import torch
import dgl
import os
import pickle
import copy
import numpy as np
import tqdm
from dgl.data.utils import download, Subset
from .common import BaseScenarioLoader
from .datasets import *
from . import evaluator_map

def load_graph_dataset(dataset_name, dataset_load_func, incr_type, save_path):
    domain_info, time_info = None, None
    if dataset_load_func is not None:
        custom_dataset = dataset_load_func(save_path=save_path)
        dataset = custom_dataset['graphs']
        num_feats = custom_dataset['num_feats']
        num_classes = custom_dataset['num_classes']
        domain_info = custom_dataset.get('domain_info', None)
        time_info = custom_dataset.get('time_info', None)
    elif dataset_name in ['mnist', 'cifar10'] and incr_type in ['task', 'class']:
        dataset = DGLGNNBenchmarkDataset(dataset_name, raw_dir=save_path)
        num_feats, num_classes = dataset.num_feats, dataset.num_classes
    elif dataset_name in ['aromaticity'] and incr_type in ['task', 'class']:
        dataset = AromaticityDataset(raw_dir=save_path)
        num_feats, num_classes = 2, 30
        # load train/val/test split (6:2:2 random split)
        pkl_path = os.path.join(save_path, f'{dataset_name}_metadata_allIL.pkl')
        download(f'https://github.com/anonymous-submission-24/BeGin-iclr24/raw/main/metadata/{dataset_name}_metadata_allIL.pkl', path=pkl_path)
        metadata = pickle.load(open(pkl_path, 'rb'))
        inner_tvt_splits = metadata['inner_tvt_splits']
        dataset._train_masks = (inner_tvt_splits % 10) < 6
        dataset._val_masks = ((inner_tvt_splits % 10) == 6) | ((inner_tvt_splits % 10) == 7)
        dataset._test_masks = (inner_tvt_splits % 10) > 7
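        # Illustrative note: each instance carries a fixed integer code in the downloaded metadata;
        # taking it mod 10 yields a digit in [0, 9], so digits 0-5 go to train, 6-7 to validation,
        # and 8-9 to test, which realizes the 6:2:2 split deterministically across runs.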
    elif dataset_name in ['ogbg-molhiv'] and incr_type in ['domain']:
        dataset = DglGraphPropPredDataset('ogbg-molhiv', root=save_path)
        num_feats, num_classes = dataset[0][0].ndata['feat'].shape[-1], 1
        """ (For Task/Class-IL)
        # load train/val/test split
        split_idx = dataset.get_idx_split()
        for _split, _split_name in [('train', '_train'), ('valid', '_val'), ('test', '_test')]:
            _indices = torch.zeros(len(dataset), dtype=torch.bool)
            _indices[split_idx[_split]] = True
            setattr(dataset, _split_name + '_mask', _indices)
        """
        # load train/val/test split and domain_info
        pkl_path = os.path.join(save_path, 'molhivx_metadata_domainIL.pkl')
        download('https://github.com/anonymous-submission-24/BeGin-iclr24/raw/main/metadata/molhivx_metadata_domainIL.pkl', path=pkl_path)
        metadata = pickle.load(open(pkl_path, 'rb'))
        inner_tvt_splits = metadata['inner_tvt_splits']
        # set train/val/test split (random split, 8:1:1)
        dataset._train_masks = (inner_tvt_splits % 10) < 8
        dataset._val_masks = (inner_tvt_splits % 10) == 8
        dataset._test_masks = (inner_tvt_splits % 10) > 8
        domain_info = metadata['domain_splits']
    elif dataset_name in ['nyctaxi'] and incr_type in ['time']:
        dataset = NYCTaxiDataset(dataset_name, raw_dir=save_path)
        num_feats, num_classes = dataset[0][0].ndata['feat'].shape[-1], 2
        # load time split information and train/val/test splits (random split, 6:2:2)
        pkl_path = os.path.join(save_path, 'nyctaxi_metadata_timeIL.pkl')
        download('https://github.com/anonymous-submission-24/BeGin-iclr24/raw/main/metadata/nyctaxi_metadata_timeIL.pkl', path=pkl_path)
        metadata = pickle.load(open(pkl_path, 'rb'))
        inner_tvt_splits = metadata['inner_tvt_splits']
        dataset._train_masks = (inner_tvt_splits % 10) < 6
        dataset._val_masks = ((inner_tvt_splits % 10) == 6) | ((inner_tvt_splits % 10) == 7)
        dataset._test_masks = (inner_tvt_splits % 10) > 7
        time_info = metadata['time_splits']
    elif dataset_name in ['ogbg-ppa'] and incr_type in ['domain']:
        dataset = OgbgPpaSampledDataset(save_path)
        num_feats, num_classes = 2, 37
        # load train/val/test split and domain_info
        pkl_path = os.path.join(save_path, 'ogbg-ppa_metadata_domainIL.pkl')
        download('https://github.com/anonymous-submission-24/BeGin-iclr24/raw/main/metadata/ogbg-ppa_metadata_domainIL.pkl', path=pkl_path)
        metadata = pickle.load(open(pkl_path, 'rb'))
        inner_tvt_splits = metadata['inner_tvt_split']
        # set train/val/test split (random split, 8:1:1)
        dataset._train_masks = (inner_tvt_splits % 10) < 8
        dataset._val_masks = (inner_tvt_splits % 10) == 8
        dataset._test_masks = (inner_tvt_splits % 10) > 8
        domain_info = metadata['domain_info']
    elif dataset_name in ['sentiment'] and incr_type in ['time']:
        dataset = SentimentGraphDataset(dataset_name='sentiment', raw_dir=save_path)
        num_feats, num_classes = dataset._num_feats, dataset._num_classes
        time_info = dataset._time_info
        delattr(dataset, "_time_info")
    else:
        raise NotImplementedError("Tried to load an unsupported scenario.")

    print("=====CHECK=====")
    print("num_classes:", num_classes, ", num_feats:", num_feats)
    print("dataset._train_masks:", dataset._train_masks.shape)
    print("dataset._val_masks:", dataset._val_masks.shape)
    print("dataset._test_masks:", dataset._test_masks.shape)
    print("dataset.labels:", dataset.labels.shape)
    if incr_type == 'time':
        print("time_info:", time_info is not None)
    if incr_type == 'domain':
        print("domain_info:", domain_info is not None)
    print("===============")
    return num_classes, num_feats, dataset, domain_info, time_info
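
# Illustrative usage of load_graph_dataset (a sketch, not part of the public API; assumes the OGB
# dependency behind DglGraphPropPredDataset is installed and that './data' is writable):
#
#     num_classes, num_feats, dataset, domain_info, time_info = \
#         load_graph_dataset('ogbg-molhiv', None, 'domain', './data')
#     num_domains = domain_info.max().item() + 1  # domains become tasks under domain-IL
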
class GCScenarioLoader(BaseScenarioLoader):
    """
    The scenario loader for graph classification problems.

    **Usage example:**

        >>> scenario = GCScenarioLoader(dataset_name="ogbg-molhiv", num_tasks=10, metric="rocauc",
        ...                             save_path="./data", incr_type="domain", task_shuffle=True)

    Bases: ``BaseScenarioLoader``
    """
    def _init_continual_scenario(self):
        self.num_classes, self.num_feats, self.__dataset, self.__domain_info, self.__time_splits = load_graph_dataset(self.dataset_name, self.dataset_load_func, self.incr_type, self.save_path)

        if self.incr_type in ['class', 'task']:
            # determine task configuration
            if self.kwargs is not None and 'task_orders' in self.kwargs:
                self.__splits = tuple([torch.LongTensor(class_ids) for class_ids in self.kwargs['task_orders']])
            elif self.kwargs is not None and 'task_shuffle' in self.kwargs and self.kwargs['task_shuffle']:
                self.__splits = torch.split(torch.randperm(self.num_classes), self.num_classes // self.num_tasks)[:self.num_tasks]
            else:
                self.__splits = torch.split(torch.arange(self.num_classes), self.num_classes // self.num_tasks)[:self.num_tasks]

            # compute task ids for each instance
            print('class split information:', self.__splits)
            id_to_task = self.num_tasks * torch.ones(self.num_classes).long()
            for i in range(self.num_tasks):
                id_to_task[self.__splits[i]] = i
            self.__task_ids = id_to_task[self.__dataset.labels]
            self.__original_labels = self.__dataset.labels.clone()
            self.__dataset.labels[self.__dataset._test_masks] = -1
        elif self.incr_type == 'time':
            # compute task ids for each instance
            self.num_tasks = self.__time_splits.max().item() + 1
            self.__task_ids = self.__time_splits
            self.__dataset.labels = self.__dataset.labels.squeeze()
            self.__original_labels = self.__dataset.labels.clone()
            self.__dataset.labels[self.__dataset._test_masks] = -1
        elif self.incr_type == 'domain':
            # determine task configuration
            self.num_tasks = self.__domain_info.max().item() + 1
            if self.kwargs is not None and 'task_shuffle' in self.kwargs and self.kwargs['task_shuffle']:
                self.__task_order = torch.randperm(self.num_tasks)
                print('domain_order:', self.__task_order)
                self.__task_ids = self.__task_order[self.__domain_info]
            else:
                self.__task_ids = self.__domain_info
            self.__original_labels = self.__dataset.labels.clone()
            self.__dataset.labels[self.__dataset._test_masks] = -1

        # we need to provide task information (only for task-IL)
        if self.incr_type == 'task':
            self.__task_masks = torch.zeros(self.num_tasks + 1, self.num_classes).bool()
            for i in range(self.num_tasks):
                self.__task_masks[i, self.__splits[i]] = True
            self.__dataset._task_specific_masks = self.__task_masks[self.__task_ids]

        # set evaluator for the target scenario
        if self.metric is not None:
            self.__evaluator = evaluator_map[self.metric](self.num_tasks, self.__task_ids)
        self.__test_results = []
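    # Worked example of the class-to-task mapping above (illustrative values, not executed): with
    # num_classes = 6 and num_tasks = 3, the default splits are
    # (tensor([0, 1]), tensor([2, 3]), tensor([4, 5])), so id_to_task == tensor([0, 0, 1, 1, 2, 2])
    # and a graph labeled 4 belongs to task 2; if it is a test instance, its label is masked to -1
    # so models cannot peek at ground truth.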
    def _update_target_dataset(self):
        # create subset for training / validation / test of the current task
        target_train_indices = torch.nonzero((self.__task_ids == self._curr_task) & self.__dataset._train_masks, as_tuple=True)[0]
        target_val_indices = torch.nonzero((self.__task_ids == self._curr_task) & self.__dataset._val_masks, as_tuple=True)[0]
        target_test_indices = torch.nonzero(self.__dataset._test_masks, as_tuple=True)[0]
        self._target_dataset = {'train': Subset(self.__dataset, target_train_indices),
                                'val': Subset(self.__dataset, target_val_indices),
                                'test': Subset(self.__dataset, target_test_indices)}

    def _update_accumulated_dataset(self):
        # create accumulated subset for training / validation / test of the current task
        target_train_indices = torch.nonzero((self.__task_ids <= self._curr_task) & self.__dataset._train_masks, as_tuple=True)[0]
        target_val_indices = torch.nonzero((self.__task_ids <= self._curr_task) & self.__dataset._val_masks, as_tuple=True)[0]
        target_test_indices = torch.nonzero(self.__dataset._test_masks, as_tuple=True)[0]
        self._accumulated_dataset = {'train': Subset(self.__dataset, target_train_indices),
                                     'val': Subset(self.__dataset, target_val_indices),
                                     'test': Subset(self.__dataset, target_test_indices)}
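    # Note: dgl.data.utils.Subset records the selected positions in its `.indices` attribute;
    # the evaluation helpers below rely on these indices to align predictions with the stored
    # original (unmasked) labels.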
    def _get_eval_result_inner(self, preds, target_split):
        """
        The inner function of get_eval_result.

        Args:
            preds (torch.Tensor): predicted output of the current model
            target_split (str): target split to measure the performance (spec., 'val' or 'test')
        """
        gt = self.__original_labels[self._target_dataset[target_split].indices]
        assert preds.shape == gt.shape, "shape mismatch"
        return self.__evaluator(preds, gt, self._target_dataset[target_split].indices)

    def get_eval_result(self, preds, target_split='test'):
        return self._get_eval_result_inner(preds, target_split)

    def get_accum_eval_result(self, preds, target_split='test'):
        """
        Compute performance on the accumulated dataset for the given target split.
        It can be used to compute train/val performance during training.

        Args:
            preds (torch.Tensor): predicted output of the current model
            target_split (str): target split to measure the performance (spec., 'val' or 'test')
        """
        gt = self.__original_labels[self._accumulated_dataset[target_split].indices]
        assert preds.shape == gt.shape, "shape mismatch"
        return self.__evaluator(preds, gt, self._accumulated_dataset[target_split].indices)

    def get_simple_eval_result(self, curr_batch_preds, curr_batch_gts):
        """
        Compute performance for the given batch, ignoring the task configuration.
        It can be used to compute train/val performance during training.

        Args:
            curr_batch_preds (torch.Tensor): predicted output of the current model
            curr_batch_gts (torch.Tensor): ground-truth labels
        """
        return self.__evaluator.simple_eval(curr_batch_preds, curr_batch_gts)

    def next_task(self, preds=torch.empty(1)):
        if self.export_mode:
            super().next_task(preds)
        else:
            self.__test_results.append(self._get_eval_result_inner(preds, target_split='test'))
            super().next_task(preds)
            if self._curr_task == self.num_tasks:
                scores = torch.stack(self.__test_results, dim=0)
                scores_np = scores.detach().cpu().numpy()
                # Average Performance (AP): mean of the per-task scores after the final task
                ap = scores_np[-1, :-1].mean().item()
                # Average Forgetting (AF): drop from each task's just-after-training score to its final score
                af = (scores_np[np.arange(self.num_tasks), np.arange(self.num_tasks)] - scores_np[-1, :-1]).sum().item() / (self.num_tasks - 1)
                # Forward Transfer (FWT): score on each task just before training on it,
                # relative to the initial (untrained) test result
                if self.initial_test_result is not None:
                    fwt = (scores_np[np.arange(self.num_tasks - 1), np.arange(self.num_tasks - 1) + 1] - self.initial_test_result.detach().cpu().numpy()[1:-1]).sum() / (self.num_tasks - 1)
                else:
                    fwt = None
                return {'exp_results': scores, 'AP': ap, 'AF': af, 'FWT': fwt}
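    # Worked toy example for the metrics above (illustrative numbers): with num_tasks = 2,
    # rows of scores_np are [after task 0, after task 1] and columns are [task 0, task 1, aggregate]:
    #     scores_np = [[0.80, 0.50, 0.65],
    #                  [0.70, 0.90, 0.80]]
    #     AP  = mean([0.70, 0.90]) = 0.80
    #     AF  = ((0.80 - 0.70) + (0.90 - 0.90)) / (2 - 1) = 0.10
    #     FWT compares scores_np[0, 1] = 0.50 against the untrained score on task 1.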
    def get_current_dataset_for_export(self, _global=False):
        """
        Returns:
            The graph dataset the implemented model uses in the current task
        """
        if _global:
            metadata = {'num_classes': self.num_classes, 'task': self.__task_ids}
            if self.incr_type == 'task':
                metadata['task_specific_mask'] = self.__task_masks[self.__task_ids]
            metadata['graphs'] = []
            metadata['labels'] = []
            for i in tqdm.trange(len(self.__dataset)):
                if self.incr_type != 'task':
                    g, l = self.__dataset[i]
                else:
                    g, l, _ = self.__dataset[i]
                g_data = {}
                g_data['edges'] = g.edges()
                g_data['ndata_feat'] = g.ndata['feat']
                if 'feat' in g.edata:
                    g_data['edata_feat'] = g.edata['feat']
                metadata['graphs'].append(g_data)
                metadata['labels'].append(l)
            metadata['labels'] = torch.LongTensor(metadata['labels'])
            metadata['train_mask'] = self.__dataset._train_masks
            metadata['val_mask'] = self.__dataset._val_masks
            metadata['test_mask'] = self.__dataset._test_masks
            metadata['test_indices'] = torch.nonzero(self.__dataset._test_masks, as_tuple=True)[0]
        else:
            metadata = {}
            metadata['train_indices'] = torch.nonzero((self.__task_ids == self._curr_task) & self.__dataset._train_masks, as_tuple=True)[0]
            metadata['val_indices'] = torch.nonzero((self.__task_ids == self._curr_task) & self.__dataset._val_masks, as_tuple=True)[0]
        return metadata