import abc
from typing import Optional

# Modified from https://github.com/endgameinc/malware_evasion_competition
import torch
import torch.nn as nn
import torch.nn.functional as F

from .utils import conv1d_l_out, ragged_len, window_pad


class Classifier(nn.Module, abc.ABC):
    """Abstract base class for classifiers that produce per-class scores.

    Subclasses implement ``forward`` and :meth:`proba_reduce`. ``forward`` is
    expected to return unnormalized scores (logits) with the classes along
    dimension 1, because :meth:`predict_proba` applies softmax over that
    dimension.

    Args:
        out_size (int): Number of output classes. Only stored here; how it is
            used is up to the subclass.
    """

    def __init__(
        self,
        out_size: int = 2,
    ):
        super().__init__()
        self.out_size = out_size

    def predict(self, x: torch.Tensor) -> torch.Tensor:
        """Predict the class of each sample in a batch.

        Args:
            x (torch.Tensor): The input data.

        Returns:
            torch.Tensor: Categorical output of the classes the inputs are
                classified as, as computed by :meth:`proba_reduce`.
        """
        return self.proba_reduce(self.predict_proba(x))

    def predict_proba(self, x: torch.Tensor) -> torch.Tensor:
        """Get the predicted probability of each class (softmax over dim 1).

        Args:
            x (torch.Tensor): The input data.

        Returns:
            torch.Tensor: Matrix with the probability of each class, summing
                to 1 along dimension 1.
        """
        # Call the module instance rather than ``self.forward`` so that any
        # registered forward (pre-)hooks run, as PyTorch recommends.
        logits = self(x)
        return torch.softmax(logits, dim=1)

    @abc.abstractmethod
    def proba_reduce(self, proba: torch.Tensor, counts: Optional[int] = None) -> torch.Tensor:
        """Aggregate predicted probabilities into class labels.

        Args:
            proba (torch.Tensor): Output of :meth:`predict_proba`, the
                probability of each class.
            counts (Optional[int]): Optional, subclass-specific argument.
                :meth:`predict` never passes it.

        Returns:
            torch.Tensor: Categorical output of the classes.
        """


class EmbeddingClassifier(Classifier):
    """Base class for classifiers whose first layer is an embedding layer.

    Subclasses implement :meth:`embedd_and_forward` (everything after the
    embedding). They must also set ``self.window_size``, because the default
    :meth:`forward` passes it to ``window_pad`` before embedding.

    Args:
        out_size (int): Number of output classes.
        embed_num (int): Number of rows in the embedding table. Index 0 is
            the padding index.
        embed_size (int): Dimension of each embedding vector.
        scale_grad_by_freq (bool): Forwarded to ``nn.Embedding``.
    """

    # Window size used by ``forward`` to pad its input. Subclasses must set
    # it. Declared here so that a missing value produces a clear error
    # instead of an opaque AttributeError.
    window_size: Optional[int] = None

    def __init__(
        self,
        out_size: int = 2,
        embed_num: int = 257,
        embed_size: int = 8,
        scale_grad_by_freq: bool = False,
    ):
        super().__init__(out_size=out_size)
        self.embed_num = embed_num
        self.embed_size = embed_size
        # padding_idx=0: token 0 marks padding; its embedding row receives
        # no gradient updates.
        self.embed_1 = nn.Embedding(
            embed_num, embed_size, padding_idx=0, scale_grad_by_freq=scale_grad_by_freq
        )

    def embed(self, x: torch.IntTensor) -> torch.Tensor:
        """Get the embedding of an integer token tensor.

        Args:
            x (torch.IntTensor): The input tokens.

        Returns:
            torch.Tensor: The embedded data, with one extra trailing
                dimension of size ``embed_size``.
        """
        return self.embed_1(x)

    @abc.abstractmethod
    def embedd_and_forward(
        self,
        x: torch.Tensor,
        return_logits: Optional[bool] = True,
    ) -> torch.Tensor:
        """Compute logits/probabilities from embeddings.

        Args:
            x (torch.Tensor): The embedded input data.
            return_logits (Optional[bool], optional): Whether to return
                logits (True, the default) or probabilities (False).

        Returns:
            torch.Tensor: Matrix with the logit/probability of each class.
        """

    def forward(
        self,
        x: torch.IntTensor,
        return_logits: Optional[bool] = True,
    ) -> torch.Tensor:
        """Pad and embed a batch of byte tokens, then run the network.

        Args:
            x (torch.Tensor): The input byte tokens. Each sample is a
                sequence of integer tokens in ``{1, ..., embed_num - 1}``,
                padded with zeroes.
            return_logits (Optional[bool], optional): Whether to return
                logits (True, the default) or probabilities (False).

        Returns:
            torch.Tensor: The output of :meth:`embedd_and_forward`.

        Raises:
            NotImplementedError: If the subclass never set ``window_size``.
        """
        if self.window_size is None:
            raise NotImplementedError(
                f"{type(self).__name__} must set `window_size` before calling forward()"
            )
        x = x.int()
        # NOTE(review): window_pad (from .utils) presumably pads the sequence
        # length to a multiple of window_size — confirm against its source.
        x = window_pad(x, self.window_size)
        x = self.embed(x)
        return self.embedd_and_forward(x, return_logits=return_logits)


class MalConv(EmbeddingClassifier):
    """MalConv: a gated 1-D convolutional classifier over raw byte tokens.

    Each input sample is a sequence of integer tokens in
    ``{1, ..., embed_num - 1}`` padded with zeroes. The model works in four
    stages:

    1. Embed the tokens.
    2. Run two parallel non-overlapping convolutions (kernel size == stride
       == ``window_size``) and combine them with a sigmoid gate.
    3. Max-pool over positions, ignoring windows that cover only padding.
    4. Classify with a two-layer MLP.

    The model is trained to minimize cross-entropy loss
    (``nn.CrossEntropyLoss``), which expects raw logits. That is why
    :meth:`forward` returns logits by default.

    Args:
        out_size (int): Number of output classes.
        channels (int): Output channels of both convolutions and hidden width
            of ``fc_1``.
        window_size (int): Convolution kernel size and stride.
        embed_num (int): Number of embedding rows (index 0 is padding).
        embed_size (int): Embedding dimension, and the convolutions' input
            channels.
        scale_grad_by_freq (bool): Forwarded to ``nn.Embedding``.
        threshold (Optional[torch.Tensor]): Decision threshold on the
            probability of class 1, used by :meth:`proba_reduce`.
    """

    def __init__(
        self,
        out_size: int = 2,
        channels: int = 128,
        window_size: int = 512,
        embed_num: int = 257,
        embed_size: int = 8,
        scale_grad_by_freq: bool = False,
        threshold: Optional[torch.Tensor] = None,
    ):
        super().__init__(
            out_size=out_size,
            embed_num=embed_num,
            embed_size=embed_size,
            scale_grad_by_freq=scale_grad_by_freq,
        )
        self.channels = channels
        self.window_size = window_size

        # conv_1 produces the features and conv_2 the gate. Because stride ==
        # kernel size, each output position sees one non-overlapping window.
        self.conv_1 = nn.Conv1d(
            embed_size, channels, window_size, stride=window_size, bias=True
        )
        self.conv_2 = nn.Conv1d(
            embed_size, channels, window_size, stride=window_size, bias=True
        )
        self.pooling = nn.AdaptiveMaxPool1d(1)
        self.fc_1 = nn.Linear(channels, channels)
        self.fc_2 = nn.Linear(channels, out_size)

        # Registered as a buffer (not a parameter) so that a tensor threshold
        # is saved in the state_dict and follows .to(device).
        self.register_buffer("threshold", threshold)

    def _round_up_to_window(self, l_in: torch.Tensor) -> torch.Tensor:
        """Round each length up to the nearest multiple of ``window_size``.

        A partial final window still contains data, so it is kept rather than
        truncated. Integer ceiling division keeps the result exact for any
        length; float32 division would round lengths above 2**24.
        """
        w = self.window_size
        return torch.div(l_in + (w - 1), w, rounding_mode="floor").int() * w

    def embedd_and_forward(
        self,
        x: torch.Tensor,
        l_in: Optional[torch.LongTensor] = None,
        return_logits: Optional[bool] = True,
    ) -> torch.Tensor:
        """Compute logits/probabilities from embeddings.

        Args:
            x (torch.Tensor): The embedded input. Assumed to be
                ``(N, L, embed_size)``, as produced by ``self.embed``.
            l_in (Optional[torch.LongTensor]): Ragged length of each input,
                rounded up to a multiple of ``window_size``. Can be None, in
                which case it is inferred (slightly slower).
            return_logits (Optional[bool], optional): Whether to return
                logits (True, the default) or softmax probabilities (False).

        Returns:
            torch.Tensor: ``(N, out_size)`` matrix of logits/probabilities.
        """
        if l_in is None:
            l_in = self._round_up_to_window(ragged_len(x))

        # (N, L, E) -> (N, E, L): Conv1d expects channels before length.
        x = torch.transpose(x, -1, -2)

        # Gated convolution: features scaled elementwise by a sigmoid gate.
        cnn_value = self.conv_1(x)
        gating_weight = torch.sigmoid(self.conv_2(x))
        x = cnn_value * gating_weight

        # Conv1d does not support ragged batches, so every sample yields
        # x.size(2) windows. Windows at or beyond a sample's true output
        # length cover only padding and must not take part in max pooling.
        correct_l_out = conv1d_l_out(l_in, self.window_size, stride=self.window_size)
        # Mask of shape (N, L_out); unsqueezed to (N, 1, L_out) so that it
        # broadcasts across channels.
        mask = torch.arange(x.size(2), device=x.device) >= correct_l_out.unsqueeze(1)
        # -inf never wins a max, so masked windows are ignored by pooling.
        x = x.masked_fill(mask.unsqueeze(1), float("-inf"))
        x = self.pooling(x)

        # (N, C, 1) -> (N, C)
        x = x.flatten(1)
        x = F.relu(self.fc_1(x))
        x = self.fc_2(x)
        # Only convert to probabilities when logits were NOT requested.
        # nn.CrossEntropyLoss and predict_proba both expect raw logits.
        if not return_logits:
            x = F.softmax(x, dim=1)
        return x

    def forward(
        self,
        x: torch.IntTensor,
        return_logits: Optional[bool] = True,
    ) -> torch.Tensor:
        """Pad, embed and classify a batch of byte tokens.

        Args:
            x (torch.Tensor): The input byte tokens. Each sample is a
                sequence of integer tokens in ``{1, ..., embed_num - 1}``,
                padded with zeroes.
            return_logits (Optional[bool], optional): Whether to return
                logits (True, the default) or probabilities (False).

        Returns:
            torch.Tensor: ``(N, out_size)`` logits or probabilities.
        """
        x = window_pad(x.int(), self.window_size)
        # Length of each sample excluding padding, rounded up to a multiple
        # of window_size so that a partial final window is kept for the
        # convolution.
        l_in = self._round_up_to_window(ragged_len(x))
        x = self.embed(x)
        return self.embedd_and_forward(x, l_in, return_logits=return_logits)

    def proba_reduce(self, probs: torch.Tensor, counts: Optional[int] = None) -> torch.Tensor:
        """Turn class probabilities into hard labels.

        Args:
            probs (torch.Tensor): Output of ``predict_proba``, shape
                ``(N, out_size)``.
            counts (Optional[int]): Unused. Accepted so the signature matches
                ``Classifier.proba_reduce``.

        Returns:
            torch.Tensor: ``(N,)`` integer labels. With a registered
                threshold, each label is 1 where ``probs[:, 1] > threshold``
                and 0 otherwise. Without one, the label is the argmax over
                classes.
        """
        if self.threshold is None:
            # No threshold registered: fall back to the most likely class
            # instead of failing on a `> None` comparison.
            return probs.argmax(dim=1)
        # Probability of class 1 (treated as the malicious class).
        p_mals = probs[:, 1]
        return torch.where(p_mals > self.threshold, 1, 0)
