Source code for begin.utils.models

import torch
from torch import nn
from dgl.nn import GraphConv, SumPooling, AvgPooling, MaxPooling
import torch.nn.functional as F
from dgl.base import DGLError
from dgl.utils import expand_as_pair
import dgl.function as fn
from torch_scatter import segment_csr

class AdaptiveLinear(nn.Module):
    r"""
    The linear layer for helping the continual learning procedure, by masking the outputs.

    Arguments:
        in_channels (int): the number of input channels (in_features of `torch.nn.Linear`).
        out_channels (int): the number of output channels (out_features of `torch.nn.Linear`).
        bias (bool): If set to False, the layer will not learn an additive bias.
        accum (bool): If set to True, the layer can also provide the logits for the previously seen outputs.
    """
    def __init__(self, in_channels, out_channels, bias=True, accum=True):
        super().__init__()
        self.lin = nn.Linear(in_channels, out_channels, bias)
        self.bias = bias
        self.accum = accum
        self.num_outputs = out_channels
        self.output_masks = None
        self.observed = torch.zeros(out_channels, dtype=torch.bool)
    def observe_outputs(self, new_outputs, verbose=True):
        r"""
        Observes the ideal outputs in the training dataset.
        By observing the outputs, the layer determines which outputs (logits) will be masked or not.

        Arguments:
            new_outputs (torch.Tensor): the ideal outputs in the training dataset.
        """
        device = self.lin.weight.data.device
        new_outputs = torch.unique(new_outputs)
        new_num_outputs = max(self.num_outputs, new_outputs.max().item() + 1)
        new_output_mask = torch.zeros(new_num_outputs, dtype=torch.bool).to(device)
        new_output_mask[new_outputs] = True

        prv_observed = self.observed
        if self.output_masks is None:
            self.output_masks = [new_output_mask]
        else:
            if new_num_outputs > self.num_outputs:
                # pad the stored per-task masks to the enlarged output dimension
                self.output_masks = [torch.cat((output_mask, torch.zeros(new_num_outputs - self.num_outputs, dtype=torch.bool).to(device)), dim=-1)
                                     for output_mask in self.output_masks]
            if self.accum:
                self.output_masks.append(self.output_masks[-1] | new_output_mask)
            else:
                self.output_masks.append(new_output_mask)

        # if a new class whose index exceeds `self.num_outputs` is observed, expand the output layer
        if new_num_outputs > self.num_outputs:
            in_feats = self.lin.in_features
            prev_weight = self.lin.weight.data[prv_observed]
            prev_bias = self.lin.bias.data[prv_observed] if self.bias else None
            self.observed = torch.cat((self.observed.to(device), torch.zeros(new_num_outputs - self.num_outputs, dtype=torch.bool).to(device)), dim=-1)
            # rebuild the linear layer with the enlarged output dimension, keeping the learned weights
            self.lin = nn.Linear(in_feats, new_num_outputs, bias=self.bias).to(device)
            self.lin.weight.data[self.observed] = prev_weight
            if self.bias:
                self.lin.bias.data[self.observed] = prev_bias
            self.num_outputs = new_num_outputs

        self.observed = self.observed.to(device) | new_output_mask
    def get_output_mask(self, task_ids=None):
        r"""
        Returns the mask managed by the layer.

        Arguments:
            task_ids (torch.Tensor or None): If task_ids is provided, the layer returns the mask for each index of the task. Otherwise, it returns the mask for the last task it observed.
        """
        if task_ids is None:
            # for class-IL, time-IL, and domain-IL
            return self.output_masks[-1]
        else:
            # for task-IL
            mask = torch.zeros(task_ids.shape[0], self.num_outputs).bool().to(task_ids.device)
            observed_mask = task_ids < len(self.output_masks)
            mask[observed_mask] = torch.stack(self.output_masks, dim=0)[task_ids[observed_mask]]
            return mask
    def forward(self, x, task_masks=None):
        r"""
        Returns the masked results of the inner linear layer.

        Arguments:
            x (torch.Tensor): the input features.
            task_masks (torch.Tensor or None): If task_masks is provided, the layer uses the tensor for masking the outputs. Otherwise, the layer uses the mask managed by the layer.
        """
        out = self.lin(x)
        if task_masks is None:
            out[..., ~self.observed] = -1e12
        else:
            out[~task_masks] = -1e12
        return out
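# --- Editor's usage sketch (not part of the original module) ---
# A minimal illustration of how `AdaptiveLinear` accumulates and applies its
# output masks across tasks; the class indices and features are synthetic.
def _example_adaptive_linear():
    clf = AdaptiveLinear(in_channels=8, out_channels=4, bias=True, accum=True)
    clf.observe_outputs(torch.tensor([0, 1]))   # task 0 introduces classes 0 and 1
    clf.observe_outputs(torch.tensor([2, 5]))   # task 1 introduces class 5, so the layer grows to 6 outputs
    x = torch.randn(3, 8)
    out = clf(x)                                # logits of never-observed classes (3, 4) are set to -1e12
    assert out.shape == (3, 6)
    task0_mask = clf.get_output_mask(torch.tensor([0]))  # per-task mask for task-IL evaluation
    return out, task0_mask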
class GCNNode(nn.Module):
    def __init__(self, in_feats, n_classes, n_hidden, activation=F.relu, dropout=0.0, n_layers=3, incr_type='class', use_classifier=True):
        super().__init__()
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()
        for i in range(n_layers):
            in_hidden = n_hidden if i > 0 else in_feats
            out_hidden = n_hidden
            self.convs.append(GraphConv(in_hidden, out_hidden, "both", bias=False, allow_zero_in_degree=True))
            self.norms.append(nn.BatchNorm1d(out_hidden))
        self.dropout = nn.Dropout(dropout)
        self.activation = activation
        if use_classifier:
            self.classifier = AdaptiveLinear(n_hidden, n_classes, bias=True, accum=(incr_type != 'task'))
        else:
            self.classifier = None

    def forward(self, graph, feat, task_masks=None):
        h = self.dropout(feat)
        for i in range(self.n_layers):
            h = self.convs[i](graph, h)
            h = self.norms[i](h)
            h = self.activation(h)
            h = self.dropout(h)
        if self.classifier is not None:
            h = self.classifier(h, task_masks)
        return h

    def forward_hat(self, graph, feat, hat_masks, task_masks=None):
        h = self.dropout(feat)
        for i in range(self.n_layers):
            h = self.convs[i](graph, h)
            h = self.norms[i](h)
            h = self.activation(h)
            h = self.dropout(h)
            # HAT-style gating of the hidden units of each layer
            h = h * hat_masks[i].to(h.device).expand_as(h)
        if self.classifier is not None:
            h = self.classifier(h, task_masks)
        return h

    def forward_without_classifier(self, graph, feat, task_masks=None):
        h = self.dropout(feat)
        for i in range(self.n_layers):
            h = self.convs[i](graph, h)
            h = self.norms[i](h)
            h = self.activation(h)
            h = self.dropout(h)
        return h

    def bforward(self, blocks, feat, task_masks=None):
        # mini-batch forward over neighbor-sampled blocks (one block per layer)
        h = self.dropout(feat)
        for i in range(self.n_layers):
            h = self.convs[i](blocks[i], h)
            h = self.norms[i](h)
            h = self.activation(h)
            h = self.dropout(h)
        if self.classifier is not None:
            h = self.classifier(h, task_masks)
        return h

    def observe_labels(self, new_labels, verbose=True):
        self.classifier.observe_outputs(new_labels, verbose=verbose)

    def get_observed_labels(self, tid=None):
        if tid is None or tid < 0:
            return self.classifier.observed
        else:
            return self.classifier.output_masks[tid]
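# --- Editor's usage sketch (not part of the original module) ---
# Node classification with `GCNNode` under class-incremental learning;
# the graph, features, and label set below are synthetic.
def _example_gcn_node():
    import dgl
    g = dgl.add_self_loop(dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]), num_nodes=4))
    model = GCNNode(in_feats=16, n_classes=6, n_hidden=32)
    model.observe_labels(torch.tensor([0, 1, 2]))   # the first task exposes classes 0-2
    feat = torch.randn(4, 16)
    logits = model(g, feat)                         # (4, 6); logits of unseen classes are masked to -1e12
    return logits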
class GCNLink(nn.Module):
    def __init__(self, in_feats, n_classes, n_hidden, activation=F.relu, dropout=0.0, n_layers=3, incr_type='class'):
        super().__init__()
        # the GCN encoder has no classifier of its own; edge scores are computed by the MLP below
        self.gcn = GCNNode(in_feats, n_hidden, n_hidden, activation=activation, dropout=dropout, n_layers=n_layers, incr_type=incr_type, use_classifier=False)
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        self.linears = nn.ModuleList()
        for i in range(n_layers - 1):
            self.linears.append(nn.Linear(n_hidden, n_hidden))
        self.dropout = nn.Dropout(dropout)
        self.activation = activation
        self.classifier = AdaptiveLinear(n_hidden, n_classes, bias=True, accum=(incr_type != 'task'))

    def forward(self, graph, feat, srcs, dsts, task_masks=None):
        _h = self.gcn(graph, feat, task_masks)
        x = _h[srcs] * _h[dsts]   # element-wise product of source and destination embeddings
        x = self.dropout(x)
        for i in range(self.n_layers - 1):
            x = self.linears[i](x)
            x = self.activation(x)
            x = self.dropout(x)
        x = self.classifier(x, task_masks)
        return x

    def forward_hat(self, graph, feat, srcs, dsts, hat_masks=None, task_masks=None):
        _h = self.gcn.forward_hat(graph, feat, hat_masks[:-(self.n_layers - 1)], task_masks)
        x = _h[srcs] * _h[dsts]
        x = self.dropout(x)
        for i in range(self.n_layers - 1):
            x = self.linears[i](x)
            x = self.activation(x)
            x = self.dropout(x)
        x = self.classifier(x, task_masks)
        return x

    def forward_without_classifier(self, graph, feat, srcs, dsts, task_masks=None):
        _h = self.gcn(graph, feat, task_masks)
        x = _h[srcs] * _h[dsts]
        x = self.dropout(x)
        for i in range(self.n_layers - 1):
            x = self.linears[i](x)
            x = self.activation(x)
            x = self.dropout(x)
        return x

    def observe_labels(self, new_labels, verbose=True):
        self.classifier.observe_outputs(new_labels, verbose=verbose)

    def get_observed_labels(self, tid=None):
        if tid is None or tid < 0:
            return self.classifier.observed
        else:
            return self.classifier.output_masks[tid]
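# --- Editor's usage sketch (not part of the original module) ---
# Link prediction with `GCNLink`: score candidate (src, dst) node pairs;
# the graph, features, and edge list below are synthetic.
def _example_gcn_link():
    import dgl
    g = dgl.add_self_loop(dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]), num_nodes=4))
    model = GCNLink(in_feats=16, n_classes=2, n_hidden=32)
    model.observe_labels(torch.tensor([0, 1]))       # e.g., negative / positive edge classes
    feat = torch.randn(4, 16)
    srcs, dsts = torch.tensor([0, 1]), torch.tensor([2, 3])
    scores = model(g, feat, srcs, dsts)              # (2, 2) logits, one row per candidate pair
    return scores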
class GCNConv(nn.Module):
    # a variant of dgl.nn.GraphConv extended with per-edge scalar weights
    # (`edge_weight`) or encoded edge attributes (`edge_attr`)
    def __init__(self, in_feats, out_feats, norm='both', weight=True, bias=True, activation=None, allow_zero_in_degree=False, edge_encoder_fn=None):
        super(GCNConv, self).__init__()
        if norm not in ('none', 'both', 'right'):
            raise DGLError('Invalid norm value. Must be either "none", "both" or "right".'
                           ' But got "{}".'.format(norm))
        self._in_feats = in_feats
        self._out_feats = out_feats
        self._norm = norm
        self._allow_zero_in_degree = allow_zero_in_degree
        if weight:
            self.weight = nn.Parameter(torch.Tensor(in_feats, out_feats))
        else:
            self.register_parameter('weight', None)
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_feats))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()
        self._activation = activation
        if edge_encoder_fn is not None:
            self.edge_encoder = edge_encoder_fn()

    def reset_parameters(self):
        if self.weight is not None:
            nn.init.xavier_uniform_(self.weight)
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def set_allow_zero_in_degree(self, set_value):
        self._allow_zero_in_degree = set_value

    def forward(self, graph, feat, weight=None, edge_weight=None, edge_attr=None):
        with graph.local_scope():
            if not self._allow_zero_in_degree:
                if (graph.in_degrees() == 0).any():
                    raise DGLError('There are 0-in-degree nodes in the graph, '
                                   'output for those nodes will be invalid. '
                                   'This is harmful for some applications, '
                                   'causing silent performance regression. '
                                   'Adding self-loop on the input graph by '
                                   'calling `g = dgl.add_self_loop(g)` will resolve '
                                   'the issue. Setting ``allow_zero_in_degree`` '
                                   'to be `True` when constructing this module will '
                                   'suppress the check and let the code run.')

            aggregate_fn = fn.copy_u('h', 'm')
            if edge_weight is not None:
                assert edge_weight.shape[0] == graph.number_of_edges()
                graph.edata['_edge_weight'] = edge_weight
                aggregate_fn = fn.u_mul_e('h', '_edge_weight', 'm')
            elif edge_attr is not None:
                assert edge_attr.shape[0] == graph.number_of_edges()
                graph.edata['_edge_attr'] = self.edge_encoder(edge_attr)
                aggregate_fn = fn.u_add_e('h', '_edge_attr', 'm')

            # (BarclayII) For RGCN on heterogeneous graphs we need to support GCN on bipartite.
            feat_src, feat_dst = expand_as_pair(feat, graph)
            if self._norm == 'both':
                degs = graph.out_degrees().float().clamp(min=1)
                norm = torch.pow(degs, -0.5)
                shp = norm.shape + (1,) * (feat_src.dim() - 1)
                norm = torch.reshape(norm, shp)
                feat_src = feat_src * norm

            if weight is not None:
                if self.weight is not None:
                    raise DGLError('External weight is provided while at the same time the'
                                   ' module has defined its own weight parameter. Please'
                                   ' create the module with flag weight=False.')
            else:
                weight = self.weight

            if self._in_feats > self._out_feats:
                # multiply by W first to reduce the feature size for aggregation
                if weight is not None:
                    feat_src = torch.matmul(feat_src, weight)
                graph.srcdata['h'] = feat_src
                graph.update_all(aggregate_fn, fn.sum(msg='m', out='h'))
                rst = graph.dstdata['h']
            else:
                # aggregate first, then multiply by W
                graph.srcdata['h'] = feat_src
                graph.update_all(aggregate_fn, fn.sum(msg='m', out='h'))
                rst = graph.dstdata['h']
                if weight is not None:
                    rst = torch.matmul(rst, weight)

            if self._norm != 'none':
                degs = graph.in_degrees().float().clamp(min=1)
                if self._norm == 'both':
                    norm = torch.pow(degs, -0.5)
                else:
                    norm = 1.0 / degs
                shp = norm.shape + (1,) * (feat_dst.dim() - 1)
                norm = torch.reshape(norm, shp)
                rst = rst * norm

            if self.bias is not None:
                rst = rst + self.bias
            if self._activation is not None:
                rst = self._activation(rst)
            return rst

    def extra_repr(self):
        summary = 'in={_in_feats}, out={_out_feats}'
        summary += ', normalization={_norm}'
        if '_activation' in self.__dict__:
            summary += ', activation={_activation}'
        return summary.format(**self.__dict__)
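# --- Editor's usage sketch (not part of the original module) ---
# `GCNConv` mirrors dgl.nn.GraphConv but also accepts one scalar weight per
# edge; the graph and weights below are synthetic. Per DGL's GraphConv
# convention, `edge_weight` has shape (E,).
def _example_gcn_conv():
    import dgl
    g = dgl.add_self_loop(dgl.graph(([0, 1, 2], [1, 2, 0]), num_nodes=3))
    conv = GCNConv(8, 4, norm='both')
    feat = torch.randn(3, 8)
    ew = torch.rand(g.number_of_edges())   # one non-negative weight per edge
    out = conv(g, feat, edge_weight=ew)    # (3, 4)
    return out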
class GCNGraph(nn.Module):
    def __init__(self, in_feats, n_classes, n_hidden, activation=F.relu, dropout=0.0, n_layers=4, n_mlp_layers=2, incr_type='class', readout='mean', node_encoder_fn=None, edge_encoder_fn=None):
        super().__init__()
        self.n_layers = n_layers
        self.n_mlp_layers = n_mlp_layers
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()
        if node_encoder_fn is None:
            self.enc = nn.Linear(in_feats, n_hidden)
        else:
            self.enc = node_encoder_fn()
        for i in range(n_layers):
            self.convs.append(GCNConv(n_hidden, n_hidden, "both", bias=False, allow_zero_in_degree=True, edge_encoder_fn=edge_encoder_fn))
            self.norms.append(nn.BatchNorm1d(n_hidden))
        self.dropout = nn.Dropout(dropout)
        self.activation = activation
        # the MLP head halves the hidden size at every layer
        self.mlp_layers = nn.ModuleList([nn.Linear(n_hidden // (1 << i), n_hidden // (1 << (i + 1))) for i in range(n_mlp_layers)])
        self.classifier = AdaptiveLinear(n_hidden // (1 << n_mlp_layers), n_classes, bias=True, accum=(incr_type != 'task'))
        self.readout_mode = readout

    def forward(self, graph, feat, task_masks=None, edge_weight=None, edge_attr=None, get_intermediate_outputs=False):
        h = self.enc(feat)
        h = self.dropout(h)
        inter_hs = []
        for i in range(self.n_layers):
            h = self.convs[i](graph, h, edge_weight=edge_weight, edge_attr=edge_attr)
            h = self.norms[i](h)
            h = self.activation(h)
            h = self.dropout(h)
            inter_hs.append(h)
        if self.readout_mode != 'none':
            # deterministic segment-wise readout over the batched graph
            # (matches dgl's pooling readouts up to numerical noise)
            ptrs = torch.cat((torch.LongTensor([0]).to(h.device), torch.cumsum(graph.batch_num_nodes(), dim=-1)), dim=-1)
            h = segment_csr(h, ptrs, reduce=self.readout_mode)
        for layer in self.mlp_layers:
            h = layer(h)
            h = self.activation(h)
            h = self.dropout(h)
        inter_hs.append(h)
        h = self.classifier(h, task_masks)
        if get_intermediate_outputs:
            return h, inter_hs
        else:
            return h

    def forward_hat(self, graph, feat, hat_masks, task_masks=None, edge_weight=None, edge_attr=None):
        h = self.enc(feat)
        h = self.dropout(h)
        h = h * hat_masks[0].to(h.device).expand_as(h)
        for i in range(self.n_layers):
            h = self.convs[i](graph, h, edge_weight=edge_weight, edge_attr=edge_attr)
            h = self.norms[i](h)
            h = self.activation(h)
            h = self.dropout(h)
            h = h * hat_masks[1 + i].to(h.device).expand_as(h)
        if self.readout_mode != 'none':
            ptrs = torch.cat((torch.LongTensor([0]).to(h.device), torch.cumsum(graph.batch_num_nodes(), dim=-1)), dim=-1)
            h = segment_csr(h, ptrs, reduce=self.readout_mode)
        for j, layer in enumerate(self.mlp_layers):
            h = layer(h)
            h = self.activation(h)
            h = self.dropout(h)
            # apply the HAT mask assigned to this MLP layer
            h = h * hat_masks[1 + self.n_layers + j].to(h.device).expand_as(h)
        h = self.classifier(h, task_masks)
        return h

    def forward_without_classifier(self, graph, feat, task_masks=None, edge_weight=None, edge_attr=None):
        h = self.enc(feat)
        h = self.dropout(h)
        for i in range(self.n_layers):
            h = self.convs[i](graph, h, edge_weight=edge_weight, edge_attr=edge_attr)
            h = self.norms[i](h)
            h = self.activation(h)
            h = self.dropout(h)
        return h

    def observe_labels(self, new_labels, verbose=True):
        self.classifier.observe_outputs(new_labels, verbose=verbose)

    def get_observed_labels(self, tid=None):
        if tid is None or tid < 0:
            return self.classifier.observed
        else:
            return self.classifier.output_masks[tid]
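# --- Editor's usage sketch (not part of the original module) ---
# Graph classification over a batched graph with `GCNGraph`; the two toy
# graphs and the observed label set below are synthetic.
def _example_gcn_graph():
    import dgl
    g1 = dgl.add_self_loop(dgl.graph(([0, 1, 2], [1, 2, 0]), num_nodes=3))
    g2 = dgl.add_self_loop(dgl.graph(([0, 1], [1, 0]), num_nodes=2))
    batched = dgl.batch([g1, g2])
    model = GCNGraph(in_feats=8, n_classes=4, n_hidden=32, n_layers=2, n_mlp_layers=2, readout='mean')
    model.observe_labels(torch.tensor([0, 1]))
    feat = torch.randn(batched.number_of_nodes(), 8)
    logits = model(batched, feat)   # (2, 4): one row of logits per graph in the batch
    return logits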