import torch
from torch.autograd import Variable
import torchvision
import torch.nn as nn
import torch.nn.functional as F

import copy

from secml_malware.models.malconv import MalConv
from secml_malware.models.c_classifier_end2end_malware import CClassifierEnd2EndMalware
from secml_malware.models.basee2e import End2EndModel
from secml.settings import SECML_PYTORCH_USE_CUDA

# CUDA is used only when a GPU is present AND secml's setting allows it.
use_cuda = torch.cuda.is_available() and SECML_PYTORCH_USE_CUDA
# torch.backends.mps only exists in PyTorch >= 1.12; guard the lookup so that
# importing this module does not raise AttributeError on older installs.
use_mps = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()

class Custom_MalConv(End2EndModel):
    """MalConv variant reusing pretrained feature layers with a new FC head.

    The byte embedding and the two 1-D convolutions are taken from the
    pretrained MalConv loaded through ``CClassifierEnd2EndMalware``. The
    classifier head (``self.classifier``) is created and initialised here and
    ends in a sigmoid, so ``predict`` yields one score per sample.
    """

    def __init__(self, ablation_idx=0, embedding_size=8, max_input_size=2 ** 18, unfreeze=True, batch_size=16):
        """Build the model from a pretrained MalConv.

        Args:
            ablation_idx: stored as ``self.ablation_idx``; not used in this class.
            embedding_size: forwarded to ``End2EndModel.__init__``.
            max_input_size: forwarded to the base class, stored on ``self`` and
                used as the length of the dummy probe input.
            unfreeze: if truthy the reused embedding/conv layers stay trainable,
                otherwise their parameters are frozen.
            batch_size: accepted for interface compatibility; unused here.
        """
        super().__init__(embedding_size, max_input_size, 256, False)

        net = CClassifierEnd2EndMalware(MalConv())
        net.load_pretrained_model()
        # `net` is a local that is discarded after __init__, so its layers can
        # be adopted directly; deep-copying it first only doubled peak memory.

        self.ablation_idx = ablation_idx
        self.max_input_size = max_input_size
        self.embedding_1 = net._model.embedding_1
        self.conv1d_1 = net._model.conv1d_1
        self.conv1d_2 = net._model.conv1d_2

        self.unfreeze_training_until_fc(unfreeze)

        fc_layer_input_shape = self.get_fc_layer_input_shape(net)

        self.classifier = nn.Sequential(nn.Linear(in_features=fc_layer_input_shape, out_features=128, bias=True),
                                        nn.ReLU(True),
                                        nn.Linear(in_features=128, out_features=1, bias=True),
                                        nn.Sigmoid())

        self.init_classifier_weights()
        if use_cuda:
            self.cuda()
            print("Using CUDA")
        elif use_mps:
            self.to(torch.device('mps'))
            print("Using MPS")

    @staticmethod
    def _gated_conv_max_pool(embedding, conv_a, conv_b):
        """Apply the MalConv gated convolution, then global max pooling.

        ``relu(conv_a(x)) * sigmoid(conv_b(x))`` is max-pooled over the whole
        last dimension and flattened to ``(batch, features)``.
        """
        gated = torch.relu(conv_a(embedding)) * torch.sigmoid(conv_b(embedding))
        # max_pool1d (rather than amax) keeps the original gradient routing:
        # the whole gradient goes to a single argmax position.
        pooled = F.max_pool1d(input=gated, kernel_size=gated.size()[2:])
        return pooled.view(pooled.size(0), -1)

    def get_fc_layer_input_shape(self, net):
        """Return the feature size fed to the classifier head.

        A zero-filled probe of length ``self.max_input_size`` is pushed through
        ``net``'s embedding and convolutions; the width of the pooled output is
        returned.

        Args:
            net: the pretrained ``CClassifierEnd2EndMalware`` wrapper.
        """
        malconv_model = net._model
        probe = torch.zeros(1, self.max_input_size)
        # Shape inference only: no autograd graph is needed.
        with torch.no_grad():
            embedding = net.embed(probe)
            features = self._gated_conv_max_pool(embedding, malconv_model.conv1d_1, malconv_model.conv1d_2)
        return features.size(1)

    def init_classifier_weights(self):
        """Initialise each Linear in the head: weights ~ N(0, 0.01), biases 0."""
        for layer in self.classifier:
            if isinstance(layer, nn.Linear):
                nn.init.normal_(layer.weight, mean=0.0, std=0.01)
                nn.init.zeros_(layer.bias)

    def unfreeze_training_until_fc(self, unfreeze):
        """Set ``requires_grad`` on every parameter of the reused layers.

        Iterating ``parameters()`` covers the embedding weight and the conv
        weights/biases, and does not break if a conv was built without a bias.
        """
        trainable = bool(unfreeze)
        for layer in (self.embedding_1, self.conv1d_1, self.conv1d_2):
            for param in layer.parameters():
                param.requires_grad_(trainable)

    def embed(self, input_x, transpose=True):
        """Map byte indices to embedding vectors.

        Args:
            input_x: a torch tensor or array-like (e.g. numpy array) of byte
                indices. A singleton dim 1 is squeezed, so presumably shape is
                ``(batch, 1, length)`` or ``(batch, length)`` — confirm.
            transpose: if true, swap dims 1 and 2 of the embedding output
                (channels-first layout for Conv1d, assuming a 2-D index tensor).

        Returns:
            The embedded tensor, on the same device as the embedding weights.
        """
        if isinstance(input_x, torch.Tensor):
            # Indices are not differentiable: detach and cast to int64. (Calling
            # requires_grad_(True) here raised on integer inputs.)
            x = input_x.detach().to(dtype=torch.long)
        else:
            x = torch.as_tensor(input_x).long()
        x = x.squeeze(dim=1)
        # Follow the embedding weights so inputs land wherever the model is.
        x = x.to(self.embedding_1.weight.device)
        emb_x = self.embedding_1(x)
        if transpose:
            emb_x = torch.transpose(emb_x, 1, 2)
        return emb_x

    def embedd_and_forward(self, x):
        """Run the conv + pooling + classifier stages on an embedded input.

        Despite the name, no embedding happens here: ``x`` must already be the
        output of ``embed``.
        """
        features = self._gated_conv_max_pool(x, self.conv1d_1, self.conv1d_2)
        return self.classifier(features)

    def predict(self, x):
        """Embed raw byte input ``x`` and return the classifier's sigmoid score."""
        embedding = self.embed(x)
        return self.embedd_and_forward(embedding)