import copy
import os
import pickle

import dgl
import numpy as np
import torch
from dgl.data.utils import download, Subset
from ogb.nodeproppred import DglNodePropPredDataset
from torch_scatter import scatter

from .common import BaseScenarioLoader
from .datasets import *
from . import evaluator_map
[docs]def load_node_dataset(dataset_name, dataset_load_func, incr_type, save_path):
"""
The function for load node-level datasets.
"""
cover_rule = {'feat': 'node', 'label': 'node', 'train_mask': 'node', 'val_mask': 'node', 'test_mask': 'node'}
if dataset_load_func is not None:
custom_dataset = dataset_load_func(save_path=save_path)
graph = custom_dataset['graph']
num_feats = custom_dataset['num_feats']
num_classes = custom_dataset['num_classes']
elif dataset_name in ['cora'] and incr_type in ['task', 'class']:
dataset = dgl.data.CoraGraphDataset(raw_dir=save_path, verbose=False)
graph = dataset._g
num_feats, num_classes = graph.ndata['feat'].shape[-1], dataset.num_classes
elif dataset_name in ['citeseer'] and incr_type in ['task', 'class']:
dataset = dgl.data.CiteseerGraphDataset(raw_dir=save_path)
graph = dataset._g
num_feats, num_classes = graph.ndata['feat'].shape[-1], dataset.num_classes
elif dataset_name in ['corafull'] and incr_type in ['task', 'class']:
dataset = dgl.data.CoraFullDataset(raw_dir=save_path, verbose=False)
graph = dataset._graph
num_feats, num_classes = graph.ndata['feat'].shape[-1], dataset.num_classes
# We need to designate train/val/test split since DGL does not provide the information.
# We used random train/val/test split (6 : 2 : 2)
pkl_path = os.path.join(save_path, f'corafull_metadata_allIL.pkl')
download(f'https://github.com/anonymous-submission-24/BeGin-iclr24/raw/main/metadata/corafull_metadata_allIL.pkl', pkl_path)
metadata = pickle.load(open(pkl_path, 'rb'))
inner_tvt_splits = metadata['inner_tvt_splits'] % 10
graph.ndata['train_mask'] = (inner_tvt_splits < 6)
graph.ndata['val_mask'] = (6 <= inner_tvt_splits) & (inner_tvt_splits < 8)
graph.ndata['test_mask'] = (8 <= inner_tvt_splits)
elif dataset_name in ['ogbn-arxiv'] and incr_type in ['task', 'class', 'time']:
dataset = DglNodePropPredDataset('ogbn-arxiv', root=save_path)
graph, label = dataset[0]
num_feats, num_classes = graph.ndata['feat'].shape[-1], dataset.num_classes
# to_bidirected
srcs, dsts = graph.all_edges()
graph.add_edges(dsts, srcs)
if incr_type == 'time':
# load train/val/test split
pkl_path = os.path.join(save_path, f'ogbn-arxiv_metadata_timeIL.pkl')
download(f'https://github.com/anonymous-submission-24/BeGin-iclr24/raw/main/metadata/ogbn-arxiv_metadata_timeIL.pkl', pkl_path)
metadata = pickle.load(open(pkl_path, 'rb'))
inner_tvt_splits = metadata['inner_tvt_splits']
graph.ndata['train_mask'] = (inner_tvt_splits < 4)
graph.ndata['val_mask'] = (inner_tvt_splits == 4)
graph.ndata['test_mask'] = (inner_tvt_splits > 4)
# time information for splitting tasks
graph.ndata['time'] = torch.clamp(graph.ndata.pop('year').squeeze() - 1997, 0, 20000)
else:
# load train/val/test split
split_idx = dataset.get_idx_split()
for _split, _split_name in [('train', 'train'), ('valid', 'val'), ('test', 'test')]:
_indices = torch.zeros(graph.num_nodes(), dtype=torch.bool)
_indices[split_idx[_split]] = True
graph.ndata[_split_name + '_mask'] = _indices
# load target label and timestamp information
graph.ndata['label'] = label.squeeze()
elif dataset_name in ['ogbn-products'] and incr_type in ['task', 'class']:
dataset = DglNodePropPredDataset('ogbn-products', root=save_path)
graph, label = dataset[0]
num_feats, num_classes = graph.ndata['feat'].shape[-1], dataset.num_classes
# load train/val/test split
split_idx = dataset.get_idx_split()
for _split, _split_name in [('train', 'train'), ('valid', 'val'), ('test', 'test')]:
_indices = torch.zeros(graph.num_nodes(), dtype=torch.bool)
_indices[split_idx[_split]] = True
graph.ndata[_split_name + '_mask'] = _indices
# load target label and timestamp information
graph.ndata['label'] = label.squeeze()
elif dataset_name in ['ogbn-proteins'] and incr_type in ['domain']:
dataset = DglNodePropPredDataset('ogbn-proteins', root=save_path)
graph, label = dataset[0]
# create node features using edge features + load species information
# (See https://github.com/snap-stanford/ogb/blob/master/examples/nodeproppred/proteins/gnn.py : commit d04eada)
graph.ndata['feat'] = scatter(graph.edata.pop('feat'), graph.edges()[0], dim=0, reduce='mean')
unique_ids = torch.unique(graph.ndata['species'])
raw_species_to_domain = -torch.ones(unique_ids.max().item() + 1, dtype=torch.long)
raw_species_to_domain[unique_ids] = torch.arange(8)
graph.ndata['species'] = raw_species_to_domain[graph.ndata.pop('species').squeeze(-1)]
num_feats, num_classes = graph.ndata['feat'].shape[-1], label.shape[-1]
# load train/val/test split
pkl_path = os.path.join(save_path, f'ogbn-proteins_metadata_domainIL.pkl')
download(f'https://github.com/anonymous-submission-24/BeGin-iclr24/raw/main/metadata/ogbn-proteins_metadata_domainIL.pkl', pkl_path)
metadata = pickle.load(open(pkl_path, 'rb'))
inner_tvt_splits = metadata['inner_tvt_splits']
graph.ndata['train_mask'] = (inner_tvt_splits < 4)
graph.ndata['val_mask'] = (inner_tvt_splits == 4)
graph.ndata['test_mask'] = (inner_tvt_splits > 4)
# load target label and domain information
graph.ndata['label'] = label
if incr_type == 'domain': graph.ndata['domain'] = graph.ndata.pop('species').squeeze()
elif dataset_name in ['ogbn-mag'] and incr_type in ['task', 'class', 'time']:
dataset = DglNodePropPredDataset('ogbn-mag', root=save_path)
_graph, _label = dataset[0]
srcs, dsts = _graph.edges(etype='cites')
graph = dgl.graph((srcs, dsts))
# pick nodes whose entity is 'paper'
graph.ndata['feat'] = _graph.ndata['feat']['paper']
graph.add_edges(dsts, srcs)
label = _label['paper'].squeeze()
split_idx = dataset.get_idx_split()
# (for task, class) select classes with at least 10 nodes (in train, valid, and test split)
if incr_type in ['task', 'class']:
traincnt = torch.bincount(label[split_idx['train']['paper']])
valcnt = torch.bincount(label[split_idx['valid']['paper']])
testcnt = torch.bincount(label[split_idx['test']['paper']])
considered_labels = torch.nonzero(torch.min(torch.stack((traincnt, valcnt, testcnt), dim=-1), dim=-1).values >= 10, as_tuple=True)[0]
processed_labels = torch.ones(label.max() + 1, dtype=torch.long) * considered_labels.shape[0]
processed_labels[considered_labels] = torch.arange(considered_labels.shape[0])
label = processed_labels[label]
num_feats, num_classes = graph.ndata['feat'].shape[-1], label.max().item() # ignore the last class
# load train/val/test split
for _split, _split_name in [('train', 'train'), ('valid', 'val'), ('test', 'test')]:
_indices = torch.zeros(graph.num_nodes(), dtype=torch.bool)
_indices[split_idx[_split]['paper']] = True
graph.ndata[_split_name + '_mask'] = _indices
graph.ndata['label'] = label.squeeze()
elif incr_type in ['time']:
pkl_path = os.path.join(save_path, f'ogbn-mag_metadata_timeIL.pkl')
download(f'https://github.com/anonymous-submission-24/BeGin-iclr24/raw/main/metadata/ogbn-mag_metadata_timeIL.pkl', pkl_path)
metadata = pickle.load(open(pkl_path, 'rb'))
inner_tvt_splits = metadata['inner_tvt_splits']
graph.ndata['train_mask'] = (inner_tvt_splits < 4)
graph.ndata['val_mask'] = (inner_tvt_splits == 4)
graph.ndata['test_mask'] = (inner_tvt_splits > 4)
graph.ndata['label'] = label.squeeze()
num_feats, num_classes = graph.ndata['feat'].shape[-1], (label.max().item() + 1)
graph.ndata['time'] = _graph.ndata['year']['paper'].squeeze() - 2010
elif dataset_name in ['twitch'] and incr_type in ['domain']:
dataset = TwitchGamerNodeDataset('twitch', raw_dir=save_path)
graph = dataset[0]
num_feats, num_classes = graph.ndata['feat'].shape[-1], dataset.num_classes
pkl_path = os.path.join(save_path, f'twitch_metadata_domainIL.pkl')
download(f'https://github.com/anonymous-submission-24/BeGin-iclr24/raw/main/metadata/twitch_metadata_domainIL.pkl', pkl_path)
metadata = pickle.load(open(pkl_path, 'rb'))
inner_tvt_splits = metadata['inner_tvt_splits']
graph.ndata['train_mask'] = (inner_tvt_splits < 4)
graph.ndata['val_mask'] = (inner_tvt_splits == 4)
graph.ndata['test_mask'] = (inner_tvt_splits > 4)
else:
raise NotImplementedError("Tried to load unsupported scenario.")
# We hide information of unseen nodes (for Time-IL)
for k in graph.ndata.keys():
if k not in cover_rule:
cover_rule[k] = 'node'
for k in graph.edata.keys():
if k not in cover_rule:
cover_rule[k] = 'edge'
# remove and add self-loop (to prevent duplicates)
srcs, dsts = graph.edges()
is_non_loop = (srcs != dsts)
final_graph = dgl.graph((srcs[is_non_loop], dsts[is_non_loop]), num_nodes=graph.num_nodes())
for k in graph.ndata.keys():
final_graph.ndata[k] = graph.ndata[k]
for k in graph.edata.keys():
final_graph.edata[k] = graph.edata[k][is_non_loop]
final_graph = dgl.add_self_loop(final_graph)
print("=====CHECK=====")
print("num_classes:", num_classes, ", num_feats:", num_feats)
print("graph.ndata['train_mask']:", 'train_mask' in graph.ndata)
print("graph.ndata['val_mask']:", 'val_mask' in graph.ndata)
print("graph.ndata['test_mask']:", 'test_mask' in graph.ndata)
print("graph.ndata['label']:", 'label' in graph.ndata)
if incr_type == 'time':
print("graph.ndata['time']:", 'time' in graph.ndata)
if incr_type == 'domain':
print("graph.ndata['domain']:", 'domain' in graph.ndata)
print("===============")
return num_classes, num_feats, final_graph, cover_rule
[docs]class NCScenarioLoader(BaseScenarioLoader):
"""
The sceanario loader for node classification problems.
**Usage example:**
>>> scenario = NCScenarioLoader(dataset_name dataset_object=None, num_tasks=3, metric="accuracy",
... save_path="./data", incr_type="task", task_shuffle=True)
Bases: ``BaseScenarioLoader``
"""
[docs] def _init_continual_scenario(self):
self.num_classes, self.num_feats, self.__graph, self.__cover_rule = load_node_dataset(self.dataset_name, self.dataset_load_func, self.incr_type, self.save_path)
if self.incr_type in ['class', 'task']:
# determine task configuration
if self.kwargs is not None and 'task_orders' in self.kwargs:
self.__splits = tuple([torch.LongTensor(class_ids) for class_ids in self.kwargs['task_orders']])
elif self.kwargs is not None and 'task_shuffle' in self.kwargs and self.kwargs['task_shuffle']:
self.__splits = torch.split(torch.randperm(self.num_classes), self.num_classes // self.num_tasks)[:self.num_tasks]
else:
self.__splits = torch.split(torch.arange(self.num_classes), self.num_classes // self.num_tasks)[:self.num_tasks]
print('class split information:', self.__splits)
# compute task ids for each node
id_to_task = self.num_tasks * torch.ones(self.__graph.ndata['label'].max() + 1).long()
for i in range(self.num_tasks):
id_to_task[self.__splits[i]] = i
self.__task_ids = id_to_task[self.__graph.ndata['label']]
# ignore classes which are not used in the tasks
self.__graph.ndata['test_mask'] = self.__graph.ndata['test_mask'] & (self.__task_ids < self.num_tasks)
elif self.incr_type == 'time':
# compute task ids for each node
self.__task_ids = self.__graph.ndata['time']
if self.num_tasks != self.__task_ids.max().item() + 1:
print("WARNING: Mismatch between the number of tasks and the processed data. Please check again.")
# overwrite num_tasks
self.num_tasks = self.__task_ids.max().item() + 1
elif self.incr_type == 'domain':
# num_tasks only depends on the number of domains
self.num_tasks = self.__graph.ndata['domain'].max().item() + 1
# determine task configuration
if self.kwargs is not None and 'task_shuffle' in self.kwargs and self.kwargs['task_shuffle']:
self.__task_order = torch.randperm(self.num_tasks)
print('domain_order:', self.__task_order)
self.__task_ids = self.__task_order[self.__graph.ndata['domain']]
else:
self.__task_ids = self.__graph.ndata['domain']
# we need to provide task information (only for task-IL)
if self.incr_type == 'task':
self.__task_masks = torch.zeros(self.num_tasks + 1, self.num_classes).bool()
for i in range(self.num_tasks):
self.__task_masks[i, self.__splits[i]] = True
# set evaluator for the target scenario
if self.metric is not None:
self.__evaluator = evaluator_map[self.metric](self.num_tasks, self.__task_ids)
self.__test_results = []
[docs] def _update_target_dataset(self):
target_dataset = self.__graph.clone()
# conceal unnecessary information
for k, v in self.__cover_rule.items():
if v == 'node': target_dataset.ndata.pop(k)
elif v == 'edge': target_dataset.edata.pop(k)
target_dataset.ndata['feat'] = self.__graph.ndata['feat'].clone()
target_dataset.ndata['label'] = self.__graph.ndata['label'].clone()
target_dataset.ndata['train_mask'] = self.__graph.ndata['train_mask'].clone()
target_dataset.ndata['val_mask'] = self.__graph.ndata['val_mask'].clone()
target_dataset.ndata['test_mask'] = self.__graph.ndata['test_mask'].clone()
# update train/val/test mask for the current task
target_dataset.ndata['train_mask'] = target_dataset.ndata['train_mask'] & (self.__task_ids == self._curr_task)
target_dataset.ndata['val_mask'] = target_dataset.ndata['val_mask'] & (self.__task_ids == self._curr_task)
target_dataset.ndata['label'][target_dataset.ndata['test_mask'] | (self.__task_ids > self._curr_task)] = -1
if self.incr_type == 'class':
# for class-IL, no need to change
self._target_dataset = target_dataset
elif self.incr_type == 'task':
# for task-IL, we need task information. BeGin provide the information with 'task_specific_mask'
self._target_dataset = target_dataset
self._target_dataset.ndata['task_specific_mask'] = self.__task_masks[self.__task_ids]
elif self.incr_type == 'time':
# for time-IL, we need to hide unseen nodes and information at the current timestamp
# remain only seen nodes and edges
srcs, dsts = target_dataset.edges()
nodes_ready = self.__task_ids <= self._curr_task
edges_ready = (self.__task_ids[srcs] <= self._curr_task) & (self.__task_ids[dsts] <= self._curr_task)
self._target_dataset = dgl.graph((srcs[edges_ready], dsts[edges_ready]), num_nodes=self.__graph.num_nodes())
# cover the information of the unseen nodes/edges
for k in target_dataset.ndata.keys():
self._target_dataset.ndata[k] = target_dataset.ndata[k]
if self._target_dataset.ndata[k].dtype in [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]:
self._target_dataset.ndata[k][~nodes_ready] = -1
else:
self._target_dataset.ndata[k][~nodes_ready] = 0
for k in target_dataset.edata.keys():
self._target_dataset.edata[k] = target_dataset.edata[k][edges_ready]
# update test mask (exclude unseen test nodes)
self._target_dataset.ndata['test_mask'] = self._target_dataset.ndata['test_mask'] & (self.__task_ids <= self._curr_task)
elif self.incr_type == 'domain':
# for domain-IL, no need to change
self._target_dataset = target_dataset
[docs] def _update_accumulated_dataset(self):
target_dataset = self.__graph.clone()
# conceal unnecessary information
for k, v in self.__cover_rule.items():
if v == 'node': target_dataset.ndata.pop(k)
elif v == 'edge': target_dataset.edata.pop(k)
target_dataset.ndata['feat'] = self.__graph.ndata['feat'].clone()
target_dataset.ndata['label'] = self.__graph.ndata['label'].clone()
target_dataset.ndata['train_mask'] = self.__graph.ndata['train_mask'].clone()
target_dataset.ndata['val_mask'] = self.__graph.ndata['val_mask'].clone()
target_dataset.ndata['test_mask'] = self.__graph.ndata['test_mask'].clone()
# update train/val/test mask for the current task
target_dataset.ndata['train_mask'] = target_dataset.ndata['train_mask'] & (self.__task_ids <= self._curr_task)
target_dataset.ndata['val_mask'] = target_dataset.ndata['val_mask'] & (self.__task_ids <= self._curr_task)
target_dataset.ndata['label'][target_dataset.ndata['test_mask'] | (self.__task_ids > self._curr_task)] = -1
if self.incr_type == 'class':
# for class-IL, no need to change
self._accumulated_dataset = target_dataset
elif self.incr_type == 'task':
# for task-IL, we need task information. BeGin provide the information with 'task_specific_mask'
self._accumulated_dataset = target_dataset
self._accumulated_dataset.ndata['task_specific_mask'] = self.__task_masks[self.__task_ids]
elif self.incr_type == 'time':
# for time-IL, we need to hide unseen nodes and information at the current timestamp
srcs, dsts = target_dataset.edges()
nodes_ready = self.__task_ids <= self._curr_task
edges_ready = (self.__task_ids[srcs] <= self._curr_task) & (self.__task_ids[dsts] <= self._curr_task)
self._accumulated_dataset = dgl.graph((srcs[edges_ready], dsts[edges_ready]), num_nodes=self.__graph.num_nodes())
# cover the information of the unseen nodes/edges
for k in target_dataset.ndata.keys():
self._accumulated_dataset.ndata[k] = target_dataset.ndata[k]
if self._accumulated_dataset.ndata[k].dtype in [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]:
self._accumulated_dataset.ndata[k][~nodes_ready] = -1
else:
self._accumulated_dataset.ndata[k][~nodes_ready] = 0
for k in target_dataset.edata.keys():
self._accumulated_dataset.edata[k] = target_dataset.edata[k][edges_ready]
# update test mask (exclude unseen test nodes)
self._accumulated_dataset.ndata['test_mask'] = self._accumulated_dataset.ndata['test_mask'] & (self.__task_ids <= self._curr_task)
elif self.incr_type == 'domain':
self._accumulated_dataset = target_dataset
[docs] def _get_eval_result_inner(self, preds, target_split):
"""
The inner function of get_eval_result.
Args:
preds (torch.Tensor): predicted output of the current model
target_split (str): target split to measure the performance (spec., 'val' or 'test')
"""
gt = self.__graph.ndata['label'][self._target_dataset.ndata[target_split + '_mask']]
assert preds.shape == gt.shape, "shape mismatch"
return self.__evaluator(preds, gt, torch.arange(self._target_dataset.num_nodes())[self._target_dataset.ndata[target_split + '_mask']])
def get_eval_result(self, preds, target_split='test'):
return self._get_eval_result_inner(preds, target_split)
[docs] def get_accum_eval_result(self, preds, target_split='test'):
"""
Compute performance on the accumulated dataset for the given target split.
It can be used to compute train/val performance during training.
Args:
preds (torch.Tensor): predicted output of the current model
target_split (str): target split to measure the performance (spec., 'val' or 'test')
"""
gt = self.__graph.ndata['label'][self._accumulated_dataset.ndata[target_split + '_mask']]
assert preds.shape == gt.shape, "shape mismatch"
return self.__evaluator(preds, gt, torch.arange(self._accumulated_dataset.num_nodes())[self._accumulated_dataset.ndata[target_split + '_mask']])
[docs] def get_simple_eval_result(self, curr_batch_preds, curr_batch_gts):
"""
Compute performance for the given batch when we ignore task configuration.
It can be used to compute train/val performance during training.
Args:
curr_batch_preds (torch.Tensor): predicted output of the current model
curr_batch_gts (torch.Tensor): ground-truth labels
"""
return self.__evaluator.simple_eval(curr_batch_preds, curr_batch_gts)
[docs] def next_task(self, preds=torch.empty(1)):
if self.export_mode:
super().next_task(preds)
else:
self.__test_results.append(self._get_eval_result_inner(preds, target_split='test'))
super().next_task(preds)
if self._curr_task == self.num_tasks:
scores = torch.stack(self.__test_results, dim=0)
scores_np = scores.detach().cpu().numpy()
ap = scores_np[-1, :-1].mean().item()
af = (scores_np[np.arange(self.num_tasks), np.arange(self.num_tasks)] - scores_np[-1, :-1]).sum().item() / (self.num_tasks - 1)
if self.initial_test_result is not None:
fwt = (scores_np[np.arange(self.num_tasks-1), np.arange(self.num_tasks-1)+1] - self.initial_test_result.detach().cpu().numpy()[1:-1]).sum() / (self.num_tasks - 1)
else:
fwt = None
return {'exp_results': scores, 'AP': ap, 'AF': af, 'FWT': fwt}
[docs] def get_current_dataset_for_export(self, _global=False):
"""
Returns:
The graph dataset the implemented model uses in the current task
"""
target_graph = self.__graph if _global else self._target_dataset
metadata = {'num_classes': self.num_classes, 'ndata_feat': self.__graph.ndata['feat'], 'task': self.__task_ids} if _global else {}
if _global and self.incr_type == 'task': metadata['task_specific_mask'] = self.__task_masks[self.__task_ids]
metadata['edges'] = target_graph.edges()
metadata['train_mask'] = target_graph.ndata['train_mask']
metadata['val_mask'] = target_graph.ndata['val_mask']
if _global: metadata['test_mask'] = target_graph.ndata['test_mask']
metadata['label'] = copy.deepcopy(target_graph.ndata['label'])
return metadata