from typing import Optional, Sequence, Tuple, Union

import numpy as np
import torch
from statsmodels.stats.proportion import proportion_confint

from ..types import EmbedBinary, EmbedBinarySample, IntBinary, IntBinarySample
from .perturbation import RandomPerturbation


def _edit_cert(p_del: float, mu: float, eta: float = 0.5) -> float:
    """Approximate edit distance certificate when all types of edits are permitted.

    Args:
        p_del: probability of deleting a value at a location.
        mu: predicted probability
        eta: threshold for the prediction

    Returns:
        Radius of the certificate
    """
    radius = np.log(1 + eta - mu) / np.log(p_del)
    radius = np.floor(radius).item()
    return max(radius, -1)


def _disp_cert(p_del: float, mu: float, eta: float = 0.5) -> float:
    """Approximate disp distance certificate when all types of edits are permitted.

    Args:
        p_del: probability of deleting a value at a location.
        mu: predicted probability
        eta: threshold for the prediction

    Returns:
        Radius of the certificate
    """
    radius = 2 * np.log((eta + np.sqrt(4 + eta**2 - 4 * mu)) / 2) / np.log(p_del)
    radius = np.floor(radius).item()
    return max(radius, -1)


def _del_cert(p_del: float, mu: float, eta: float = 0.5) -> float:
    """Approximate deletion distance certificate when all types of edits are permitted.

    Args:
        p_del: probability of deleting a value at a location.
        mu: predicted probability
        eta: threshold for the prediction

    Returns:
        Radius of the certificate
    """
    radius = np.log(eta / mu) / np.log(p_del)
    radius = np.floor(radius).item()
    return max(radius, -1)


def _ins_cert(p_del: float, mu: float, eta: float = 0.5) -> float:
    """Approximate deletion insertion distance certificate when all types of edits are permitted.

    Args:
        p_del: probability of deleting a value at a location.
        mu: predicted probability
        eta: threshold for the prediction

    Returns:
        Radius of the certificate
    """
    radius = np.log((1 - mu) / (1 - eta)) / np.log(p_del)
    radius = np.floor(radius).item()
    return max(radius, -1)


def sample_del(
    binaries: torch.Tensor,
    p_del: float,
    byte_chunks: Optional[torch.LongTensor] = None,
    min_chunks: Optional[int] = None,
    pad_value: Optional[torch.Tensor] = None,
) -> Tuple[torch.LongTensor, torch.LongTensor]:
    """Sample bytes to delete

    Three modes are supported:

    * Neither `byte_chunks` nor a positive `min_chunks` is given: every address
      (padding included) is selected independently with probability `p_del`.
    * Only a positive `min_chunks` is given: every non-padding byte is treated
      as its own chunk.
    * `byte_chunks` is given: labeled chunks are deleted as a unit and every
      unlabeled non-padding byte is treated as its own chunk.

    In the last two modes each chunk is selected independently with probability
    `p_del`; if `min_chunks` is positive, a random subset of the selected
    chunks is then dropped per binary so that at least `min_chunks` chunks
    survive (or nothing is deleted if the binary has fewer chunks than that).

    Args:
        binaries: A padded batch of binaries. The 1st dimension indexes binaries in the batch, and the 2nd
            dimension indexes byte addresses. The byte values may be represented as embedding vectors, in
            which case, the 3rd dimension corresponds to the dimension of the embedding space.
        p_del: Probability of deleting a byte (or chunk of bytes if `byte_chunks` is specified).

    Keyword args:
        byte_chunks: Specifies chunks of bytes that correspond to a single semantic object---e.g. an instruction.
            Must be a tensor with the same dimensions as the one used to represent the padded batch of binaries. The
            1st dimension indexes binaries in the batch, and the 2nd dimension indexes byte addresses. A contiguous
            chunk of bytes is labeled by a positive integer id (unique within each binary). Individual bytes that are
            not part of a specific chunk may be labeled zero (the same token used to represent padding). Within each
            binary, the chunk ids must appear in ascending order. For instance, [0, 1, 1, 0, 2, 3] is valid but
            [0, 2, 2, 0, 1, 3] is not.
        min_chunks: Minimum number of byte chunks to retain post-deletion.
        pad_value: Value in `binaries` used to represent padding. For IntBinary, it needs to be a Int32 type tensor.
            Otherwise it will be a float32 matching the embedding dimension. If None, then it is inferred by the input binaries.

    Returns:
        A tuple (pair) of tensors that indexes byte addresses in `binaries` to delete.

    Raises:
        ValueError: If `byte_chunks` is given but is not 2-dimensional.
    """
    num_binaries = binaries.size(0)
    length = binaries.size(1)
    # Reshape to 3d (flatten or unsqueeze) so padding is detected the same way
    # for integer and embedded inputs
    binaries = binaries.reshape(num_binaries, length, -1)
    if pad_value is None:
        pad_value = torch.zeros(
            binaries.size(2), dtype=binaries.dtype, device=binaries.device
        )
    else:
        pad_value = pad_value.flatten()

    enforce_min = min_chunks is not None and min_chunks > 0

    if byte_chunks is not None:
        if byte_chunks.dim() != 2:
            raise ValueError("`byte_chunks` must be a 2-dimensional tensor")

        if byte_chunks.is_sparse:
            byte_chunks = byte_chunks.to_dense()

        # Number of labeled byte chunks in each binary (ids are assumed to be 1..max)
        num_labeled_chunks = torch.amax(byte_chunks, 1)
        # Non-padding bytes without a label; each one becomes its own chunk
        unlabeled_mask = torch.logical_and(
            byte_chunks == 0, torch.any(binaries != pad_value, dim=-1)
        )
        num_unlabeled_chunks = unlabeled_mask.sum(dim=1)

        num_chunks = num_labeled_chunks + num_unlabeled_chunks

        # Adjust labeled byte chunks so that the ids are unique across binaries. E.g.
        # if every zero is padding, [[0, 1, 1, 0, 2], [0, 0, 0, 1, 1]] will become
        # [[0, 1, 1, 0, 2], [0, 0, 0, 3, 3]]
        # NOTE: The increments also include the unlabeled single-byte chunks
        offsets = torch.zeros(
            (num_binaries, 1), dtype=torch.int64, device=byte_chunks.device
        )
        offsets[1:, 0] = num_chunks.cumsum(0)[:-1]
        # `torch.where` returns a new tensor, so the caller's `byte_chunks` is
        # not modified by the in-place write below
        byte_chunks = torch.where(byte_chunks != 0, byte_chunks + offsets, byte_chunks)

        # Give each unlabeled byte a running index, then shift it past the labeled
        # ids of its own binary so that the ids of one binary form a contiguous range
        byte_chunks[unlabeled_mask] = torch.arange(
            num_unlabeled_chunks.sum(), device=byte_chunks.device, dtype=torch.int64
        )
        byte_chunks = torch.where(
            unlabeled_mask,
            byte_chunks + num_labeled_chunks.cumsum(0)[:, None] + 1,
            byte_chunks,
        )
    elif enforce_min:
        # No chunk information: every non-padding byte is a chunk with a globally
        # unique id in 1..total (row-major order)
        non_pad_mask = torch.any(binaries != pad_value, dim=-1)
        num_chunks = torch.sum(non_pad_mask, dim=1)

        byte_chunks = torch.zeros(
            [num_binaries, length], dtype=torch.int64, device=binaries.device
        )
        byte_chunks[non_pad_mask] = torch.arange(
            1, num_chunks.sum() + 1, dtype=torch.int64, device=byte_chunks.device
        )
    else:
        # Unconstrained byte-level deletion: sample every address directly
        addr = torch.nonzero(
            torch.rand(
                [num_binaries, length], dtype=torch.float32, device=binaries.device
            )
            < p_del,
            as_tuple=True,
        )
        return addr

    total_chunks = int(num_chunks.sum())

    # Sample chunks to delete (chunk ids start at 1). `nonzero` returns shape (k, 1);
    # flatten keeps the result 1-dimensional even when k == 1, which a bare
    # `squeeze()` would collapse to a 0-d tensor.
    del_chunk_ids = (
        torch.nonzero(
            torch.rand(total_chunks, dtype=torch.float32, device=byte_chunks.device)
            < p_del
        ).flatten()
        + 1
    )

    if enforce_min:
        # Largest chunk id belonging to each binary
        max_chunk_ids = num_chunks.cumsum(0)

        # Generate random permutation of del_chunk_ids
        perm_id = torch.randperm(del_chunk_ids.size(0), device=del_chunk_ids.device)

        # Sort perm_id so that entries corresponding to the same binary appear consecutively
        binary_idx = torch.bucketize(del_chunk_ids, max_chunk_ids)
        sort_perm_id = torch.argsort(binary_idx[perm_id])
        perm_id = perm_id[sort_perm_id]

        # `minlength` guarantees one count per binary even if the trailing binaries
        # (or all of them) had no chunk sampled
        num_del_chunks = torch.bincount(binary_idx, minlength=num_binaries)
        # Number of chunks to delete after enforcing min_chunks constraint
        cor_num_del_chunks = torch.minimum(
            num_del_chunks,
            torch.maximum(num_chunks - min_chunks, torch.zeros_like(num_chunks)),
        )

        # Generate indices into `perm_id` that will select the corrected number of chunks for each binary
        idx = torch.arange(cor_num_del_chunks.sum(), device=byte_chunks.device)
        # Constant delta to add to the indices for each binary: the number of
        # sampled-but-dropped chunks in all preceding binaries
        delta = torch.cat(
            (
                torch.zeros(1, dtype=torch.int64, device=byte_chunks.device),
                (num_del_chunks - cor_num_del_chunks).cumsum(0)[:-1],
            )
        )
        delta = torch.repeat_interleave(delta, cor_num_del_chunks)
        idx += delta

        del_chunk_ids = del_chunk_ids[perm_id[idx]]

    # Get indices of tensor entries to mask
    addr = torch.isin(byte_chunks, del_chunk_ids).nonzero(as_tuple=True)
    return addr


def apply_del(
    binaries: torch.Tensor,
    indices: Tuple[torch.LongTensor, torch.LongTensor],
    pad_value: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Delete bytes from a batch of binaries

    The selected addresses are overwritten with the padding value, the
    remaining non-padding entries of each row are compacted to the front, and
    trailing columns that are padding in every row are dropped.

    Args:
        binaries: A padded batch of binaries. The 1st dimension indexes binaries in the batch, and the 2nd
            dimension indexes byte addresses. The byte values may be represented as embedding vectors, in
            which case, the 3rd dimension corresponds to the dimension of the embedding space.
        indices: A tuple of tensors that indexes byte addresses in `binaries` to delete.
        pad_value: Value used to represent padding in `binaries`. If None then a zero
            vector matching the (flattened) trailing dimensions is used.

    Returns:
        Batch of binaries post-deletion
    """
    batch_size = binaries.size(0)
    # Remember the original trailing shape so it can be restored at the end
    trailing_shape = binaries.size()[2:]
    seq_len = binaries.size(1)
    # Work on a 3d view: (batch, address, flattened value)
    flat = binaries.reshape(batch_size, seq_len, -1)

    if pad_value is not None:
        pad = pad_value.flatten()
    else:
        pad = torch.zeros(flat.size()[2:], dtype=flat.dtype, device=flat.device)

    # Out-of-place write, so the caller's tensor is left untouched
    flat = flat.index_put(indices, pad)

    # Shift pad values to the end of each row
    # Based on http://stackoverflow.com/a/42859463/3293881
    keep_mask = (flat != pad).any(dim=-1)
    kept_counts = keep_mask.sum(dim=1, keepdims=True)
    positions = torch.arange(seq_len, device=flat.device)
    # True for the leading `kept_counts[i]` slots of row i
    front_mask = positions < kept_counts
    flat[front_mask] = flat[keep_mask]
    flat[~front_mask] = pad

    # Remove excess padding
    longest = torch.amax(kept_counts)
    flat = flat[:, :longest]

    # Final adjustment of size
    if len(trailing_shape) == 0:
        return flat.squeeze(-1)
    if len(trailing_shape) > 1:
        return flat.reshape(batch_size, longest, *trailing_shape)
    return flat


# And inherit another deletion class that takes in embeddings.
class DeletionMech(
    RandomPerturbation[
        Union[IntBinarySample, EmbedBinarySample], Union[IntBinary, EmbedBinary]
    ]
):
    """Defines the deletion randomized smoothing mechanism.

    `forward` randomly deletes bytes (or chunks of bytes) from a batch of
    binaries, and `certified_radius` converts class frequencies observed under
    that perturbation into a certified radius for a chosen threat model.
    """

    def __init__(
        self,
        p_del: float,
        pad_value: Optional[torch.Tensor] = None,
        group_insn: bool = False,
        threshold: Optional[float] = None,
        min_chunks: int = 500,
    ) -> None:
        """
        Args:
            p_del: Probability of deleting a specific byte (or chunk of bytes)
            pad_value: Value used to represent padding.
            group_insn: If True, instructions are treated as chunks of bytes.
            threshold: Probability threshold for predicting class index 1 in a two-class problem. If not specified,
                defaults to 0.5.
            min_chunks: Minimum number of bytes (or chunks of bytes) to retain post-deletion. This parameter is
                only effective in training mode.
        """
        super().__init__(threshold=threshold)
        # All settings are stored as buffers (scalars are wrapped in tensors)
        self.register_buffer("p_del", torch.tensor(p_del))
        self.register_buffer("pad_value", pad_value)
        self.register_buffer("group_insn", torch.tensor(group_insn))
        self.register_buffer("min_chunks", torch.tensor(min_chunks))

    def forward(
        self, input: Union[IntBinarySample, EmbedBinarySample]
    ) -> Union[IntBinary, EmbedBinary]:
        """Apply one random deletion perturbation to a batch of binaries.

        Args:
            input: Pair ``(binaries, metadata)``. When `group_insn` is set, the
                ``"insn_addr"`` entry of `metadata` (if present) is used as the
                chunk labelling for the binaries.

        Returns:
            The binaries with the sampled bytes removed and padding shifted to
            the end of each row.
        """
        binaries, metadata = input
        insn_addr_ranges = metadata.get("insn_addr", None) if self.group_insn else None

        # If inputs does not have a batch dim, expand it
        if isinstance(binaries, IntBinary):
            binaries = torch.atleast_2d(binaries)
        if isinstance(binaries, EmbedBinary):
            binaries = torch.atleast_3d(binaries)

        # The minimum-chunk constraint is only enforced in training mode
        del_indices = sample_del(
            binaries,
            self.p_del,
            byte_chunks=insn_addr_ranges,
            min_chunks=self.min_chunks if self.training else 0,
            pad_value=self.pad_value,
        )
        binaries = apply_del(binaries, del_indices, pad_value=self.pad_value)

        return binaries

    def certified_radius(
        self,
        input: Union[IntBinarySample, EmbedBinarySample],
        pred: int,
        counts: Sequence[int],
        alpha: float = 0.05,
        threat_model: str = "edit",
        **kwargs,
    ) -> float:
        """Compute the certified edit distance radius for inputs to a classifier smoothed under this perturbation

        Args:
            input: Unperturbed binary sample. Not used by this implementation.
            pred: Estimated prediction of the smoothed classifier for `input`. Must be a class index in the set
                {0, 1, 2, ..., n_classes - 1}.
            counts: Class frequencies for randomly perturbed inputs passed through the classifier. Must be a sequence
                where `counts[i]` is the number of perturbed inputs with class index `i`.

        Keyword args:
            alpha: Significance level. Defaults to 0.05.
            threat_model: Threat model used to compute the certificate. One of "edit", "ins", "disp", "del",
                "delins", "sub", "inssub" or "delsub".
            **kwargs: Accepted and ignored.

        Returns:
            The certified radius for this sample. The underlying formulas never return less than -1.

        Raises:
            ValueError: If an explicit threshold is combined with more than two classes, or if `threat_model`
                is not one of the permitted values.
        """
        if isinstance(counts, torch.Tensor):
            counts = counts.detach().cpu().numpy()
        num_classes = len(counts)
        threshold = self.threshold
        if num_classes > 2 and threshold is not None:
            raise ValueError("Only supports explicit threshold for 2 class problems")

        # One-sided lower bound on probabilities of most frequent class (\underbar{p_A})
        # By setting `method ="beta"` we are computing an exact Clopper-Pearson interval.
        # Passing `2 * alpha` makes the lower end of the two-sided interval a
        # one-sided bound at level `alpha`.
        num_samples = np.sum(counts)
        p_lower, _ = proportion_confint(
            counts[pred], num_samples, alpha=2 * alpha, method="beta"
        )

        # Thresh defaults to 0.5
        eta = 0.5 if threshold is None else float(threshold)
        # The threshold is stated for class index 1, so use 1 - threshold when
        # class index 0 was predicted
        if pred == 0:
            eta = 1 - eta

        # Formula used for each supported threat model
        cert_fns = {
            "edit": _edit_cert,
            "sub": _edit_cert,
            "inssub": _edit_cert,
            "delsub": _edit_cert,
            "disp": _disp_cert,
            "ins": _ins_cert,
            "del": _del_cert,
            "delins": _del_cert,
        }
        valid_threat_models = set(cert_fns)
        if threat_model not in valid_threat_models:
            raise ValueError(
                "`threat_model = {}` is not one of the permitted values {}".format(
                    threat_model, valid_threat_models
                )
            )

        # Hand NumPy float64 scalars to the formulas rather than torch tensors:
        # this keeps the arithmetic in double precision, is independent of the
        # device the buffers live on, and (unlike plain Python floats) does not
        # raise on a zero lower bound.
        return cert_fns[threat_model](
            p_del=np.float64(float(self.p_del)),
            mu=np.float64(float(p_lower)),
            eta=np.float64(eta),
        )

    def extra_dim(self):
        # NOTE(review): the meaning of this value is defined by the caller /
        # base class, which is not visible here — confirm before changing.
        return 1

    def __repr__(self):
        return (
            self.__class__.__name__
            + f"(p_del={self.p_del}, pad_value={self.pad_value}, group_insn={self.group_insn}, threshold={self.threshold}, min_chunks={self.min_chunks})"
        )
