ContinualGNN

ContinualGNN is a continual learning method for graphs that combines replay-based and regularization-based approaches. For details, see the original paper.

Node-level Problems

class NCClassILCGNNTrainer(model, scenario, optimizer_fn, loss_fn, device, **kwargs)
initTrainingStates(scenario, model, optimizer)

The event function to initialize the dictionary for storing training states (i.e., intermediate results).

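Parameters:
  • scenario – the scenario given to the trainer.

  • model (torch.nn.Module) – the model given to the trainer.

  • optimizer (torch.optim.Optimizer) – the optimizer for the model.

Returns: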

Initialized training state (dict).
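As an illustration only, the returned dictionary might bundle the state that the ContinualGNN event functions share within a task. The sketch below assumes plain dictionary entries; the function `init_training_states_sketch` and every key name (`memory`, `fisher`, `old_params`, `best_val_score`, `patience`, `bad_epochs`) are hypothetical, not the library's actual layout.

```python
def init_training_states_sketch(patience=20):
    """Hypothetical stand-in for initTrainingStates (key names are assumptions)."""
    return {
        "memory": [],          # node IDs kept for replay from earlier tasks
        "fisher": None,        # per-parameter importance, set after the first task
        "old_params": None,    # parameter snapshot from the previous task
        "best_val_score": float("-inf"),  # early-stopping bookkeeping
        "patience": patience,  # epochs allowed without improvement
        "bad_epochs": 0,       # consecutive epochs without improvement
    }

states = init_training_states_sketch()
```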

prepareLoader(curr_dataset, curr_training_states)

The event function to generate dataloaders from the given dataset for the current task.

ContinualGNN also adds the nodes saved from previous tasks to the training set of the current task.

Parameters:
  • curr_dataset (object) – The dataset for the current task. Its type is dgl.graph for node-level and link-level problems, and dgl.data.DGLDataset for graph-level problems.

  • curr_training_states (dict) – the dictionary containing the current training states.

Returns:

A tuple of three dataloaders, which the trainer uses for training, validation, and test, respectively.
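A minimal sketch of the replay step described above, assuming the current task provides training, validation, and test node IDs and the buffer stores node IDs from earlier tasks. The function `prepare_loaders_sketch` and its arguments are hypothetical; the library's loaders for node-level problems may yield graph data rather than bare node IDs.

```python
import torch
from torch.utils.data import DataLoader

def prepare_loaders_sketch(train_ids, val_ids, test_ids, memory_ids, batch_size=4):
    # Add replayed nodes from earlier tasks to the current training set,
    # dropping IDs that are already present (torch.unique also sorts them).
    merged = torch.unique(torch.cat([train_ids, memory_ids]))
    train_loader = DataLoader(merged, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_ids, batch_size=batch_size)
    test_loader = DataLoader(test_ids, batch_size=batch_size)
    # Order matters: (training, validation, test)
    return train_loader, val_loader, test_loader

train_l, val_l, test_l = prepare_loaders_sketch(
    torch.tensor([0, 1, 2]), torch.tensor([3]), torch.tensor([4]),
    memory_ids=torch.tensor([2, 7, 9]),
)
```

Deduplicating keeps a replayed node from being counted twice when it also belongs to the current training set.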

processAfterEachIteration(curr_model, curr_optimizer, curr_training_states, curr_iter_results)

The event function executed at the end of every epoch. Its return value determines whether training continues: if it returns False, the trainer stops training the current model on the current task.

Note

This function is called at the end of every epoch, whereas the event function processAfterTraining is called only once, after learning on the current task has ended.

Parameters:
  • curr_model (torch.nn.Module) – the model currently being trained.

  • curr_optimizer (torch.optim.Optimizer) – the current optimizer.

  • curr_training_states (dict) – the dictionary containing the current training states.

  • curr_iter_results (dict) – the dictionary containing the training/validation results of the current epoch.

Returns:

A boolean value. If the returned value is False, the trainer stops training the current model in the current task.
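Because the Boolean return value controls whether training continues, patience-based early stopping fits naturally here. The sketch below assumes the epoch results carry a validation score under a hypothetical key `val_score` and that the training state holds the hypothetical counters shown; the library's actual stopping criterion may differ.

```python
def process_after_each_iteration_sketch(training_states, iter_results):
    """Return False to stop training on the current task (hypothetical logic)."""
    score = iter_results["val_score"]
    if score > training_states["best_val_score"]:
        training_states["best_val_score"] = score
        training_states["bad_epochs"] = 0
    else:
        training_states["bad_epochs"] += 1
    # Keep training while fewer than `patience` epochs passed without improvement.
    return training_states["bad_epochs"] < training_states["patience"]

states = {"best_val_score": float("-inf"), "bad_epochs": 0, "patience": 2}
decisions = [process_after_each_iteration_sketch(states, {"val_score": s})
             for s in [0.5, 0.6, 0.55, 0.58]]
print(decisions)  # [True, True, True, False]
```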

processAfterTraining(task_id, curr_dataset, curr_model, curr_optimizer, curr_training_states)

The event function executed after training on the current task.

ContinualGNN updates the buffer using the memory handler.

Parameters:
  • task_id (int) – the index of the current task.

  • curr_dataset (object) – The dataset for the current task.

  • curr_model (torch.nn.Module) – the model currently being trained.

  • curr_optimizer (torch.optim.Optimizer) – the current optimizer.

  • curr_training_states (dict) – the dictionary containing the current training states.
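This page does not specify how the memory handler chooses which nodes to keep. Purely as an illustration, the sketch below swaps in a simple class-balanced random sampler that appends up to `per_class` training nodes of each class to the buffer; it is not claimed to be ContinualGNN's own sampling strategy, and `update_buffer_sketch` is a hypothetical name.

```python
import random
from collections import defaultdict

def update_buffer_sketch(memory, train_ids, labels, per_class, seed=0):
    """Append up to `per_class` training nodes of each class to the buffer."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for nid in train_ids:
        by_class[labels[nid]].append(nid)
    for cls in sorted(by_class):
        nodes = by_class[cls]
        # Sample without replacement; small classes contribute all their nodes.
        memory.extend(rng.sample(nodes, min(per_class, len(nodes))))
    return memory

labels = {0: "a", 1: "a", 2: "a", 3: "b"}
memory = update_buffer_sketch([], [0, 1, 2, 3], labels, per_class=2)
```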

processBeforeTraining(task_id, curr_dataset, curr_model, curr_optimizer, curr_training_states)

The event function executed before training on the current task.

In this function, ContinualGNN chooses important nodes for regularization.

Parameters:
  • task_id (int) – the index of the current task.

  • curr_dataset (object) – The dataset for the current task.

  • curr_model (torch.nn.Module) – the model currently being trained.

  • curr_optimizer (torch.optim.Optimizer) – the current optimizer.

  • curr_training_states (dict) – the dictionary containing the current training states.
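The criterion used to pick important nodes is not described on this page. A minimal sketch, assuming node in-degree as a stand-in importance score (an assumption, not ContinualGNN's own criterion), selects the top-k nodes from an edge list; with a DGL graph the same scores could come from `g.in_degrees()`. The function name and the `edge_index` layout are hypothetical.

```python
import torch

def select_important_nodes_sketch(edge_index, num_nodes, k):
    """Return the k nodes with the largest in-degree (stand-in importance)."""
    # edge_index[1] holds destination nodes; count incoming edges per node.
    deg = torch.bincount(edge_index[1], minlength=num_nodes).float()
    return torch.topk(deg, k).indices

edges = torch.tensor([[0, 1, 2, 3, 3],    # source nodes
                      [1, 2, 1, 1, 2]])   # destination nodes
important = select_important_nodes_sketch(edges, num_nodes=4, k=2)
```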

processTrainIteration(model, optimizer, _curr_batch, training_states)

The event function to handle every training iteration.

ContinualGNN performs its EWC-based regularization in this function.

Parameters:
  • model (torch.nn.Module) – the model currently being trained.

  • optimizer (torch.optim.Optimizer) – the current optimizer.

  • _curr_batch (object) – the data (or minibatch) for the current iteration.

  • training_states (dict) – the dictionary containing the current training states.

Returns:

A dictionary containing the outcomes (stats) during the training iteration.
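A minimal sketch of an EWC-style training step, assuming the training state holds a parameter snapshot `old_params` and diagonal Fisher estimates `fisher` from earlier tasks (both hypothetical key names). The penalty (lam / 2) * sum_i fisher_i * (theta_i - theta_old_i)^2 is added to the task loss before the backward pass; the weight `lam`, the function names, and the returned stats key are likewise assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def ewc_penalty(model, old_params, fisher):
    """Quadratic EWC penalty: sum_i fisher_i * (theta_i - theta_old_i)^2."""
    penalty = torch.zeros(())
    for name, p in model.named_parameters():
        if name in fisher:
            penalty = penalty + (fisher[name] * (p - old_params[name]) ** 2).sum()
    return penalty

def train_iteration_sketch(model, optimizer, batch, states, lam=100.0):
    x, y = batch
    optimizer.zero_grad()
    loss = F.cross_entropy(model(x), y)
    if states.get("fisher") is not None:  # no penalty before the first task ends
        loss = loss + (lam / 2) * ewc_penalty(model, states["old_params"], states["fisher"])
    loss.backward()
    optimizer.step()
    return {"loss": loss.item()}

torch.manual_seed(0)
model = nn.Linear(4, 3)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
states = {
    "old_params": {n: p.detach().clone() for n, p in model.named_parameters()},
    "fisher": {n: torch.ones_like(p) for n, p in model.named_parameters()},
}
# The penalty is zero while the parameters still equal the snapshot.
print(ewc_penalty(model, states["old_params"], states["fisher"]).item())  # 0.0
stats = train_iteration_sketch(model, opt, (torch.randn(8, 4), torch.randint(0, 3, (8,))), states)
```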

class NCDomainILCGNNTrainer(model, scenario, optimizer_fn, loss_fn, device, **kwargs)
class NCTaskILCGNNTrainer(model, scenario, optimizer_fn, loss_fn, device, **kwargs)
initTrainingStates(scenario, model, optimizer)

The event function to initialize the dictionary for storing training states (i.e., intermediate results).

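Parameters:
  • scenario – the scenario given to the trainer.

  • model (torch.nn.Module) – the model given to the trainer.

  • optimizer (torch.optim.Optimizer) – the optimizer for the model.

Returns: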

Initialized training state (dict).

prepareLoader(curr_dataset, curr_training_states)

The event function to generate dataloaders from the given dataset for the current task.

ContinualGNN also adds the nodes saved from previous tasks to the training set of the current task.

Parameters:
  • curr_dataset (object) – The dataset for the current task. Its type is dgl.graph for node-level and link-level problems, and dgl.data.DGLDataset for graph-level problems.

  • curr_training_states (dict) – the dictionary containing the current training states.

Returns:

A tuple of three dataloaders, which the trainer uses for training, validation, and test, respectively.

processAfterEachIteration(curr_model, curr_optimizer, curr_training_states, curr_iter_results)

The event function executed at the end of every epoch. Its return value determines whether training continues: if it returns False, the trainer stops training the current model on the current task.

Note

This function is called at the end of every epoch, whereas the event function processAfterTraining is called only once, after learning on the current task has ended.

Parameters:
  • curr_model (torch.nn.Module) – the model currently being trained.

  • curr_optimizer (torch.optim.Optimizer) – the current optimizer.

  • curr_training_states (dict) – the dictionary containing the current training states.

  • curr_iter_results (dict) – the dictionary containing the training/validation results of the current epoch.

Returns:

A boolean value. If the returned value is False, the trainer stops training the current model in the current task.

processAfterTraining(task_id, curr_dataset, curr_model, curr_optimizer, curr_training_states)

The event function executed after training on the current task.

ContinualGNN updates the buffer using the memory handler.

Parameters:
  • task_id (int) – the index of the current task.

  • curr_dataset (object) – The dataset for the current task.

  • curr_model (torch.nn.Module) – the model currently being trained.

  • curr_optimizer (torch.optim.Optimizer) – the current optimizer.

  • curr_training_states (dict) – the dictionary containing the current training states.

processBeforeTraining(task_id, curr_dataset, curr_model, curr_optimizer, curr_training_states)

The event function executed before training on the current task.

In this function, ContinualGNN chooses important nodes for regularization.

Parameters:
  • task_id (int) – the index of the current task.

  • curr_dataset (object) – The dataset for the current task.

  • curr_model (torch.nn.Module) – the model currently being trained.

  • curr_optimizer (torch.optim.Optimizer) – the current optimizer.

  • curr_training_states (dict) – the dictionary containing the current training states.

processEvalIteration(model, _curr_batch)

The event function to handle every evaluation iteration.

Parameters:
  • model (torch.nn.Module) – the current model.

  • _curr_batch (object) – the data (or minibatch) for the current iteration.

Returns:

A dictionary containing the outcomes (stats) during the evaluation iteration.
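In the task-incremental setting the task identity is known at evaluation time, so a common approach is to restrict predictions to the classes of that task. A minimal sketch, assuming each batch carries a Boolean mask over classes for its task (a hypothetical batch layout; this page does not state how the library handles it):

```python
import torch
import torch.nn as nn

@torch.no_grad()
def eval_iteration_sketch(model, batch):
    x, y, task_class_mask = batch  # mask: (num_classes,) bool, True for the task's classes
    logits = model(x)
    # Exclude classes outside the current task before taking the argmax.
    logits = logits.masked_fill(~task_class_mask, float("-inf"))
    preds = logits.argmax(dim=-1)
    return {"correct": (preds == y).sum().item(), "total": y.numel()}

model = nn.Linear(4, 6)
with torch.no_grad():
    model.weight.zero_()  # logits then equal the bias for every input
    model.bias.copy_(torch.tensor([10.0, 0.0, 0.0, 1.0, 0.0, 0.0]))
mask = torch.tensor([False, False, True, True, False, False])  # task owns classes 2 and 3
stats = eval_iteration_sketch(model, (torch.randn(5, 4), torch.full((5,), 3), mask))
print(stats)  # {'correct': 5, 'total': 5}
```

Without the mask, the large bias on class 0 would make every prediction wrong; with it, only classes 2 and 3 compete.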

processTrainIteration(model, optimizer, _curr_batch, training_states)

The event function to handle every training iteration.

ContinualGNN performs its EWC-based regularization in this function.

Parameters:
  • model (torch.nn.Module) – the model currently being trained.

  • optimizer (torch.optim.Optimizer) – the current optimizer.

  • _curr_batch (object) – the data (or minibatch) for the current iteration.

  • training_states (dict) – the dictionary containing the current training states.

Returns:

A dictionary containing the outcomes (stats) during the training iteration.
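For the task-incremental variant, one common choice (assumed here, not confirmed by this page) is to compute the classification loss only over the current task's classes. The sketch below shows such a masked loss inside one optimizer step; a regularization penalty would be added to `loss` before `backward()`. The function name `task_masked_train_step` and the batch layout are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def task_masked_train_step(model, optimizer, batch):
    x, y, task_class_mask = batch  # mask: (num_classes,) bool, True for the task's classes
    optimizer.zero_grad()
    # Classes outside the task get -inf logits, so they take no probability mass.
    logits = model(x).masked_fill(~task_class_mask, float("-inf"))
    loss = F.cross_entropy(logits, y)
    loss.backward()
    optimizer.step()
    return {"loss": loss.item()}

torch.manual_seed(0)
model = nn.Linear(4, 6)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
mask = torch.tensor([False, False, True, True, False, False])
w_before = model.weight.detach().clone()
y = torch.randint(2, 4, (8,))  # labels drawn from the task's classes {2, 3}
stats = task_masked_train_step(model, opt, (torch.randn(8, 4), y, mask))
# Weight rows for classes outside the task receive zero gradient, so SGD leaves them unchanged.
print(torch.equal(model.weight[~mask], w_before[~mask]))  # True
```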

class NCTimeILCGNNTrainer(model, scenario, optimizer_fn, loss_fn, device, **kwargs)

This trainer has the same behavior as NCClassILCGNNTrainer.