import math

import torch
import torchmetrics

from .base import FairnessLoss
from ..statistic import Statistic


class ViolationLoss(FairnessLoss):
    """Fairness loss built from the relative deviation of per-group statistics.

    ``self.stat`` supplies an overall statistic ``c`` (via
    ``overall_statistic``) and a tensor of per-group statistics (via
    ``__call__``). The relative violation ``stats / c - 1`` is taken in
    absolute value and reduced to a scalar by :meth:`quantify_violation`,
    which concrete subclasses must implement.
    """

    def __init__(self, stat: Statistic):
        super().__init__()
        self.stat = stat

    def quantify_violation(self, violation):
        """Reduce the absolute-violation tensor to a scalar loss.

        Must be overridden by subclasses.
        """
        raise NotImplementedError

    def forward(self, logit, feat, sens, label, as_logit=True, metrics=None, **kwargs):
        """Compute the violation loss.

        Args:
            logit: Model outputs. When ``as_logit`` is true they are passed
                through ``torch.sigmoid``; otherwise they are used as-is.
            feat: Forwarded to both statistic calls.
            sens: Forwarded only to the per-group statistic call
                (presumably the sensitive-attribute assignment — confirm
                against ``Statistic``).
            label: Forwarded to both statistic calls.
            as_logit: Whether ``logit`` still needs the sigmoid.
            metrics: Optional mapping whose values accept ``.update(float)``
                under the keys ``"c"``, ``"min_stat"``, ``"max_stat"``,
                ``"min_violation"`` and ``"max_violation"`` — presumably the
                collection returned by :meth:`internal_metrics`.
            **kwargs: Accepted and ignored.

        Returns:
            ``stats.sum()`` when the overall statistic is exactly zero,
            otherwise ``quantify_violation(|stats / c - 1|)``.
        """
        pred = torch.sigmoid(logit) if as_logit else logit

        c = self.stat.overall_statistic(pred, feat, label)
        # Read the scalar once; every .item() call forces a host/device sync.
        c_value = c.item()
        if metrics is not None:
            metrics["c"].update(c_value)

        stats = self.stat(pred, feat, sens, label)
        if c_value == 0.:
            # stats / c is undefined here, so fall back to the raw sum of the
            # group statistics instead of a relative violation.
            return stats.sum()
        violation = stats / c - 1

        if metrics is not None:
            # The signed violation is logged, before abs() is applied below.
            metrics["min_stat"].update(stats.min().item())
            metrics["max_stat"].update(stats.max().item())
            metrics["min_violation"].update(violation.min().item())
            metrics["max_violation"].update(violation.max().item())

        return self.quantify_violation(torch.abs(violation))

    @staticmethod
    def internal_metrics(prefix=""):
        """Build the metric collection that ``forward`` can update.

        Args:
            prefix: Prefix passed through to ``MetricCollection``.

        Returns:
            A ``torchmetrics.MetricCollection`` of ``MeanMetric`` objects
            keyed by ``"c"``, ``"min_stat"``, ``"max_stat"``,
            ``"min_violation"`` and ``"max_violation"``.
        """
        return torchmetrics.MetricCollection({
            "c": torchmetrics.MeanMetric(),
            "min_stat": torchmetrics.MeanMetric(),
            "max_stat": torchmetrics.MeanMetric(),
            "min_violation": torchmetrics.MeanMetric(),
            "max_violation": torchmetrics.MeanMetric(),
        }, prefix=prefix, compute_groups=False)


class NormLoss(ViolationLoss, name="norm"):
    """Violation loss equal to the p-norm of the violations along the last dim.

    The norm is taken over ``dim=-1`` and then summed over any remaining
    leading dimensions.
    """

    def __init__(self, stat: Statistic, p=1, **kwargs):
        super().__init__(stat, **kwargs)
        # Order of the vector norm (``ord`` argument of vector_norm).
        self.p = p

    def quantify_violation(self, violation):
        """Return the summed ``self.p``-norm of ``violation`` over its last dim."""
        norms_along_last_dim = torch.linalg.vector_norm(
            violation, ord=self.p, dim=-1
        )
        return norms_along_last_dim.sum()


class LSELoss(ViolationLoss, name="lse"):
    """Violation loss using a log-mean-exp (smooth maximum) over the last dim.

    For ``n = violation.shape[-1]`` entries this computes
    ``logsumexp(violation, dim=-1) - log(n)``, i.e. ``log(mean(exp(v)))``.
    The value is 0 when all violations are 0. The result is summed over any
    leading dimensions.
    """

    def quantify_violation(self, violation):
        """Return the summed log-mean-exp of ``violation`` over its last dim."""
        # Use a Python float for log(n) rather than building a fresh CPU
        # integer tensor on every call. That tensor relied on torch.log
        # promoting integers to float and on mixing a CPU scalar into
        # possibly-GPU arithmetic.
        num_entries = violation.shape[-1]
        loss = torch.logsumexp(violation, dim=-1) - math.log(num_entries)
        return loss.sum()
