import torch
from torch import nn
from torch_scatter import scatter
import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score
import time

class BaseEvaluator:
    r"""
    Base class for evaluating performance.
    Users can create their own evaluator by extending this class;
    a minimal subclassing sketch follows this class definition.

    Args:
        num_tasks (int): The number of tasks in the target scenario.
        task_ids (torch.Tensor): The task id of each instance.
    """
    def __init__(self, num_tasks, task_ids):
        self.num_tasks = num_tasks
        self._task_ids = task_ids
    def __call__(self, prediction, answer, indices):
        r"""
        Measure the performance on each task.

        Args:
            prediction (torch.Tensor): Predicted output of the current model.
            answer (torch.Tensor): Ground-truth answer.
            indices (torch.Tensor): Indices of the instances chosen for evaluation.
        """
        raise NotImplementedError
    def simple_eval(self, prediction, answer):
        r"""
        Compute the performance on the given batch, ignoring the task configuration.
        During training, this function is called by ``get_simple_eval_result`` implemented in the ScenarioLoaders.

        Args:
            prediction (torch.Tensor): Predicted output of the current model.
            answer (torch.Tensor): Ground-truth answer.
        """
        raise NotImplementedError
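
# --- Illustrative sketch (not part of the library) ---------------------------
# A minimal example of extending BaseEvaluator, assuming the convention used by
# the evaluators below: ``__call__`` returns a tensor of length num_tasks + 1
# whose last entry holds the metric over all instances of valid tasks, and
# ``simple_eval`` returns a single float. The class name MAEEvaluator and the
# choice of mean absolute error are hypothetical, used only for illustration.
class MAEEvaluator(BaseEvaluator):
    r"""
    Example evaluator computing the mean absolute error (MAE) per task.
    """
    def __call__(self, _prediction, _answer, indices):
        prediction = _prediction.squeeze().to(_answer.device)
        answer = _answer.squeeze()
        # Instances whose task id is smaller than num_tasks belong to valid tasks.
        scope = self._task_ids[indices] < self.num_tasks
        # Average |prediction - answer| within each task (assumes real-valued targets).
        mae_per_task = scatter((prediction - answer).abs().float(), self._task_ids[indices],
                               dim=-1, reduce='mean', dim_size=self.num_tasks + 1)
        # The extra slot stores the metric over all valid instances.
        mae_per_task[self.num_tasks] = self.simple_eval(prediction[scope], answer[scope])
        return mae_per_task

    def simple_eval(self, prediction, answer):
        return (prediction.squeeze().to(answer.device) - answer.squeeze()).abs().float().mean().item()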

class AccuracyEvaluator(BaseEvaluator):
    r"""
    The evaluator for computing accuracy.
    A usage sketch follows this class definition.

    Bases: ``BaseEvaluator``
    """
    def __call__(self, _prediction, _answer, indices):
        prediction = _prediction.squeeze().to(_answer.device)
        answer = _answer.squeeze()
        # Instances whose task id is smaller than num_tasks belong to valid tasks.
        scope = self._task_ids[indices] < self.num_tasks
        # Per-task accuracy via mean-scatter; the extra slot (index num_tasks) is
        # filled with the overall accuracy below.
        accuracy_per_task = scatter((prediction == answer).float(), self._task_ids[indices],
                                    dim=-1, reduce='mean', dim_size=self.num_tasks + 1)
        accuracy_per_task[self.num_tasks] = self.simple_eval(prediction[scope], answer[scope])
        return accuracy_per_task
    def simple_eval(self, prediction, answer):
        return ((prediction.squeeze().to(answer.device) == answer.squeeze()).float().sum() / answer.shape[0]).item()
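
# --- Illustrative sketch (not part of the library) ---------------------------
# Assuming hard class predictions and one task id per instance, the evaluator
# returns per-task accuracies plus the overall accuracy in the last slot.
# All tensors below are made up for illustration.
def _demo_accuracy_evaluator():
    task_ids = torch.tensor([0, 0, 1, 1, 2, 2])
    evaluator = AccuracyEvaluator(num_tasks=3, task_ids=task_ids)
    prediction = torch.tensor([1, 0, 1, 1, 0, 0])
    answer = torch.tensor([1, 1, 1, 1, 0, 1])
    indices = torch.arange(6)
    # -> tensor([0.5000, 1.0000, 0.5000, 0.6667]); the last entry is the overall accuracy.
    return evaluator(prediction, answer, indices)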

class ROCAUCEvaluator(BaseEvaluator):
    r"""
    The evaluator for computing the ROC-AUC score.
    A usage sketch follows this class definition.

    Bases: ``BaseEvaluator``
    """
    def __call__(self, _prediction, answer, indices):
        prediction = _prediction.to(answer.device)
        target_ids = self._task_ids[indices]
        retval = torch.zeros(self.num_tasks + 1)
        # ROC-AUC of each task that has at least one evaluated instance.
        for i in range(self.num_tasks):
            is_target = (target_ids == i)
            if is_target.any():
                retval[i] = self.simple_eval(prediction[is_target], answer[is_target])
        # The last slot stores the ROC-AUC over all instances.
        retval[self.num_tasks] = self.simple_eval(prediction, answer)
        return retval
    def simple_eval(self, prediction, answer):
        num_items, num_qs = answer.shape
        # Keep only the columns that contain both positive and negative labels;
        # the ROC-AUC is undefined otherwise.
        valid_cols = (answer.sum(0) < num_items) & (answer.sum(0) > 0)
        prediction, answer = prediction[..., valid_cols], answer[..., valid_cols]
        num_valid_qs = valid_cols.long().sum().item()
        # Reorder the answers of each column by ascending prediction score.
        idx_order = (torch.argsort(prediction, dim=0).to(answer.device).view(-1) * num_valid_qs
                     + torch.arange(num_valid_qs).to(answer.device).repeat(num_items))
        ordered_answer = (answer.view(-1)[idx_order]).view(num_items, num_valid_qs)
        # num_pos[i] / num_neg[i]: numbers of positives / negatives among the (i + 1) lowest-scored items.
        num_pos = torch.cumsum(ordered_answer, dim=0)
        num_neg = torch.arange(1, num_items + 1).unsqueeze(-1).to(answer.device) - num_pos
        # ROC-AUC = 1 - (# of (positive, negative) pairs in which the positive is
        # ranked below the negative) / (# positives * # negatives).
        rocauc_scores = 1. - ((num_pos * (ordered_answer == 0)).sum(0) / (num_pos[-1] * num_neg[-1] + 1e-6))
        return rocauc_scores.mean().item()
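
# --- Illustrative sketch (not part of the library) ---------------------------
# ``simple_eval`` expects a (num_items, num_questions) score matrix and a binary
# answer matrix of the same shape, and averages the per-column ROC-AUC over the
# columns that contain both labels. The tensors below are made up for
# illustration, and the sklearn call serves only as a rough cross-check.
def _demo_rocauc_evaluator():
    torch.manual_seed(0)
    answer = (torch.rand(100, 4) < 0.3).long()
    prediction = torch.rand(100, 4) + 0.5 * answer.float()  # scores loosely correlated with labels
    evaluator = ROCAUCEvaluator(num_tasks=1, task_ids=torch.zeros(100, dtype=torch.long))
    ours = evaluator.simple_eval(prediction, answer)
    reference = np.mean([roc_auc_score(answer[:, q].numpy(), prediction[:, q].numpy())
                         for q in range(4)])
    return ours, reference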

class HitsEvaluator(BaseEvaluator):
    r"""
    The evaluator for computing Hits@K. Unlike the other evaluators, its second
    constructor argument is ``K`` instead of ``task_ids``.
    A usage sketch follows this class definition.

    Bases: ``BaseEvaluator``
    """
    def __init__(self, num_tasks, k):
        super().__init__(num_tasks, None)
        self.k = k
    def __call__(self, _prediction, _answer, task_ids):
        prediction = _prediction.squeeze().to(_answer.device)
        answer = _answer.squeeze()
        # Hits@K: the fraction of positive instances scored above the K-th
        # highest-scored negative instance.
        neg_samples = prediction[answer == 0]
        if neg_samples.shape[0] < self.k:
            return torch.ones(self.num_tasks + 1)
        neg_threshold = torch.topk(neg_samples, self.k).values[-1]
        num_pos = torch.bincount(task_ids[answer == 1], minlength=self.num_tasks).float()
        num_hits = torch.bincount(task_ids[(answer == 1) & (prediction > neg_threshold)], minlength=self.num_tasks).float()
        hits_per_task = torch.zeros(self.num_tasks + 1)
        hits_per_task[:self.num_tasks] = num_hits / torch.clamp(num_pos, min=1.)
        # The last slot stores the overall Hits@K.
        hits_per_task[self.num_tasks] = num_hits.sum() / num_pos.sum()
        return hits_per_task
    def simple_eval(self, prediction, answer):
        prediction = prediction.squeeze().to(answer.device)
        answer = answer.squeeze()
        neg_samples = prediction[answer == 0]
        # With fewer than K negatives, every positive trivially counts as a hit.
        if neg_samples.shape[0] < self.k:
            return 1.
        neg_threshold = torch.topk(neg_samples, self.k).values[-1]
        num_pos = (answer == 1).float().sum()
        num_hits = ((answer == 1) & (prediction > neg_threshold)).float().sum()
        return (num_hits / num_pos).item()
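
# --- Illustrative sketch (not part of the library) ---------------------------
# Hits@K for (e.g.) link prediction: a positive instance counts as a hit when
# its score exceeds the K-th highest negative score. Note that ``__call__``
# receives the per-instance task ids directly as its third argument.
# All tensors below are made up for illustration.
def _demo_hits_evaluator():
    torch.manual_seed(0)
    evaluator = HitsEvaluator(num_tasks=2, k=10)
    answer = torch.cat([torch.ones(50), torch.zeros(200)]).long()
    prediction = torch.rand(250) + 0.5 * answer.float()  # positives tend to score higher
    task_ids = torch.randint(0, 2, (250,))
    # Returns per-task Hits@10 plus the overall Hits@10 in the last slot.
    return evaluator(prediction, answer, task_ids)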