# fmt: off
from typing import List, Optional

import numpy as np
import torch
import torch.nn.functional as F
from numpy.typing import ArrayLike
from torch import Tensor
from torchmetrics import Metric
from torchmetrics.utilities import rank_zero_warn
from torchmetrics.utilities.data import dim_zero_cat


# fmt: on
def stable_cumsum(
    arr: ArrayLike, rtol: float = 1e-05, atol: float = 1e-08
) -> np.ndarray:
    """Cumulatively sum ``arr`` in float64 and verify the final value.

    From https://github.com/hendrycks/anomaly-seg

    The cumulative sum is accumulated in ``np.float64`` over the flattened
    input (no axis is given to ``np.cumsum``) and its last element is
    compared with ``np.sum`` of the same data.

    Args:
        arr: Array-like to be cumulatively summed as flat.
        rtol: Relative tolerance, see ``np.allclose``.
        atol: Absolute tolerance, see ``np.allclose``.

    Returns:
        The 1-D float64 array of cumulative sums. An empty input yields an
        empty array.

    Raises:
        RuntimeError: If the last cumulative value does not match the sum of
            the input within the given tolerances.
    """
    out = np.cumsum(arr, dtype=np.float64)

    # An empty input has no last element to validate; its cumulative sum is
    # simply the empty array.
    if out.size == 0:
        return out

    expected = np.sum(arr, dtype=np.float64)
    if not np.allclose(out[-1], expected, rtol=rtol, atol=atol):
        raise RuntimeError(
            "cumsum was found to be unstable: "
            "its last element does not correspond to sum"
        )
    return out


class NLLMetric(Metric):
    """Negative log-likelihood computed from buffered probabilities.

    Each :meth:`update` call stores the logarithm of the given probabilities
    together with the targets; :meth:`compute` concatenates the buffers along
    dim 0 and evaluates ``F.nll_loss`` on them.
    """

    log_probs: List[Tensor]
    targets: List[Tensor]

    full_state_update: bool = False

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)

        # Both states are lists that are concatenated across processes.
        self.add_state(name="log_probs", default=[], dist_reduce_fx="cat")
        self.add_state(name="targets", default=[], dist_reduce_fx="cat")

        rank_zero_warn(
            "Metric `NLLMetric` will save all targets and predictions"
            " in buffer. For large datasets this may lead to large memory"
            " footprint."
        )

    def update(self, probs: Tensor, target: Tensor) -> None:  # type: ignore
        """Buffer the log of ``probs`` and the matching ``target``."""
        self.targets.append(target)
        self.log_probs.append(probs.log())

    def compute(self) -> Tensor:
        """Return ``F.nll_loss`` evaluated on all buffered data."""
        return F.nll_loss(
            input=dim_zero_cat(self.log_probs),
            target=dim_zero_cat(self.targets),
        )


class GalNLLMetric(Metric):
    """Gaussian negative log-likelihood of the likelihood averaged over dim 0.

    For each buffered element the Gaussian NLL of ``target`` under
    ``N(prediction, var)`` is evaluated. The likelihoods are then averaged
    along dim 0 of the concatenated buffers, and the negative log of that
    average is finally averaged over the remaining dimensions.

    Args:
        reduction: Stored on the instance; it is not used by
            :meth:`compute`.
        epsilon: Small constant added to the variance terms for numerical
            stability.
        **kwargs: Forwarded to ``torchmetrics.Metric``.
    """

    predictions: List[Tensor]
    targets: List[Tensor]
    predicted_vars: List[Tensor]

    full_state_update: bool = False

    def __init__(
        self, reduction: str = "mean", epsilon: float = 1e-8, **kwargs
    ) -> None:
        super().__init__(**kwargs)
        self.reduction = reduction
        self.add_eps = epsilon

        self.add_state("predictions", [], dist_reduce_fx="cat")
        self.add_state("predicted_vars", [], dist_reduce_fx="cat")
        self.add_state("targets", [], dist_reduce_fx="cat")

        rank_zero_warn(
            "Metric `GalNLLMetric` will save all targets and predictions in"
            " buffer. For large datasets this may lead to large memory"
            " footprint."
        )

    def update(self, prediction: Tensor, target: Tensor, var: Tensor) -> None:  # type: ignore
        """Buffer one batch of predictions, targets and predicted variances."""
        self.predictions.append(prediction)
        self.targets.append(target)
        self.predicted_vars.append(var)

    def compute(self) -> Tensor:
        """Return the mean of ``-log(mean(exp(-nll), dim=0))``.

        Returns:
            Tensor: The scalar metric value.
        """
        predictions = dim_zero_cat(self.predictions)
        targets = dim_zero_cat(self.targets)
        predicted_vars = dim_zero_cat(self.predicted_vars)
        nlls = (
            (targets - predictions) ** 2 / (2 * predicted_vars + self.add_eps)
            + torch.log(predicted_vars + self.add_eps) / 2
            + np.log(2 * torch.pi) / 2
        )
        # log(mean(exp(-nll))) == logsumexp(-nll) - log(N). Going through
        # logsumexp avoids exp() underflowing to 0 (and the result becoming
        # inf) when the NLLs are large.
        log_mean_likelihood = torch.logsumexp(-nlls, dim=0) - float(
            np.log(nlls.shape[0])
        )
        return (-log_mean_likelihood).mean()


class GaussianNLLMetric(Metric):
    """Gaussian negative log-likelihood over all buffered samples.

    Predictions, targets and predicted variances are accumulated by
    :meth:`update` and handed to ``F.gaussian_nll_loss`` in :meth:`compute`.

    Args:
        reduction: ``reduction`` argument passed to ``F.gaussian_nll_loss``.
        full: ``full`` argument passed to ``F.gaussian_nll_loss``.
        **kwargs: Forwarded to ``torchmetrics.Metric``.
    """

    predictions: List[Tensor]
    targets: List[Tensor]
    predicted_vars: List[Tensor]

    full_state_update: bool = False

    def __init__(
        self, reduction: str = "mean", full: bool = False, **kwargs
    ) -> None:
        super().__init__(**kwargs)
        self.full = full
        self.reduction = reduction

        # List states, concatenated when synchronising across processes.
        self.add_state(name="predictions", default=[], dist_reduce_fx="cat")
        self.add_state(
            name="predicted_vars", default=[], dist_reduce_fx="cat"
        )
        self.add_state(name="targets", default=[], dist_reduce_fx="cat")

        rank_zero_warn(
            "Metric `GaussianNLLMetric` will save all targets and predictions"
            " in buffer. For large datasets this may lead to large memory"
            " footprint."
        )

    def update(self, prediction: Tensor, target: Tensor, var: Tensor) -> None:  # type: ignore
        """Buffer one batch of predictions, targets and variances."""
        self.predicted_vars.append(var)
        self.targets.append(target)
        self.predictions.append(prediction)

    def compute(self) -> Tensor:
        """Evaluate ``F.gaussian_nll_loss`` on the concatenated buffers."""
        all_predictions = dim_zero_cat(self.predictions)
        all_targets = dim_zero_cat(self.targets)
        all_vars = dim_zero_cat(self.predicted_vars)
        return F.gaussian_nll_loss(
            input=all_predictions,
            target=all_targets,
            var=all_vars,
            full=self.full,
            reduction=self.reduction,
        )


class FPR95Metric(Metric):
    """Class which computes the False Positive Rate at 95% Recall.

    Args:
        pos_label: Value in ``target`` that marks a positive
            (out-of-distribution) sample. ``None`` is treated as ``1``.
        **kwargs: Forwarded to ``torchmetrics.Metric``.
    """

    is_differentiable: bool = False
    higher_is_better: bool = False
    full_state_update: bool = False

    conf: List[Tensor]
    targets: List[Tensor]

    def __init__(self, pos_label: Optional[int] = None, **kwargs) -> None:
        super().__init__(**kwargs)

        self.pos_label = pos_label
        self.add_state("conf", [], dist_reduce_fx="cat")
        self.add_state("targets", [], dist_reduce_fx="cat")

        rank_zero_warn(
            "Metric `FPR95Metric` will save all targets and predictions"
            " in buffer. For large datasets this may lead to large memory"
            " footprint."
        )

    def update(self, conf: Tensor, target: Tensor) -> None:  # type: ignore
        """Buffer one batch of scores and their targets."""
        self.conf.append(conf)
        self.targets.append(target)

    def compute(self) -> Tensor:
        r"""From https://github.com/hendrycks/anomaly-seg

        Compute the actual False Positive Rate at 95% Recall.

        Returns:
            Tensor: The value of the FPR95. NaN is returned when the buffered
            targets contain no positive or no negative sample, since the
            rate is undefined in that case.
        """
        conf = dim_zero_cat(self.conf).cpu().numpy()
        targets = dim_zero_cat(self.targets).cpu().numpy()

        # Comparing against ``None`` would match nothing, so fall back to
        # the conventional positive label.
        pos_label = 1 if self.pos_label is None else self.pos_label

        # out_labels is a boolean mask - True if OOD, False if in-distribution
        out_labels = targets == pos_label

        # pos = OOD
        pos = np.asarray(conf[out_labels]).reshape(-1)
        neg = np.asarray(conf[np.logical_not(out_labels)]).reshape(-1)
        n_pos, n_neg = pos.size, neg.size

        if n_pos == 0 or n_neg == 0:
            return torch.as_tensor(float("nan"), dtype=torch.float32)

        # Positives first, negatives second. The labels are built directly
        # as booleans: they must NOT be compared to ``pos_label`` again,
        # otherwise any ``pos_label`` other than 1 corrupts them.
        examples = np.concatenate((pos, neg))
        labels = np.zeros(examples.shape[0], dtype=bool)
        labels[:n_pos] = True

        # sort scores and corresponding truth values
        desc_score_indices = np.argsort(examples, kind="mergesort")[::-1]
        examples = examples[desc_score_indices]
        labels = labels[desc_score_indices]

        # examples typically has many tied values. Here we extract
        # the indices associated with the distinct values. We also
        # concatenate a value for the end of the curve.
        distinct_value_indices = np.where(np.diff(examples))[0]
        threshold_idxs = np.r_[distinct_value_indices, labels.shape[0] - 1]

        # accumulate the true positives with decreasing threshold
        tps = stable_cumsum(labels)[threshold_idxs]
        fps = 1 + threshold_idxs - tps  # add one because of zero-based indexing

        recall = tps / tps[-1]

        # Keep the curve up to the first threshold reaching full recall,
        # reversed, with one extra end point appended (as in the reference
        # implementation).
        last_ind = tps.searchsorted(tps[-1])
        sl = slice(last_ind, None, -1)  # [last_ind::-1]
        recall = np.r_[recall[sl], 1]
        fps = np.r_[fps[sl], 0]

        # operating point whose recall is closest to 95%
        cutoff = np.argmin(np.abs(recall - 0.95))

        return torch.as_tensor(fps[cutoff] / n_neg, dtype=torch.float32)


class TopVSecond(Metric):
    """Ratio between the largest and second-largest score, minus one.

    Args:
        reduction: ``"mean"`` or ``"sum"`` to aggregate the per-example
            ratios; any other value leaves them unreduced.
        logits: When true, the buffered scores are passed through
            ``F.softplus`` (after scaling by ``eps_mul``) before ranking.
        eps_add: Constant added to the second-largest score in the
            denominator.
        eps_mul: Multiplier applied to the scores before the softplus.
        **kwargs: Forwarded to ``torchmetrics.Metric``.
    """

    full_state_update: bool = False
    is_differentiable: bool = True
    higher_is_better: bool = True

    crit: List[Tensor]

    def __init__(
        self,
        reduction: str = "mean",
        logits: bool = False,
        eps_add: float = 1e-6,
        eps_mul: float = 1,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)

        self.eps_mul = eps_mul
        self.eps_add = eps_add
        self.use_logits = logits
        self.reduction = reduction

        self.add_state(name="crit", default=[], dist_reduce_fx="cat")

        rank_zero_warn(
            "Metric `TopVSecond` will save all predictions"
            " in buffer. For large datasets this may lead to large memory"
            " footprint."
        )

    def update(self, crit: Tensor) -> None:  # type: ignore
        """Buffer one batch of scores, stored as (example, class)."""
        self.crit.append(crit)

    def compute(self) -> Tensor:
        """Return the (optionally reduced) top-vs-second ratio minus one."""
        scores = dim_zero_cat(self.crit)

        if self.use_logits:
            scores = F.softplus(self.eps_mul * scores)

        # two largest scores along the last dimension, largest first
        best_two = scores.topk(2, dim=-1).values
        largest = best_two[:, 0]
        runner_up = best_two[:, 1]

        ratio = largest / (runner_up + self.eps_add)

        if self.reduction == "mean":
            ratio = ratio.mean()
        elif self.reduction == "sum":
            ratio = ratio.sum()

        return ratio - 1


class DisagreementMetric(Metric):
    """Mean pairwise disagreement between the estimators' predicted classes."""

    full_state_update: bool = False
    probs: List[Tensor]

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)

        self.add_state(name="probs", default=[], dist_reduce_fx="cat")

        rank_zero_warn(
            "Metric `DisagreementMetric` will save all targets and predictions"
            " in buffer. For large datasets this may lead to large memory"
            " footprint."
        )

    def update(self, probs: Tensor) -> None:  # type: ignore
        """Buffer one batch after swapping its first two dimensions."""
        # store data as (example, estimator, class)
        per_example = probs.transpose(0, 1)
        self.probs.append(per_example)

    def _compute_disagreement(self, classes: Tensor) -> Tensor:
        r"""Computes the disagreement between the predicted classes among
        all pairs of estimators.

        Args:
            classes (Tensor): Classes predicted by the `n_estimators`
                estimators.

        Returns:
            Tensor: Mean disagreement between estimators.
        """
        # TODO: Using onehot might be memory intensive
        n_estimators = classes.shape[-1]

        # number of estimators voting for each class, per example
        votes = F.one_hot(classes).sum(dim=1)

        # pairs of estimators that agree, summed over the classes
        agreeing_pairs = (votes * (votes - 1) / 2).sum(dim=1)
        all_pairs = n_estimators * (n_estimators - 1) / 2

        return torch.mean(1 - agreeing_pairs / all_pairs)

    def compute(self) -> Tensor:
        """Return the mean disagreement over all buffered examples."""
        predicted_classes = dim_zero_cat(self.probs).argmax(dim=-1)
        return self._compute_disagreement(predicted_classes)


class VariationRatio(Metric):
    """From https://proceedings.mlr.press/v70/gal17a/gal17a.pdf

    Args:
        probabilistic: When true, use the estimators' probabilities for the
            chosen class; otherwise use their hard votes.
        reduction: ``"mean"`` or ``"sum"`` to aggregate the per-example
            ratios; any other value leaves them unreduced.
        **kwargs: Forwarded to ``torchmetrics.Metric``.
    """

    full_state_update: bool = False
    is_differentiable: bool = True
    higher_is_better: bool = False

    probs: List[Tensor]

    def __init__(
        self, probabilistic: bool = True, reduction: str = "mean", **kwargs
    ) -> None:
        super().__init__(**kwargs)

        self.reduction = reduction
        self.probabilistic = probabilistic

        self.add_state(name="probs", default=[], dist_reduce_fx="cat")

        rank_zero_warn(
            "Metric `VariationRatio` will save all predictions in buffer. For "
            " large datasets this may lead to large memory footprint."
        )

    def update(self, probs: Tensor) -> None:  # type: ignore
        """Buffer one batch after swapping its first two dimensions."""
        # store data as (example, estimator, class)
        per_example = probs.transpose(0, 1)
        self.probs.append(per_example)

    def compute(self) -> Tensor:
        r"""Computes the variation ratio which amounts to the proportion of
        predicted class labels which are not the chosen class.

        The chosen class of an example is the argmax of the probabilities
        averaged over the estimators.

        Returns:
            Tensor: The variation ratio, reduced according to
            :attr:`reduction`.
        """
        probs_per_est = dim_zero_cat(self.probs)
        n_estimators = probs_per_est.shape[1]

        # best class for each example, from the estimator-averaged probs
        max_classes = probs_per_est.mean(dim=1).argmax(dim=-1)

        if self.probabilistic:
            # (example, class, estimator), then pick the chosen class row
            by_class = probs_per_est.permute((0, 2, 1))
            example_idx = torch.arange(by_class.size(0))
            chosen_probs = by_class[example_idx, max_classes]
            variation_ratio = 1 - chosen_probs.mean(dim=1)
        else:
            # best class for each (example, estimator)
            votes = probs_per_est.argmax(dim=-1)
            n_agreeing = torch.sum(votes == max_classes.unsqueeze(1), dim=-1)
            variation_ratio = 1 - n_agreeing / n_estimators

        if self.reduction == "mean":
            variation_ratio = variation_ratio.mean()
        elif self.reduction == "sum":
            variation_ratio = variation_ratio.sum()

        return variation_ratio


class JensenShannonDivergence(Metric):
    """KL divergence between each estimator and the estimators' mean.

    The divergence is summed over all entries, divided by the number of
    examples (``batchmean``) and then by the number of estimators.
    """

    full_state_update: bool = False

    probs: List[Tensor]

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)

        self.add_state(name="probs", default=[], dist_reduce_fx="cat")
        # NOTE: not used by `compute`, which calls `F.kl_div` directly.
        self.kl = torch.nn.KLDivLoss(reduction="batchmean", log_target=True)

        rank_zero_warn(
            "Metric `JensenShannonDivergence` will save all "
            "predictions in buffer. For large datasets this may lead to large "
            "memory footprint."
        )

    def update(self, probs: Tensor) -> None:  # type: ignore
        """Buffer one batch after swapping its first two dimensions."""
        # store data as (example, estimator, class)
        per_example = probs.transpose(0, 1)
        self.probs.append(per_example)

    def compute(self) -> Tensor:
        """Return the divergence averaged over examples and estimators."""
        probs = dim_zero_cat(self.probs)
        n_estimators = probs.shape[1]

        # mean over the estimators, tiled back to the shape of `probs`
        mean_proba = probs.mean(1, keepdim=True).repeat(1, n_estimators, 1)

        divergence = F.kl_div(
            mean_proba.log(),
            probs.log(),
            reduction="batchmean",
            log_target=True,
        )
        return divergence / n_estimators


class Entropy(Metric):
    r"""The Shannon Entropy to estimate the confidence of the estimators.

    A higher entropy means a lower confidence.

    Args:
        over_estimators: When true, the first two dimensions of the inputs
            to :meth:`update` are swapped before buffering and the
            per-estimator entropies are averaged over the estimators.
        reduction: ``"mean"`` or ``"sum"`` to aggregate the per-example
            entropies; any other value leaves them unreduced.
        **kwargs: Forwarded to ``torchmetrics.Metric``.
    """
    full_state_update: bool = False

    probs: List[Tensor]

    def __init__(
        self, over_estimators: bool = False, reduction: str = "mean", **kwargs
    ) -> None:
        super().__init__(**kwargs)

        self.over_estimators = over_estimators
        self.reduction = reduction
        self.add_state("probs", [], dist_reduce_fx="cat")

        rank_zero_warn(
            "Metric `Entropy` will save all predictions in buffer."
            " For large datasets this may lead to a large memory footprint."
        )

    def update(self, probs: Tensor) -> None:  # type: ignore
        """Buffer one batch of probabilities."""
        # store data as (example, class) or (example, estimator, class)

        if self.over_estimators:
            # As (example, estimator, class)
            probs = probs.transpose(0, 1)

        self.probs.append(probs)

    def compute(self) -> Tensor:
        r"""Computes the entropy on the data.

        note:
        The entropy is taken over the last (class) dimension.

        If :attr:`over_estimators`, the entropy of each estimator is
        computed and then averaged over the estimators, giving one value
        per example.

        If not :attr:`over_estimators`, the entropy is computed directly on
        the buffered probabilities.

        The result is finally reduced according to :attr:`reduction`.

        Returns:
            Tensor: The total entropy.
        """
        probs = dim_zero_cat(self.probs)

        # entr(p) = -p * log(p) with entr(0) = 0, so exact zeros contribute
        # nothing instead of turning the sum into NaN (0 * -inf).
        entropy = torch.special.entr(probs).sum(dim=-1)

        if self.over_estimators:
            entropy = entropy.mean(-1)

        if self.reduction == "mean":
            entropy = entropy.mean()
        elif self.reduction == "sum":
            entropy = entropy.sum()

        return entropy


class MutualInformation(Metric):
    r"""The Mutual Information to estimate the epistemic uncertainty.

    A higher mutual information means a higher uncertainty.

    Args:
        reduction: ``"mean"`` or ``"sum"`` to aggregate the per-example
            values; any other value leaves them unreduced.
        **kwargs: Forwarded to ``torchmetrics.Metric``.
    """
    full_state_update: bool = False

    probs_per_est: List[Tensor]

    def __init__(self, reduction: str = "mean", **kwargs) -> None:
        super().__init__(**kwargs)

        self.reduction = reduction
        # NOTE: these two sub-metrics are not used by `compute`, which
        # evaluates both entropy terms inline.
        self.entropy_over_estimators = Entropy(over_estimators=True)
        self.entropy = Entropy()
        self.add_state("probs_per_est", [], dist_reduce_fx="cat")

        rank_zero_warn(
            "Metric `MutualInformation` will save all predictions "
            "in buffer. For large datasets this may lead to a large memory "
            "footprint."
        )

    def update(self, probs_per_est: Tensor) -> None:  # type: ignore
        """Buffer one batch after swapping its first two dimensions."""
        # store data as (example, estimator, class)
        self.probs_per_est.append(probs_per_est.transpose(0, 1))

    def compute(self) -> Tensor:
        r"""Computes the mutual information on the data.

        It is the entropy of the estimator-averaged probabilities minus the
        estimator-average of the per-estimator entropies.

        Returns:
            Tensor: The total mutual information.
        """

        # convert data to (estimator, example, class)
        probs_per_est = dim_zero_cat(self.probs_per_est).transpose(0, 1)
        probs = probs_per_est.mean(dim=0)

        # entr(p) = -p * log(p) with entr(0) = 0, so exact zeros do not
        # produce NaN (0 * -inf) in either term.

        # Entropy of the mean over the estimators
        entropy_mean = torch.special.entr(probs).sum(dim=-1)

        # Mean over the estimators of the entropy over the classes
        mean_entropy = torch.special.entr(probs_per_est).sum(dim=-1).mean(dim=0)

        mutual_information = entropy_mean - mean_entropy

        if self.reduction == "mean":
            mutual_information = mutual_information.mean()
        elif self.reduction == "sum":
            mutual_information = mutual_information.sum()

        return mutual_information
