##########################################################################################################
#
# FROM https://github.com/cvignac/DiGress/blob/main/dgd/metrics/abstract_metrics.py
#
##########################################################################################################

import torch
from torch import Tensor
from torch.nn import functional as F
from torchmetrics import Metric, MeanSquaredError


class SumExceptBatchMetric(Metric):
    """Running ratio of summed values to the number of leading-dimension entries.

    Every call to ``update`` adds the sum over *all* elements of ``values`` to
    ``total_value`` and adds ``values.shape[0]`` to ``total_samples``.
    ``compute`` returns ``total_value / total_samples``. With no samples
    accumulated this is ``0 / 0``, i.e. NaN.

    Both states are registered with ``dist_reduce_fx="sum"``, so torchmetrics
    sums them across processes before ``compute``.
    """

    def __init__(self):
        super().__init__()
        # Register both accumulators with identical defaults and reduction;
        # each iteration creates its own fresh zero tensor.
        for state_name in ('total_value', 'total_samples'):
            self.add_state(state_name, default=torch.tensor(0.), dist_reduce_fx="sum")

    def update(self, values) -> None:
        """Accumulate the element-wise sum of ``values`` and its leading-dimension size."""
        self.total_value += values.sum()
        # presumably the leading dimension is the batch (per the class name) — confirm with callers
        self.total_samples += values.shape[0]

    def compute(self):
        """Return the accumulated total divided by the accumulated sample count."""
        numerator = self.total_value
        denominator = self.total_samples
        return numerator / denominator


def regression_accuracy(
    pred: Tensor,
    true: Tensor,
    reduction: str = 'mean',
    *,
    tolerance: float = 0.5,
) -> Tensor:
    """ Compute accuracy for regression task.

    A prediction counts as correct when its absolute error ``|pred - true|``
    is strictly below ``tolerance``. With the default of 0.5 and
    integer-valued targets, this is the same as ``pred`` rounding to the
    target.

    Args:
        pred: predicted values; must broadcast against ``true``.
        true: target values.
        reduction: ``'mean'`` (fraction correct), ``'sum'`` (number correct)
            or ``'none'`` (per-element float tensor of 0/1 flags).
        tolerance: strict upper bound on the absolute error of a correct
            prediction.

    Returns:
        The per-element correctness reduced according to ``reduction``.
        For ``'mean'`` on empty input this is NaN (mean of an empty tensor).

    Raises:
        ValueError: if ``reduction`` is not ``'mean'``, ``'sum'`` or ``'none'``.
    """
    error = torch.abs(pred - true)
    correct = (error < tolerance).float()
    if reduction == 'mean':
        return torch.mean(correct)
    if reduction == 'sum':
        return torch.sum(correct)
    if reduction == 'none':
        return correct
    # Previously an unknown reduction fell through to an UnboundLocalError.
    raise ValueError(f"Unsupported reduction {reduction!r}; expected 'mean', 'sum' or 'none'.")

def classification_accuracy(pred_logits: Tensor, true: Tensor, reduction: str = 'mean') -> Tensor:
    """ Compute accuracy for classification task.

    The predicted class is the argmax of ``pred_logits`` over its last
    dimension. It is compared for equality against ``true``.

    Args:
        pred_logits: class scores with classes along the last dimension.
        true: target class indices; must broadcast against
            ``pred_logits.argmax(-1)``.
        reduction: ``'mean'`` (fraction correct), ``'sum'`` (number correct)
            or ``'none'`` (per-element float tensor of 0/1 flags).

    Returns:
        The per-element correctness reduced according to ``reduction``.
        For ``'mean'`` with nothing to score, returns ``tensor([0.5])`` on
        the logits' device. NOTE: that fallback has shape ``(1,)``, whereas
        the normal mean is a 0-d tensor.

    Raises:
        ValueError: if ``reduction`` is not ``'mean'``, ``'sum'`` or ``'none'``.
    """
    pred = torch.argmax(pred_logits, dim=-1)
    correct = (pred == true).float()
    if reduction == 'mean':
        if correct.numel() == 0:
            # Avoid the NaN mean of an empty tensor; report a neutral 0.5.
            return torch.full((1,), 0.5, device=pred_logits.device)
        return torch.mean(correct)
    if reduction == 'sum':
        return torch.sum(correct)
    if reduction == 'none':
        return correct
    # Previously an unknown reduction fell through to an UnboundLocalError.
    raise ValueError(f"Unsupported reduction {reduction!r}; expected 'mean', 'sum' or 'none'.")

def binary_classification_accuracy(pred_logits: Tensor, true: Tensor, reduction: str = 'mean') -> Tensor:
    """ Compute accuracy for binary classification task. Inputs have final dimension 1.

    A prediction is positive when its logit is ``> 0`` (sigmoid probability
    above 0.5). A target is positive when it is ``> 0.5``. This matches
    exact comparison for 0/1 labels, but also scores smoothed labels (e.g.
    0.9 / 0.1) correctly instead of always counting them wrong.

    Args:
        pred_logits: raw logits; must broadcast against ``true``.
        true: binary (or soft) targets.
        reduction: ``'mean'`` (fraction correct), ``'sum'`` (number correct)
            or ``'none'`` (per-element float tensor of 0/1 flags).

    Returns:
        The per-element correctness reduced according to ``reduction``.
        For ``'mean'`` with nothing to score, returns ``tensor([0.5])`` on
        the logits' device. That fallback has shape ``(1,)``, whereas the
        normal mean is 0-d.

    Raises:
        ValueError: if ``reduction`` is not ``'mean'``, ``'sum'`` or ``'none'``.
    """
    pred_positive = pred_logits > 0
    true_positive = true > 0.5
    correct = (pred_positive == true_positive).float()
    if reduction == 'mean':
        if correct.numel() == 0:
            # Avoid the NaN mean of an empty tensor; report a neutral 0.5.
            return torch.full((1,), 0.5, device=pred_logits.device)
        return torch.mean(correct)
    if reduction == 'sum':
        return torch.sum(correct)
    if reduction == 'none':
        return correct
    # Previously an unknown reduction fell through to an UnboundLocalError.
    raise ValueError(f"Unsupported reduction {reduction!r}; expected 'mean', 'sum' or 'none'.")

def binary_classification_recall(pred_logits: Tensor, true: Tensor, reduction: str = 'mean') -> Tensor:
    """ Compute recall for binary classification task. Inputs have final dimension 1.

    Recall is the fraction of positive targets (``true > 0.5``) whose logit
    is ``> 0``. Targets are thresholded rather than compared exactly, so
    smoothed positives (e.g. 0.9) are scored the same way as 1.0.

    Args:
        pred_logits: raw logits; must broadcast against ``true``.
        true: binary (or soft) targets.
        reduction: ``'mean'`` (recall), ``'sum'`` (number of true positives)
            or ``'none'`` (float 0/1 hit flags for the positive targets only).

    Returns:
        The positive-target hits reduced according to ``reduction``. For
        ``'mean'`` with no positive targets, returns ``tensor([0.5])`` on the
        logits' device. That fallback has shape ``(1,)``, whereas the normal
        mean is 0-d.

    Raises:
        ValueError: if ``reduction`` is not ``'mean'``, ``'sum'`` or ``'none'``.
    """
    # Positions whose target is positive.
    mask = true > 0.5

    # Agreement between thresholded prediction and thresholded target, kept
    # only at positive targets. There the target is True, so a hit is simply
    # "predicted positive".
    agreement = (pred_logits > 0) == mask
    positive = agreement[mask].float()
    if reduction == 'mean':
        if positive.numel() == 0:
            # No positive targets: recall is undefined; report a neutral 0.5.
            return torch.full((1,), 0.5, device=pred_logits.device)
        return torch.mean(positive)
    if reduction == 'sum':
        return torch.sum(positive)
    if reduction == 'none':
        return positive
    # Previously an unknown reduction fell through to an UnboundLocalError.
    raise ValueError(f"Unsupported reduction {reduction!r}; expected 'mean', 'sum' or 'none'.")