MAS
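Memory Aware Synapses (MAS) is a regularization-based continual learning method that estimates the importance of each parameter by how sensitive the model's learned function (its output) is to changes in that parameter. When later tasks are learned, changes to important parameters are penalized. For details, see the original paper (Aljundi et al., "Memory Aware Synapses: Learning what (not) to forget," ECCV 2018).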
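The original paper defines the MAS quantities as follows.

- Importance: the importance Ω_ij of a parameter θ_ij is the average magnitude of the gradient of the squared ℓ2 norm of the model output F(x_k; θ), taken over the N data points x_k.
- Penalty: when a new task n with loss L_n is learned, a quadratic penalty weighted by the hyperparameter λ keeps important parameters close to the values θ* learned on previous tasks.

The trainers on this page may scale or accumulate these terms differently, so treat the formulas below as the paper's formulation rather than the exact code:

```latex
\begin{aligned}
g_{ij}(x_k) &= \frac{\partial\, \lVert F(x_k;\theta) \rVert_2^2}{\partial \theta_{ij}},
\qquad
\Omega_{ij} = \frac{1}{N} \sum_{k=1}^{N} \lvert g_{ij}(x_k) \rvert, \\
L(\theta) &= L_n(\theta) + \lambda \sum_{i,j} \Omega_{ij} \left(\theta_{ij} - \theta^{*}_{ij}\right)^2
\end{aligned}
```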
Node-level Problems
- class NCClassILMASMinibatchTrainer(model, scenario, optimizer_fn, loss_fn, device, **kwargs)
- afterInference(results, model, optimizer, _curr_batch, training_states)
The event function to execute some processes right after the inference step (for training). We recommend performing backpropagation in this event function.
MAS performs its regularization in this function by adding the penalty term to the training loss.
- Parameters:
results (dict) – the dictionary returned by the event function inference().
model (torch.nn.Module) – the model currently being trained.
optimizer (torch.optim.Optimizer) – the current optimizer.
_curr_batch (object) – the data (or minibatch) for the current iteration.
training_states (dict) – the dictionary containing the current training states.
- Returns:
A dictionary containing the information from the results.
- initTrainingStates(scenario, model, optimizer)
The event function to initialize the dictionary for storing training states (i.e., intermediate results).
- Parameters:
scenario (begin.scenarios.common.BaseScenarioLoader) – the ScenarioLoader given to the trainer.
model (torch.nn.Module) – the model given to the trainer.
optimizer (torch.optim.Optimizer) – the optimizer generated from the given optimizer_fn.
- Returns:
Initialized training state (dict).
- processAfterTraining(task_id, curr_dataset, curr_model, curr_optimizer, curr_training_states)
The event function to execute some processes after training the current task.
MAS computes the parameter importances and stores the learned weights; both are used to compute the penalty term in afterInference().
- Parameters:
task_id (int) – the index of the current task.
curr_dataset (object) – the dataset for the current task.
curr_model (torch.nn.Module) – the model trained on the current task.
curr_optimizer (torch.optim.Optimizer) – the current optimizer.
curr_training_states (dict) – the dictionary containing the current training states.
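The two MAS steps described above can be sketched in plain Python:

- Importance estimation, which happens in processAfterTraining().
- The penalty, which is computed in afterInference().

The sketch uses a hypothetical bias-free linear model F(x) = Wx, because its gradient has a closed form and no autograd is needed. It is not the trainer's implementation, which works on torch.nn.Module parameters. The function names and the weight `lamb` are illustrative assumptions.

```python
from typing import List, Sequence

Matrix = List[List[float]]


def linear_forward(weights: Matrix, x: Sequence[float]) -> List[float]:
    """Output F(x) = W x of a hypothetical bias-free linear model."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weights]


def mas_importance(weights: Matrix, samples: Sequence[Sequence[float]]) -> Matrix:
    """Average |d ||F(x)||_2^2 / d W_ij| over the samples (MAS importance).

    For F(x) = W x, the gradient of ||W x||^2 with respect to W_ij is
    2 * (W x)_i * x_j, so it is written in closed form here.
    """
    rows, cols = len(weights), len(weights[0])
    omega = [[0.0] * cols for _ in range(rows)]
    for x in samples:
        out = linear_forward(weights, x)
        for i in range(rows):
            for j in range(cols):
                omega[i][j] += abs(2.0 * out[i] * x[j])
    n = len(samples)
    return [[value / n for value in row] for row in omega]


def mas_penalty(weights: Matrix, old_weights: Matrix, omega: Matrix, lamb: float) -> float:
    """Penalty lamb * sum_ij Omega_ij * (W_ij - W*_ij)^2 against the stored weights W*."""
    total = 0.0
    for w_row, old_row, om_row in zip(weights, old_weights, omega):
        for w, old, om in zip(w_row, old_row, om_row):
            total += om * (w - old) ** 2
    return lamb * total
```

Example: take W = [[1.0, 0.0], [0.0, 1.0]] and the samples [1.0, 2.0] and [3.0, 0.0]. `mas_importance` then returns [[10.0, 2.0], [2.0, 4.0]]. Moving W_00 is therefore penalized five times more strongly than moving W_01 by the same amount.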
- class NCClassILMASTrainer(model, scenario, optimizer_fn, loss_fn, device, **kwargs)
- afterInference(results, model, optimizer, _curr_batch, training_states)
The event function to execute some processes right after the inference step (for training). We recommend performing backpropagation in this event function.
MAS performs its regularization in this function by adding the penalty term to the training loss.
- Parameters:
results (dict) – the dictionary returned by the event function inference().
model (torch.nn.Module) – the model currently being trained.
optimizer (torch.optim.Optimizer) – the current optimizer.
_curr_batch (object) – the data (or minibatch) for the current iteration.
training_states (dict) – the dictionary containing the current training states.
- Returns:
A dictionary containing the information from the results.
- initTrainingStates(scenario, model, optimizer)
The event function to initialize the dictionary for storing training states (i.e., intermediate results).
- Parameters:
scenario (begin.scenarios.common.BaseScenarioLoader) – the ScenarioLoader given to the trainer.
model (torch.nn.Module) – the model given to the trainer.
optimizer (torch.optim.Optimizer) – the optimizer generated from the given optimizer_fn.
- Returns:
Initialized training state (dict).
- processAfterTraining(task_id, curr_dataset, curr_model, curr_optimizer, curr_training_states)
The event function to execute some processes after training the current task.
MAS computes the parameter importances and stores the learned weights; both are used to compute the penalty term in afterInference().
- Parameters:
task_id (int) – the index of the current task.
curr_dataset (object) – the dataset for the current task.
curr_model (torch.nn.Module) – the model trained on the current task.
curr_optimizer (torch.optim.Optimizer) – the current optimizer.
curr_training_states (dict) – the dictionary containing the current training states.
- class NCDomainILMASTrainer(model, scenario, optimizer_fn, loss_fn, device, **kwargs)
This trainer has the same behavior as NCClassILMASTrainer.
- class NCTaskILMASTrainer(model, scenario, optimizer_fn, loss_fn, device, **kwargs)
- afterInference(results, model, optimizer, _curr_batch, training_states)
The event function to execute some processes right after the inference step (for training). We recommend performing backpropagation in this event function.
MAS performs its regularization in this function by adding the penalty term to the training loss.
- Parameters:
results (dict) – the dictionary returned by the event function inference().
model (torch.nn.Module) – the model currently being trained.
optimizer (torch.optim.Optimizer) – the current optimizer.
_curr_batch (object) – the data (or minibatch) for the current iteration.
training_states (dict) – the dictionary containing the current training states.
- Returns:
A dictionary containing the information from the results.
- inference(model, _curr_batch, training_states)
The event function to execute the inference step.
For task-IL, the inference step must additionally take task information into account.
- Parameters:
model (torch.nn.Module) – the model currently being trained.
_curr_batch (object) – the data (or minibatch) for the current iteration.
training_states (dict) – the dictionary containing the current training states.
- Returns:
A dictionary containing the inference results, such as prediction result and loss.
- initTrainingStates(scenario, model, optimizer)
The event function to initialize the dictionary for storing training states (i.e., intermediate results).
- Parameters:
scenario (begin.scenarios.common.BaseScenarioLoader) – the ScenarioLoader given to the trainer.
model (torch.nn.Module) – the model given to the trainer.
optimizer (torch.optim.Optimizer) – the optimizer generated from the given optimizer_fn.
- Returns:
Initialized training state (dict).
- processAfterTraining(task_id, curr_dataset, curr_model, curr_optimizer, curr_training_states)
The event function to execute some processes after training the current task.
MAS computes the parameter importances and stores the learned weights; both are used to compute the penalty term in afterInference().
- Parameters:
task_id (int) – the index of the current task.
curr_dataset (object) – the dataset for the current task.
curr_model (torch.nn.Module) – the model trained on the current task.
curr_optimizer (torch.optim.Optimizer) – the current optimizer.
curr_training_states (dict) – the dictionary containing the current training states.
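In task-IL, the task identity is known at inference time. One common way to use it is to restrict predictions to the classes of the given task, masking all other outputs before taking the argmax. The sketch below assumes this masking scheme and a hypothetical `task_classes` mapping from task index to class indices. It is not necessarily how NCTaskILMASTrainer.inference() uses the task information.

```python
import math
from typing import Dict, List, Sequence


def masked_prediction(logits: Sequence[float], task_id: int,
                      task_classes: Dict[int, List[int]]) -> int:
    """Return the highest-scoring class among the classes that belong to task_id."""
    allowed = set(task_classes[task_id])
    # Classes outside the current task get -inf so they can never win the argmax.
    masked = [score if c in allowed else -math.inf for c, score in enumerate(logits)]
    return max(range(len(masked)), key=lambda c: masked[c])
```

Example: take logits [0.1, 0.5, 2.0, 0.3] and task_classes {0: [0, 1], 1: [2, 3]}. The unmasked argmax is class 2, but for task 0 the prediction becomes class 1.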
Link-level Problems
- class LCClassILMASTrainer(model, scenario, optimizer_fn, loss_fn, device, **kwargs)
- afterInference(results, model, optimizer, _curr_batch, training_states)
The event function to execute some processes right after the inference step (for training). We recommend performing backpropagation in this event function.
MAS performs its regularization in this function by adding the penalty term to the training loss.
- Parameters:
results (dict) – the dictionary returned by the event function inference().
model (torch.nn.Module) – the model currently being trained.
optimizer (torch.optim.Optimizer) – the current optimizer.
_curr_batch (object) – the data (or minibatch) for the current iteration.
training_states (dict) – the dictionary containing the current training states.
- Returns:
A dictionary containing the information from the results.
- initTrainingStates(scenario, model, optimizer)
The event function to initialize the dictionary for storing training states (i.e., intermediate results).
- Parameters:
scenario (begin.scenarios.common.BaseScenarioLoader) – the ScenarioLoader given to the trainer.
model (torch.nn.Module) – the model given to the trainer.
optimizer (torch.optim.Optimizer) – the optimizer generated from the given optimizer_fn.
- Returns:
Initialized training state (dict).
- processAfterTraining(task_id, curr_dataset, curr_model, curr_optimizer, curr_training_states)
The event function to execute some processes after training the current task.
MAS computes the parameter importances and stores the learned weights; both are used to compute the penalty term in afterInference().
- Parameters:
task_id (int) – the index of the current task.
curr_dataset (object) – the dataset for the current task.
curr_model (torch.nn.Module) – the model trained on the current task.
curr_optimizer (torch.optim.Optimizer) – the current optimizer.
curr_training_states (dict) – the dictionary containing the current training states.
- class LCTaskILMASTrainer(model, scenario, optimizer_fn, loss_fn, device, **kwargs)
- afterInference(results, model, optimizer, _curr_batch, training_states)
The event function to execute some processes right after the inference step (for training). We recommend performing backpropagation in this event function.
MAS performs its regularization in this function by adding the penalty term to the training loss.
- Parameters:
results (dict) – the dictionary returned by the event function inference().
model (torch.nn.Module) – the model currently being trained.
optimizer (torch.optim.Optimizer) – the current optimizer.
_curr_batch (object) – the data (or minibatch) for the current iteration.
training_states (dict) – the dictionary containing the current training states.
- Returns:
A dictionary containing the information from the results.
- inference(model, _curr_batch, training_states)
The event function to execute the inference step.
For task-IL, the inference step must additionally take task information into account.
- Parameters:
model (torch.nn.Module) – the model currently being trained.
_curr_batch (object) – the data (or minibatch) for the current iteration.
training_states (dict) – the dictionary containing the current training states.
- Returns:
A dictionary containing the inference results, such as prediction result and loss.
- initTrainingStates(scenario, model, optimizer)
The event function to initialize the dictionary for storing training states (i.e., intermediate results).
- Parameters:
scenario (begin.scenarios.common.BaseScenarioLoader) – the ScenarioLoader given to the trainer.
model (torch.nn.Module) – the model given to the trainer.
optimizer (torch.optim.Optimizer) – the optimizer generated from the given optimizer_fn.
- Returns:
Initialized training state (dict).
- prepareLoader(_curr_dataset, curr_training_states)
The event function to generate dataloaders from the given dataset for the current task.
For task-IL, the dataloaders must additionally carry the task information that the inference step needs.
- Parameters:
_curr_dataset (object) – the dataset for the current task.
curr_training_states (dict) – the dictionary containing the current training states.
- Returns:
A tuple of three dataloaders, which the trainer uses for training, validation, and test, respectively.
- processAfterTraining(task_id, curr_dataset, curr_model, curr_optimizer, curr_training_states)
The event function to execute some processes after training the current task.
MAS computes the parameter importances and stores the learned weights; both are used to compute the penalty term in afterInference().
- Parameters:
task_id (int) – the index of the current task.
curr_dataset (object) – the dataset for the current task.
curr_model (torch.nn.Module) – the model trained on the current task.
curr_optimizer (torch.optim.Optimizer) – the current optimizer.
curr_training_states (dict) – the dictionary containing the current training states.
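prepareLoader() returns a (training, validation, test) triple. That return contract can be illustrated with plain lists of minibatches standing in for real dataloaders. In this sketch, each sample (for example, a candidate edge) is paired with its task index so that the inference step can use it. The helper names, the index-based split, and the tagging scheme are hypothetical and are not BeGin's implementation.

```python
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def to_batches(items: Sequence[T], batch_size: int) -> List[List[T]]:
    """Split items into consecutive minibatches; the last one may be smaller."""
    return [list(items[i:i + batch_size]) for i in range(0, len(items), batch_size)]


def prepare_loaders(samples: Sequence[T], task_id: int,
                    train_idx: Sequence[int], val_idx: Sequence[int],
                    test_idx: Sequence[int], batch_size: int
                    ) -> Tuple[List[List[Tuple[T, int]]], ...]:
    """Return (train, val, test) minibatch lists; every sample is paired with task_id."""
    tagged = [(sample, task_id) for sample in samples]

    def split(indices: Sequence[int]) -> List[List[Tuple[T, int]]]:
        return to_batches([tagged[i] for i in indices], batch_size)

    # Same order as described above: training, validation, test.
    return split(train_idx), split(val_idx), split(test_idx)
```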
- class LCTimeILMASTrainer(model, scenario, optimizer_fn, loss_fn, device, **kwargs)
- afterInference(results, model, optimizer, _curr_batch, training_states)
The event function to execute some processes right after the inference step (for training). We recommend performing backpropagation in this event function.
MAS performs its regularization in this function by adding the penalty term to the training loss.
- Parameters:
results (dict) – the dictionary returned by the event function inference().
model (torch.nn.Module) – the model currently being trained.
optimizer (torch.optim.Optimizer) – the current optimizer.
_curr_batch (object) – the data (or minibatch) for the current iteration.
training_states (dict) – the dictionary containing the current training states.
- Returns:
A dictionary containing the information from the results.
- initTrainingStates(scenario, model, optimizer)
The event function to initialize the dictionary for storing training states (i.e., intermediate results).
- Parameters:
scenario (begin.scenarios.common.BaseScenarioLoader) – the ScenarioLoader given to the trainer.
model (torch.nn.Module) – the model given to the trainer.
optimizer (torch.optim.Optimizer) – the optimizer generated from the given optimizer_fn.
- Returns:
Initialized training state (dict).
- processAfterTraining(task_id, curr_dataset, curr_model, curr_optimizer, curr_training_states)
The event function to execute some processes after training the current task.
MAS computes the parameter importances and stores the learned weights; both are used to compute the penalty term in afterInference().
- Parameters:
task_id (int) – the index of the current task.
curr_dataset (object) – the dataset for the current task.
curr_model (torch.nn.Module) – the model trained on the current task.
curr_optimizer (torch.optim.Optimizer) – the current optimizer.
curr_training_states (dict) – the dictionary containing the current training states.
- processBeforeTraining(task_id, curr_dataset, curr_model, curr_optimizer, curr_training_states)
The event function to execute some processes before training the current task.
This trainer extends the base function because its output format differs slightly from that of the base trainer.
- Parameters:
task_id (int) – the index of the current task.
curr_dataset (object) – the dataset for the current task.
curr_model (torch.nn.Module) – the model currently being trained.
curr_optimizer (torch.optim.Optimizer) – the current optimizer.
curr_training_states (dict) – the dictionary containing the current training states.
- processEvalIteration(model, _curr_batch)
The event function to handle every evaluation iteration.
This trainer extends the base function because its output format differs slightly from that of the base trainer.
- Parameters:
model (torch.nn.Module) – the current model.
_curr_batch (object) – the data (or minibatch) for the current iteration.
- Returns:
A dictionary containing the outcomes (stats) during the evaluation iteration.
- class LPDomainILMASTrainer(model, scenario, optimizer_fn, loss_fn, device, **kwargs)
This trainer has the same behavior as LPTimeILMASTrainer.
- class LPTimeILMASTrainer(model, scenario, optimizer_fn, loss_fn, device, **kwargs)
- initTrainingStates(scenario, model, optimizer)
The event function to initialize the dictionary for storing training states (i.e., intermediate results).
- Parameters:
scenario (begin.scenarios.common.BaseScenarioLoader) – the ScenarioLoader given to the trainer.
model (torch.nn.Module) – the model given to the trainer.
optimizer (torch.optim.Optimizer) – the optimizer generated from the given optimizer_fn.
- Returns:
Initialized training state (dict).
- processAfterTraining(task_id, curr_dataset, curr_model, curr_optimizer, curr_training_states)
The event function to execute some processes after training the current task.
MAS computes the parameter importances and stores the learned weights; both are used to compute the penalty term in processTrainIteration().
- Parameters:
task_id (int) – the index of the current task.
curr_dataset (object) – the dataset for the current task.
curr_model (torch.nn.Module) – the model trained on the current task.
curr_optimizer (torch.optim.Optimizer) – the current optimizer.
curr_training_states (dict) – the dictionary containing the current training states.
- processTrainIteration(model, optimizer, _curr_batch, training_states)
The event function to handle every training iteration.
MAS performs both the inference and the regularization in this function.
- Parameters:
model (torch.nn.Module) – the model currently being trained.
optimizer (torch.optim.Optimizer) – the current optimizer.
_curr_batch (object) – the data (or minibatch) for the current iteration.
training_states (dict) – the dictionary containing the current training states.
- Returns:
A dictionary containing the outcomes (stats) during the training iteration.
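LPTimeILMASTrainer performs inference and regularization together in processTrainIteration(). A single update therefore combines the gradient of the task loss with the gradient of the MAS penalty, which is 2λΩ_i(θ_i − θ*_i). The sketch below shows one plain-SGD step on a flat list of scalar parameters. The function name, the flat-list representation, and the use of plain SGD instead of the optimizer produced by optimizer_fn are all assumptions for illustration.

```python
from typing import List, Sequence


def mas_sgd_step(params: Sequence[float], task_grads: Sequence[float],
                 old_params: Sequence[float], omega: Sequence[float],
                 lamb: float, lr: float) -> List[float]:
    """One SGD step on task_loss + lamb * sum_i omega_i * (p_i - old_i)^2.

    task_grads holds d(task_loss)/d(p_i), assumed to come from the inference step;
    the penalty gradient 2 * lamb * omega_i * (p_i - old_i) is added to it.
    """
    new_params = []
    for p, g, old, om in zip(params, task_grads, old_params, omega):
        grad = g + 2.0 * lamb * om * (p - old)
        new_params.append(p - lr * grad)
    return new_params
```

Example: use params [1.0, 2.0], zero task gradients, old_params [0.0, 0.0], omega [1.0, 0.0], lamb 0.5, and lr 0.5. The step returns [0.5, 2.0]. The important parameter is pulled back toward its old value, and the unimportant one is left unchanged.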
Graph-level Problems
- class GCClassILMASTrainer(model, scenario, optimizer_fn, loss_fn, device, **kwargs)
- afterInference(results, model, optimizer, _curr_batch, training_states)
The event function to execute some processes right after the inference step (for training). We recommend performing backpropagation in this event function.
MAS performs its regularization in this function by adding the penalty term to the training loss.
- Parameters:
results (dict) – the dictionary returned by the event function inference().
model (torch.nn.Module) – the model currently being trained.
optimizer (torch.optim.Optimizer) – the current optimizer.
_curr_batch (object) – the data (or minibatch) for the current iteration.
training_states (dict) – the dictionary containing the current training states.
- Returns:
A dictionary containing the information from the results.
- initTrainingStates(scenario, model, optimizer)
The event function to initialize the dictionary for storing training states (i.e., intermediate results).
- Parameters:
scenario (begin.scenarios.common.BaseScenarioLoader) – the ScenarioLoader given to the trainer.
model (torch.nn.Module) – the model given to the trainer.
optimizer (torch.optim.Optimizer) – the optimizer generated from the given optimizer_fn.
- Returns:
Initialized training state (dict).
- processAfterTraining(task_id, curr_dataset, curr_model, curr_optimizer, curr_training_states)
The event function to execute some processes after training the current task.
MAS computes the parameter importances and stores the learned weights; both are used to compute the penalty term in afterInference().
- Parameters:
task_id (int) – the index of the current task.
curr_dataset (object) – the dataset for the current task.
curr_model (torch.nn.Module) – the model trained on the current task.
curr_optimizer (torch.optim.Optimizer) – the current optimizer.
curr_training_states (dict) – the dictionary containing the current training states.
- class GCDomainILMASTrainer(model, scenario, optimizer_fn, loss_fn, device, **kwargs)
This trainer has the same behavior as GCClassILMASTrainer.
- class GCTaskILMASTrainer(model, scenario, optimizer_fn, loss_fn, device, **kwargs)
- afterInference(results, model, optimizer, _curr_batch, training_states)
The event function to execute some processes right after the inference step (for training). We recommend performing backpropagation in this event function.
MAS performs its regularization in this function by adding the penalty term to the training loss.
- Parameters:
results (dict) – the dictionary returned by the event function inference().
model (torch.nn.Module) – the model currently being trained.
optimizer (torch.optim.Optimizer) – the current optimizer.
_curr_batch (object) – the data (or minibatch) for the current iteration.
training_states (dict) – the dictionary containing the current training states.
- Returns:
A dictionary containing the information from the results.
- inference(model, _curr_batch, training_states)
The event function to execute the inference step.
For task-IL, the inference step must additionally take task information into account.
- Parameters:
model (torch.nn.Module) – the model currently being trained.
_curr_batch (object) – the data (or minibatch) for the current iteration.
training_states (dict) – the dictionary containing the current training states.
- Returns:
A dictionary containing the inference results, such as prediction result and loss.
- initTrainingStates(scenario, model, optimizer)
The event function to initialize the dictionary for storing training states (i.e., intermediate results).
- Parameters:
scenario (begin.scenarios.common.BaseScenarioLoader) – the ScenarioLoader given to the trainer.
model (torch.nn.Module) – the model given to the trainer.
optimizer (torch.optim.Optimizer) – the optimizer generated from the given optimizer_fn.
- Returns:
Initialized training state (dict).
- processAfterTraining(task_id, curr_dataset, curr_model, curr_optimizer, curr_training_states)
The event function to execute some processes after training the current task.
MAS computes the parameter importances and stores the learned weights; both are used to compute the penalty term in afterInference().
- Parameters:
task_id (int) – the index of the current task.
curr_dataset (object) – the dataset for the current task.
curr_model (torch.nn.Module) – the model trained on the current task.
curr_optimizer (torch.optim.Optimizer) – the current optimizer.
curr_training_states (dict) – the dictionary containing the current training states.