from regex import F
import torch
from metrics.metric import Metric


class TestLossMetric(Metric):
    """Metric that reports the mean value of a loss criterion.

    The result is published under the single key ``'value'``, which is also
    registered as ``main_metric_name``.
    """

    def __init__(self, criterion, train=False):
        """Store the criterion and register the metric under the name 'Loss'.

        Args:
            criterion: Callable invoked as ``criterion(outputs, labels)``;
                its result must support ``.mean().item()``.
            train: Forwarded to the ``Metric`` base class constructor.
        """
        self.criterion = criterion
        # Set before the base constructor runs, as in the original ordering
        # (the base class may read it during initialisation -- not visible
        # from this file).
        self.main_metric_name = 'value'
        # Forward the caller's flag instead of a hard-coded False, so the
        # ``train`` argument is no longer silently ignored.
        super().__init__(name='Loss', train=train)

    def compute_metric(self, outputs: torch.Tensor,
                       labels: torch.Tensor, top_k=(1,)):
        """Compute the mean loss of ``outputs`` against ``labels``.

        Three-dimensional ``outputs`` are flattened so that every position
        becomes one row, ``labels`` is flattened to match, and positions
        whose label equals ``-1`` are dropped before the criterion is
        applied. Inputs of any other rank are passed to the criterion
        unchanged.

        Args:
            outputs: Model outputs; the last dimension is preserved when
                flattening.
            labels: Target labels; ``-1`` marks positions to ignore in the
                three-dimensional case.
            top_k: Unused. Kept so the signature stays compatible with
                existing callers.

        Returns:
            A dict ``{'value': <float>}`` holding the mean loss.
        """
        # Reshape for the next-word-prediction task: collapse all leading
        # dimensions into one so each row is a single prediction.
        if outputs.dim() == 3:
            outputs = outputs.reshape([-1, outputs.shape[-1]])
            labels = labels.reshape([-1])
            # Keep only positions with a real target (label != -1).
            keep = labels != -1
            outputs = outputs[keep]
            labels = labels[keep]
            # NOTE(review): if every label is -1 the criterion receives
            # empty tensors; the resulting value depends on the criterion.

        loss = self.criterion(outputs, labels)
        # .mean() handles both per-sample (unreduced) and scalar losses.
        return {'value': loss.mean().item()}