TWP

TWP (Topology-aware Weight Preserving) is a regularization-based continual learning method for graphs. It estimates how important each parameter is to both the task objective and the topological aggregation performed by graph neural networks, and penalizes changes to important parameters when learning new tasks. For the details, see the original paper.
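To make the idea concrete: TWP-style methods penalize changes to parameters in proportion to their estimated importance. The function below is a minimal pure-Python sketch of such an importance-weighted quadratic penalty (the names `params`, `old_params`, `importance`, and `lmbda` are illustrative, not part of this library's API):

```python
def twp_penalty(params, old_params, importance, lmbda=1.0):
    """Importance-weighted quadratic penalty, in the spirit of TWP.

    Parameters that were important for previous tasks (large values in
    `importance`) are penalized more strongly for drifting away from
    the values they had after those tasks (`old_params`).
    """
    return lmbda * sum(
        w * (p - q) ** 2
        for p, q, w in zip(params, old_params, importance)
    )

# The second parameter moved twice as far but has half the importance:
# 1.0 * 1.0**2 + 0.5 * 2.0**2 = 3.0
print(twp_penalty([1.0, 2.0], [0.0, 0.0], [1.0, 0.5]))  # → 3.0
```

In the trainers below, a penalty of this kind is added to the task loss before backpropagation.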

Node-level Problems

class NCClassILTWPMinibatchTrainer(model, scenario, optimizer_fn, loss_fn, device, **kwargs)[source]
afterInference(results, model, optimizer, _curr_batch, training_states)[source]

The event function to execute some processes right after the inference step (for training). We recommend performing backpropagation in this event function.

TWP applies its regularization process in this function.

Parameters:
  • results (dict) – the returned dictionary from the event function inference.

  • model (torch.nn.Module) – the current trained model.

  • optimizer (torch.optim.Optimizer) – the current optimizer function.

  • curr_batch (object) – the data (or minibatch) for the current iteration.

  • curr_training_states (dict) – the dictionary containing the current training states.

Returns:

A dictionary containing the information from the results.

inference(model, _curr_batch, training_states, return_elist=False)[source]

The event function to execute the inference step.

TWP requires the edge weights computed by the attention mechanism.

Parameters:
  • model (torch.nn.Module) – the current trained model.

  • curr_batch (object) – the data (or minibatch) for the current iteration.

  • curr_training_states (dict) – the dictionary containing the current training states.

  • return_elist (bool) – if True, the inference step additionally returns the list of attention-based edge weights required by TWP.

Returns:

A dictionary containing the inference results, such as prediction result and loss.

initTrainingStates(scenario, model, optimizer)[source]

The event function to initialize the dictionary for storing training states (i.e., intermediate results).

Parameters:
  • scenario (object) – the learning scenario for the training.

  • model (torch.nn.Module) – the current trained model.

  • optimizer (torch.optim.Optimizer) – the current optimizer function.

Returns:

Initialized training state (dict).
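For intuition only, a training-states dictionary for a TWP-style trainer could hold one parameter snapshot and one importance vector per completed task; the key names below are assumptions for illustration, not the library's actual state layout:

```python
def init_training_states():
    # Illustrative layout: one entry is appended to each list after
    # every completed task (see processAfterTraining).
    return {
        "old_params_per_task": [],   # parameter snapshots, one per task
        "importance_per_task": [],   # matching importance scores
    }

states = init_training_states()
# After finishing a (toy) task with two parameters:
states["old_params_per_task"].append([0.3, -1.2])
states["importance_per_task"].append([0.9, 0.1])
```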

processAfterTraining(task_id, curr_dataset, curr_model, curr_optimizer, curr_training_states)[source]

The event function to execute some processes after training the current task.

TWP computes the importance weights for the regularization process and stores the learned weights in this function.

Parameters:
  • task_id (int) – the index of the current task.

  • curr_dataset (object) – the dataset for the current task.

  • curr_model (torch.nn.Module) – the current trained model.

  • curr_optimizer (torch.optim.Optimizer) – the current optimizer function.

  • curr_training_states (dict) – the dictionary containing the current training states.
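As a sketch of what this step computes: TWP derives each parameter's importance from the magnitude of its task-loss gradient plus a down-weighted magnitude of its gradient with respect to the topology (attention) term. The helper below assumes the two gradient lists are already available (e.g., from autograd); the function name and the `beta` factor are illustrative assumptions:

```python
def compute_importance(task_grads, topo_grads, beta=0.01):
    # A larger gradient magnitude means the parameter matters more to
    # the task loss and/or the learned topology, so it is preserved
    # more strongly while later tasks are trained.
    return [abs(g) + beta * abs(t) for g, t in zip(task_grads, topo_grads)]

print(compute_importance([-2.0, 0.5], [10.0, 0.0]))
```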

class NCClassILTWPTrainer(model, scenario, optimizer_fn, loss_fn, device, **kwargs)[source]
afterInference(results, model, optimizer, _curr_batch, training_states)[source]

The event function to execute some processes right after the inference step (for training). We recommend performing backpropagation in this event function.

TWP applies its regularization process in this function.

Parameters:
  • results (dict) – the returned dictionary from the event function inference.

  • model (torch.nn.Module) – the current trained model.

  • optimizer (torch.optim.Optimizer) – the current optimizer function.

  • curr_batch (object) – the data (or minibatch) for the current iteration.

  • curr_training_states (dict) – the dictionary containing the current training states.

Returns:

A dictionary containing the information from the results.

inference(model, _curr_batch, training_states, return_elist=False)[source]

The event function to execute the inference step.

TWP requires the edge weights computed by the attention mechanism.

Parameters:
  • model (torch.nn.Module) – the current trained model.

  • curr_batch (object) – the data (or minibatch) for the current iteration.

  • curr_training_states (dict) – the dictionary containing the current training states.

  • return_elist (bool) – if True, the inference step additionally returns the list of attention-based edge weights required by TWP.

Returns:

A dictionary containing the inference results, such as prediction result and loss.

initTrainingStates(scenario, model, optimizer)[source]

The event function to initialize the dictionary for storing training states (i.e., intermediate results).

Parameters:
  • scenario (object) – the learning scenario for the training.

  • model (torch.nn.Module) – the current trained model.

  • optimizer (torch.optim.Optimizer) – the current optimizer function.

Returns:

Initialized training state (dict).

processAfterTraining(task_id, curr_dataset, curr_model, curr_optimizer, curr_training_states)[source]

The event function to execute some processes after training the current task.

TWP computes the importance weights for the regularization process and stores the learned weights in this function.

Parameters:
  • task_id (int) – the index of the current task.

  • curr_dataset (object) – the dataset for the current task.

  • curr_model (torch.nn.Module) – the current trained model.

  • curr_optimizer (torch.optim.Optimizer) – the current optimizer function.

  • curr_training_states (dict) – the dictionary containing the current training states.

class NCDomainILTWPTrainer(model, scenario, optimizer_fn, loss_fn, device, **kwargs)[source]

This trainer has the same behavior as NCClassILTWPTrainer.

class NCTaskILTWPTrainer(model, scenario, optimizer_fn, loss_fn, device, **kwargs)[source]
afterInference(results, model, optimizer, _curr_batch, training_states)[source]

The event function to execute some processes right after the inference step (for training). We recommend performing backpropagation in this event function.

TWP applies its regularization process in this function.

Parameters:
  • results (dict) – the returned dictionary from the event function inference.

  • model (torch.nn.Module) – the current trained model.

  • optimizer (torch.optim.Optimizer) – the current optimizer function.

  • curr_batch (object) – the data (or minibatch) for the current iteration.

  • curr_training_states (dict) – the dictionary containing the current training states.

Returns:

A dictionary containing the information from the results.

inference(model, _curr_batch, training_states, return_elist=False)[source]

The event function to execute the inference step.

For task-IL, we additionally need to consider the task information in the inference step. TWP requires the edge weights computed by the attention mechanism.

Parameters:
  • model (torch.nn.Module) – the current trained model.

  • curr_batch (object) – the data (or minibatch) for the current iteration.

  • curr_training_states (dict) – the dictionary containing the current training states.

  • return_elist (bool) – if True, the inference step additionally returns the list of attention-based edge weights required by TWP.

Returns:

A dictionary containing the inference results, such as prediction result and loss.
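To illustrate the extra task information used here (a generic task-IL sketch, not this library's actual code): at inference time the prediction can be restricted to the classes of the current task, e.g. by masking out all other logits before taking the argmax:

```python
def mask_logits(logits, task_classes):
    """Suppress the logits of classes outside the current task.

    `task_classes` is a hypothetical set of class indices belonging to
    the current task; all other entries become -inf so they can never
    be selected by argmax.
    """
    neg_inf = float("-inf")
    return [x if i in task_classes else neg_inf for i, x in enumerate(logits)]

print(mask_logits([0.1, 0.5, -0.2], task_classes={0, 2}))  # → [0.1, -inf, -0.2]
```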

initTrainingStates(scenario, model, optimizer)[source]

The event function to initialize the dictionary for storing training states (i.e., intermediate results).

Parameters:
  • scenario (object) – the learning scenario for the training.

  • model (torch.nn.Module) – the current trained model.

  • optimizer (torch.optim.Optimizer) – the current optimizer function.

Returns:

Initialized training state (dict).

processAfterTraining(task_id, curr_dataset, curr_model, curr_optimizer, curr_training_states)[source]

The event function to execute some processes after training the current task.

TWP computes the importance weights for the regularization process and stores the learned weights in this function.

Parameters:
  • task_id (int) – the index of the current task.

  • curr_dataset (object) – the dataset for the current task.

  • curr_model (torch.nn.Module) – the current trained model.

  • curr_optimizer (torch.optim.Optimizer) – the current optimizer function.

  • curr_training_states (dict) – the dictionary containing the current training states.

class NCTimeILTWPTrainer(model, scenario, optimizer_fn, loss_fn, device, **kwargs)[source]

This trainer has the same behavior as NCClassILTWPTrainer.

Graph-level Problems

class GCClassILTWPTrainer(model, scenario, optimizer_fn, loss_fn, device, **kwargs)[source]
afterInference(results, model, optimizer, _curr_batch, training_states)[source]

The event function to execute some processes right after the inference step (for training). We recommend performing backpropagation in this event function.

TWP applies its regularization process in this function.

Parameters:
  • results (dict) – the returned dictionary from the event function inference.

  • model (torch.nn.Module) – the current trained model.

  • optimizer (torch.optim.Optimizer) – the current optimizer function.

  • curr_batch (object) – the data (or minibatch) for the current iteration.

  • curr_training_states (dict) – the dictionary containing the current training states.

Returns:

A dictionary containing the information from the results.

inference(model, _curr_batch, training_states, return_elist=False)[source]

The event function to execute the inference step.

TWP requires the edge weights computed by the attention mechanism.

Parameters:
  • model (torch.nn.Module) – the current trained model.

  • curr_batch (object) – the data (or minibatch) for the current iteration.

  • curr_training_states (dict) – the dictionary containing the current training states.

  • return_elist (bool) – if True, the inference step additionally returns the list of attention-based edge weights required by TWP.

Returns:

A dictionary containing the inference results, such as prediction result and loss.

initTrainingStates(scenario, model, optimizer)[source]

The event function to initialize the dictionary for storing training states (i.e., intermediate results).

Parameters:
  • scenario (object) – the learning scenario for the training.

  • model (torch.nn.Module) – the current trained model.

  • optimizer (torch.optim.Optimizer) – the current optimizer function.

Returns:

Initialized training state (dict).

processAfterTraining(task_id, curr_dataset, curr_model, curr_optimizer, curr_training_states)[source]

The event function to execute some processes after training the current task.

TWP computes the importance weights for the regularization process and stores the learned weights in this function.

Parameters:
  • task_id (int) – the index of the current task.

  • curr_dataset (object) – the dataset for the current task.

  • curr_model (torch.nn.Module) – the current trained model.

  • curr_optimizer (torch.optim.Optimizer) – the current optimizer function.

  • curr_training_states (dict) – the dictionary containing the current training states.

class GCDomainILTWPTrainer(model, scenario, optimizer_fn, loss_fn, device, **kwargs)[source]

This trainer has the same behavior as GCClassILTWPTrainer.

class GCTaskILTWPTrainer(model, scenario, optimizer_fn, loss_fn, device, **kwargs)[source]
afterInference(results, model, optimizer, _curr_batch, training_states)[source]

The event function to execute some processes right after the inference step (for training). We recommend performing backpropagation in this event function.

TWP applies its regularization process in this function.

Parameters:
  • results (dict) – the returned dictionary from the event function inference.

  • model (torch.nn.Module) – the current trained model.

  • optimizer (torch.optim.Optimizer) – the current optimizer function.

  • curr_batch (object) – the data (or minibatch) for the current iteration.

  • curr_training_states (dict) – the dictionary containing the current training states.

Returns:

A dictionary containing the information from the results.

inference(model, _curr_batch, training_states, return_elist=False)[source]

The event function to execute the inference step.

For task-IL, we additionally need to consider the task information in the inference step. TWP requires the edge weights computed by the attention mechanism.

Parameters:
  • model (torch.nn.Module) – the current trained model.

  • curr_batch (object) – the data (or minibatch) for the current iteration.

  • curr_training_states (dict) – the dictionary containing the current training states.

  • return_elist (bool) – if True, the inference step additionally returns the list of attention-based edge weights required by TWP.

Returns:

A dictionary containing the inference results, such as prediction result and loss.

initTrainingStates(scenario, model, optimizer)[source]

The event function to initialize the dictionary for storing training states (i.e., intermediate results).

Parameters:
  • scenario (object) – the learning scenario for the training.

  • model (torch.nn.Module) – the current trained model.

  • optimizer (torch.optim.Optimizer) – the current optimizer function.

Returns:

Initialized training state (dict).

processAfterTraining(task_id, curr_dataset, curr_model, curr_optimizer, curr_training_states)[source]

The event function to execute some processes after training the current task.

TWP computes the importance weights for the regularization process and stores the learned weights in this function.

Parameters:
  • task_id (int) – the index of the current task.

  • curr_dataset (object) – the dataset for the current task.

  • curr_model (torch.nn.Module) – the current trained model.

  • curr_optimizer (torch.optim.Optimizer) – the current optimizer function.

  • curr_training_states (dict) – the dictionary containing the current training states.

class GCTimeILTWPTrainer(model, scenario, optimizer_fn, loss_fn, device, **kwargs)[source]

This trainer has the same behavior as GCClassILTWPTrainer.