from typing import Sequence, TypeVar, Tuple, Optional
from scipy.stats import binomtest

import torch
from torch import nn

from ..transforms import Transform

# Generic type parameters: the type a perturbation accepts (IT) and the type it produces (OT).
IT, OT = TypeVar("IT"), TypeVar("OT")


def standard_predict(
    counts: Sequence[int], threshold: Optional[float] = None
) -> Tuple[int, float]:
    """Predict a class from per-class counts and test its statistical significance.

    Two-class case: class 1 is predicted when its empirical frequency strictly
    exceeds ``threshold`` (0.5 if not given). The p-value comes from a one-sided
    ("greater") binomial test of the predicted class's count against its decision
    boundary. That boundary is ``threshold`` for class 1 and ``1 - threshold`` for
    class 0.

    Multi-class case: the most frequent class is predicted. The p-value comes from
    a two-sided binomial test (p = 0.5) on the counts of the top two classes.

    Args:
        counts: Class frequencies; ``counts[i]`` is the number of samples with class
            index ``i``.
        threshold: Decision threshold on the frequency of class 1. Only supported
            for two-class problems. May be a Python number or a 0-d tensor (such as
            a registered module buffer); it is converted to a Python float.

    Returns:
        The predicted class index (a Python ``int``) and the p-value.

    Raises:
        ValueError: If ``threshold`` is given with more than two classes, if fewer
            than two classes are given, or if all counts are zero.
    """
    num_classes = len(counts)
    if num_classes > 2 and threshold is not None:
        raise ValueError("Only supports explicit threshold for 2 class problems")
    if num_classes < 2:
        # torch.topk(k=2) below would otherwise fail with an opaque RuntimeError.
        raise ValueError("counts must contain at least 2 classes")

    counts = torch.as_tensor(counts, device="cpu")
    counts_total = counts.sum().item()
    if counts_total <= 0:
        # Avoids a ZeroDivisionError (2 classes) / binomtest error (n=0) below.
        raise ValueError("counts must contain at least one sample")

    if num_classes == 2:
        # Coerce to a plain float. Comparing against a tensor threshold would make
        # `pred` a tensor, and binomtest needs a scalar probability.
        threshold = 0.5 if threshold is None else float(threshold)
        pred = int(counts[1].item() / counts_total > threshold)
        # The two cases have to be tested separately (depending on the predicted
        # class). Otherwise the result is always abstain when pred is 0.
        if pred:
            test = binomtest(
                counts[1].item(), counts_total, p=threshold, alternative="greater"
            )
        else:
            test = binomtest(
                counts[0].item(), counts_total, p=1 - threshold, alternative="greater"
            )
    else:
        toptwo = torch.topk(counts, 2, sorted=True)
        pred = toptwo.indices[0].item()

        n_A = toptwo.values[0].item()
        n_B = toptwo.values[1].item()
        test = binomtest(n_A, n_A + n_B, p=0.5, alternative="two-sided")

    return pred, test.pvalue


class RandomPerturbation(nn.Module, Transform[IT, OT]):
    """Base class for a random perturbation used to smooth a classifier.

    Subclasses are expected to override :meth:`forward` (draw a perturbed input) and
    :meth:`certified_radius`. :meth:`predict` turns class counts collected over
    perturbed inputs into a prediction via :func:`standard_predict`.

    Args:
        threshold: Optional decision threshold forwarded to :func:`standard_predict`
            (only supported there for two-class problems). When given, it is stored
            as a 0-d tensor buffer named ``threshold``.
    """

    def __init__(self, threshold: Optional[float] = None):
        super().__init__()
        if threshold is not None:
            threshold = torch.tensor(threshold)
        # Registered as a buffer (rather than a plain attribute) so it follows the
        # module across devices; a None buffer is allowed.
        self.register_buffer("threshold", threshold)

    def forward(self, input: IT) -> OT:
        """Perturb an input.

        The base implementation does nothing and returns ``None``; subclasses
        override it.

        Args:
            input: Input to be perturbed.

        Returns:
            Perturbed output.
        """
        pass

    def __repr__(self):
        return self.__class__.__name__ + "()"

    def predict(
        self,
        input: IT,
        counts: Sequence[int],
        **kwargs,
    ) -> Tuple[int, float]:
        """Compute the predicted class for an input to a classifier smoothed under this perturbation

        Args:
            input: Unperturbed input.
            counts: Class frequencies for randomly perturbed inputs passed through the classifier. Must be a sequence
                where `counts[i]` is the number of perturbed inputs with class index `i`.

        Keyword args:
            **kwargs: Other keyword arguments used in derived classes.

        Returns:
            The predicted class index and the p-value
        """
        threshold = self.threshold
        if threshold is not None:
            # The buffer is a 0-d tensor (possibly on an accelerator). Passing it on
            # unchanged would make the predicted class a tensor, so convert it to a
            # plain Python float first.
            threshold = float(threshold)
        return standard_predict(counts, threshold)

    def certified_radius(
        self, input: IT, pred: int, counts: Sequence[int], alpha: float = 0.05, **kwargs
    ) -> Tuple[float, float]:
        """Compute the certified radius for an input to a classifier smoothed under this perturbation

        The base implementation does nothing and returns ``None``; subclasses
        override it.

        Args:
            input: Unperturbed input.
            pred: Estimated prediction of the smoothed classifier for `input`. Must be a class index in the set
                {0, 1, 2, ..., n_classes - 1}.
            counts: Class frequencies for randomly perturbed inputs passed through the classifier. Must be a sequence
                where `counts[i]` is the number of perturbed inputs with class index `i`.

        Keyword args:
            alpha: Significance level. Defaults to 0.05.
            **kwargs: Other keyword arguments used in derived classes.

        Returns:
            Lowerbound of the most probable class and the largest certified radius for the input.
        """
        pass

    def extra_dim(self):
        """Return the number of extra dimensions for this perturbation (0 in the base class).

        NOTE(review): presumably the number of extra input dimensions a subclass's
        perturbation adds; confirm against subclasses/callers.
        """
        return 0
