"""
Module with custom metrics for simplicial neural networks.
"""
import torch
from torchmetrics import Metric


class RegressionAccuracy(Metric):
    """
    Thresholded accuracy metric for regression tasks.

    The thresholded accuracy metric follows the accuracy definition as used in the
    simplicial neural network paper. A missing value is considered to be correctly
    imputed if the prediction differs from the target by at most a fraction
    `threshold` of the target's magnitude, i.e. if
    ``|preds - target| <= threshold * |target|``. Consequently, a target of zero is
    only matched by an exact prediction.

    Parameters
    ----------
    threshold : float
        Relative tolerance expressed as a fraction of the target (e.g. ``0.1``
        means the prediction may deviate by up to 10% of the target). Must be
        between 0 and 1 (inclusive).
    **kwargs : keyword arguments
        Additional keyword arguments passed to torchmetrics' base `Metric` class.

    Raises
    ------
    ValueError
        If `threshold` is not between 0 and 1 (NaN is rejected as well).
    """

    full_state_update = False
    threshold: float

    def __init__(self, threshold: float, **kwargs) -> None:
        """
        Thresholded accuracy metric for regression tasks.

        A missing value is considered to be correctly imputed if
        ``|preds - target| <= threshold * |target|``.

        Parameters
        ----------
        threshold : float
            Relative tolerance expressed as a fraction of the target (e.g. ``0.1``
            for 10%). Must be between 0 and 1 (inclusive).
        **kwargs : keyword arguments
            Additional keyword arguments passed to torchmetrics' base `Metric` class.

        Raises
        ------
        ValueError
            If `threshold` is not between 0 and 1 (NaN is rejected as well).
        """
        super().__init__(**kwargs)

        # Chained comparison so that NaN (which fails every comparison) is
        # rejected too; the previous `< 0 or > 1` check let NaN through.
        if not 0.0 <= threshold <= 1.0:
            raise ValueError("threshold must be between 0 and 1")

        self.threshold = threshold

        # Running integer counts; registered with dist_reduce_fx="sum" so they
        # are summed when metric states are synchronized.
        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        """
        Update metric state with new predictions.

        Parameters
        ----------
        preds : torch.Tensor
            Predictions. Must have the same shape as `target`.
        target : torch.Tensor
            Targets.

        Raises
        ------
        ValueError
            If `preds` and `target` do not have the same shape.
        """
        if preds.shape != target.shape:
            raise ValueError(
                f"`preds` and `target` must have the same shape, but got {preds.shape} and {target.shape}"
            )

        absolute_difference = torch.abs(preds - target)
        # Compare against a scaled tolerance rather than dividing by the target:
        # division yields NaN (0 / 0) or inf (x / 0) for zero targets, which made
        # even exact predictions of a zero target count as incorrect.
        tolerance = self.threshold * torch.abs(target)

        self.correct += torch.sum(absolute_difference <= tolerance)
        self.total += target.numel()

    def compute(self) -> torch.Tensor:
        """
        Compute the thresholded accuracy of all predictions seen so far.

        Returns
        -------
        torch.Tensor
            Zero-dimensional float tensor holding the fraction of values within
            the threshold. NaN if no values have been seen yet (``total == 0``).
        """
        return self.correct.float() / self.total
