from math import ceil, nan, comb
from typing import Callable, Sequence, Tuple, Optional, Union, OrderedDict

import numpy as np
import torch
from statsmodels.stats.proportion import proportion_confint

from ..transforms import functional as F
from ..types import IntBinarySample, IntBinary, EmbedBinary, EmbedBinarySample
from .perturbation import RandomPerturbation
from .utils import binary_search_solve, brute_force_solve, combln



def _lecuyer_cert(
    num_dim: int,
    num_mask: int,
    counts: Sequence[int],
    alpha: float = 0.05,
) -> Callable[[int], float]:
    """Approximate certificate for randomized masking based on the statistical test of Lecuyer et al. (2019)

    Args:
        num_dim: Number of dimensions in the input.
        num_mask: Number of dimensions randomly masked in the input.
        counts: Class frequencies for randomly perturbed inputs passed through the classifier. Must be a sequence
            with at least two entries, where `counts[i]` is the number of perturbed inputs with class index `i`.
            Any sequence type (list, tuple, ndarray) is accepted.

    Keyword args:
        alpha: Significance level. Defaults to 0.05.

    Returns:
        A decreasing function of the L0 radius. If function is positive for a particular radius, then the certificate
        holds.
    """
    # The fancy indexing below (`counts[top2_ind]`) only works on an ndarray, so coerce plain lists/tuples first.
    counts = np.asarray(counts)

    # Make a Bonferroni correction since we will run two tests for each instance
    alpha = alpha / 2

    # Counts of the two most frequent classes, sorted ascending: [runner-up, top].
    top2_ind = np.argpartition(counts, -2)[-2:]
    top2_counts = np.sort(counts[top2_ind])
    num_samples = np.sum(counts)

    # One-sided lower bound on probability of most frequent class (\underbar{p_A})
    # By setting `method="beta"` we are computing an exact Clopper-Pearson interval. `proportion_confint` is called
    # with `2 * alpha` so that a single tail of the returned interval has level `alpha`.
    p_A_lower, _ = proportion_confint(
        top2_counts[-1], num_samples, alpha=2 * alpha, method="beta"
    )

    # One-sided upper bound on probability of 2nd most frequent class (\overbar{p_B})
    _, p_B_upper = proportion_confint(
        top2_counts[-2], num_samples, alpha=2 * alpha, method="beta"
    )

    def f(radius: int) -> float:
        # NOTE(review): `combln` is presumably the log binomial coefficient, which would make the exponential the
        # ratio C(num_dim - radius, num_mask - radius) / C(num_dim, num_mask) - confirm against `.utils`.
        delta = 1.0 - np.exp(
            combln(num_dim - radius, num_mask - radius) - combln(num_dim, num_mask)
        )
        return 0.5 * (p_A_lower - p_B_upper) - delta

    return f


def _cohen_cert(
    num_dim: int,
    num_mask: int,
    counts: Sequence[int],
    threshold: Optional[float] = None,
    alpha: float = 0.05,
) -> Callable[[int], float]:
    """Approximate certificate for randomized masking based on the statistical test presented in Section 3.2.2 of
    Cohen, Rosenfeld and Kolter (2019)

    Args:
        num_dim: Number of dimensions in the input.
        num_mask: Number of dimensions randomly masked in the input.
        counts: Class frequencies for randomly perturbed inputs passed through the classifier.

    Keyword args:
        threshold: Classification threshold. Should be set to 0.5 for multiclass problems, but may be adjusted for
            two-class problems. If None, 0.5 is used.
        alpha: Significance level. Defaults to 0.05.

    Returns:
        A decreasing function of the L0 radius. If function is positive for a particular radius, then the certificate
        holds.
    """
    # The signature allows `threshold=None`; without a fallback the subtraction inside `f` would raise a TypeError.
    if threshold is None:
        threshold = 0.5

    # One-sided lower bound on probabilities of most frequent class (\underbar{p_A})
    # By setting `method="beta"` we are computing an exact Clopper-Pearson interval. `proportion_confint` is called
    # with `2 * alpha` so that a single tail of the returned interval has level `alpha`.
    max_counts = np.max(counts)
    num_samples = np.sum(counts)
    p_A_lower, _ = proportion_confint(
        max_counts, num_samples, alpha=2 * alpha, method="beta"
    )

    def f(radius: int) -> float:
        # NOTE(review): `combln` is presumably the log binomial coefficient, which would make the exponential the
        # ratio C(num_dim - radius, num_mask - radius) / C(num_dim, num_mask) - confirm against `.utils`.
        delta = 1.0 - np.exp(
            combln(num_dim - radius, num_mask - radius) - combln(num_dim, num_mask)
        )
        return p_A_lower - threshold - delta

    return f


def _jia_cert(
    num_dim: int,
    num_mask: int,
    counts: Sequence[int],
    threshold: Optional[float] = None,
    alpha: float = 0.05,
) -> Callable[[int], float]:
    """Approximate certificate for randomized masking based on the method proposed by Jia et al. (2022)

    Args:
        num_dim: Number of dimensions in the input.
        num_mask: Number of dimensions randomly masked in the input.
        counts: Class frequencies for randomly perturbed inputs passed through the classifier. Must have exactly
            two entries.

    Keyword args:
        threshold: Classification threshold. Should be set to 0.5 for multiclass problems, but may be adjusted for
            two-class problems. If None, 0.5 is used.
        alpha: Significance level. Defaults to 0.05.

    Returns:
        A decreasing function of the L0 radius. If function is positive for a particular radius, then the certificate
        holds.

    Raises:
        NotImplementedError: If `counts` does not have exactly two entries.
    """
    # We implement Jia et al.'s certificate for the 2-class setting, where their method for estimating bounds on the
    # probability scores is the same as Cohen et al.'s. We therefore raise an exception if we're not in the 2-class
    # setting. Note: if we want to extend to multi-class, we need to implement SimuEM (Jia et al., 2020).
    if len(counts) != 2:
        raise NotImplementedError("This certificate is not yet implemented for more than 2 classes")

    # The signature allows `threshold=None`; without a fallback the subtraction inside `f` would raise a TypeError.
    if threshold is None:
        threshold = 0.5

    # One-sided lower bound on probability of most frequent class (\underbar{p_A}). `proportion_confint` is called
    # with `2 * alpha` so that a single tail of the returned (Clopper-Pearson) interval has level `alpha`.
    max_counts = np.max(counts)
    num_samples = np.sum(counts)
    p_A_lower, _ = proportion_confint(
        max_counts, num_samples, alpha=2 * alpha, method="beta"
    )

    # Adjust the lower bound by rounding up to the nearest integer multiple of q = 1 / comb(num_dim, num_mask).
    # We only need to do the rounding if q >= machine epsilon, otherwise it has no impact due to floating point
    # quantization.
    machine_eps = np.finfo(float).eps
    if combln(num_dim, num_mask) <= -np.log(machine_eps):
        # Use arbitrary precision integer arithmetic. By symmetry of the binomial coefficient,
        # comb(num_dim, num_dim - num_mask) == comb(num_dim, num_mask) == 1 / q.
        q_inv = comb(num_dim, num_dim - num_mask)
        # Integer representation of lower bound
        num, den = p_A_lower.as_integer_ratio()
        # Below is equivalent to p_A_lower = np.ceil(p_A_lower * q_inv) / q_inv. By adding `den - 1`, we perform
        # integer division where the result is rounded up, rather than down.
        p_A_lower = ((q_inv * num + den - 1) // den) / q_inv

    def f(radius: int) -> float:
        # NOTE(review): `combln` is presumably the log binomial coefficient, which would make the exponential the
        # ratio C(num_dim - radius, num_mask - radius) / C(num_dim, num_mask) - confirm against `.utils`.
        delta = 1.0 - np.exp(
            combln(num_dim - radius, num_mask - radius) - combln(num_dim, num_mask)
        )
        return p_A_lower - threshold - delta

    return f


def sample_mask(
    binaries: torch.Tensor,
    mask_fraction: float,
    byte_chunks: Optional[torch.LongTensor] = None,
    min_chunks: Optional[int] = None,
    pad_value: Optional[torch.Tensor] = None,
) -> Tuple[torch.LongTensor, torch.LongTensor]:
    """Sample bytes to mask

    Args:
        binaries: A padded batch of binaries. The 1st dimension indexes binaries in the batch, and the 2nd
            dimension indexes byte addresses. The byte values may be represented as embedding vectors, in
            which case, the 3rd dimension corresponds to the dimension of the embedding space.
        mask_fraction: Fraction of bytes (or chunk of bytes if `byte_chunks` is specified) to mask.

    Keyword args:
        byte_chunks: Specifies chunks of bytes that correspond to a single semantic object---e.g. an instruction.
            Must be a tensor with the same dimensions as the one used to represent the padded batch of binaries. The
            1st dimension indexes binaries in the batch, and the 2nd dimension indexes byte addresses. A contiguous
            chunk of bytes is labeled by a positive integer id (unique within each binary). Individual bytes that are
            not part of a specific chunk may be labeled zero (the same token used to represent padding). Within each
            binary, the chunk ids must appear in ascending order. For instance, [0, 1, 1, 0, 2, 3] is valid but
            [0, 2, 2, 0, 1, 3] is not. The largest id in each binary is taken as its number of labeled chunks.
        min_chunks: Minimum number of byte chunks to retain post-masking. Ignored if None or not positive.
        pad_value: Value in `binaries` used to represent padding. If None, zero is used.

    Returns:
        A tuple (pair) of tensors that indexes byte addresses in `binaries` to mask. Both tensors are empty if
        there is nothing to mask.

    Raises:
        ValueError: If `byte_chunks` is given but is not 2-dimensional.
    """
    num_binaries = binaries.size(0)
    length = binaries.size(1)
    # Reshape to 3d (flatten or unsqueeze)
    binaries = binaries.reshape(num_binaries, length, -1)
    if pad_value is None:
        pad_value = torch.zeros(
            binaries.size(2), dtype=binaries.dtype, device=binaries.device
        )
    else:
        pad_value = pad_value.flatten()

    # True at byte addresses holding real content, i.e. that differ from the pad value in at least one component
    non_pad_mask = (binaries != pad_value).any(dim=-1)

    if byte_chunks is not None:
        if byte_chunks.dim() != 2:
            raise ValueError("`byte_chunks` must be a 2-dimensional tensor")

        if byte_chunks.is_sparse:
            byte_chunks = byte_chunks.to_dense()

        # First task is to label the unlabeled byte chunks - i.e. assign a unique id to each zero in `byte_chunks`
        # that doesn't correspond to padding.

        # Number of *labeled* and *unlabeled* byte chunks in each binary
        num_labeled_chunks = torch.amax(byte_chunks, 1)
        unlabeled_mask = torch.logical_and(byte_chunks == 0, non_pad_mask)
        num_unlabeled_chunks = unlabeled_mask.sum(dim=1)

        num_chunks = num_labeled_chunks + num_unlabeled_chunks

        # Adjust labeled byte chunks so that the ids are unique across binaries: binary `i` is shifted by the total
        # number of chunks (labeled + unlabeled) in binaries 0..i-1. E.g. for two unpadded binaries
        # [[0, 1, 1, 0, 2], [0, 0, 0, 1, 1]] (4 chunks each) this yields [[0, 1, 1, 0, 2], [0, 0, 0, 5, 5]].
        offsets = torch.zeros(
            (num_binaries, 1), dtype=torch.int64, device=byte_chunks.device
        )
        offsets[1:, 0] = num_chunks.cumsum(0)[:-1]
        byte_chunks = torch.where(byte_chunks != 0, byte_chunks + offsets, byte_chunks)

        # Number the unlabeled bytes globally (0, 1, 2, ...) in row-major order, then shift them so that they land
        # immediately after the labeled ids of their own binary. Continuing the example above, the final result is
        # [[3, 1, 1, 4, 2], [6, 7, 8, 5, 5]], i.e. binary `i` owns ids offsets[i] + 1, ..., offsets[i] + num_chunks[i].
        byte_chunks[unlabeled_mask] = torch.arange(
            num_unlabeled_chunks.sum(), device=byte_chunks.device, dtype=torch.int64
        )
        byte_chunks = torch.where(
            unlabeled_mask,
            byte_chunks + num_labeled_chunks.cumsum(0)[:, None] + 1,
            byte_chunks,
        )
    else:
        # Without explicit chunks, every non-padding byte is its own chunk
        num_chunks = torch.sum(non_pad_mask, dim=1)

        offsets = torch.zeros(
            (num_binaries, 1), dtype=torch.int64, device=binaries.device
        )
        offsets[1:, 0] = num_chunks.cumsum(0)[:-1]

        # `torch.zeros` (not `zeros_like`, which requires a tensor argument) builds the id grid from a shape.
        byte_chunks = torch.zeros(
            (num_binaries, length), dtype=torch.int64, device=binaries.device
        )
        # Ids are assigned in row-major order, so binary `i` owns offsets[i] + 1, ..., offsets[i] + num_chunks[i]
        byte_chunks[non_pad_mask] = torch.arange(
            1, num_chunks.sum() + 1, dtype=torch.int64, device=byte_chunks.device
        )

    # Number of byte chunks to mask in each binary is a fraction of the total, and should not exceed
    # num_chunks - min_chunks
    num_mask = torch.ceil(mask_fraction * num_chunks).type(torch.int64)
    if min_chunks is not None and min_chunks > 0:
        min_chunks = torch.tensor(min_chunks, dtype=torch.int64, device=binaries.device)
        num_mask = torch.minimum(
            num_mask,
            torch.maximum(num_chunks - min_chunks, torch.zeros_like(num_chunks, device=binaries.device)),
        )

    max_num_chunks = int(torch.max(num_chunks))
    max_num_mask = int(torch.max(num_mask))

    # Nothing to mask in any binary: `torch.multinomial` cannot be asked for zero samples, so return an empty
    # index directly.
    if max_num_mask == 0:
        no_addr = torch.zeros(0, dtype=torch.int64, device=byte_chunks.device)
        return no_addr, no_addr.clone()

    # Sample chunks to mask. Since multinomial is not vectorized in `num_samples`, select `max_num_mask`
    # chunks to mask for all binaries. In general, this will yield too many chunks for some of the binaries, but
    # we discard them later.
    # Note: need to shift by 1 since chunks are labeled starting at 1.
    weights = 1.0 * (
        torch.arange(max_num_chunks, device=num_chunks.device) < num_chunks[:, None]
    )
    insn_id_mask = torch.multinomial(weights, max_num_mask, replacement=False) + 1

    # Relabel chunks ids to match those in byte_chunks
    insn_id_mask += offsets

    # Deal with the fact that we may have sampled too many instructions to mask for some of the binaries
    mask = torch.arange(max_num_mask, device=num_mask.device) < num_mask[:, None]
    insn_id_mask = insn_id_mask[mask]

    # Get indices of tensor entries to mask
    addr = torch.isin(byte_chunks, insn_id_mask).nonzero(as_tuple=True)

    return addr


def apply_mask(
    binaries: torch.Tensor,
    indices: Tuple[torch.LongTensor, torch.LongTensor],
    mask_value: torch.Tensor,
) -> torch.Tensor:
    """Overwrite selected byte addresses of a batch of binaries with a mask value

    The input is left untouched; a new tensor carrying the masked entries is returned.

    Args:
        binaries: Padded batch of binaries, indexed by (binary, byte address). When bytes are represented as
            embedding vectors, a trailing 3rd dimension indexes the embedding space.
        indices: Pair of index tensors (binary index, byte address) selecting the entries to mask.
        mask_value: Value written at every selected entry.

    Returns:
        A copy of `binaries` with the selected entries replaced by `mask_value`.
    """
    # Out-of-place variant of `index_put_`: values overwrite (rather than accumulate into) the selected entries
    masked = binaries.index_put(indices, mask_value, accumulate=False)
    return masked


class MaskingMech(
    RandomPerturbation[
        Union[IntBinarySample, EmbedBinarySample], Union[IntBinary, EmbedBinary]
    ]
):
    """Masking randomized smoothing mechanism"""

    def __init__(
        self,
        mask_fraction: float,
        mask_value: torch.Tensor,
        pad_value: Optional[torch.Tensor] = None,
        group_insn: bool = False,
        threshold: Optional[float] = None,
        min_chunks: int = 500,
    ) -> None:
        """
        Args:
            mask_fraction: Fraction of instruction bytes (or chunks of bytes) to mask.
            mask_value: Value used to represent a masked byte (or chunk of bytes).
                During training time, if embedding is used, then this should be the weight associated with masked
                values in the embedding layer, so it will be updated automatically.
            pad_value: Value used to represent padding.
            group_insn: If True, instructions are treated as chunks of bytes.
            threshold: Probability threshold for predicting class index 1 in a two-class problem. If not specified,
                defaults to 0.5.
            min_chunks: Minimum number of bytes (or chunks of bytes) to retain post-deletion. This parameter is
                only effective in training mode.
        """
        super().__init__(threshold=threshold)
        # All settings are stored as buffers so that they are saved/restored with the module's state dict
        self.register_buffer("mask_fraction", torch.tensor(mask_fraction))
        self.register_buffer("mask_value", mask_value)
        self.register_buffer("pad_value", pad_value)
        self.register_buffer("group_insn", torch.tensor(group_insn))
        self.register_buffer("min_chunks", torch.tensor(min_chunks))

    def forward(
        self, input: Union[IntBinarySample, EmbedBinarySample]
    ) -> Union[IntBinary, EmbedBinary]:
        """Randomly mask a fraction of the bytes (or chunks of bytes) in a batch of binaries

        Args:
            input: A pair `(binaries, metadata)`. When `group_insn` is set, the 'insn_addr' entry of `metadata`
                (if present) is used to group bytes into chunks.

        Returns:
            The masked binaries, with a batch dimension added if the input did not have one.
        """
        binaries, metadata = input
        insn_addr_ranges = metadata.get("insn_addr", None) if self.group_insn else None

        # If inputs does not have a batch dim, expand it
        if isinstance(binaries, IntBinary):
            binaries = torch.atleast_2d(binaries)
        if isinstance(binaries, EmbedBinary):
            binaries = torch.atleast_3d(binaries)

        mask_indices = sample_mask(
            binaries,
            self.mask_fraction,
            byte_chunks=insn_addr_ranges,
            # The retention floor only applies in training mode
            min_chunks=self.min_chunks if self.training else 0,
            pad_value=self.pad_value,
        )
        binaries = apply_mask(binaries, mask_indices, mask_value=self.mask_value)

        return binaries

    def certified_radius(
        self,
        input: IntBinarySample,
        pred: int,
        counts: Sequence[int],
        alpha: float = 0.05,
        stat_test: str = "cohen",
        strategy: str = "binary_search",
        **kwargs,
    ) -> float:
        """Compute the certified radius for inputs to a classifier smoothed under this perturbation

        Args:
            input: Unperturbed binary sample. It must contain metadata with an entry for 'insn_addr'.
            pred: Estimated prediction of the smoothed classifier for `input`. Must be a class index in the set
                {0, 1, 2, ..., n_classes - 1}.
            counts: Class frequencies for randomly perturbed inputs passed through the classifier. Must be a sequence
                where `counts[i]` is the number of perturbed inputs with class index `i`.

        Keyword args:
            alpha: Significance level. Defaults to 0.05.
            stat_test: Statistical test used to compute the certificate. If "lecuyer" the test is based on
                Proposition 2 of Lecuyer et al. (2019). If "cohen" the test is based on Section 3.2.2 of Cohen,
                Rosenfeld and Kolter (2019). If "jia" the test is based on the method of Jia et al. (2022).
            strategy: Search strategy used to find the largest radius for which the certificate holds. One of
                "binary_search" or "brute_force".
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            Return certified radius for this sample, or nan if there is nothing to mask.

        Raises:
            ValueError: If `stat_test` or `strategy` is not a permitted value, or if `stat_test` is "lecuyer" and
                the threshold is not 0.5.
        """
        binary, metadata = input

        # `binary` holds integer byte values, so padding must be compared against a scalar. Fall back to zero when
        # no pad value is configured or when the configured one is not a scalar (e.g. an embedding vector).
        if self.pad_value is None or self.pad_value.dim() > 0:
            pad_value = torch.tensor(0, dtype=torch.int32, device=binary.device)
        else:
            pad_value = self.pad_value

        insn_addr_ranges = metadata.get("insn_addr", None) if self.group_insn else None

        valid_stat_tests = {"cohen", "lecuyer", "jia"}
        if stat_test not in valid_stat_tests:
            raise ValueError(
                "`stat_test = {}` is not one of the permitted values {}".format(
                    stat_test, valid_stat_tests
                )
            )
        num_bytes = (binary != pad_value).sum()

        if self.group_insn:
            if insn_addr_ranges.layout == torch.sparse_coo:
                num_insn = insn_addr_ranges._values().max()
                num_insn_bytes = insn_addr_ranges._values().size(0)
            else:
                binary = F.to_tensor(binary, dtype=torch.int32)
                num_insn = torch.max(insn_addr_ranges)
                # Use the resolved scalar `pad_value`: `self.pad_value` may be None or non-scalar here
                num_insn_bytes = torch.count_nonzero(
                    torch.logical_and(insn_addr_ranges != 0, binary != pad_value)
                )
            # Each instruction is one chunk; every remaining (non-instruction) byte is a chunk of its own
            num_chunks = num_insn + (num_bytes - num_insn_bytes)
        else:
            num_chunks = num_bytes

        num_mask = ceil(self.mask_fraction * num_chunks)

        # Handle file with no instructions
        if num_mask == 0:
            return nan

        if isinstance(counts, torch.Tensor):
            counts = counts.cpu().numpy()

        threshold = self.threshold if self.threshold is not None else torch.tensor(0.5)
        if stat_test == "lecuyer":
            # Check the resolved threshold, so that an unspecified (None) threshold counts as the default 0.5
            if float(threshold) != 0.5:
                raise ValueError("lecuyer stat_test cannot be used if threshold != 0.5")
            f = _lecuyer_cert(num_chunks, num_mask, counts, alpha=alpha)
        elif stat_test == "cohen":
            f = _cohen_cert(
                num_chunks,
                num_mask,
                counts,
                threshold=threshold.detach().cpu().numpy(),
                alpha=alpha,
            )
        else:
            f = _jia_cert(
                num_chunks,
                num_mask,
                counts,
                threshold=threshold.detach().cpu().numpy(),
                alpha=alpha,
            )
        if strategy == "brute_force":
            largest_radius = brute_force_solve(f, x_max=num_mask)
        elif strategy == "binary_search":
            largest_radius = binary_search_solve(f, x_max=num_mask)
        else:
            raise ValueError(f"Unknown search strategy: {strategy}")

        return float(largest_radius)

    def extra_dim(self):
        # NOTE(review): presumably the number of extra dimensions this mechanism reports to the base class -
        # confirm against `RandomPerturbation`.
        return 1

    def __repr__(self):
        return (
            self.__class__.__name__
            + f"(mask_fraction={self.mask_fraction}, mask_value={self.mask_value}, pad_value={self.pad_value}, group_insn={self.group_insn}, threshold={self.threshold}, min_chunks={self.min_chunks})"
        )

    def load_state_dict(
        self, state_dict: OrderedDict[str, torch.Tensor], strict: bool = True
    ):
        """Load a state dict, warning if the stored mask value is a multi-element (embedding) tensor

        Args:
            state_dict: State to load.
            strict: Forwarded to the parent implementation.

        Returns:
            Whatever the parent `load_state_dict` returns.
        """
        for k, v in state_dict.items():
            # `nelement` is a method and must be called; comparing the bound method itself to an int raises
            if k == "mask_value" and isinstance(v, torch.Tensor) and v.nelement() > 1:
                print(
                    "Warning: Loading a embedding dimension mask value. Be careful about the derivatives. "
                    "You can obtain a proper derivative if you load the embedding layer's weight here so they are associated."
                )
        return super().load_state_dict(state_dict, strict)
