EWC

Elastic Weight Consolidation (EWC) is a regularization-based continual learning method that penalizes changes to parameters that were important for previous tasks. Each parameter's penalty is weighted by the corresponding diagonal entry of the Fisher information matrix. For details, see the original paper.
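For reference, the objective from the original paper is shown below for a new task B trained after task A. Here F_i is the i-th diagonal entry of the Fisher information matrix estimated at the end of task A, θ*_A are the weights learned on task A, and λ sets how strongly old knowledge is protected. The exact scaling of λ used by these trainers (e.g., whether the factor 1/2 is kept) may differ.

```latex
\mathcal{L}(\theta) = \mathcal{L}_B(\theta) + \sum_i \frac{\lambda}{2} F_i \left(\theta_i - \theta^{*}_{A,i}\right)^2
```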

Node-level Problems

class NCClassILEWCMinibatchTrainer(model, scenario, optimizer_fn, loss_fn, device, **kwargs)
afterInference(results, model, optimizer, _curr_batch, training_states)

The event function executed right after the inference step during training. We recommend performing backpropagation in this event function.

EWC computes its regularization penalty in this function and adds it to the loss.

Parameters:
  • results (dict) – the returned dictionary from the event function inference.

  • model (torch.nn.Module) – the current trained model.

  • optimizer (torch.optim.Optimizer) – the current optimizer function.

  • _curr_batch (object) – the data (or minibatch) for the current iteration.

  • training_states (dict) – the dictionary containing the current training states.

Returns:

A dictionary containing the information from the results.
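The quadratic penalty applied in afterInference() can be sketched in plain Python as follows. This is a minimal illustration, not the library's implementation: the flat-list parameter representation and the names `ewc_penalty`, `regularized_loss` and `lamb` are hypothetical, while the actual trainer operates on PyTorch tensors and backpropagates the result.

```python
def ewc_penalty(params, old_params, fisher, lamb):
    """Quadratic EWC penalty for one previous task.

    params, old_params, fisher: flat lists of floats of equal length
    (a hypothetical flattened representation; a real trainer works on
    torch tensors per named parameter).
    """
    if not (len(params) == len(old_params) == len(fisher)):
        raise ValueError("params, old_params and fisher must have equal length")
    return 0.5 * lamb * sum(
        f * (p - p_old) ** 2 for p, p_old, f in zip(params, old_params, fisher)
    )


def regularized_loss(task_loss, params, old_params, fisher, lamb):
    # Total loss that would be backpropagated: task loss plus EWC penalty.
    return task_loss + ewc_penalty(params, old_params, fisher, lamb)


# Usage: the second parameter moved the most, but its Fisher weight is small,
# so it contributes little to the penalty.
params = [1.0, 2.0, 0.5]
old_params = [0.8, 1.0, 0.5]
fisher = [2.0, 0.1, 5.0]
print(regularized_loss(0.3, params, old_params, fisher, lamb=1.0))
```

Weighting by the Fisher diagonal lets parameters that barely affected the old task's likelihood move freely, while important ones are held near their stored values.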

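initTrainingStates(scenario, model, optimizer)

The event function to initialize the dictionary for storing training states (i.e., intermediate results).

Parameters:
  • scenario (object) – the scenario given to the trainer.

  • model (torch.nn.Module) – the model given to the trainer.

  • optimizer (torch.optim.Optimizer) – the optimizer created from optimizer_fn.

Returns: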

Initialized training state (dict).
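For EWC, the training states need room for the quantities that processAfterTraining() produces and afterInference() consumes. A minimal sketch, assuming a dictionary with the hypothetical keys `fishers` and `params` holding one entry per finished task; the library's actual keys and tensor types may differ.

```python
import copy


def init_ewc_training_states():
    # Hypothetical layout: one Fisher diagonal and one parameter
    # snapshot per finished task.
    return {"fishers": [], "params": []}


def record_task(states, fisher, params):
    # Store deep copies so that training on later tasks cannot
    # silently modify the stored snapshot.
    states["fishers"].append(copy.deepcopy(fisher))
    states["params"].append(copy.deepcopy(params))


states = init_ewc_training_states()
current = {"w": [0.8, 1.0]}
record_task(states, {"w": [2.0, 0.1]}, current)
current["w"][0] = 5.0          # training continues: live parameters change
print(states["params"][0])     # the stored snapshot is unchanged
```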

processAfterTraining(task_id, curr_dataset, curr_model, curr_optimizer, curr_training_states)

The event function executed after training on the current task.

EWC computes the diagonal of the Fisher information matrix and stores the learned weights; both are used to compute the penalty term in afterInference().

Parameters:
  • task_id (int) – the index of the current task.

  • curr_dataset (object) – the dataset for the current task.

  • curr_model (torch.nn.Module) – the current trained model.

  • curr_optimizer (torch.optim.Optimizer) – the current optimizer function.

  • curr_training_states (dict) – the dictionary containing the current training states.
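The Fisher computation in processAfterTraining() can be illustrated with the empirical Fisher diagonal, i.e., the mean over samples of the squared gradient of the log-likelihood. The sketch below assumes a tiny logistic-regression model so the gradient can be written by hand; it is a common approximation shown for illustration, and the library's own estimate may differ (for example, in which loss or labels it differentiates).

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def empirical_fisher_diag(weights, samples):
    """Diagonal empirical Fisher for a logistic-regression model.

    weights: list of floats; samples: list of (features, label) pairs with
    label in {0, 1}. Each diagonal entry is the mean over samples of the
    squared gradient of the log-likelihood w.r.t. that weight.
    """
    fisher = [0.0] * len(weights)
    for x, y in samples:
        p = sigmoid(sum(w * xi for w, xi in zip(weights, x)))
        # d/dw_j log p(y | x, w) = (y - p) * x_j for logistic regression
        for j, xj in enumerate(x):
            fisher[j] += ((y - p) * xj) ** 2
    n = len(samples)
    return [f / n for f in fisher]


weights = [0.0, 0.0]
samples = [([1.0, 0.0], 1), ([0.0, 2.0], 0)]
print(empirical_fisher_diag(weights, samples))  # → [0.125, 0.5]
```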

class NCClassILEWCTrainer(model, scenario, optimizer_fn, loss_fn, device, **kwargs)
afterInference(results, model, optimizer, _curr_batch, training_states)

The event function executed right after the inference step during training. We recommend performing backpropagation in this event function.

EWC computes its regularization penalty in this function and adds it to the loss.

Parameters:
  • results (dict) – the returned dictionary from the event function inference.

  • model (torch.nn.Module) – the current trained model.

  • optimizer (torch.optim.Optimizer) – the current optimizer function.

  • _curr_batch (object) – the data (or minibatch) for the current iteration.

  • training_states (dict) – the dictionary containing the current training states.

Returns:

A dictionary containing the information from the results.

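initTrainingStates(scenario, model, optimizer)

The event function to initialize the dictionary for storing training states (i.e., intermediate results).

Parameters:
  • scenario (object) – the scenario given to the trainer.

  • model (torch.nn.Module) – the model given to the trainer.

  • optimizer (torch.optim.Optimizer) – the optimizer created from optimizer_fn.

Returns: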

Initialized training state (dict).

processAfterTraining(task_id, curr_dataset, curr_model, curr_optimizer, curr_training_states)

The event function executed after training on the current task.

EWC computes the diagonal of the Fisher information matrix and stores the learned weights; both are used to compute the penalty term in afterInference().

Parameters:
  • task_id (int) – the index of the current task.

  • curr_dataset (object) – the dataset for the current task.

  • curr_model (torch.nn.Module) – the current trained model.

  • curr_optimizer (torch.optim.Optimizer) – the current optimizer function.

  • curr_training_states (dict) – the dictionary containing the current training states.
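Because processAfterTraining() stores weights after every task, afterInference() eventually has to penalize deviations from several snapshots. One common formulation sums one quadratic term per earlier task, each anchored at that task's stored parameters and Fisher diagonal. A sketch, assuming per-task snapshots kept as flat lists (a hypothetical layout) and a hypothetical function name:

```python
def multi_task_ewc_penalty(params, stored_params, stored_fishers, lamb):
    """Sum of quadratic EWC penalties over all previously finished tasks.

    stored_params[t] and stored_fishers[t] hold the parameter snapshot and
    the Fisher diagonal recorded after task t (hypothetical flat lists).
    """
    total = 0.0
    for old_params, fisher in zip(stored_params, stored_fishers):
        total += sum(
            f * (p - p_old) ** 2 for p, p_old, f in zip(params, old_params, fisher)
        )
    return 0.5 * lamb * total


params = [1.0, 1.0]
stored_params = [[0.0, 1.0], [1.0, 0.0]]   # snapshots after task 0 and task 1
stored_fishers = [[1.0, 1.0], [2.0, 3.0]]
print(multi_task_ewc_penalty(params, stored_params, stored_fishers, lamb=2.0))  # → 4.0
```

Before any task has finished, both lists are empty and the penalty is zero, so the first task trains on its plain loss.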

class NCDomainILEWCTrainer(model, scenario, optimizer_fn, loss_fn, device, **kwargs)

This trainer has the same behavior as NCClassILEWCTrainer.

class NCTaskILEWCTrainer(model, scenario, optimizer_fn, loss_fn, device, **kwargs)
afterInference(results, model, optimizer, _curr_batch, training_states)

The event function executed right after the inference step during training. We recommend performing backpropagation in this event function.

EWC computes its regularization penalty in this function and adds it to the loss.

Parameters:
  • results (dict) – the returned dictionary from the event function inference.

  • model (torch.nn.Module) – the current trained model.

  • optimizer (torch.optim.Optimizer) – the current optimizer function.

  • _curr_batch (object) – the data (or minibatch) for the current iteration.

  • training_states (dict) – the dictionary containing the current training states.

Returns:

A dictionary containing the information from the results.

inference(model, _curr_batch, training_states)

The event function to execute the inference step.

For task-IL, the inference step must additionally take task information into account.

Parameters:
  • model (torch.nn.Module) – the current trained model.

  • _curr_batch (object) – the data (or minibatch) for the current iteration.

  • training_states (dict) – the dictionary containing the current training states.

Returns:

A dictionary containing the inference results, such as the predictions and the loss.
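One common way to use task information in task-IL is to restrict each prediction to the classes of the sample's task. The sketch below illustrates that idea in plain Python; the mask format (a dict from task index to class indices) and the function name are hypothetical, and the trainer's actual inference() works on model outputs in PyTorch.

```python
def masked_predict(logits, task_ids, task_classes):
    """Predict only among the classes that belong to each sample's task.

    logits: list of per-sample score lists; task_ids: task index per sample;
    task_classes: dict mapping task index -> set of class indices
    (all names here are hypothetical).
    """
    preds = []
    for scores, t in zip(logits, task_ids):
        allowed = task_classes[t]
        # Classes outside the task are excluded by masking them to -inf.
        masked = [s if c in allowed else float("-inf") for c, s in enumerate(scores)]
        preds.append(max(range(len(masked)), key=masked.__getitem__))
    return preds


logits = [[0.1, 0.9, 0.3, 0.2], [0.8, 0.1, 0.4, 0.6]]
task_ids = [1, 1]
task_classes = {0: {0, 1}, 1: {2, 3}}
print(masked_predict(logits, task_ids, task_classes))  # → [2, 3]
```

Without the mask, the same scores would yield [1, 0], i.e., classes from a different task.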

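initTrainingStates(scenario, model, optimizer)

The event function to initialize the dictionary for storing training states (i.e., intermediate results).

Parameters:
  • scenario (object) – the scenario given to the trainer.

  • model (torch.nn.Module) – the model given to the trainer.

  • optimizer (torch.optim.Optimizer) – the optimizer created from optimizer_fn.

Returns: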

Initialized training state (dict).

processAfterTraining(task_id, curr_dataset, curr_model, curr_optimizer, curr_training_states)

The event function executed after training on the current task.

EWC computes the diagonal of the Fisher information matrix and stores the learned weights; both are used to compute the penalty term in afterInference().

Parameters:
  • task_id (int) – the index of the current task.

  • curr_dataset (object) – the dataset for the current task.

  • curr_model (torch.nn.Module) – the current trained model.

  • curr_optimizer (torch.optim.Optimizer) – the current optimizer function.

  • curr_training_states (dict) – the dictionary containing the current training states.

class NCTimeILEWCTrainer(model, scenario, optimizer_fn, loss_fn, device, **kwargs)

This trainer has the same behavior as NCClassILEWCTrainer.

Graph-level Problems

class GCClassILEWCTrainer(model, scenario, optimizer_fn, loss_fn, device, **kwargs)
afterInference(results, model, optimizer, _curr_batch, training_states)

The event function executed right after the inference step during training. We recommend performing backpropagation in this event function.

EWC computes its regularization penalty in this function and adds it to the loss.

Parameters:
  • results (dict) – the returned dictionary from the event function inference.

  • model (torch.nn.Module) – the current trained model.

  • optimizer (torch.optim.Optimizer) – the current optimizer function.

  • _curr_batch (object) – the data (or minibatch) for the current iteration.

  • training_states (dict) – the dictionary containing the current training states.

Returns:

A dictionary containing the information from the results.
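Because the penalty is quadratic, its gradient with respect to each parameter is λ F_i (θ_i − θ*_i), so backpropagation pulls parameters with large Fisher values back toward their stored values. The sketch below shows this with one hand-written SGD step; the function names and the plain-SGD update are illustrative assumptions, whereas the trainer relies on autograd and the configured optimizer.

```python
def ewc_penalty_grad(params, old_params, fisher, lamb):
    # Gradient of 0.5 * lamb * sum_i F_i * (theta_i - theta*_i)^2
    return [lamb * f * (p - p_old) for p, p_old, f in zip(params, old_params, fisher)]


def sgd_step(params, task_grads, old_params, fisher, lamb, lr):
    """One plain SGD step on task loss + EWC penalty."""
    reg_grads = ewc_penalty_grad(params, old_params, fisher, lamb)
    return [p - lr * (g + r) for p, g, r in zip(params, task_grads, reg_grads)]


params = [1.0, 1.0]
old_params = [0.0, 0.0]
fisher = [2.0, 0.0]       # first weight mattered for the old task, second did not
task_grads = [0.0, 0.0]   # no pull from the new task, to isolate the penalty
print(sgd_step(params, task_grads, old_params, fisher, lamb=1.0, lr=0.25))  # → [0.5, 1.0]
```

Only the weight with nonzero Fisher information moves back toward its old value; the other is left entirely to the new task's gradient.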

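initTrainingStates(scenario, model, optimizer)

The event function to initialize the dictionary for storing training states (i.e., intermediate results).

Parameters:
  • scenario (object) – the scenario given to the trainer.

  • model (torch.nn.Module) – the model given to the trainer.

  • optimizer (torch.optim.Optimizer) – the optimizer created from optimizer_fn.

Returns: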

Initialized training state (dict).

processAfterTraining(task_id, curr_dataset, curr_model, curr_optimizer, curr_training_states)

The event function executed after training on the current task.

EWC computes the diagonal of the Fisher information matrix and stores the learned weights; both are used to compute the penalty term in afterInference().

Parameters:
  • task_id (int) – the index of the current task.

  • curr_dataset (object) – the dataset for the current task.

  • curr_model (torch.nn.Module) – the current trained model.

  • curr_optimizer (torch.optim.Optimizer) – the current optimizer function.

  • curr_training_states (dict) – the dictionary containing the current training states.

class GCDomainILEWCTrainer(model, scenario, optimizer_fn, loss_fn, device, **kwargs)

This trainer has the same behavior as GCClassILEWCTrainer.

class GCTaskILEWCTrainer(model, scenario, optimizer_fn, loss_fn, device, **kwargs)
afterInference(results, model, optimizer, _curr_batch, training_states)

The event function executed right after the inference step during training. We recommend performing backpropagation in this event function.

EWC computes its regularization penalty in this function and adds it to the loss.

Parameters:
  • results (dict) – the returned dictionary from the event function inference.

  • model (torch.nn.Module) – the current trained model.

  • optimizer (torch.optim.Optimizer) – the current optimizer function.

  • _curr_batch (object) – the data (or minibatch) for the current iteration.

  • training_states (dict) – the dictionary containing the current training states.

Returns:

A dictionary containing the information from the results.

inference(model, _curr_batch, training_states)

The event function to execute the inference step.

For task-IL, the inference step must additionally take task information into account.

Parameters:
  • model (torch.nn.Module) – the current trained model.

  • _curr_batch (object) – the data (or minibatch) for the current iteration.

  • training_states (dict) – the dictionary containing the current training states.

Returns:

A dictionary containing the inference results, such as the predictions and the loss.

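initTrainingStates(scenario, model, optimizer)

The event function to initialize the dictionary for storing training states (i.e., intermediate results).

Parameters:
  • scenario (object) – the scenario given to the trainer.

  • model (torch.nn.Module) – the model given to the trainer.

  • optimizer (torch.optim.Optimizer) – the optimizer created from optimizer_fn.

Returns: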

Initialized training state (dict).

processAfterTraining(task_id, curr_dataset, curr_model, curr_optimizer, curr_training_states)

The event function executed after training on the current task.

EWC computes the diagonal of the Fisher information matrix and stores the learned weights; both are used to compute the penalty term in afterInference().

Parameters:
  • task_id (int) – the index of the current task.

  • curr_dataset (object) – the dataset for the current task.

  • curr_model (torch.nn.Module) – the current trained model.

  • curr_optimizer (torch.optim.Optimizer) – the current optimizer function.

  • curr_training_states (dict) – the dictionary containing the current training states.

class GCTimeILEWCTrainer(model, scenario, optimizer_fn, loss_fn, device, **kwargs)

This trainer has the same behavior as GCClassILEWCTrainer.