Source code for begin.scenarios.links

import torch
import dgl
import os
import pickle
import copy
import numpy as np
from dgl.data.utils import download, Subset
from ogb.linkproppred import DglLinkPropPredDataset

from .common import BaseScenarioLoader
from .datasets import *
from . import evaluator_map

def load_linkp_dataset(dataset_name, dataset_load_func, incr_type, save_path):
    neg_edges = {}
    if dataset_load_func is not None:
        custom_dataset = dataset_load_func(save_path=save_path)
        graph = custom_dataset['graph']
        num_feats = custom_dataset['num_feats']
        tvt_splits = custom_dataset['tvt_splits']
        neg_edges = custom_dataset['neg_edges']
        tvt_splits[tvt_splits == 1] = 8
        tvt_splits[tvt_splits == 2] = 9
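        # NOTE: this remapping aligns custom splits with the internal convention
        # used throughout this loader: values below 8 mark training edges,
        # 8 marks validation edges, and 9 marks test edges (an 8:1:1 split).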
    elif dataset_name in ['ogbl-collab'] and incr_type in ['time']:
        dataset = DglLinkPropPredDataset('ogbl-collab', root=save_path)
        # load edges and negative edges
        split_edge = dataset.get_edge_split()
        train_graph = dataset[0]
        combined = {}
        for k in split_edge["train"].keys():
            combined[k] = torch.cat((split_edge["train"][k], split_edge["valid"][k], split_edge["test"][k]), dim=0)
            if k == 'edge':
                rev_edges = torch.cat((combined['edge'][:, 1:2], combined['edge'][:, 0:1]), dim=-1)
                combined[k] = torch.cat((combined[k], rev_edges), dim=-1).view(-1, 2)
            else:
                combined[k] = torch.repeat_interleave(combined[k], 2, dim=0)
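        # After this loop, forward and reverse directions are interleaved, e.g.,
        # combined['edge'] = [(u0, v0), (v0, u0), (u1, v1), (v1, u1), ...], and
        # every other per-edge attribute is repeated so that each undirected
        # pair occupies two consecutive (even, odd) positions.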
        
        # generate graphs with all edges (including val/test)
        graph = dgl.graph((combined['edge'][:, 0], combined['edge'][:, 1]), num_nodes=train_graph.num_nodes())
        for k in combined.keys():
            if k != 'edge':
                if k == 'year': graph.edata['time'] = torch.clamp(combined[k] - 1970, 0, 20000)
                else: graph.edata[k] = combined[k]
        for k in train_graph.ndata.keys():
            graph.ndata[k] = train_graph.ndata[k]
        _srcs, _dsts = map(lambda x: x.numpy().tolist(), graph.edges())
        edgeset = {(s, d) for s, d in zip(_srcs, _dsts)}
        
        num_feats = graph.ndata['feat'].shape[-1]
        # load time split and train/val/test split information
        pkl_path = os.path.join(save_path, f'ogbl-collab_metadata_timeIL.pkl')
        download(f'https://github.com/anonymous-submission-24/BeGin-iclr24/raw/main/metadata/ogbl-collab_metadata_timeIL.pkl', pkl_path)
        metadata = pickle.load(open(pkl_path, 'rb'))
        tvt_splits = metadata['inner_tvt_splits']    
        # choose negative edges
        neg_edges['val'] = torch.LongTensor([[_s, _d] for _s, _d in split_edge['valid']['edge_neg'].numpy().tolist() if (_s, _d) not in edgeset])
        neg_edges['test'] = torch.LongTensor([[_s, _d] for _s, _d in split_edge['test']['edge_neg'].numpy().tolist() if (_s, _d) not in edgeset])
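        # NOTE: since `graph` merges train/valid/test positives, some negatives
        # provided by OGB may coincide with actual edges; the membership test
        # against `edgeset` above filters out such false negatives.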
        
    elif dataset_name in ['wikics'] and incr_type in ['domain']:
        dataset = WikiCSLinkDataset(raw_dir=save_path)
        graph = dataset._g
        num_feats = graph.ndata['feat'].shape[-1]
        # load tvt_splits and negative edges
        pkl_path = os.path.join(save_path, f'wikics_metadata_domainIL.pkl')
        download(f'https://github.com/anonymous-submission-24/BeGin-iclr24/raw/main/metadata/wikics_metadata_domainIL.pkl', pkl_path)
        metadata = pickle.load(open(pkl_path, 'rb'))
        tvt_splits = metadata['inner_tvt_splits']
        neg_edges = metadata['neg_edges']
        
        num_tasks = 54
        task_map = torch.LongTensor([[0,1,2,3,4,5,-1,6,7,8],
                                     [-1,9,10,11,12,13,14,15,16,17],
                                     [-1,-1,18,19,20,21,22,23,24,25],
                                     [-1,-1,-1,26,27,28,29,30,31,32],
                                     [-1,-1,-1,-1,33,34,35,36,37,38],
                                     [-1,-1,-1,-1,-1,39,40,41,42,43],
                                     [-1,-1,-1,-1,-1,-1,44,45,46,47],
                                     [-1,-1,-1,-1,-1,-1,-1,48,49,50],
                                     [-1,-1,-1,-1,-1,-1,-1,-1,51,52],
                                     [-1,-1,-1,-1,-1,-1,-1,-1,-1,53]])
        domain_info = graph.ndata.pop('domain')
        srcs, dsts = graph.edges()
        graph.edata['domain'] = task_map[torch.min(domain_info[srcs], domain_info[dsts]), torch.max(domain_info[srcs], domain_info[dsts])]
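        # NOTE: `task_map` assigns a task id to each unordered pair of node domains
        # (i, j) with i <= j. Ten domains yield 55 such pairs; one pair, (0, 6), is
        # excluded (mapped to -1), giving the 54 tasks above. Taking min/max of the
        # endpoint domains makes the lookup independent of edge direction.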
    elif dataset_name in ['askubuntu'] and incr_type in ['time']:
        dataset = AskUbuntuDataset(dataset_name=dataset_name, raw_dir=save_path)
        graph = dataset.graphs[0]
        num_feats = graph.ndata['feat'].shape[-1]
        
        pkl_path = os.path.join(save_path, f'askubuntu_metadata_timeIL.pkl')
        download(f'https://github.com/anonymous-submission-24/BeGin-iclr24/raw/main/metadata/askubuntu_metadata_timeIL.pkl', pkl_path)
        metadata = pickle.load(open(pkl_path, 'rb'))
        tvt_splits = torch.repeat_interleave(metadata['inner_tvt_splits'], 2, dim=0)
        neg_edges = metadata['neg_edges']
    elif dataset_name in ['facebook'] and incr_type in ['domain']:
        dataset = FacebookLinkDataset(dataset_name=dataset_name, raw_dir=save_path)
        graph = dataset.graphs[0]
        num_feats = graph.ndata['feat'].shape[-1]
        pkl_path = os.path.join(save_path, f'facebook_metadata_domainIL.pkl')
        download(f'https://github.com/anonymous-submission-24/BeGin-iclr24/raw/main/metadata/facebook_metadata_domainIL.pkl', pkl_path)
        metadata = pickle.load(open(pkl_path, 'rb'))
        tvt_splits = torch.repeat_interleave(metadata['inner_tvt_splits'], 2, dim=0)
        neg_edges = metadata['neg_edges']
    else:
        raise NotImplementedError("Tried to load unsupported scenario.")
    
    print("=====CHECK=====")
    print("num_feats:", num_feats)
    print("inner_tvt_splits:", tvt_splits.shape)
    print("neg_edges['val']:", neg_edges['val'].shape)
    print("neg_edges['test']:", neg_edges['test'].shape)
    if incr_type == 'time':
        print("graph.edata['time']:", graph.edata['time'].shape)
    if incr_type == 'domain':
        print("graph.edata['domain']:", graph.edata['domain'].shape)
    print("===============")
    
    return num_feats, graph, tvt_splits, neg_edges
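# A minimal usage sketch of `load_linkp_dataset` (a sketch, assuming network
# access for the OGB data and metadata downloads, and `./data` as the save path):
#
#     num_feats, graph, tvt_splits, neg_edges = load_linkp_dataset(
#         'ogbl-collab', dataset_load_func=None, incr_type='time', save_path='./data')
#     print(num_feats, graph.num_edges(), tvt_splits.shape, neg_edges['val'].shape)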

class LPScenarioLoader(BaseScenarioLoader):
    """
        The scenario loader for link prediction.

        **Usage example:**

            >>> scenario = LPScenarioLoader(dataset_name="ogbl-collab", num_tasks=3, metric="hits@50",
            ...                             save_path="./data", incr_type="time", task_shuffle=True)
    """
    def _init_continual_scenario(self):
        self.num_feats, self.__graph, self.__inner_tvt_splits, self.__neg_edges = load_linkp_dataset(self.dataset_name, self.dataset_load_func, self.incr_type, self.save_path)
        self.num_classes = 1
        if self.incr_type in ['class', 'task']:
            # it is impossible to make class-IL and task-IL settings for link prediction
            raise NotImplementedError
        elif self.incr_type == 'time':
            self.num_tasks = self.__graph.edata['time'].max().item() + 1
            self.__task_ids = self.__graph.edata['time']
        elif self.incr_type == 'domain':
            self.num_tasks = self.__graph.edata['domain'].max().item() + 1
            self.__task_ids = self.__graph.edata['domain']
            if self.kwargs is not None and 'task_shuffle' in self.kwargs and self.kwargs['task_shuffle']:
                domain_order = torch.randperm(self.num_tasks)
            else:
                domain_order = torch.arange(self.num_tasks)
            print('domain_order:', domain_order)
            domain_order_inv = torch.arange(self.num_tasks + 1)
            domain_order_inv[domain_order] = torch.arange(self.num_tasks)
            self.__graph.edata['domain'][self.__graph.edata['domain'] < 0] = self.num_tasks
            self.__task_ids = domain_order_inv[self.__graph.edata['domain']]

        # set evaluator for the target scenario
        if self.metric is not None:
            if '@' in self.metric:
                metric_name, metric_k = self.metric.split('@')
                self.__evaluator = evaluator_map[metric_name](self.num_tasks, int(metric_k))
            else:
                self.__evaluator = evaluator_map[self.metric](self.num_tasks, self.__task_ids)
        self.__test_results = []
    def _update_target_dataset(self):
        # get sources and destinations
        srcs, dsts = self.__graph.edges()
        # note that the edges are bi-directed
        is_even = ((torch.arange(self.__inner_tvt_splits.shape[0]) % 2) == 0)
        # train/val/test - 8:1:1 random split
        edges_for_train = (self.__inner_tvt_splits < 8)
        if self.incr_type == 'time':
            edges_for_train &= (self.__task_ids <= self._curr_task)
        edges_ready = {'val': ((self.__inner_tvt_splits == 8) & (self.__task_ids == self._curr_task)) & is_even,
                       'test': (self.__inner_tvt_splits > 8) & is_even}

        # generate data using only train edges; 'time'/'domain' (task-split
        # information) are excluded from the copied attributes
        target_dataset = dgl.graph((srcs[edges_for_train], dsts[edges_for_train]), num_nodes=self.__graph.num_nodes())
        for k in self.__graph.ndata.keys():
            if k not in ('time', 'domain'):
                target_dataset.ndata[k] = self.__graph.ndata[k]
        for k in self.__graph.edata.keys():
            if k not in ('time', 'domain'):
                target_dataset.edata[k] = self.__graph.edata[k][edges_for_train]

        # prepare val/test data for the current task (containing negative edges)
        target_edges = {_split: torch.stack((srcs[edges_ready[_split]], dsts[edges_ready[_split]]), dim=-1) for _split in ['val', 'test']}
        gt_labels = {_split: torch.cat((self.__task_ids[edges_ready[_split]] + 1, torch.zeros(self.__neg_edges[_split].shape[0], dtype=torch.long)), dim=0) for _split in ['val', 'test']}
        randperms = {_split: torch.randperm(gt_labels[_split].shape[0]) for _split in ['val', 'test']}
        target_edges = {_split: torch.cat((target_edges[_split], self.__neg_edges[_split]), dim=0)[randperms[_split]] for _split in ['val', 'test']}

        # generate train/val/test dataset for the current task
        edges_ready['train'] = (edges_for_train & is_even) & (self.__task_ids == self._curr_task)
        target_edges['train'] = torch.stack((srcs[edges_ready['train']], dsts[edges_ready['train']]), dim=-1)
        self.__target_labels = {_split: gt_labels[_split][randperms[_split]] for _split in ['val', 'test']}
        self._target_dataset = {'graph': dgl.add_self_loop(target_dataset),
                                'train': {'edge': target_edges['train']},
                                'val': {'edge': target_edges['val'], 'label': (self.__target_labels['val'] > 0).long()},
                                'test': {'edge': target_edges['test'], 'label': -torch.ones_like(self.__target_labels['test'])}}
        self._target_dataset['train']['label'] = torch.ones(self._target_dataset['train']['edge'].shape[0], dtype=torch.long)
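    # NOTE: after `_update_target_dataset`, `self._target_dataset` is a dictionary:
    #   'graph': DGLGraph containing only training edges (plus self-loops),
    #   'train': {'edge': (N, 2) LongTensor, 'label': all ones},
    #   'val':   {'edge': ..., 'label': 1 for positive and 0 for negative edges},
    #   'test':  {'edge': ..., 'label': all -1, i.e., hidden from the model}.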
    def _update_accumulated_dataset(self):
        # get sources and destinations
        srcs, dsts = self.__graph.edges()
        # note that the edges are bi-directed
        is_even = ((torch.arange(self.__inner_tvt_splits.shape[0]) % 2) == 0)
        # train/val/test - 8:1:1 random split
        edges_for_train = (self.__inner_tvt_splits < 8)
        if self.incr_type == 'time':
            edges_for_train &= (self.__task_ids <= self._curr_task)
        edges_ready = {'val': ((self.__inner_tvt_splits == 8) & (self.__task_ids <= self._curr_task)) & is_even,
                       'test': (self.__inner_tvt_splits > 8) & is_even}

        # generate data using only train edges; 'time'/'domain' (task-split
        # information) are excluded from the copied attributes
        target_dataset = dgl.graph((srcs[edges_for_train], dsts[edges_for_train]), num_nodes=self.__graph.num_nodes())
        for k in self.__graph.ndata.keys():
            if k not in ('time', 'domain'):
                target_dataset.ndata[k] = self.__graph.ndata[k]
        for k in self.__graph.edata.keys():
            if k not in ('time', 'domain'):
                target_dataset.edata[k] = self.__graph.edata[k][edges_for_train]

        # prepare val data for the tasks seen so far (containing negative edges)
        target_edges = {_split: torch.stack((srcs[edges_ready[_split]], dsts[edges_ready[_split]]), dim=-1) for _split in ['val']}
        gt_labels = {_split: torch.cat((self.__task_ids[edges_ready[_split]] + 1, torch.zeros(self.__neg_edges[_split].shape[0], dtype=torch.long)), dim=0) for _split in ['val']}
        randperms = {_split: torch.randperm(gt_labels[_split].shape[0]) for _split in ['val']}
        target_edges = {_split: torch.cat((target_edges[_split], self.__neg_edges[_split]), dim=0)[randperms[_split]] for _split in ['val']}

        # generate train/val/test dataset accumulated up to the current task
        edges_ready['train'] = (edges_for_train & is_even) & (self.__task_ids <= self._curr_task)
        target_edges['train'] = torch.stack((srcs[edges_ready['train']], dsts[edges_ready['train']]), dim=-1)
        self.__accumulated_labels = {_split: gt_labels[_split][randperms[_split]] for _split in ['val']}
        self.__accumulated_labels['test'] = self.__target_labels['test']
        self._accumulated_dataset = {'graph': dgl.add_self_loop(target_dataset),
                                     'train': {'edge': target_edges['train']},
                                     'val': {'edge': target_edges['val'], 'label': (self.__accumulated_labels['val'] > 0).long()},
                                     'test': self._target_dataset['test']}
        self._accumulated_dataset['train']['label'] = torch.ones(self._accumulated_dataset['train']['edge'].shape[0], dtype=torch.long)
    def _get_eval_result_inner(self, preds, target_split):
        """
            The inner function of get_eval_result.

            Args:
                preds (torch.Tensor): predicted output of the current model
                target_split (str): target split to measure the performance (spec., 'val' or 'test')
        """
        gt = (self.__target_labels[target_split] > 0).long()
        assert preds.shape == gt.shape, "shape mismatch"
        return self.__evaluator(preds, gt, self.__target_labels[target_split] - 1)
    def get_eval_result(self, preds, target_split='test'):
        return self._get_eval_result_inner(preds, target_split)
    def get_accum_eval_result(self, preds, target_split='test'):
        """
            Compute performance on the accumulated dataset for the given target split.
            It can be used to compute train/val performance during training.

            Args:
                preds (torch.Tensor): predicted output of the current model
                target_split (str): target split to measure the performance (spec., 'val' or 'test')
        """
        gt = (self.__accumulated_labels[target_split] > 0).long()
        assert preds.shape == gt.shape, "shape mismatch"
        return self.__evaluator(preds, gt, self.__accumulated_labels[target_split] - 1)
    def get_simple_eval_result(self, curr_batch_preds, curr_batch_gts):
        """
            Compute performance for the given batch when we ignore task configuration.
            It can be used to compute train/val performance during training.

            Args:
                curr_batch_preds (torch.Tensor): predicted output of the current model
                curr_batch_gts (torch.Tensor): ground-truth labels
        """
        return self.__evaluator.simple_eval(curr_batch_preds, curr_batch_gts)
    def next_task(self, preds=torch.empty(1)):
        if self.export_mode:
            super().next_task(preds)
        else:
            self.__test_results.append(self._get_eval_result_inner(preds, target_split='test'))
            super().next_task(preds)
            if self._curr_task == self.num_tasks:
                scores = torch.stack(self.__test_results, dim=0)
                scores_np = scores.detach().cpu().numpy()
                ap = scores_np[-1, :-1].mean().item()
                af = (scores_np[np.arange(self.num_tasks), np.arange(self.num_tasks)] - scores_np[-1, :-1]).sum().item() / (self.num_tasks - 1)
                if self.initial_test_result is not None:
                    fwt = (scores_np[np.arange(self.num_tasks - 1), np.arange(self.num_tasks - 1) + 1] - self.initial_test_result.detach().cpu().numpy()[1:-1]).sum() / (self.num_tasks - 1)
                else:
                    fwt = None
                return {'exp_results': scores, 'AP': ap, 'AF': af, 'FWT': fwt}
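    # NOTE: the metrics returned above follow common continual-learning summaries.
    # With M = scores_np (row i: test scores measured after learning task i; the
    # last column, dropped by `[:-1]`, is the evaluator's aggregate entry):
    #   AP  (Average Performance) = mean over tasks of the last row,
    #   AF  (Average Forgetting)  = sum_i (M[i, i] - M[-1, i]) / (num_tasks - 1),
    #   FWT (Forward Transfer)    = sum_i (M[i, i+1] - init[i+1]) / (num_tasks - 1),
    # where `init` comes from `initial_test_result`, i.e., the test scores of the
    # model before training on any task.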
    def get_current_dataset_for_export(self, _global=False):
        """
            Returns:
                The graph dataset the implemented model uses in the current task
        """
        target_graph = self.__graph if _global else self._target_dataset
        if _global:
            metadata = {'ndata_feat': self.__graph.ndata['feat'], 'task': self.__task_ids}
            metadata['edges'] = self.__graph.edges()
            metadata['neg_edges'] = self.__neg_edges
            metadata['test_edges'] = self._target_dataset['test']['edge']
            metadata['test_labels'] = self.__target_labels['test']
        else:
            metadata = {}
            metadata['edges'] = target_graph['graph'].edges()
            metadata['train_edges'] = target_graph['train']['edge']
            metadata['train_labels'] = target_graph['train']['label']
            metadata['val_edges'] = target_graph['val']['edge']
            metadata['val_labels'] = target_graph['val']['label']
        return metadata
def load_linkc_dataset(dataset_name, dataset_load_func, incr_type, save_path):
    if dataset_load_func is not None:
        custom_dataset = dataset_load_func(save_path=save_path)
        graph = custom_dataset['graph']
        num_feats = custom_dataset['num_feats']
        num_classes = custom_dataset['num_classes']
    elif dataset_name == 'bitcoin' and incr_type in ['task', 'class', 'time']:
        dataset = BitcoinOTCDataset(dataset_name, raw_dir=save_path)
        graph = dataset[0]
        num_feats = graph.ndata['feat'].shape[-1]
        if incr_type == 'time':
            num_classes = 1
            num_tasks = 7
            # make 7 chunks (of the same size) for making 7 tasks
            counts = torch.cumsum(torch.bincount(graph.edata['time']), dim=-1)
            task_ids = (counts / ((graph.num_edges() + 1.) / num_tasks)).long()
            time_info = task_ids[graph.edata['time']]
            # to formulate a binary classification problem
            graph.edata['label'] = (graph.edata.pop('label') < 0).long()
            graph.edata['time'] = time_info
        else:
            num_classes = 6
            # for balanced split
            label_to_class = torch.LongTensor([0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 6, 6, 6, 2, 3, 4, 5, 5, 5, 5, 5])
            graph.edata['label'] = label_to_class[graph.edata.pop('label').squeeze(-1) + 10]
        pkl_path = os.path.join(save_path, f'bitcoin_metadata_allIL.pkl')
        download(f'https://github.com/anonymous-submission-24/BeGin-iclr24/raw/main/metadata/bitcoin_metadata_allIL.pkl', pkl_path)
        metadata = pickle.load(open(pkl_path, 'rb'))
        graph.edata['train_mask'] = ((metadata['inner_tvt_split'] % 10) < 8)
        graph.edata['val_mask'] = ((metadata['inner_tvt_split'] % 10) == 8)
        graph.edata['test_mask'] = ((metadata['inner_tvt_split'] % 10) > 8)
    else:
        raise NotImplementedError("Tried to load unsupported scenario.")

    print("=====CHECK=====")
    print("num_classes:", num_classes, ", num_feats:", num_feats)
    print("graph.edata['train_mask']:", 'train_mask' in graph.edata)
    print("graph.edata['val_mask']:", 'val_mask' in graph.edata)
    print("graph.edata['test_mask']:", 'test_mask' in graph.edata)
    print("graph.edata['label']:", 'label' in graph.edata)
    if incr_type == 'time':
        print("graph.edata['time']:", 'time' in graph.edata)
    if incr_type == 'domain':
        print("graph.edata['domain']:", 'domain' in graph.edata)
    print("===============")
    return num_classes, num_feats, graph
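# A minimal usage sketch of `load_linkc_dataset` (a sketch, assuming network
# access for the BitcoinOTC data and metadata downloads, and `./data` as the
# save path):
#
#     num_classes, num_feats, graph = load_linkc_dataset(
#         'bitcoin', dataset_load_func=None, incr_type='time', save_path='./data')
#     print(num_classes, num_feats, graph.edata['label'].shape)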
class LCScenarioLoader(BaseScenarioLoader):
    """
        The scenario loader for link classification.

        **Usage example:**

            >>> scenario = LCScenarioLoader(dataset_name="bitcoin", num_tasks=3, metric="accuracy",
            ...                             save_path="./data", incr_type="task", task_shuffle=True)

            >>> scenario = LCScenarioLoader(dataset_name="bitcoin", num_tasks=7, metric="aucroc",
            ...                             save_path="./data", incr_type="time")
    """
    def _init_continual_scenario(self):
        self.num_classes, self.num_feats, self.__graph = load_linkc_dataset(self.dataset_name, self.dataset_load_func, self.incr_type, self.save_path)
        if 'domain' in self.__graph.edata:
            self.__domain_info = self.__graph.edata['domain']
        if 'time' in self.__graph.edata:
            self.__time_splits = self.__graph.edata['time']

        if self.incr_type in ['domain']:
            raise NotImplementedError
        elif self.incr_type == 'time':
            # split into tasks using timestamps
            self.num_tasks = self.__time_splits.max().item() + 1
            print('num_tasks:', self.num_tasks)
            self.__task_ids = self.__time_splits
        elif self.incr_type in ['class', 'task']:
            # determine task configuration
            if self.kwargs is not None and 'task_orders' in self.kwargs:
                self.__splits = tuple([torch.LongTensor(class_ids) for class_ids in self.kwargs['task_orders']])
            elif self.kwargs is not None and 'task_shuffle' in self.kwargs and self.kwargs['task_shuffle']:
                self.__splits = torch.split(torch.randperm(self.num_classes), self.num_classes // self.num_tasks)[:self.num_tasks]
            else:
                self.__splits = torch.split(torch.arange(self.num_classes), self.num_classes // self.num_tasks)[:self.num_tasks]
            print('class split information:', self.__splits)

            # compute task ids for each instance and remove time information (since it is unnecessary)
            id_to_task = self.num_tasks * torch.ones(self.__graph.edata['label'].max() + 1).long()
            for i in range(self.num_tasks):
                id_to_task[self.__splits[i]] = i
            self.__task_ids = id_to_task[self.__graph.edata['label']]
            # ignore classes which are not used in the tasks
            self.__graph.edata['test_mask'] = self.__graph.edata['test_mask'] & (self.__task_ids < self.num_tasks)

            # we need to provide task information (only for task-IL)
            if self.incr_type == 'task':
                self.__task_masks = torch.zeros(self.num_tasks + 1, self.num_classes).bool()
                for i in range(self.num_tasks):
                    self.__task_masks[i, self.__splits[i]] = True

        # set evaluator for the target scenario
        if self.metric is not None:
            self.__evaluator = evaluator_map[self.metric](self.num_tasks, self.__task_ids)
        self.__test_results = []
    def _update_target_dataset(self):
        target_dataset = copy.deepcopy(self.__graph)
        # set train/val/test split for the current task
        target_dataset.edata['train_mask'] = self.__graph.edata['train_mask'] & (self.__task_ids == self._curr_task)
        target_dataset.edata['val_mask'] = self.__graph.edata['val_mask'] & (self.__task_ids == self._curr_task)
        target_dataset.edata['test_mask'] = self.__graph.edata['test_mask']
        # hide labels of test edges and of edges from other tasks
        target_dataset.edata['label'] = self.__graph.edata['label'].clone()
        target_dataset.edata['label'][target_dataset.edata['test_mask'] | (self.__task_ids != self._curr_task)] = -1

        if self.incr_type == 'class':
            # for class-IL, no further change is needed
            self._target_dataset = target_dataset
        elif self.incr_type == 'task':
            # for task-IL, we need task information. BeGin provides the information with 'task_specific_mask'
            self._target_dataset = target_dataset
            self._target_dataset.edata['task_specific_mask'] = self.__task_masks[self.__task_ids]
        elif self.incr_type == 'time':
            # for time-IL, we need to hide unseen edges and information at the current timestamp
            srcs, dsts = target_dataset.edges()
            edges_ready = (self.__task_ids <= self._curr_task)
            self._target_dataset = dgl.graph((srcs[edges_ready], dsts[edges_ready]), num_nodes=self.__graph.num_nodes())
            for k in target_dataset.ndata.keys():
                self._target_dataset.ndata[k] = target_dataset.ndata[k]
            for k in target_dataset.edata.keys():
                self._target_dataset.edata[k] = target_dataset.edata[k][edges_ready]
        elif self.incr_type == 'domain':
            pass
    def _update_accumulated_dataset(self):
        target_dataset = copy.deepcopy(self.__graph)
        # set train/val/test split for the tasks seen so far
        target_dataset.edata['train_mask'] = self.__graph.edata['train_mask'] & (self.__task_ids <= self._curr_task)
        target_dataset.edata['val_mask'] = self.__graph.edata['val_mask'] & (self.__task_ids <= self._curr_task)
        target_dataset.edata['test_mask'] = self.__graph.edata['test_mask']
        # hide labels of test edges and of edges from future tasks
        target_dataset.edata['label'] = self.__graph.edata['label'].clone()
        target_dataset.edata['label'][target_dataset.edata['test_mask'] | (self.__task_ids > self._curr_task)] = -1

        if self.incr_type == 'class':
            # for class-IL, no further change is needed
            self._accumulated_dataset = target_dataset
        elif self.incr_type == 'task':
            # for task-IL, we need task information. BeGin provides the information with 'task_specific_mask'
            self._accumulated_dataset = target_dataset
            self._accumulated_dataset.edata['task_specific_mask'] = self.__task_masks[self.__task_ids]
        elif self.incr_type == 'time':
            # for time-IL, we need to hide unseen edges and information at the current timestamp
            srcs, dsts = target_dataset.edges()
            edges_ready = (self.__task_ids <= self._curr_task)
            self._accumulated_dataset = dgl.graph((srcs[edges_ready], dsts[edges_ready]), num_nodes=self.__graph.num_nodes())
            for k in target_dataset.ndata.keys():
                self._accumulated_dataset.ndata[k] = target_dataset.ndata[k]
            for k in target_dataset.edata.keys():
                self._accumulated_dataset.edata[k] = target_dataset.edata[k][edges_ready]
        elif self.incr_type == 'domain':
            pass
    def _get_eval_result_inner(self, preds, target_split):
        """
            The inner function of get_eval_result.

            Args:
                preds (torch.Tensor): predicted output of the current model
                target_split (str): target split to measure the performance (spec., 'val' or 'test')
        """
        if self.incr_type == 'time':
            # for time-IL, we evaluate the performance only on currently seen edges
            gt = self.__graph.edata['label'][self.__task_ids <= self._curr_task][self._target_dataset.edata[target_split + '_mask']]
            assert preds.shape == gt.shape, "shape mismatch"
            return self.__evaluator(preds, gt, torch.arange(self.__graph.num_edges())[self.__task_ids <= self._curr_task][self._target_dataset.edata[target_split + '_mask']])
        else:
            gt = self.__graph.edata['label'][self._target_dataset.edata[target_split + '_mask']]
            assert preds.shape == gt.shape, "shape mismatch"
            return self.__evaluator(preds, gt, torch.arange(self._target_dataset.num_edges())[self._target_dataset.edata[target_split + '_mask']])
    def get_eval_result(self, preds, target_split='test'):
        return self._get_eval_result_inner(preds, target_split)
    def get_accum_eval_result(self, preds, target_split='test'):
        """
            Compute performance on the accumulated dataset for the given target split.
            It can be used to compute train/val performance during training.

            Args:
                preds (torch.Tensor): predicted output of the current model
                target_split (str): target split to measure the performance (spec., 'val' or 'test')
        """
        if self.incr_type == 'time':
            # for time-IL, we evaluate the performance only on currently seen edges
            gt = self.__graph.edata['label'][self.__task_ids <= self._curr_task][self._accumulated_dataset.edata[target_split + '_mask']]
            assert preds.shape == gt.shape, "shape mismatch"
            return self.__evaluator(preds, gt, torch.arange(self.__graph.num_edges())[self.__task_ids <= self._curr_task][self._accumulated_dataset.edata[target_split + '_mask']])
        else:
            gt = self.__graph.edata['label'][self._accumulated_dataset.edata[target_split + '_mask']]
            assert preds.shape == gt.shape, "shape mismatch"
            return self.__evaluator(preds, gt, torch.arange(self._accumulated_dataset.num_edges())[self._accumulated_dataset.edata[target_split + '_mask']])
    def get_simple_eval_result(self, curr_batch_preds, curr_batch_gts):
        """
            Compute performance for the given batch when we ignore task configuration.
            It can be used to compute train/val performance during training.

            Args:
                curr_batch_preds (torch.Tensor): predicted output of the current model
                curr_batch_gts (torch.Tensor): ground-truth labels
        """
        return self.__evaluator.simple_eval(curr_batch_preds, curr_batch_gts)
    def next_task(self, preds=torch.empty(1)):
        if self.export_mode:
            super().next_task(preds)
        else:
            self.__test_results.append(self._get_eval_result_inner(preds, target_split='test'))
            super().next_task(preds)
            if self._curr_task == self.num_tasks:
                scores = torch.stack(self.__test_results, dim=0)
                scores_np = scores.detach().cpu().numpy()
                ap = scores_np[-1, :-1].mean().item()
                af = (scores_np[np.arange(self.num_tasks), np.arange(self.num_tasks)] - scores_np[-1, :-1]).sum().item() / (self.num_tasks - 1)
                if self.initial_test_result is not None:
                    fwt = (scores_np[np.arange(self.num_tasks - 1), np.arange(self.num_tasks - 1) + 1] - self.initial_test_result.detach().cpu().numpy()[1:-1]).sum() / (self.num_tasks - 1)
                else:
                    fwt = None
                return {'exp_results': scores, 'AP': ap, 'AF': af, 'FWT': fwt}
    def get_current_dataset_for_export(self, _global=False):
        """
            Returns:
                The graph dataset the implemented model uses in the current task
        """
        target_graph = self.__graph if _global else self._target_dataset
        metadata = {'num_classes': self.num_classes, 'ndata_feat': self.__graph.ndata['feat'], 'task': self.__task_ids} if _global else {}
        if _global and self.incr_type == 'task':
            metadata['task_specific_mask'] = self.__task_masks[self.__task_ids]
        metadata['edges'] = target_graph.edges()
        metadata['train_mask'] = target_graph.edata['train_mask']
        metadata['val_mask'] = target_graph.edata['val_mask']
        if _global:
            metadata['test_mask'] = target_graph.edata['test_mask']
        metadata['label'] = copy.deepcopy(target_graph.edata['label'])
        return metadata
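# A minimal end-to-end sketch of driving a scenario loader (a sketch only:
# `my_train_and_predict` is a hypothetical user-provided routine, and the
# constructor arguments follow the usage example in the class docstring):
#
#     scenario = LCScenarioLoader(dataset_name='bitcoin', num_tasks=7, metric='aucroc',
#                                 save_path='./data', incr_type='time')
#     results = None
#     while results is None:
#         # train on scenario._target_dataset, then predict on its test split
#         preds = my_train_and_predict(scenario._target_dataset)
#         results = scenario.next_task(preds)  # AP/AF/FWT are returned after the last task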