import warnings
import abc
from typing import Optional, Dict, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from ..metadata import Metadata
from ..models import EmbeddingClassifier
from ..models.utils import conv1d_l_out, ragged_len, window_pad
from ..utils import collate_pad, inv_softmax, inv_collate_pad
from .perturbation import RandomPerturbation


class CertifiedEmbeddingClassifier(EmbeddingClassifier):
    """
    A certified model comprising an embedding layer followed by a perturbation layer.

    Subclasses supply the network applied to already-perturbed embeddings
    (``_perturbed_forward``) and the per-sample voting rule
    (``_base_proba_reduce``). They must also define ``window_size``, which this
    class reads when padding inputs.

    Note: unless ``EmbeddingClassifier`` uses ``ABCMeta`` (not visible here), the
    ``abc.abstractmethod`` decorators below do not prevent instantiation, so the
    abstract methods raise ``NotImplementedError`` when called.
    """

    # Reduction modes and the integer codes stored in the ``_reduce`` buffer.
    REDUCES = dict(
        none=-1,
        hard=0,
        soft=1,
    )
    _INV_REDUCES = {v: k for k, v in REDUCES.items()}

    def __init__(
        self,
        perturbation: RandomPerturbation,
        out_size: int = 2,
        embed_num: int = 257,
        embed_size: int = 8,
        scale_grad_by_freq: bool = False,
        reduce: str = "soft",
        certify_threshold: Optional[torch.Tensor] = None,
        debug_seed: Optional[int] = None,
    ):
        """Initialize a certified embedding classifier where perturbation is applied after embedding

        Args:
            perturbation (RandomPerturbation): The perturbation to apply to inputs
            out_size (int, optional): Size of the final output layer. Defaults to 2.
            embed_num (int, optional): Number of possible tokens. Defaults to 257.
            embed_size (int, optional): Size of the embedding dimension. Defaults to 8.
            scale_grad_by_freq (bool, optional): If gradient should be scaled by frequency. Defaults to False.
            reduce (str, optional): Reduction method for aggregating results; one of
                "none", "hard" or "soft" (case-insensitive). Defaults to "soft".
            certify_threshold (torch.Tensor, optional): The threshold used by the perturbation. Defaults to None.
            debug_seed (int, optional): If set, the torch RNG is reseeded with this value
                before every perturbation so perturbations are reproducible. Defaults to None.

        Raises:
            ValueError: If an invalid reduction method is provided.
        """
        reduce_key = reduce.lower()
        if reduce_key not in self.REDUCES:
            # Previously an unknown value silently fell back to "soft", hiding typos.
            raise ValueError(
                f"Invalid reduction method {reduce!r}; "
                f"expected one of {sorted(self.REDUCES)}."
            )
        super().__init__(
            out_size=out_size,
            embed_num=embed_num,
            embed_size=embed_size,
            scale_grad_by_freq=scale_grad_by_freq,
        )
        self.perturbation = perturbation
        # Stored as a buffer so the reduction mode is saved in the state dict.
        self.register_buffer("_reduce", torch.tensor(self.REDUCES[reduce_key]))
        self.register_buffer("certify_threshold", certify_threshold)
        # Share the threshold with the perturbation
        self.perturbation.threshold = self.certify_threshold
        self.debug_seed = debug_seed

    @property
    def reduce(self):
        """The reduction mode name ("none", "hard" or "soft")."""
        return self._INV_REDUCES[self._reduce.item()]

    @reduce.setter
    def reduce(self, val: str):
        # Keep the buffer on its current device; a bare torch.tensor(...) would
        # move it back to the CPU. Unknown names raise KeyError.
        self._reduce = torch.tensor(
            self.REDUCES[val.lower()], device=self._reduce.device
        )

    def perturb(self, x: torch.Tensor, **kwargs):
        """Apply the perturbation to ``x`` (extra keyword arguments are ignored here)."""
        if self.debug_seed is not None:
            torch.random.manual_seed(self.debug_seed)
        return self.perturbation(x)

    def embedd_and_forward(
        self,
        x: torch.Tensor,
        num_samples: Optional[int] = None,
        return_logits: Optional[bool] = True,
        return_radii: Optional[bool] = False,
        batch_size: Optional[int] = None,
        forward_kwargs: Optional[Dict] = None,
        certify_kwargs: Optional[Dict] = None,
    ) -> torch.Tensor:
        """Compute logits from embeddings.
            If reduction method is 'hard', then the result will be categorical.
            If reduction method is 'soft', then it will be logits.
            If reduction method is 'none', then this will return logits of
            each single stochastic forward logits in a 3d tensor of [num_samples, batch_size, num_classes]

        Args:
            x (torch.Tensor): The (already embedded) input data
            num_samples (int): The number of samples to estimate with. If None, then only one stochastic sample is taken.
            return_logits (Optional[bool], optional): Whether or not to return logit or probabilities,
                default to return logits. Logits is defined as the value where if you take softmax, you get probabilities.
                If return_logits is false, then the behavior depends on the specified aggregation function.
            return_radii (Optional[bool], optional): If the radii should be computed and returned.
                The radii is computed prior to converting output to logits.
            batch_size (Optional[int], optional): The maximum batch size when computing the output.
                The actual batch is the largest multiple of the input count not exceeding it,
                so if x has shape [10, n] and batch_size is 32, the actual batch is 10 * 3 = 30.
                If batch_size is smaller than the input count, one copy of x is processed per pass.
                Set to None to disable (do everything in one forward pass). Defaults to None.
            forward_kwargs (Optional[Dict], optional): (Metadata, etc): Other arguments will be passed into the perturbation class
            certify_kwargs (Optional[Dict], optional): Kwargs to be passed in for radii computation.

        Returns:
            torch.Tensor: Matrix representing the logit (or probability) of each class,
            or a tuple ``(output, radii)`` when ``return_radii`` is True.
        """
        forward_kwargs = {} if forward_kwargs is None else forward_kwargs
        certify_kwargs = {} if certify_kwargs is None else certify_kwargs
        # Read the mode once: the property calls .item(), a device sync on GPU.
        mode = self.reduce

        if return_logits and mode != "none" and num_samples is not None and num_samples > 1:
            warnings.warn(
                "Warning: returning logits while the reduction method is not none makes the logits not unique."
            )

        # Set default for sample size and batch size
        num_samples = num_samples if num_samples is not None else 1
        n_inputs = x.size(0)
        batch_size = batch_size if batch_size is not None else num_samples * n_inputs

        if mode == "none":
            out_shape = [num_samples, n_inputs, self.out_size]
        else:
            out_shape = [n_inputs, self.out_size]
        out = torch.zeros(out_shape, device=x.device, dtype=torch.float)

        # Number of copies of x that fit in one forward pass. Clamped to >= 1:
        # a batch_size below the input count would otherwise give 0 repeats and
        # the loop below would never terminate (max(1, n_inputs) guards /0).
        repeats = max(1, batch_size // max(1, n_inputs))
        num_samples_remain = num_samples
        repeat_idx = 0
        while num_samples_remain > 0:
            # The actual number of repeats is capped by the remaining samples
            this_repeats = min(repeats, num_samples_remain)
            num_samples_remain -= this_repeats

            # Duplicate x and forward kwargs.
            # The indexing for _x is [repeat * id_in_batch, ...]: repeat appears before id_in_batch
            _x = x.expand(this_repeats, *x.size()).reshape(
                this_repeats * n_inputs, *x.size()[1:]
            )
            _forward_kwargs = {
                key: collate_pad([arg] * this_repeats, stack=False)
                for key, arg in forward_kwargs.items()
            }

            _x = self.perturb(_x, **_forward_kwargs)
            # Pad again, because the perturbation may change the sequence length
            _x = window_pad(_x, self.window_size)
            _out = self._perturbed_forward(_x).reshape(
                this_repeats, n_inputs, self.out_size
            )

            if mode == "none" and return_logits:
                # Directly store logits if we do not aggregate
                out[repeat_idx : (repeat_idx + this_repeats)] = _out
            else:
                # If we aggregate or return probabilities, convert to probs first
                _out = F.softmax(_out, dim=-1)
                if mode == "none":
                    out[repeat_idx : (repeat_idx + this_repeats)] = _out
                elif mode == "soft":
                    # Soft reduction: accumulate the probabilities
                    out += torch.sum(_out, dim=0)
                elif mode == "hard":
                    # Hard reduction: accumulate one vote per sample
                    out += torch.sum(self._base_proba_reduce(_out), dim=0)
                else:
                    raise ValueError("Unknown aggregation mode.")
            repeat_idx += this_repeats

        # Radii are computed from the un-normalized accumulated counts
        certified_radii = None
        if return_radii:
            certified_radii = self._certified_radius(
                x=x,
                counts=torch.round(out).int(),
                forward_kwargs=forward_kwargs,
                **certify_kwargs,
            )

        # Normalize the accumulated output by the number of samples
        if mode != "none":
            out = out / num_samples

        # Logits inversion is only needed for aggregated modes, since "none"
        # stores logits directly.
        if return_logits and mode != "none":
            if mode == "hard":
                warnings.warn(
                    "Warning: reduce mode is 'hard', returning logits "
                    "will result in inf values."
                )
            out = inv_softmax(out)

        if return_radii:
            return (out, certified_radii)
        return out

    @abc.abstractmethod
    def _perturbed_forward(self, x: torch.Tensor) -> torch.Tensor:
        """Performs the forward operation on an already perturbed batch of samples

        Args:
            x (torch.Tensor): A batch of samples that is already perturbed

        Raises:
            NotImplementedError: Subclasses must implement this method.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def _base_proba_reduce(self, probs: torch.Tensor) -> torch.Tensor:
        """This should perform voting on the last dimension of the probability vector

        Args:
            probs (torch.Tensor): A tensor shape of [..., num_classes]

        Raises:
            NotImplementedError: Subclasses must implement this method.
        """
        raise NotImplementedError

    def forward(
        self,
        x: torch.IntTensor,
        num_samples: Optional[int] = None,
        return_logits: Optional[bool] = True,
        return_radii: Optional[bool] = False,
        batch_size: Optional[int] = None,
        forward_kwargs: Optional[Dict] = None,
        certify_kwargs: Optional[Dict] = None,
    ) -> torch.Tensor:
        """Compute logits from embeddings.
            If reduction method is 'hard', then the result will be categorical.
            If reduction method is 'soft', then it will be logits.
            If reduction method is 'none', then this will return logits of
            each single stochastic forward logits in a 3d tensor
            of [num_samples, batch_size, num_classes]

            Though it returns logits instead of probabilities. If radii is computed,
            it will be properly computed with probabilities.

        Args:
            x (torch.Tensor): The input byte tensor
            num_samples (int): The number of samples to estimate with. If None, then only one stochastic sample is taken.
            return_logits (Optional[bool], optional): Whether or not to return logit or probabilities,
                default to return logits. Logits is defined as the value where if you take softmax, you get probabilities.
                If return_logits is false, then the behavior depends on the specified aggregation function.
            return_radii (Optional[bool], optional): If the radii should be computed and returned.
                The radii is computed prior to converting output to logits.
            batch_size (Optional[int], optional): The batch size when computing the output. Set to None to disable. Defaults to None.
            forward_kwargs (Optional[Dict], optional): (Metadata, etc): Other arguments will be passed into the perturbation class
            certify_kwargs (Optional[Dict], optional): Kwargs to be passed in for radii computation.

        Returns:
            torch.Tensor: The predicted result, or ``(result, radii)`` when ``return_radii`` is True.
        """
        assert x.ndim >= 2, (
            f"Input must contain a batch dimension, the current input has shape {x.size()}"
        )
        # Each sample in x is a sequence of integer tokens in the set {1, ..., embed_num - 1} padded with zeroes
        x = x.int()
        x = window_pad(x, self.window_size)
        x = self.embed(x)
        return self.embedd_and_forward(
            x,
            num_samples=num_samples,
            return_logits=return_logits,
            return_radii=return_radii,
            batch_size=batch_size,
            forward_kwargs=forward_kwargs if forward_kwargs is not None else {},
            certify_kwargs=certify_kwargs if certify_kwargs is not None else {},
        )

    def certify(
        self,
        x: torch.IntTensor,
        num_samples_pred: Optional[int] = None,
        num_samples_bound: Optional[int] = None,
        batch_size: Optional[int] = None,
        forward_kwargs: Optional[Dict] = None,
        certify_kwargs: Optional[Dict] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Predict classes and certified radii for a batch of inputs.

        A prediction is made with ``num_samples_pred`` samples, and an independent
        pass with ``num_samples_bound`` samples computes the radii. Where the two
        passes disagree on the predicted class, the radius is set to 0.

        Args:
            x (torch.IntTensor): The input byte tensor
            num_samples_pred (Optional[int]): Samples used for the prediction.
            num_samples_bound (Optional[int]): Samples used for the radius bound.
            batch_size (Optional[int]): The maximum batch size when computing the output.
            forward_kwargs (Optional[Dict]): Arguments passed into the perturbation class.
            certify_kwargs (Optional[Dict]): Kwargs passed in for radii computation.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Predictions and certified radii.
        """
        forward_kwargs = {} if forward_kwargs is None else forward_kwargs
        certify_kwargs = {} if certify_kwargs is None else certify_kwargs
        preds, _ = self.predict(
            x=x,
            num_samples=num_samples_pred,
            batch_size=batch_size,
            forward_kwargs=forward_kwargs,
            return_pval=True,
        )
        bound_probs, radii = self.forward(
            x=x,
            num_samples=num_samples_bound,
            return_logits=False,
            return_radii=True,
            batch_size=batch_size,
            forward_kwargs=forward_kwargs,
            certify_kwargs=certify_kwargs,
        )
        bound_preds = self.proba_reduce(
            bound_probs, num_samples=num_samples_bound, return_pval=False
        )
        # Only keep a radius where both passes agree on the predicted class
        radii = torch.where(preds == bound_preds, radii, torch.zeros_like(radii))
        return preds, radii

    def _certified_radius(
        self,
        x: torch.Tensor,
        counts: torch.IntTensor,
        alpha=0.05,
        forward_kwargs: Optional[Dict] = None,
        **kwargs,
    ):
        """Compute the certified radius of each input from its class counts.

        Args:
            x (torch.Tensor): The batch of inputs.
            counts (torch.IntTensor): Per-input class counts, one row per input.
            alpha (float, optional): Significance level passed to the perturbation. Defaults to 0.05.
            forward_kwargs (Optional[Dict]): Batched perturbation kwargs, split per
                input with ``inv_collate_pad``.
            **kwargs: Forwarded to ``perturbation.certified_radius``.

        Returns:
            torch.Tensor: A float32 tensor of radii, one per input.

        Raises:
            ValueError: If the reduction mode is "none".
        """
        mode = self.reduce
        if mode == "none":
            raise ValueError(
                "Reduction mode is None, please specify either 'soft' or 'hard' to compute radius"
            )
        if mode == "soft":
            warnings.warn(
                "Certified radius for 'soft' reduction is not implemented yet. "
                "Defaults to the radii of 'hard' reduction."
            )
        radii = torch.empty(x.size(0), dtype=torch.float32, device=counts.device)
        per_input_kwargs = inv_collate_pad(
            forward_kwargs if forward_kwargs is not None else {}, pad_value=False
        )
        if len(per_input_kwargs) == 0:
            per_input_kwargs = [{}] * x.size(0)
        for i, (_x, _forward_kwargs, _counts) in enumerate(
            zip(x, per_input_kwargs, counts)
        ):
            pred, _ = self.perturbation.predict(input=_x, counts=_counts)
            radii[i] = self.perturbation.certified_radius(
                input=_x,
                pred=pred,
                counts=_counts,
                alpha=alpha,
                **kwargs,
                **_forward_kwargs,
            )
        return radii

    def predict(
        self,
        x: torch.IntTensor,
        num_samples: Optional[int] = None,
        batch_size: Optional[int] = None,
        forward_kwargs: Optional[Dict] = None,
        return_pval: bool = False,
    ) -> torch.Tensor:
        """Predict the class of the batch of input data

        Args:
            x (torch.Tensor): The input data
            num_samples (int): The number of samples to estimate with. If None, then only one stochastic sample is taken.
            batch_size (Optional[int], optional): The batch size when computing the output. Set to None to disable. Defaults to None.
            forward_kwargs (Optional[Dict], optional): (Metadata, etc): Other arguments will be passed into the perturbation class
            return_pval (bool, optional): If pvalue should also be returned as a tuple. Default to False

        Returns:
            torch.Tensor: Categorical output of the classes the inputs are classified as
            (a ``(preds, pvals)`` tuple when ``return_pval`` is True)
        """
        probs = self.predict_proba(
            x=x,
            num_samples=num_samples,
            batch_size=batch_size,
            forward_kwargs=forward_kwargs,
        )
        return self.proba_reduce(
            probs, num_samples=num_samples, return_pval=return_pval
        )

    def predict_proba(
        self,
        x: torch.IntTensor,
        num_samples: Optional[int] = None,
        batch_size: Optional[int] = None,
        forward_kwargs: Optional[Dict] = None,
    ) -> torch.Tensor:
        """Get the predicted probabilities for each class after softmax

        Args:
            x (torch.Tensor): The input data
            num_samples (Optional[int]): The number of samples to estimate with.
            batch_size (Optional[int]): The maximum batch size when computing the output.
            forward_kwargs (Optional[Dict]): Arguments passed into the perturbation class.

        Returns:
            torch.Tensor: Matrix representing the probability of each class
        """
        return self.forward(
            x=x,
            num_samples=num_samples,
            return_logits=False,
            return_radii=False,
            batch_size=batch_size,
            forward_kwargs=forward_kwargs if forward_kwargs is not None else {},
        )

    def proba_reduce(
        self,
        probs: torch.Tensor,
        num_samples: Optional[int] = None,
        return_pval: bool = False,
    ) -> torch.Tensor:
        """Reduce probabilities to predictions with custom rule

        Args:
            probs (torch.Tensor): The probabilities of prediction
            num_samples (int): The number of samples to estimate with. If None, then we estimate with a large sample size (10000)
            return_pval (bool, optional): If pvalue should also be returned as a tuple. Default to False

        Returns:
            torch.Tensor: The predicted classes of each probabilities (float32),
            or ``(preds, pvals)`` when ``return_pval`` is True.

        Raises:
            ValueError: If the reduction mode is "none".
        """
        if self.reduce == "none":
            raise ValueError(
                "Reduction mode is None, please specify either 'soft' or 'hard' to for proper reduction"
            )
        if num_samples is None:
            num_samples = 10000
        preds = torch.empty(probs.size(0), dtype=torch.float32, device=probs.device)
        pvals = torch.empty(probs.size(0), dtype=torch.float32, device=probs.device)
        for i, prob in enumerate(probs):
            # Turn the averaged probabilities back into integer vote counts
            counts = torch.round(prob * num_samples).int()
            preds[i], pvals[i] = self.perturbation.predict(
                input=None,
                counts=counts,
            )
        if return_pval:
            return preds, pvals
        return preds


class CertifiedMalConv(CertifiedEmbeddingClassifier):
    """Certified MalConv: a gated 1D-convolution classifier over perturbed byte embeddings.

    Two parallel ``Conv1d`` layers (kernel and stride equal to ``window_size``)
    are combined as ``conv_1(x) * sigmoid(conv_2(x))``, globally max-pooled while
    ignoring padded windows, then passed through two linear layers.
    """

    def __init__(
        self,
        perturbation: RandomPerturbation,
        out_size: int = 2,
        channels: int = 128,
        window_size: int = 512,
        embed_num: int = 257,
        embed_size: int = 8,
        scale_grad_by_freq: bool = False,
        threshold: Optional[torch.Tensor] = None,
        certify_threshold: Optional[torch.Tensor] = None,
        reduce: str = "soft",
        debug_seed: Optional[int] = None,
    ):
        """Initialize the certified MalConv model.

        Args:
            perturbation (RandomPerturbation): The perturbation to apply to inputs.
            out_size (int, optional): Number of output classes. Defaults to 2.
            channels (int, optional): Number of convolution channels. Defaults to 128.
            window_size (int, optional): Convolution kernel size and stride. Defaults to 512.
            embed_num (int, optional): Number of possible tokens. Defaults to 257.
            embed_size (int, optional): Size of the embedding dimension. Defaults to 8.
            scale_grad_by_freq (bool, optional): If gradient should be scaled by frequency. Defaults to False.
            threshold (torch.Tensor, optional): Binary decision threshold on the
                class-1 probability used when voting; only supported when out_size == 2.
            certify_threshold (torch.Tensor, optional): The threshold used by the perturbation.
            reduce (str, optional): Reduction method ("none", "hard" or "soft"). Defaults to "soft".
            debug_seed (int, optional): Seed reapplied before every perturbation. Defaults to None.

        Raises:
            ValueError: If ``threshold`` is given for multi-class classification.
        """
        super().__init__(
            perturbation=perturbation,
            out_size=out_size,
            embed_num=embed_num,
            embed_size=embed_size,
            scale_grad_by_freq=scale_grad_by_freq,
            reduce=reduce,
            certify_threshold=certify_threshold,
            debug_seed=debug_seed,
        )
        if self.out_size > 2 and threshold is not None:
            raise ValueError(
                "Thresholding for multi-class classification is not supported (yet)."
            )

        self.channels = channels
        self.window_size = window_size
        self.conv_1 = nn.Conv1d(
            embed_size, channels, window_size, stride=window_size, bias=True
        )
        self.conv_2 = nn.Conv1d(
            embed_size, channels, window_size, stride=window_size, bias=True
        )
        self.pooling = nn.AdaptiveMaxPool1d(1)
        self.fc_1 = nn.Linear(channels, channels)
        self.fc_2 = nn.Linear(channels, out_size)

        self.register_buffer("threshold", threshold)

    def perturb(self, x: torch.Tensor, metadata: Optional[Metadata] = None):
        """Perturb ``x``; the perturbation receives the pair ``(x, metadata)``."""
        return super().perturb((x, metadata))

    def _perturbed_forward(
        self, x: torch.Tensor, l_in: Optional[torch.LongTensor] = None
    ) -> torch.Tensor:
        """Performs the forward operation on an already perturbed batch of samples

        Args:
            x (torch.Tensor): A batch of samples that is already perturbed, laid out as
                (N, L, embed_size) — the last dim is fed to Conv1d as channels.
            l_in (Optional[torch.LongTensor], optional): The ragged length of inputs.
                If None, it is derived with ``ragged_len`` and rounded up to a
                multiple of ``window_size``.

        Returns:
            torch.Tensor: Logits of shape (N, out_size).
        """
        if l_in is None:
            l_in = ragged_len(x)
            l_in = torch.ceil(l_in / self.window_size).int() * self.window_size

        # (N, L, C_in) -> (N, C_in, L) for Conv1d
        x = torch.transpose(x, -1, -2)

        cnn_value = self.conv_1(x)
        gating_weight = torch.sigmoid(self.conv_2(x))
        x = cnn_value * gating_weight

        l_out = x.size(2)
        # If the Conv1d layers were compatible with ragged tensors, L_out would vary for each sample in the batch,
        # as computed below.
        correct_l_out = conv1d_l_out(l_in, self.window_size, stride=self.window_size)

        # Positions beyond each sample's true L_out correspond to padding; mask
        # of shape (N, 1, L_out) broadcasts across channels. Setting them to -inf
        # makes max pooling ignore them.
        mask = torch.arange(l_out, device=x.device) >= correct_l_out.unsqueeze(1)
        x = x.masked_fill(mask.unsqueeze(1), -torch.inf)
        x = self.pooling(x)

        # Flatten
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc_1(x))
        x = self.fc_2(x)
        return x

    def _base_proba_reduce(self, probs: torch.Tensor) -> torch.Tensor:
        """Vote on the last dimension of the probability vector.

        With a binary model and a ``threshold``, class 1 wins exactly when its
        probability exceeds the threshold (p1 - t > p0 - (1 - t) <=> p1 > t).

        Args:
            probs (torch.Tensor): A tensor shape of [..., num_classes]

        Returns:
            torch.Tensor: One-hot votes with the same shape as ``probs``.
        """
        _probs = probs.clone()
        if self.out_size == 2 and self.threshold is not None:
            _probs[..., 0] -= (1 - self.threshold)
            _probs[..., 1] -= self.threshold
        return DiffHardMax.apply(_probs)

    def _certified_radius(
        self,
        x: torch.Tensor,
        counts: torch.IntTensor,
        alpha=0.05,
        forward_kwargs=None,
        **kwargs,
    ):
        """Compute the certified radius of each input from its class counts.

        Each input is paired with its ``metadata`` entry (or None) before being
        handed to the perturbation.

        Args:
            x (torch.Tensor): The batch of inputs.
            counts (torch.IntTensor): Per-input class counts, one row per input.
            alpha (float, optional): Significance level. Defaults to 0.05.
            forward_kwargs (Optional[Dict]): Batched perturbation kwargs.
            **kwargs: Forwarded to ``perturbation.certified_radius``.

        Returns:
            torch.Tensor: A float32 tensor of radii, one per input.

        Raises:
            ValueError: If the reduction mode is "none".
        """
        mode = self.reduce
        if mode == "none":
            raise ValueError(
                "Reduction mode is None, please specify either 'soft' or 'hard' to compute radius"
            )
        if mode == "soft":
            warnings.warn(
                "Certified radius for 'soft' reduction is not properly implemented yet. "
                "Defaults to the radii of 'hard' reduction using soft probabilities."
            )
        radii = torch.empty(x.size(0), dtype=torch.float32, device=counts.device)
        if forward_kwargs is not None:
            forward_kwargs = inv_collate_pad(forward_kwargs, pad_value=False, batch_size=x.size(0))
        # Fall back to empty kwargs so zip() never truncates and leaves
        # entries of `radii` uninitialized.
        if forward_kwargs is None or len(forward_kwargs) == 0:
            forward_kwargs = [{}] * x.size(0)
        for i, (_x, _forward_kwargs, _counts) in enumerate(
            zip(x, forward_kwargs, counts)
        ):
            try:
                metadata = _forward_kwargs["metadata"]
            except (KeyError, TypeError):
                # No metadata supplied for this input
                metadata = None
            _input = (_x, metadata)
            pred, _ = self.perturbation.predict(input=_input, counts=_counts)
            radii[i] = self.perturbation.certified_radius(
                input=_input,
                pred=pred,
                counts=_counts,
                alpha=alpha,
                **kwargs,
            )
        return radii


class DiffHardMax(torch.autograd.Function):
    """Straight-through hard-max.

    Forward: one-hot encodes the argmax along the last dimension (same dtype and
    shape as the input). Backward: passes the incoming gradient through
    unchanged, so the non-differentiable argmax can sit inside a trained graph.
    """

    @staticmethod
    def forward(ctx, input):
        winners = input.argmax(dim=-1, keepdim=True)
        return torch.zeros_like(input).scatter_(-1, winners, 1)

    @staticmethod
    def backward(ctx, grad_output):
        # Identity gradient (straight-through estimator).
        return grad_output
