Source code for begin.algorithms.ewc.links

import sys
import numpy as np
import torch
import copy, dgl
import torch.nn.functional as F
from begin.trainers.links import LCTrainer, LPTrainer
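
# Note (added commentary, not part of the original module): the trainers below
# implement Elastic Weight Consolidation (EWC, Kirkpatrick et al., 2017) for
# link-level continual learning tasks. After each task, a diagonal Fisher
# information estimate and a snapshot of the learned parameters are stored;
# during later tasks, a quadratic penalty anchored at those snapshots is added
# to the task loss.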

class LCTaskILEWCTrainer(LCTrainer):
    def __init__(self, model, scenario, optimizer_fn, loss_fn, device, **kwargs):
        """
            EWC needs `lamb`, the additional hyperparameter for the regularization term used in :func:`afterInference`.
        """
        super().__init__(model.to(device), scenario, optimizer_fn, loss_fn, device, **kwargs)
        self.lamb = kwargs['lamb'] if 'lamb' in kwargs else 10000.
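    # Added commentary (not part of the original source): together with the
    # snapshots stored in `processAfterTraining`, this trainer optimizes
    #
    #   L_total(theta) = L_task(theta)
    #                    + lamb * sum_t sum_i F_i^(t) * (theta_i - theta_i^(t))^2,
    #
    # where t ranges over previously finished tasks, F^(t) is the stored
    # diagonal Fisher estimate, and theta^(t) are the parameters learned on
    # task t. A larger `lamb` keeps the parameters closer to earlier solutions.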
    def prepareLoader(self, _curr_dataset, curr_training_states):
        """
            The event function to generate dataloaders from the given dataset for the current task.
            For task-IL, we need to additionally consider task information for the inference step.
            
            Args:
                curr_dataset (object): The dataset for the current task. Its type is dgl.graph for node-level and link-level problem, and dgl.data.DGLDataset for graph-level problem.
                curr_training_states (dict): the dictionary containing the current training states.
                
            Returns:
                A tuple containing three dataloaders.
                The trainer considers the first dataloader, second dataloader, and third dataloader
                as dataloaders for training, validation, and test, respectively.
        """
        curr_dataset = copy.deepcopy(_curr_dataset)
        srcs, dsts = curr_dataset.edges()
        labels = curr_dataset.edata.pop('label')
        train_mask = curr_dataset.edata.pop('train_mask')
        val_mask = curr_dataset.edata.pop('val_mask')
        test_mask = curr_dataset.edata.pop('test_mask')
        task_mask = curr_dataset.edata.pop('task_specific_mask')
        curr_dataset = dgl.add_self_loop(curr_dataset)
        return [(curr_dataset, srcs[train_mask], dsts[train_mask], task_mask[train_mask], labels[train_mask])], \
               [(curr_dataset, srcs[val_mask], dsts[val_mask], task_mask[val_mask], labels[val_mask])], \
               [(curr_dataset, srcs[test_mask], dsts[test_mask], task_mask[test_mask], labels[test_mask])]
    def inference(self, model, _curr_batch, training_states):
        """
            The event function to execute inference step.
            For task-IL, we need to additionally consider task information for the inference step.
            
            Args:
                model (torch.nn.Module): the current trained model.
                curr_batch (object): the data (or minibatch) for the current iteration.
                curr_training_states (dict): the dictionary containing the current training states.
                
            Returns:
                A dictionary containing the inference results, such as prediction result and loss.
        """
        curr_batch, srcs, dsts, task_masks, labels = _curr_batch
        preds = model(curr_batch.to(self.device),
                      curr_batch.ndata['feat'].to(self.device),
                      srcs, dsts, task_masks=task_masks)
        loss = self.loss_fn(preds, labels.to(self.device))
        return {'preds': preds, 'loss': loss}
    def afterInference(self, results, model, optimizer, _curr_batch, training_states):
        """
            The event function to execute some processes right after the inference step (for training).
            We recommend performing backpropagation in this event function.
            EWC performs regularization process in this function.
            
            Args:
                results (dict): the returned dictionary from the event function `inference`.
                model (torch.nn.Module): the current trained model.
                optimizer (torch.optim.Optimizer): the current optimizer function.
                curr_batch (object): the data (or minibatch) for the current iteration.
                curr_training_states (dict): the dictionary containing the current training states.
                
            Returns:
                A dictionary containing the information from the `results`.
        """
        loss_reg = 0
        for _param, _fisher in zip(training_states['params'], training_states['fishers']):
            for name, p in model.named_parameters():
                l = self.lamb * _fisher[name]
                l = l * ((p - _param[name]) ** 2)
                loss_reg = loss_reg + l.sum()
        total_loss = results['loss'] + loss_reg
        total_loss.backward()
        optimizer.step()
        return {'loss': total_loss.item(),
                'acc': self.eval_fn(torch.argmax(results['preds'], -1), _curr_batch[-1].to(self.device))}
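    # Added commentary (not part of the original source): the nested loops above
    # accumulate one quadratic penalty per stored task snapshot, so the cost of
    # the regularizer grows linearly with the number of finished tasks.
    # Worked scalar example (hypothetical numbers): for a stored parameter value
    # 1.0, Fisher estimate 0.5, current value 1.2 and lamb = 10000, the added
    # penalty is 10000 * 0.5 * (1.2 - 1.0)**2 = 200.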
    def initTrainingStates(self, scenario, model, optimizer):
        return {'fishers': [], 'params': []}
    def processAfterTraining(self, task_id, curr_dataset, curr_model, curr_optimizer, curr_training_states):
        """
            The event function to execute some processes after training the current task.
            EWC computes the Fisher information matrix and stores the learned weights to compute the penalty term in :func:`afterInference`.
            
            Args:
                task_id (int): the index of the current task.
                curr_dataset (object): The dataset for the current task.
                curr_model (torch.nn.Module): the current trained model.
                curr_optimizer (torch.optim.Optimizer): the current optimizer function.
                curr_training_states (dict): the dictionary containing the current training states.
        """
        super().processAfterTraining(task_id, curr_dataset, curr_model, curr_optimizer, curr_training_states)
        params = {name: torch.zeros_like(p) for name, p in curr_model.named_parameters()}
        fishers = {name: torch.zeros_like(p) for name, p in curr_model.named_parameters()}
        train_loader = self.prepareLoader(curr_dataset, curr_training_states)[0]
        
        total_num_items = 0
        for i, _curr_batch in enumerate(iter(train_loader)):
            curr_model.zero_grad()
            curr_results = self.inference(curr_model, _curr_batch, curr_training_states)
            curr_results['loss'].backward()
            curr_num_items = _curr_batch[1].shape[0]
            total_num_items += curr_num_items
            for name, p in curr_model.named_parameters():
                params[name] = p.data.clone().detach()
                fishers[name] += (p.grad.data.clone().detach() ** 2) * curr_num_items
                
        for name, p in curr_model.named_parameters():
            fishers[name] /= total_num_items
            
        curr_training_states['fishers'].append(fishers)
        curr_training_states['params'].append(params)
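
# Added commentary (not part of the original source): processAfterTraining above
# estimates a diagonal Fisher information matrix from the training split, roughly
#
#   F_i ~= (1 / N) * sum_b n_b * (d L_b / d theta_i)^2,
#
# where n_b is the number of training edges in batch b and N is the total count.
# The per-task (params, fishers) snapshots are what afterInference consumes.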

class LCClassILEWCTrainer(LCTrainer):
    def __init__(self, model, scenario, optimizer_fn, loss_fn, device, **kwargs):
        """
            EWC needs `lamb`, the additional hyperparameter for the regularization term used in :func:`afterInference`.
        """
        super().__init__(model.to(device), scenario, optimizer_fn, loss_fn, device, **kwargs)
        self.lamb = kwargs['lamb'] if 'lamb' in kwargs else 10000.
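    # Added commentary (not part of the original source): unlike the Task-IL
    # trainer above, this class does not override prepareLoader/inference, so it
    # relies on the base LCTrainer behavior (no task_specific_mask). The
    # EWC-specific hooks below mirror LCTaskILEWCTrainer.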
    def afterInference(self, results, model, optimizer, _curr_batch, training_states):
        """
            The event function to execute some processes right after the inference step (for training).
            We recommend performing backpropagation in this event function.
            EWC performs regularization process in this function.
            
            Args:
                results (dict): the returned dictionary from the event function `inference`.
                model (torch.nn.Module): the current trained model.
                optimizer (torch.optim.Optimizer): the current optimizer function.
                curr_batch (object): the data (or minibatch) for the current iteration.
                curr_training_states (dict): the dictionary containing the current training states.
                
            Returns:
                A dictionary containing the information from the `results`.
        """
        loss_reg = 0
        for _param, _fisher in zip(training_states['params'], training_states['fishers']):
            for name, p in model.named_parameters():
                l = self.lamb * _fisher[name]
                l = l * ((p - _param[name]) ** 2)
                loss_reg = loss_reg + l.sum()
        total_loss = results['loss'] + loss_reg
        total_loss.backward()
        optimizer.step()
        return {'loss': total_loss.item(),
                'acc': self.eval_fn(torch.argmax(results['preds'], -1), _curr_batch[-1].to(self.device))}
    def initTrainingStates(self, scenario, model, optimizer):
        return {'fishers': [], 'params': []}
    def processAfterTraining(self, task_id, curr_dataset, curr_model, curr_optimizer, curr_training_states):
        """
            The event function to execute some processes after training the current task.
            EWC computes the Fisher information matrix and stores the learned weights to compute the penalty term in :func:`afterInference`.
            
            Args:
                task_id (int): the index of the current task.
                curr_dataset (object): The dataset for the current task.
                curr_model (torch.nn.Module): the current trained model.
                curr_optimizer (torch.optim.Optimizer): the current optimizer function.
                curr_training_states (dict): the dictionary containing the current training states.
        """
        super().processAfterTraining(task_id, curr_dataset, curr_model, curr_optimizer, curr_training_states)
        params = {name: torch.zeros_like(p) for name, p in curr_model.named_parameters()}
        fishers = {name: torch.zeros_like(p) for name, p in curr_model.named_parameters()}
        train_loader = self.prepareLoader(curr_dataset, curr_training_states)[0]
        
        total_num_items = 0
        for i, _curr_batch in enumerate(iter(train_loader)):
            curr_model.zero_grad()
            curr_results = self.inference(curr_model, _curr_batch, curr_training_states)
            curr_results['loss'].backward()
            curr_num_items = _curr_batch[1].shape[0]
            total_num_items += curr_num_items
            for name, p in curr_model.named_parameters():
                params[name] = p.data.clone().detach()
                fishers[name] += (p.grad.data.clone().detach() ** 2) * curr_num_items
                
        for name, p in curr_model.named_parameters():
            fishers[name] /= total_num_items
            
        curr_training_states['fishers'].append(fishers)
        curr_training_states['params'].append(params)

class LCTimeILEWCTrainer(LCTrainer):
    def __init__(self, model, scenario, optimizer_fn, loss_fn, device, **kwargs):
        """
            EWC needs `lamb`, the additional hyperparameter for the regularization term used in :func:`afterInference`.
        """
        super().__init__(model.to(device), scenario, optimizer_fn, loss_fn, device, **kwargs)
        self.lamb = kwargs['lamb'] if 'lamb' in kwargs else 10000.
    def processBeforeTraining(self, task_id, curr_dataset, curr_model, curr_optimizer, curr_training_states):
        """
            The event function to execute some processes before training.
            We need to extend the base function since the output format is slightly different from the base trainer.
            
            Args:
                task_id (int): the index of the current task.
                curr_dataset (object): The dataset for the current task.
                curr_model (torch.nn.Module): the current trained model.
                curr_optimizer (torch.optim.Optimizer): the current optimizer function.
                curr_training_states (dict): the dictionary containing the current training states.
        """
        curr_training_states['scheduler'] = self.scheduler_fn(curr_optimizer)
        curr_training_states['best_val_acc'] = -1.
        curr_training_states['best_val_loss'] = 1e10
        curr_model.observe_labels(torch.LongTensor([0]))
        self._reset_optimizer(curr_optimizer)
    def processEvalIteration(self, model, _curr_batch):
        """
            The event function to handle every evaluation iteration.
            We need to extend the base function since the output format is slightly different from the base trainer.
            
            Args:
                model (torch.nn.Module): the current trained model.
                curr_batch (object): the data (or minibatch) for the current iteration.
                
            Returns:
                A tuple containing the prediction results and a dictionary containing the outcomes (stats) during the evaluation iteration.
        """
        results = self.inference(model, _curr_batch, None)
        return results['preds'], {'loss': results['loss'].item()}
    def afterInference(self, results, model, optimizer, _curr_batch, training_states):
        """
            The event function to execute some processes right after the inference step (for training).
            We recommend performing backpropagation in this event function.
            EWC performs regularization process in this function.
            
            Args:
                results (dict): the returned dictionary from the event function `inference`.
                model (torch.nn.Module): the current trained model.
                optimizer (torch.optim.Optimizer): the current optimizer function.
                curr_batch (object): the data (or minibatch) for the current iteration.
                curr_training_states (dict): the dictionary containing the current training states.
                
            Returns:
                A dictionary containing the information from the `results`.
        """
        loss_reg = 0
        for _param, _fisher in zip(training_states['params'], training_states['fishers']):
            for name, p in model.named_parameters():
                l = self.lamb * _fisher[name]
                l = l * ((p - _param[name]) ** 2)
                loss_reg = loss_reg + l.sum()
        total_loss = results['loss'] + loss_reg
        total_loss.backward()
        optimizer.step()
        return {'loss': total_loss.item(),
                'acc': self.eval_fn(results['preds'], _curr_batch[-1].to(self.device))}
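    # Added commentary (not part of the original source): unlike the Task-IL and
    # Class-IL trainers, the accuracy entry above passes the raw predictions to
    # self.eval_fn without an argmax; the Time-IL evaluation metric is presumably
    # defined on continuous outputs.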
    def initTrainingStates(self, scenario, model, optimizer):
        return {'fishers': [], 'params': []}
    def processAfterTraining(self, task_id, curr_dataset, curr_model, curr_optimizer, curr_training_states):
        """
            The event function to execute some processes after training the current task.
            EWC computes the Fisher information matrix and stores the learned weights to compute the penalty term in :func:`afterInference`.
            
            Args:
                task_id (int): the index of the current task.
                curr_dataset (object): The dataset for the current task.
                curr_model (torch.nn.Module): the current trained model.
                curr_optimizer (torch.optim.Optimizer): the current optimizer function.
                curr_training_states (dict): the dictionary containing the current training states.
        """
        super().processAfterTraining(task_id, curr_dataset, curr_model, curr_optimizer, curr_training_states)
        params = {name: torch.zeros_like(p) for name, p in curr_model.named_parameters()}
        fishers = {name: torch.zeros_like(p) for name, p in curr_model.named_parameters()}
        train_loader = self.prepareLoader(curr_dataset, curr_training_states)[0]
        
        total_num_items = 0
        for i, _curr_batch in enumerate(iter(train_loader)):
            curr_model.zero_grad()
            curr_results = self.inference(curr_model, _curr_batch, curr_training_states)
            curr_results['loss'].backward()
            curr_num_items = _curr_batch[1].shape[0]
            total_num_items += curr_num_items
            for name, p in curr_model.named_parameters():
                params[name] = p.data.clone().detach()
                fishers[name] += (p.grad.data.clone().detach() ** 2) * curr_num_items
                
        for name, p in curr_model.named_parameters():
            fishers[name] /= total_num_items
            
        curr_training_states['fishers'].append(fishers)
        curr_training_states['params'].append(params)

class LPTimeILEWCTrainer(LPTrainer):
    def __init__(self, model, scenario, optimizer_fn, loss_fn, device, **kwargs):
        """
            EWC needs `lamb`, the additional hyperparameter for the regularization term used in :func:`processTrainIteration`.
        """
        super().__init__(model.to(device), scenario, optimizer_fn, loss_fn, device, **kwargs)
        self.lamb = kwargs['lamb'] if 'lamb' in kwargs else 1.
        self.T = kwargs['T'] if 'T' in kwargs else 2.
    def processTrainIteration(self, model, optimizer, _curr_batch, training_states):
        """
            The event function to handle every training iteration.
            EWC performs inference and regularization process in this function.
            
            Args:
                model (torch.nn.Module): the current trained model.
                optimizer (torch.optim.Optimizer): the current optimizer function.
                curr_batch (object): the data (or minibatch) for the current iteration.
                curr_training_states (dict): the dictionary containing the current training states.
                
            Returns:
                A dictionary containing the outcomes (stats) during the training iteration.
        """
        graph, feats = map(lambda x: x.to(self.device), training_states['graph'])
        edges, labels = map(lambda x: x.to(self.device), _curr_batch)
        optimizer.zero_grad()
        srcs, dsts = edges[:, 0], edges[:, 1]
        neg_dsts = torch.randint(low=0, high=graph.num_nodes(), size=(srcs.shape[0],)).to(self.device)
        preds = model(graph, feats, srcs.repeat(2), torch.cat((edges[:, 1], neg_dsts), dim=0)).squeeze(-1)
        loss = self.loss_fn(preds, torch.cat((labels, torch.zeros_like(labels)), dim=0))
        
        loss_reg = 0
        for _param, _fisher in zip(training_states['params'], training_states['fishers']):
            for name, p in model.named_parameters():
                l = self.lamb * _fisher[name]
                l = l * ((p - _param[name]) ** 2)
                loss_reg = loss_reg + l.sum()
        loss = loss + loss_reg
        loss.backward()
        optimizer.step()
        return {'_num_items': preds.shape[0], 'loss': loss.item()}
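    # Added commentary (not part of the original source): each positive edge
    # (src, dst) is paired with one negative edge whose destination is sampled
    # uniformly at random from all nodes; the negatives receive zero labels, so
    # the loss is computed over twice the batch size. The EWC penalty over the
    # stored task snapshots is added to this loss before the backward pass.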
    def initTrainingStates(self, scenario, model, optimizer):
        return {'fishers': [], 'params': []}
    def processAfterTraining(self, task_id, curr_dataset, curr_model, curr_optimizer, curr_training_states):
        """
            The event function to execute some processes after training the current task.
            EWC computes the Fisher information matrix and stores the learned weights to compute the penalty term in :func:`processTrainIteration`.
            
            Args:
                task_id (int): the index of the current task.
                curr_dataset (object): The dataset for the current task.
                curr_model (torch.nn.Module): the current trained model.
                curr_optimizer (torch.optim.Optimizer): the current optimizer function.
                curr_training_states (dict): the dictionary containing the current training states.
        """
        super().processAfterTraining(task_id, curr_dataset, curr_model, curr_optimizer, curr_training_states)
        params = {}
        fishers = {}
        train_loader = self.prepareLoader(curr_dataset, curr_training_states)[0]
        
        total_num_items = 0
        for i, _curr_batch in enumerate(iter(train_loader)):
            curr_model.zero_grad()
            graph, feats = map(lambda x: x.to(self.device), curr_training_states['graph'])
            edges, labels = map(lambda x: x.to(self.device), _curr_batch)
            srcs, dsts = edges[:, 0], edges[:, 1]
            neg_dsts = torch.randint(low=0, high=graph.num_nodes(), size=(srcs.shape[0],)).to(self.device)
            preds = curr_model(graph, feats, srcs.repeat(2), torch.cat((edges[:, 1], neg_dsts), dim=0)).squeeze(-1)
            loss = self.loss_fn(preds, torch.cat((labels, torch.zeros_like(labels)), dim=0))
            loss.backward()
            total_num_items += labels.shape[0]
            if i == 0:
                for name, p in curr_model.named_parameters():
                    params[name] = p.data.clone().detach()
                    fishers[name] = ((p.grad.data.clone().detach() ** 2) * labels.shape[0])
            else:
                for name, p in curr_model.named_parameters():
                    fishers[name] += ((p.grad.data.clone().detach() ** 2) * labels.shape[0])
                    
        for name, p in curr_model.named_parameters():
            fishers[name] /= total_num_items
            
        curr_training_states['fishers'].append(fishers)
        curr_training_states['params'].append(params)
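
# Added commentary (not part of the original source): the link-prediction Fisher
# estimate above re-runs the same negative-sampled loss over the training loader
# and averages the squared gradients, weighted by the number of positive edges
# per batch. Because the negative destinations are resampled here, the resulting
# estimate is stochastic rather than tied to the negatives seen during training.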

class LPDomainILEWCTrainer(LPTimeILEWCTrainer):
    """
        This trainer has the same behavior as `LPTimeILEWCTrainer`.
    """
    pass
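
# A hedged usage sketch (added commentary; `model`, `scenario`, and the data
# pipeline come from the surrounding BeGin framework and are not defined in this
# module, and the loss choice is only an assumption for illustration):
#
#   optimizer_fn = lambda params: torch.optim.Adam(params, lr=1e-3)
#   loss_fn = torch.nn.BCEWithLogitsLoss()
#   trainer = LPTimeILEWCTrainer(model, scenario, optimizer_fn, loss_fn,
#                                device=torch.device('cuda:0'), lamb=1.)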