Source code for begin.algorithms.bare.nodes

import sys
from begin.trainers.nodes import NCTrainer, NCMinibatchTrainer

class NCTaskILBareTrainer(NCTrainer):
    """
        Trainer derived from `NCTrainer` that overrides only the inference step,
        so that the per-node task masks stored on the batch are forwarded to the model.
    """

    def inference(self, model, _curr_batch, training_states):
        """
            The event function to execute inference step.

            For task-IL, the task information is needed in addition to the features,
            so the ``'task_specific_mask'`` node data is handed to the model
            through the ``task_masks`` keyword argument.

            Args:
                model (torch.nn.Module): the current trained model.
                _curr_batch (tuple): a pair ``(batch, mask)``. ``batch`` must support
                    ``.to(device)`` and expose ``.ndata`` with the keys ``'feat'``,
                    ``'task_specific_mask'`` and ``'label'``; ``mask`` is used to index
                    both the model output and the labels.
                training_states (dict): the dictionary containing the current training states.
                    It is not read by this method.

            Returns:
                A dictionary with the masked model outputs under ``'preds'`` and the value
                of ``self.loss_fn`` on those outputs under ``'loss'``.
        """
        graph, node_mask = _curr_batch
        device = self.device

        # Move every model input to the target device before the forward pass.
        graph_on_device = graph.to(device)
        node_feats = graph.ndata['feat'].to(device)
        task_masks = graph.ndata['task_specific_mask'].to(device)

        # The task masks are passed as an additional input; only the masked
        # rows of the output take part in the prediction and the loss.
        outputs = model(graph_on_device, node_feats, task_masks=task_masks)
        preds = outputs[node_mask]

        targets = graph.ndata['label'][node_mask].to(device)
        loss = self.loss_fn(preds, targets)
        return {'preds': preds, 'loss': loss}
class NCClassILBareTrainer(NCTrainer):
    """
        Subclass of `NCTrainer` that overrides nothing:
        every event function is inherited unchanged, so it behaves exactly like `NCTrainer`.
    """
class NCClassILBareMinibatchTrainer(NCMinibatchTrainer):
    """
        Subclass of `NCMinibatchTrainer` that overrides nothing:
        every event function is inherited unchanged, so it behaves exactly like `NCMinibatchTrainer`.
    """
class NCDomainILBareTrainer(NCTrainer):
    """
        Subclass of `NCTrainer` that overrides nothing:
        every event function is inherited unchanged, so it behaves exactly like `NCTrainer`.
    """
class NCTimeILBareTrainer(NCTrainer):
    """
        Subclass of `NCTrainer` that overrides nothing:
        every event function is inherited unchanged, so it behaves exactly like `NCTrainer`.
    """