Common framework
Our framework BeGin provides a scenario loader that communicates with user code (i.e., the training part) to run a benchmark under a desired incremental setting.
Because users need different functionality depending on the graph problem, we provide 4 scenario loaders and a base framework for further implementation.
We provide implemented scenario loaders for node classification (NC), link classification (LC), link prediction (LP), and graph classification (GC).
The base framework is as follows:
- class BaseScenarioLoader(dataset_name, save_path, num_tasks, incr_type, metric, **kwargs)
Base framework for implementing scenario module.
- Parameters:
dataset_name (str) – The name of the dataset.
save_path (str) – The path where the dataset file is saved.
num_tasks (int) – The number of tasks in graph continual learning.
incr_type (str) – The incremental setting of graph continual learning (i.e., task, class, domain, or time).
metric (str) – Basic metric used to measure performance (i.e., accuracy, AUCROC, or HITS@K).
kwargs (dict, optional) – Keyword arguments to be passed to the scenario module (e.g., task_shuffle (bool): if true, fixed order; otherwise, random order).
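To make the constructor's role concrete, here is a minimal, hypothetical skeleton of a base class that takes the same arguments and uses the same hook names. It is not the BeGin source: the class name ToyBaseScenarioLoader, the argument checks, and the order in which the hooks are called are assumptions made for illustration.

```python
from abc import ABC, abstractmethod


class ToyBaseScenarioLoader(ABC):
    """Hypothetical mirror of the documented constructor and hook names (not BeGin's code)."""

    # Assumption: the four settings listed for incr_type are the only valid values.
    INCR_TYPES = {"task", "class", "domain", "time"}

    def __init__(self, dataset_name, save_path, num_tasks, incr_type, metric, **kwargs):
        if incr_type not in self.INCR_TYPES:
            raise ValueError(f"unknown incr_type: {incr_type!r}")
        if num_tasks < 1:
            raise ValueError("num_tasks must be at least 1")
        self.dataset_name = dataset_name
        self.save_path = save_path
        self.num_tasks = num_tasks
        self.incr_type = incr_type
        self.metric = metric
        self.kwargs = kwargs  # e.g. {"task_shuffle": True}
        # Subclasses load the dataset and build the per-task splits here.
        self._init_continual_scenario()
        # Assumption: datasets for the first task are prepared right away.
        self._update_target_dataset()
        self._update_accumulated_dataset()

    @abstractmethod
    def _init_continual_scenario(self):
        """Load the dataset and set up the continual learning scenario."""

    @abstractmethod
    def _update_target_dataset(self):
        """Refresh the dataset seen by the implemented model."""

    @abstractmethod
    def _update_accumulated_dataset(self):
        """Refresh the dataset seen by the joint model."""
```

In this sketch a subclass only fills in the three hooks; the base constructor stores the arguments and drives the setup.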
- _init_continual_scenario()
Load the entire dataset and initialize the graph continual learning setting according to the incremental setting.
- _update_accumulated_dataset()
Update the graph dataset used by the joint model. The information that is updated depends on incr_type.
Note: The joint model can process all training data from previous tasks, including the current task.
- _update_target_dataset()
Update the graph dataset used by the implemented model in the current task. The information that is updated depends on incr_type.
Note: The implemented model can only process the training data of the current task.
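The difference between the target and accumulated datasets can be illustrated on node-level training masks. The sketch below assumes that each node carries a task id (as in the task field of the NC export format) and that tasks arrive in increasing id order; both function names are hypothetical and not part of BeGin.

```python
def target_train_mask(train_mask, node_task, curr_task):
    """Training nodes visible to the implemented model: the current task only."""
    return [m and t == curr_task for m, t in zip(train_mask, node_task)]


def accumulated_train_mask(train_mask, node_task, curr_task):
    """Training nodes visible to the joint model: every task up to and including the current one."""
    return [m and t <= curr_task for m, t in zip(train_mask, node_task)]


# Toy example: 6 nodes spread over 3 tasks.
node_task = [0, 0, 1, 1, 2, 2]
train_mask = [True, False, True, True, False, True]
print(target_train_mask(train_mask, node_task, 1))       # task 1 nodes only
print(accumulated_train_mask(train_mask, node_task, 1))  # tasks 0 and 1
```

At task 1 the target mask keeps only nodes 2 and 3, while the accumulated mask also keeps node 0 from task 0.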
- export_dataset(full=True)
Export the continual learning scenario. For flexibility, this function provides the datasets separately. The output format depends on the target problem:
(1) For NC, the global information contains the number of classes (num_classes), node features (ndata_feat), the task id of each node (task), edges of the whole graph (edges), masks for data splits (train_mask, val_mask, and test_mask), and ground-truth labels (label). For each task, the task-specific information contains edges of the graph observed at the current task (edges), masks for train/validation splits (train_mask, val_mask), and ground-truth labels observed at the current task (label).
(2) For LC, the global information contains the number of classes (num_classes), node features (ndata_feat), the task id of each edge (task), edges of the whole graph (edges), masks for data splits (train_mask, val_mask, and test_mask), and ground-truth labels (label). For each task, the task-specific information contains edges of the graph observed at the current task (edges), masks for train/validation splits (train_mask, val_mask), and ground-truth labels observed at the current task (label).
(3) For LP, the global information contains node features (ndata_feat), the task id of each edge (task), edges of the whole graph (edges), negative edges for evaluation (neg_edges), evaluation edges consisting of test edges and negative edges (test_edges), and their corresponding labels (test_label). A test label is 0 if the edge is a negative edge; otherwise, it is l for a ground-truth edge of task l. For each task, the task-specific information contains base edges observed at the current task (edges), edges for training the prediction problem and their corresponding labels (val_edges, val_label), and edges for validation and their corresponding labels (test_edges, test_label).
(4) For GC, the global information contains the number of classes (num_classes), node features (ndata_feat), the task id of each graph (task), graphs of the whole dataset (graphs), masks for data splits (train_mask, val_mask, and test_mask), indices of test graphs (test_indices), and ground-truth labels (label). For each task, the task-specific information contains indices of training graphs (train_indices) and indices of validation graphs (val_indices).
- Parameters:
full (bool, optional) – If full=True, the exported dataset contains both global information (output['global']) and task-specific information (output['tasks']). Otherwise, the exported dataset contains only global information.
- Returns:
The exported scenario (dict).
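The sketch below builds a toy dictionary in the NC layout described above and reads it back, to show how the full flag changes what is available. All values are made up, plain Python lists stand in for tensors, and toy_nc_export and summarize_nc_export are hypothetical helpers. Returning {'global': ...} when full=False and marking unobserved labels with -1 are assumptions of this sketch, not documented behavior.

```python
def toy_nc_export(full=True):
    """Hypothetical NC-style export: 4 nodes split over 2 tasks (made-up values)."""
    exported = {
        "global": {
            "num_classes": 2,
            "ndata_feat": [[0.1], [0.2], [0.3], [0.4]],
            "task": [0, 0, 1, 1],  # task id of each node
            "edges": [(0, 1), (1, 2), (2, 3)],
            "train_mask": [True, False, True, False],
            "val_mask": [False, True, False, False],
            "test_mask": [False, False, False, True],
            "label": [0, 1, 0, 1],
        },
        "tasks": [
            {   # task 0 only observes nodes 0 and 1
                "edges": [(0, 1)],
                "train_mask": [True, False, False, False],
                "val_mask": [False, True, False, False],
                "label": [0, 1, -1, -1],  # -1 = not observed yet (assumption)
            },
            {   # task 1 observes the whole graph
                "edges": [(0, 1), (1, 2), (2, 3)],
                "train_mask": [False, False, True, False],
                "val_mask": [False, False, False, False],
                "label": [0, 1, 0, 1],
            },
        ],
    }
    # Assumption: with full=False only the global part is returned.
    return exported if full else {"global": exported["global"]}


def summarize_nc_export(exported):
    """Read the global part and, if present, the task-specific part."""
    g = exported["global"]
    summary = {"num_classes": g["num_classes"], "num_nodes": len(g["label"])}
    if "tasks" in exported:
        # Summing booleans counts the training nodes of each task.
        summary["train_nodes_per_task"] = [sum(t["train_mask"]) for t in exported["tasks"]]
    return summary


print(summarize_nc_export(toy_nc_export(full=True)))
print(summarize_nc_export(toy_nc_export(full=False)))
```

Code that consumes an export this way should check for the "tasks" key rather than assume it, since it is only present when full=True.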
- get_accumulated_dataset()
- Returns:
The graph dataset used by the joint model in the current task.
- get_current_dataset()
- Returns:
The graph dataset used by the implemented model in the current task.
- next_task(preds)
Update the graph datasets used in graph continual learning. Specifically, target denotes the dataset used by the implemented model, and accumulated denotes the dataset used by the joint model.
- Parameters:
preds (torch.Tensor) – Predicted output of the models.
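The getters and next_task together define the loop that user code runs. The sketch below mimics that loop with a hypothetical ToyLoader whose datasets are just lists of training-node indices and whose predictions are plain lists instead of torch.Tensor. It only illustrates the call order (fetch the dataset, train, predict, call next_task) and is not BeGin's implementation; run_scenario and the stored recorded_preds list are likewise assumptions.

```python
class ToyLoader:
    """Hypothetical stand-in mimicking get_current_dataset, get_accumulated_dataset, and next_task."""

    def __init__(self, node_task, train_mask, num_tasks):
        self.node_task = node_task
        self.train_mask = train_mask
        self.num_tasks = num_tasks
        self._curr_task = 0
        self.recorded_preds = []  # assumption: predictions are kept per task
        self._update()

    def _update(self):
        t = self._curr_task
        pairs = list(enumerate(zip(self.train_mask, self.node_task)))
        # "target": training nodes of the current task only.
        self._target = [i for i, (m, k) in pairs if m and k == t]
        # "accumulated": training nodes of every task seen so far.
        self._accumulated = [i for i, (m, k) in pairs if m and k <= t]

    def get_current_dataset(self):
        return self._target

    def get_accumulated_dataset(self):
        return self._accumulated

    def next_task(self, preds):
        self.recorded_preds.append(preds)
        self._curr_task += 1
        if self._curr_task < self.num_tasks:
            self._update()


def run_scenario(loader, joint=False):
    """Hypothetical user-side loop; returns the node indices used at each task."""
    seen = []
    for _ in range(loader.num_tasks):
        data = loader.get_accumulated_dataset() if joint else loader.get_current_dataset()
        seen.append(list(data))
        # Placeholder for training; "predict" class 0 for every node.
        preds = [0] * len(loader.node_task)
        loader.next_task(preds)
    return seen


node_task = [0, 0, 1, 1, 2, 2]
train_mask = [True, False, True, True, False, True]
print(run_scenario(ToyLoader(node_task, train_mask, 3)))
print(run_scenario(ToyLoader(node_task, train_mask, 3), joint=True))
```

Setting joint=True corresponds to a joint-training baseline, which sees all training data observed so far rather than only the current task's data.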