ER-GNN
ER-GNN is a continual learning method for graphs. It samples nodes using topological information, stores them in a buffer, and replays them when re-training on later tasks. For details, see the original paper.
Node-level Problems
- class NCClassILERGNNMinibatchTrainer(model, scenario, optimizer_fn, loss_fn, device, **kwargs)
- afterInference(results, model, optimizer, _curr_batch, training_states)
The event function that runs right after the inference step during training. We recommend performing backpropagation in this event function.
ERGNN additionally computes the loss on the buffered nodes and adds it to the loss used for backpropagation.
- Parameters:
results (dict) – the dictionary returned by the event function inference.
model (torch.nn.Module) – the current trained model.
optimizer (torch.optim.Optimizer) – the current optimizer.
_curr_batch (object) – the data (or minibatch) for the current iteration.
training_states (dict) – the dictionary containing the current training states.
- Returns:
A dictionary containing the information from the results.
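The loss combination described for afterInference can be sketched as follows. This is a minimal plain-Python illustration, not BeGin's code. The real trainer works on PyTorch tensors with the configured loss_fn. The function names and the `buffer_weight` factor are hypothetical, and equal weighting is assumed by default.

```python
import math
from typing import List, Sequence


def cross_entropy(logits: Sequence[float], label: int) -> float:
    """Cross-entropy of one example, computed with the log-sum-exp trick."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


def mean_ce(batch_logits: List[Sequence[float]], labels: List[int]) -> float:
    """Average cross-entropy over a batch; 0.0 for an empty batch (e.g. an empty buffer)."""
    if not labels:
        return 0.0
    return sum(cross_entropy(z, y) for z, y in zip(batch_logits, labels)) / len(labels)


def replay_loss(curr_logits, curr_labels, buf_logits, buf_labels, buffer_weight=1.0):
    """Loss on the current batch plus a weighted loss on buffered (replayed) nodes.

    `buffer_weight` is a hypothetical knob, not a documented ERGNN parameter.
    """
    return mean_ce(curr_logits, curr_labels) + buffer_weight * mean_ce(buf_logits, buf_labels)
```

In the trainer, the combined value would be the tensor on which backward() is called before the optimizer step.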
- inference(model, _curr_batch, training_states)
The event function that executes the inference step.
- Parameters:
model (torch.nn.Module) – the current trained model.
_curr_batch (object) – the data (or minibatch) for the current iteration.
training_states (dict) – the dictionary containing the current training states.
- Returns:
A dictionary containing the inference results, such as prediction result and loss.
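A minimal sketch of what an inference event function returns, assuming the model maps each node's features to a list of class logits. Plain Python is used so the block runs on its own. The real hook operates on PyTorch tensors and graph inputs. `inference_sketch` and its result keys are hypothetical.

```python
from typing import Callable, Dict, List, Sequence


def inference_sketch(model: Callable[[Sequence[float]], List[float]],
                     loss_fn: Callable[[List[List[float]], List[int]], float],
                     features: List[Sequence[float]],
                     labels: List[int]) -> Dict[str, object]:
    """Forward pass on a batch, returning predictions and the loss.

    The result keys 'preds' and 'loss' are illustrative, not BeGin's actual keys.
    """
    logits = [model(x) for x in features]
    # Predicted class = index of the largest logit for each node.
    preds = [max(range(len(z)), key=z.__getitem__) for z in logits]
    return {"preds": preds, "loss": loss_fn(logits, labels)}
```

Passing `loss_fn` in mirrors the trainer's own loss_fn constructor argument, so the sketch does not fix a particular loss.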
- initTrainingStates(scenario, model, optimizer)
ERGNN requires a node sampler. We use the CM (coverage maximization) sampler as the default sampler.
- Parameters:
scenario (begin.scenarios.common.BaseScenarioLoader) – the ScenarioLoader given to the trainer.
model (torch.nn.Module) – the model given to the trainer.
optimizer (torch.optim.Optimizer) – the optimizer generated from the given optimizer_fn.
- Returns:
Initialized training state (dict).
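For intuition about the default sampler, the sketch below follows our reading of the coverage-maximization strategy in the ER-GNN paper. Within each class, it prefers nodes whose neighborhood in embedding space (within a radius) contains the fewest nodes of other classes. The function name and the `radius` and `per_class` parameters are assumptions. BeGin's CM sampler may differ in details, such as how the embeddings are obtained.

```python
import math
from collections import defaultdict
from typing import Dict, List, Sequence


def coverage_max_select(embeddings: List[Sequence[float]],
                        labels: List[int],
                        per_class: int,
                        radius: float) -> List[int]:
    """Pick `per_class` node indices per class, preferring nodes whose
    neighborhood (distance < radius) contains the fewest other-class nodes."""
    by_class: Dict[int, List[int]] = defaultdict(list)
    for idx, y in enumerate(labels):
        by_class[y].append(idx)

    selected: List[int] = []
    for y in sorted(by_class):
        scores = []
        for i in by_class[y]:
            # Count nodes of other classes that fall inside this node's radius.
            conflicts = sum(
                1 for j, yj in enumerate(labels)
                if yj != y and math.dist(embeddings[i], embeddings[j]) < radius
            )
            scores.append((conflicts, i))
        scores.sort()  # fewest conflicts first; ties broken by node index
        selected.extend(i for _, i in scores[:per_class])
    return selected
```

Here node 1 sits close to a class-1 node, so node 0 is kept for class 0. Node 3, far from everything, is kept for class 1.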
- processAfterTraining(task_id, curr_dataset, curr_model, curr_optimizer, curr_training_states)
The event function to execute some processes after training the current task.
ERGNN selects nodes using the sampler and stores them in the buffer.
- Parameters:
task_id (int) – the index of the current task.
curr_dataset (object) – the dataset for the current task.
curr_model (torch.nn.Module) – the current trained model.
curr_optimizer (torch.optim.Optimizer) – the current optimizer.
curr_training_states (dict) – the dictionary containing the current training states.
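The buffer update in processAfterTraining can be sketched as below. It assumes the buffer is a growing list of node IDs that keeps the entries from earlier tasks. `update_buffer` and `first_per_class` are hypothetical. The latter is only a placeholder standing in for the CM sampler.

```python
from typing import Callable, Dict, List, Sequence


def update_buffer(buffer: List[int],
                  task_node_ids: Sequence[int],
                  task_labels: Sequence[int],
                  select: Callable[[Sequence[int], Sequence[int]], List[int]]) -> List[int]:
    """Append the nodes chosen by `select` for the finished task to the replay buffer.

    Earlier entries are kept, and nodes already in the buffer are not added twice.
    """
    chosen = select(task_node_ids, task_labels)
    seen = set(buffer)
    return buffer + [n for n in chosen if n not in seen]


def first_per_class(node_ids: Sequence[int], labels: Sequence[int], k: int = 1) -> List[int]:
    """Placeholder sampler: keep the first k nodes of each class (stands in for CM)."""
    counts: Dict[int, int] = {}
    out: List[int] = []
    for n, y in zip(node_ids, labels):
        if counts.get(y, 0) < k:
            out.append(n)
            counts[y] = counts.get(y, 0) + 1
    return out
```

For example, after a first task with nodes [10, 11, 12, 13] labeled [0, 0, 1, 1], the buffer holds [10, 12]. A second task then appends its own selections.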
- processBeforeTraining(task_id, curr_dataset, curr_model, curr_optimizer, curr_training_states)
The event function to execute some processes before training the current task.
- Parameters:
task_id (int) – the index of the current task.
curr_dataset (object) – the dataset for the current task.
curr_model (torch.nn.Module) – the current trained model.
curr_optimizer (torch.optim.Optimizer) – the current optimizer.
curr_training_states (dict) – the dictionary containing the current training states.
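The event functions above are hooks called by the trainer. The sketch below records the order in which they would fire under these assumptions:

- The training states, which presumably hold ERGNN's sampler and buffer, are initialized once.
- Each task runs processBeforeTraining, then repeated inference/afterInference iterations, then processAfterTraining.

`HookOrderDemo` and `run` are hypothetical. The real BeGin loop also handles epochs, validation, and evaluation.

```python
from typing import List


class HookOrderDemo:
    """Hypothetical stand-in trainer that records the order in which hooks fire."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    def initTrainingStates(self, scenario, model, optimizer):
        self.calls.append("initTrainingStates")
        return {"buffer": []}  # assumed to persist across tasks

    def processBeforeTraining(self, task_id, *args):
        self.calls.append(f"processBeforeTraining[{task_id}]")

    def inference(self, model, batch, states):
        self.calls.append("inference")
        return {"loss": 0.0}

    def afterInference(self, results, model, optimizer, batch, states):
        self.calls.append("afterInference")  # backpropagation would happen here
        return results

    def processAfterTraining(self, task_id, *args):
        self.calls.append(f"processAfterTraining[{task_id}]")  # buffer update here


def run(trainer: HookOrderDemo, num_tasks: int, iters_per_task: int) -> List[str]:
    """Assumed driver: init once, then per task: before -> (inference, afterInference)* -> after."""
    states = trainer.initTrainingStates(None, None, None)
    for task_id in range(num_tasks):
        trainer.processBeforeTraining(task_id, None, None, None, states)
        for _ in range(iters_per_task):
            results = trainer.inference(None, None, states)
            trainer.afterInference(results, None, None, None, states)
        trainer.processAfterTraining(task_id, None, None, None, states)
    return trainer.calls
```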
- class NCClassILERGNNTrainer(model, scenario, optimizer_fn, loss_fn, device, **kwargs)
- afterInference(results, model, optimizer, _curr_batch, training_states)
The event function that runs right after the inference step during training. We recommend performing backpropagation in this event function.
ERGNN additionally computes the loss on the buffered nodes and adds it to the loss used for backpropagation.
- Parameters:
results (dict) – the dictionary returned by the event function inference.
model (torch.nn.Module) – the current trained model.
optimizer (torch.optim.Optimizer) – the current optimizer.
_curr_batch (object) – the data (or minibatch) for the current iteration.
training_states (dict) – the dictionary containing the current training states.
- Returns:
A dictionary containing the information from the results.
- inference(model, _curr_batch, training_states)
The event function that executes the inference step.
- Parameters:
model (torch.nn.Module) – the current trained model.
_curr_batch (object) – the data (or minibatch) for the current iteration.
training_states (dict) – the dictionary containing the current training states.
- Returns:
A dictionary containing the inference results, such as prediction result and loss.
- initTrainingStates(scenario, model, optimizer)
ERGNN requires a node sampler. We use the CM (coverage maximization) sampler as the default sampler.
- Parameters:
scenario (begin.scenarios.common.BaseScenarioLoader) – the ScenarioLoader given to the trainer.
model (torch.nn.Module) – the model given to the trainer.
optimizer (torch.optim.Optimizer) – the optimizer generated from the given optimizer_fn.
- Returns:
Initialized training state (dict).
- processAfterTraining(task_id, curr_dataset, curr_model, curr_optimizer, curr_training_states)
The event function to execute some processes after training the current task.
ERGNN selects nodes using the sampler and stores them in the buffer.
- Parameters:
task_id (int) – the index of the current task.
curr_dataset (object) – the dataset for the current task.
curr_model (torch.nn.Module) – the current trained model.
curr_optimizer (torch.optim.Optimizer) – the current optimizer.
curr_training_states (dict) – the dictionary containing the current training states.
- class NCTaskILERGNNTrainer(model, scenario, optimizer_fn, loss_fn, device, **kwargs)
- afterInference(results, model, optimizer, _curr_batch, training_states)
The event function that runs right after the inference step during training. We recommend performing backpropagation in this event function.
ERGNN additionally computes the loss on the buffered nodes and adds it to the loss used for backpropagation.
- Parameters:
results (dict) – the dictionary returned by the event function inference.
model (torch.nn.Module) – the current trained model.
optimizer (torch.optim.Optimizer) – the current optimizer.
_curr_batch (object) – the data (or minibatch) for the current iteration.
training_states (dict) – the dictionary containing the current training states.
- Returns:
A dictionary containing the information from the results.
- inference(model, _curr_batch, training_states)
The event function that executes the inference step.
- Parameters:
model (torch.nn.Module) – the current trained model.
_curr_batch (object) – the data (or minibatch) for the current iteration.
training_states (dict) – the dictionary containing the current training states.
- Returns:
A dictionary containing the inference results, such as prediction result and loss.
- initTrainingStates(scenario, model, optimizer)
ERGNN requires a node sampler. We use the CM (coverage maximization) sampler as the default sampler.
- Parameters:
scenario (begin.scenarios.common.BaseScenarioLoader) – the ScenarioLoader given to the trainer.
model (torch.nn.Module) – the model given to the trainer.
optimizer (torch.optim.Optimizer) – the optimizer generated from the given optimizer_fn.
- Returns:
Initialized training state (dict).
- processAfterTraining(task_id, curr_dataset, curr_model, curr_optimizer, curr_training_states)
The event function to execute some processes after training the current task.
ERGNN selects nodes using the sampler and stores them in the buffer.
- Parameters:
task_id (int) – the index of the current task.
curr_dataset (object) – the dataset for the current task.
curr_model (torch.nn.Module) – the current trained model.
curr_optimizer (torch.optim.Optimizer) – the current optimizer.
curr_training_states (dict) – the dictionary containing the current training states.
- class NCTimeILERGNNTrainer(model, scenario, optimizer_fn, loss_fn, device, **kwargs)
- initTrainingStates(scenario, model, optimizer)
ERGNN requires a node sampler. We use the CM (coverage maximization) sampler as the default sampler.
- Parameters:
scenario (begin.scenarios.common.BaseScenarioLoader) – the ScenarioLoader given to the trainer.
model (torch.nn.Module) – the model given to the trainer.
optimizer (torch.optim.Optimizer) – the optimizer generated from the given optimizer_fn.
- Returns:
Initialized training state (dict).
- processAfterTraining(task_id, curr_dataset, curr_model, curr_optimizer, curr_training_states)
The event function to execute some processes after training the current task.
ERGNN selects nodes using the sampler and stores them in the buffer.
- Parameters:
task_id (int) – the index of the current task.
curr_dataset (object) – the dataset for the current task.
curr_model (torch.nn.Module) – the current trained model.
curr_optimizer (torch.optim.Optimizer) – the current optimizer.
curr_training_states (dict) – the dictionary containing the current training states.
- processTrainIteration(model, optimizer, _curr_batch, training_states)
The event function to handle every training iteration.
ERGNN additionally computes the loss on the buffered nodes and adds it to the loss used for backpropagation.
- Parameters:
model (torch.nn.Module) – the current trained model.
optimizer (torch.optim.Optimizer) – the current optimizer.
_curr_batch (object) – the data (or minibatch) for the current iteration.
training_states (dict) – the dictionary containing the current training states.
- Returns:
A dictionary containing the outcomes (stats) during the training iteration.
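For NCTimeILERGNNTrainer, inference and the backward pass are handled together in processTrainIteration. The sketch below shows one such iteration end to end, with a toy binary logistic model standing in for the GNN:

1. Compute the loss on the current batch and on the buffered nodes.
2. Take one SGD step on their sum.
3. Return stats.

Plain Python is used so the sketch runs standalone. `train_iteration`, the unweighted sum, the learning rate, and the stats keys are assumptions, not BeGin's implementation.

```python
import math
from typing import Dict, List, Sequence, Tuple

Example = Tuple[Sequence[float], int]  # (node features, binary label)


def _sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def _bce_and_grad(w: List[float], batch: List[Example]) -> Tuple[float, List[float]]:
    """Mean binary cross-entropy of a logistic model and its gradient w.r.t. w."""
    if not batch:
        return 0.0, [0.0] * len(w)
    loss, grad = 0.0, [0.0] * len(w)
    for x, y in batch:
        p = _sigmoid(sum(wi * xi for wi, xi in zip(w, x)))
        loss -= y * math.log(p) + (1 - y) * math.log(1 - p)
        for k, xk in enumerate(x):
            grad[k] += (p - y) * xk  # d(BCE)/dw_k for a logistic model
    n = len(batch)
    return loss / n, [g / n for g in grad]


def train_iteration(w: List[float], curr_batch: List[Example],
                    buffer_batch: List[Example], lr: float = 0.1) -> Dict[str, float]:
    """One SGD step on (current-batch loss + buffered-node loss); updates w in place."""
    curr_loss, g_curr = _bce_and_grad(w, curr_batch)
    buf_loss, g_buf = _bce_and_grad(w, buffer_batch)
    for k in range(len(w)):
        w[k] -= lr * (g_curr[k] + g_buf[k])
    return {"loss": curr_loss + buf_loss, "curr_loss": curr_loss, "buffer_loss": buf_loss}
```

Because the buffered nodes contribute to the same gradient step, the update balances fitting the current task against the replayed examples from earlier tasks.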