Graph-level problems

class GCScenarioLoader(dataset_name, save_path, num_tasks, incr_type, metric, **kwargs)[source]

The scenario loader for graph classification problems.

Usage example:

>>> scenario = GCScenarioLoader(dataset_name="ogbg-molhiv", num_tasks=10, metric="rocauc",
...                             save_path="./data", incr_type="domain", task_shuffle=True)

Bases: BaseScenarioLoader

_get_eval_result_inner(preds, target_split)[source]

The inner function of get_eval_result.

Parameters:
  • preds (torch.Tensor) – predicted output of the current model

  • target_split (str) – target split on which to measure performance (specifically, ‘val’ or ‘test’)

_init_continual_scenario()[source]

Load the entire dataset and initialize the graph continual learning setting according to the incremental setting.

_update_accumulated_dataset()[source]

Update the graph dataset that the joint model uses. Which information is updated depends on incr_type.

Note

The joint model can process all training data from the previous tasks as well as the current task.

_update_target_dataset()[source]

Update the graph dataset that the implemented model uses in the current task. Which information is updated depends on incr_type.

Note

The implemented model can only process the training data in the current task.

get_accum_eval_result(preds, target_split='test')[source]

Compute performance on the accumulated dataset for the given target split. It can be used to compute train/val performance during training.

Parameters:
  • preds (torch.Tensor) – predicted output of the current model

  • target_split (str) – target split on which to measure performance (specifically, ‘val’ or ‘test’)

get_current_dataset_for_export(_global=False)[source]

Returns:

The graph dataset the implemented model uses in the current task

get_simple_eval_result(curr_batch_preds, curr_batch_gts)[source]

Compute performance for the given batch, ignoring the task configuration. It can be used to compute train/val performance during training.

Parameters:
  • curr_batch_preds (torch.Tensor) – predicted output of the current model

  • curr_batch_gts (torch.Tensor) – ground-truth labels
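As a rough illustration of what "ignoring the task configuration" means, the sketch below scores a batch as a single pool, without grouping samples by task. It is not the library's implementation: the helper name simple_accuracy is hypothetical, plain Python lists stand in for torch.Tensor inputs, and accuracy is assumed as the metric, whereas the loader uses whatever metric it was constructed with.

```python
def simple_accuracy(batch_preds, batch_gts):
    """Fraction of predictions that match the ground truth.

    Hypothetical helper for illustration only; every sample in the batch
    is scored together, with no per-task grouping.
    """
    if len(batch_preds) != len(batch_gts):
        raise ValueError("predictions and ground truths must have the same length")
    if not batch_gts:
        return 0.0  # convention chosen here for an empty batch
    correct = sum(int(p == g) for p, g in zip(batch_preds, batch_gts))
    return correct / len(batch_gts)


# Two of the four predictions match their labels.
acc = simple_accuracy([1, 0, 1, 1], [1, 1, 1, 0])
```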

next_task(preds=torch.empty)[source]

Update the graph datasets used in graph continual learning. Specifically, the target dataset is the one the implemented model uses, and the accumulated dataset is the one the joint model uses.

Parameters:

preds (torch.Tensor) – Predicted output of the models
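The relationship between the target and accumulated datasets can be sketched with a toy stand-in. This is a minimal illustration under stated assumptions, not the loader's code: ToyScenario and split_into_tasks are hypothetical names, samples are represented by integer ids, and the round-robin task split is an arbitrary choice rather than what incr_type actually does.

```python
def split_into_tasks(sample_ids, num_tasks):
    """Assign sample ids to tasks round-robin (arbitrary, for illustration)."""
    return [sample_ids[t::num_tasks] for t in range(num_tasks)]


class ToyScenario:
    """Hypothetical stand-in that tracks only which samples are visible."""

    def __init__(self, sample_ids, num_tasks):
        self.tasks = split_into_tasks(list(sample_ids), num_tasks)
        self.curr_task = -1
        self.target = []       # samples the implemented model may train on
        self.accumulated = []  # samples the joint model may train on

    def next_task(self):
        # Move to the next task: the target is replaced, the accumulated set grows.
        self.curr_task += 1
        self.target = list(self.tasks[self.curr_task])
        self.accumulated = self.accumulated + self.target


scenario = ToyScenario(range(6), num_tasks=3)
history = []
for _ in range(3):
    scenario.next_task()
    history.append((scenario.target, sorted(scenario.accumulated)))
```

After each call, the target holds only the current task's samples, while the accumulated set contains every sample seen so far, which is why the joint model serves as an upper-bound reference in this setting.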