Node-level problems
Classification
- class NCScenarioLoader(dataset_name, save_path, num_tasks, incr_type, metric, **kwargs)[source]
The scenario loader for node classification problems.
Usage example:
>>> scenario = NCScenarioLoader(dataset_name, dataset_object=None, num_tasks=3, metric="accuracy",
...                             save_path="./data", incr_type="task", task_shuffle=True)
Bases: BaseScenarioLoader
- _get_eval_result_inner(preds, target_split)[source]
The inner function of get_eval_result.
- Parameters:
preds (torch.Tensor) – predicted output of the current model
target_split (str) – target split on which to measure performance (specifically, ‘val’ or ‘test’)
- _init_continual_scenario()[source]
Load the entire dataset and initialize the graph continual learning setting according to the incremental setting.
- _update_accumulated_dataset()[source]
Update the graph dataset the joint model uses. The information updated depends on
incr_type.
Note
The joint model can process all training data from previous tasks as well as the current task.
- _update_target_dataset()[source]
Update the graph dataset the implemented model uses in the current task. The information updated depends on
incr_type.
Note
The implemented model can only process the training data in the current task.
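The difference between the two datasets can be illustrated with a small, framework-free sketch. It assumes each training node is tagged with the index of the task that introduces it. The names task_ids, target_train_mask, and accumulated_train_mask are hypothetical and are not part of NCScenarioLoader, whose actual update also depends on incr_type.

```python
def target_train_mask(task_ids, curr_task):
    """Nodes usable by the implemented model: the current task only."""
    return [t == curr_task for t in task_ids]


def accumulated_train_mask(task_ids, curr_task):
    """Nodes usable by the joint model: every task up to and including the current one."""
    return [t <= curr_task for t in task_ids]


# Hypothetical assignment of six training nodes to three tasks.
task_ids = [0, 0, 1, 1, 2, 2]

# While the second task (index 1) is being learned:
target = target_train_mask(task_ids, curr_task=1)
accumulated = accumulated_train_mask(task_ids, curr_task=1)
```

Here target selects only the two nodes of task 1, while accumulated selects the four nodes of tasks 0 and 1.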
- get_accum_eval_result(preds, target_split='test')[source]
Compute performance on the accumulated dataset for the given target split. It can be used to compute train/val performance during training.
- Parameters:
preds (torch.Tensor) – predicted output of the current model
target_split (str) – target split on which to measure performance (specifically, ‘val’ or ‘test’)
- get_current_dataset_for_export(_global=False)[source]
- Returns:
The graph dataset the implemented model uses in the current task
- get_simple_eval_result(curr_batch_preds, curr_batch_gts)[source]
Compute performance for the given batch when we ignore task configuration. It can be used to compute train/val performance during training.
- Parameters:
curr_batch_preds (torch.Tensor) – predicted output of the current model
curr_batch_gts (torch.Tensor) – ground-truth labels
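As a rough illustration of a batch-level metric that ignores task configuration, the sketch below computes plain accuracy over one batch. It assumes metric="accuracy" and that predictions are per-class scores, and it uses Python lists in place of torch.Tensor so that it runs without dependencies. simple_accuracy is a hypothetical helper, not the library's implementation.

```python
def simple_accuracy(batch_scores, batch_labels):
    """Fraction of rows whose highest-scoring class equals the ground-truth label."""
    if len(batch_scores) != len(batch_labels):
        raise ValueError("predictions and labels must have the same length")
    if not batch_labels:
        return 0.0
    correct = 0
    for scores, label in zip(batch_scores, batch_labels):
        # Index of the largest score plays the role of argmax over classes.
        predicted = max(range(len(scores)), key=scores.__getitem__)
        correct += int(predicted == label)
    return correct / len(batch_labels)


# Hypothetical batch: four nodes, two classes.
scores = [[0.1, 0.9], [0.8, 0.2], [0.3, 0.7], [0.6, 0.4]]
labels = [1, 0, 0, 0]
acc = simple_accuracy(scores, labels)
```

Because no task information enters the computation, the same helper could be applied to any training or validation batch.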
- next_task(preds=torch.empty)[source]
Update the graph datasets used in graph continual learning. Specifically, the
target denotes the dataset the implemented model uses and the accumulated denotes the dataset the joint model uses.
- Parameters:
preds (torch.Tensor) – Predicted output of the models