import os
import magic
from secml.array import CArray
import numpy as np
import torch
import math

from secml_malware.models.malconv import MalConv
from secml_malware.models.c_classifier_end2end_malware import CClassifierEnd2EndMalware
from secml_malware.models.basee2e import End2EndModel

def _read_labelled_directory(dir_path, label, max_length, no_samples):
    """
    Read the files of one directory and encode each of them for the model.

    @param dir_path: Directory to read, or None to read nothing.
    @param label: Class label stored (wrapped in a one-element list) for every file read.
    @param max_length: Forwarded to End2EndModel.bytes_to_numpy as its second argument.
    @param no_samples: Stop once this many directory entries have been read; -1 means no limit.
    @return: Tuple (X, y, file_names, lengths) of parallel lists for this directory.
    """
    X, y, file_names, lengths = [], [], [], []
    # Check for None *before* listing: os.listdir(None) would silently list the
    # current working directory instead of failing.
    if dir_path is None:
        return X, y, file_names, lengths

    for i, f in enumerate(os.listdir(dir_path)):
        if i == no_samples:
            break

        path = os.path.join(dir_path, f)
        with open(path, "rb") as file_handle:
            code = file_handle.read()

        # NOTE(review): 256 and False are presumably the padding value and the
        # "shift values" flag of bytes_to_numpy -- confirm against End2EndModel.
        X.append(End2EndModel.bytes_to_numpy(code, max_length, 256, False))
        y.append([label])
        file_names.append(path)
        # Raw on-disk size, i.e. the length before bytes_to_numpy processes it.
        lengths.append(len(code))

    return X, y, file_names, lengths


def get_dataset(malware_path, benign_path, max_length, no_samples=-1):
    """
    Load malware and benign files from two directories into parallel lists.

    Malware samples come first (label [1.0]), followed by benign samples
    (label [0.0]). Files are taken in the order returned by os.listdir.

    @param malware_path: Directory containing malware files, or None to skip it.
    @param benign_path: Directory containing benign files, or None to skip it.
    @param max_length: Forwarded to End2EndModel.bytes_to_numpy for every file.
    @param no_samples: The max number of files to read from each directory. If -1, include all files; otherwise include just the specified number of samples per directory.
    @return: Tuple (X, y, file_names, lengths): encoded samples, one-element label lists, file paths and raw file sizes in bytes.
    """
    X, y, file_names, lengths = [], [], [], []
    for dir_path, label in ((malware_path, 1.0), (benign_path, 0.0)):
        part = _read_labelled_directory(dir_path, label, max_length, no_samples)
        X.extend(part[0])
        y.extend(part[1])
        file_names.extend(part[2])
        lengths.extend(part[3])
    return X, y, file_names, lengths


def create_smoothed_malconv(num_ablations, ablation_size):
    """
    Build the ensemble of per-ablation classifiers.

    @param num_ablations: Number of classifiers to create (one per ablation).
    @param ablation_size: Input size given to each MalConv network; also used as the wrapper's input shape (1, ablation_size) and its _n_features.
    @return: List with num_ablations independent CClassifierEnd2EndMalware instances.
    """

    def _build_single():
        # Each ablation gets its own freshly constructed network and wrapper.
        wrapped = CClassifierEnd2EndMalware(
            MalConv(max_input_size=ablation_size),
            input_shape=(1, ablation_size),
        )
        wrapped._n_features = ablation_size
        return wrapped

    return [_build_single() for _ in range(num_ablations)]


def modify_dataset_for_smoothed_malconv(X, y, num_ablations):
    """
    Split every sample into num_ablations equal chunks and replicate its label.

    The length of the first sample decides how many trailing elements are
    dropped so the remainder divides evenly by num_ablations; that same cutoff
    is applied to every sample.

    @param X: Sequence of 1-D samples; all are cut at the length derived from X[0].
    @param y: Per-sample labels. Each entry may be a scalar or a single-element list/array (e.g. [1.0]); a value equal to 1 marks the positive class.
    @param num_ablations: Number of equal chunks per sample.
    @return: Tuple (new_X, new_Y). new_X stacks the chunks per sample, new_Y holds the sample label repeated num_ablations times. Both are empty arrays when X is empty.
    """
    if len(X) == 0:
        # Nothing to split; previously this raised IndexError on X[0].
        return np.asarray([]), np.asarray([])

    sample_len = len(X[0])
    # Largest length not exceeding sample_len that num_ablations divides.
    cutoff_idx = sample_len - sample_len % num_ablations

    new_X = []
    new_Y = []
    for i, x in enumerate(X):
        new_X.append(np.split(x[:cutoff_idx], num_ablations))
        # Read the label as a scalar: a plain list label such as [1.0] never
        # compares equal to 1, which would label every sample as 0.
        is_positive = np.asarray(y[i]).item() == 1
        new_Y.append(np.ones(num_ablations) if is_positive else np.zeros(num_ablations))

    return np.asarray(new_X), np.asarray(new_Y)


def modify_dataset_for_smoothed_malconv_by_ablation(X, y, num_ablations, ablation_idx):
    """
    Extract a single ablation (chunk) from every sample.

    @param X: Sequence of 1-D samples; the usable length is derived from X[0].
    @param y: Labels, returned unchanged apart from conversion to an ndarray.
    @param num_ablations: Number of equal chunks each sample is divided into.
    @param ablation_idx: Zero-based index of the chunk to extract.
    @return: Tuple (new_X, new_Y) of ndarrays: the selected chunk per sample, and the labels.
    """
    # Shrink the usable length until num_ablations divides it evenly.
    usable_len = len(X[0])
    while usable_len % num_ablations:
        usable_len -= 1

    chunk_len = usable_len // num_ablations
    start = chunk_len * ablation_idx
    stop = start + chunk_len

    # Truncate first, then window, so out-of-range indices never reach past
    # the usable part of the sample.
    windows = [sample[:usable_len][start:stop] for sample in X]
    return np.asarray(windows), np.asarray(y)


def pad_ablated_input(inp, ablated_idx, max_len=2 ** 20, padding_val=256):
    """
    Place one ablation back at its original offset inside a padded buffer.

    @param inp: 1-D array holding the ablation; its length defines the chunk size.
    @param ablated_idx: Zero-based chunk index; the data is written at ablated_idx * len(inp).
    @param max_len: Length of the returned buffer.
    @param padding_val: Value stored in every position not covered by inp.
    @return: float32 ndarray of shape (max_len,).
    """
    chunk_len = inp.shape[0]
    offset = ablated_idx * chunk_len

    padded = np.full(shape=(max_len,), fill_value=padding_val, dtype=np.float32)
    padded[offset:offset + chunk_len] = inp
    return padded


def train_model(nets, num_ablations, X, y, epochs):
    """
    Fit each classifier on its own ablation of the dataset.

    @param nets: Classifiers indexed by ablation; modified in place.
    @param num_ablations: Number of classifiers (and ablations) to train.
    @param X: Array indexed as X[sample, ablation, position].
    @param y: Array indexed as y[sample, ablation].
    @param epochs: Value assigned to every classifier's epochs attribute before fitting.
    @return: The same nets sequence that was passed in.
    """
    for idx in range(num_ablations):
        clf = nets[idx]
        clf.epochs = epochs
        # Whatever fit() returns is stored on the classifier as its _model.
        clf._model = clf.fit(X[:, idx, :], y[:, idx])
    return nets


def model_predict(nets, num_ablations, X):
    """
    Collect the predictions of every per-ablation classifier.

    @param nets: Classifiers indexed by ablation; each predict() result must offer tondarray().
    @param num_ablations: Number of classifiers (and ablations) to evaluate.
    @param X: Array indexed as X[sample, ablation, position].
    @return: ndarray stacking one prediction array per ablation along the first axis.
    """
    per_ablation = [
        nets[idx].predict(X[:, idx, :]).tondarray()
        for idx in range(num_ablations)
    ]
    return np.asarray(per_ablation)


def get_majority_voting(y_detailed, sample_no):
    """
    Combine per-ablation predictions into one majority vote per sample.

    @param y_detailed: Array indexed as y_detailed[ablation, sample]; non-zero entries count as malware votes, zeros as benign votes.
    @param sample_no: Number of samples (columns) to evaluate.
    @return: List with one vote per sample: 1 (malware) or 0 (benign). Ties go to 1.
    """

    def _vote(column):
        malware_count = np.count_nonzero(column)
        benign_count = np.count_nonzero(column == 0)
        # Benign wins only with a strict majority.
        return 0 if benign_count > malware_count else 1

    return [_vote(y_detailed[:, k]) for k in range(sample_no)]


def get_majority_voting_without_padding(y_detailed, sample_no, lengths, model_input_length, perturb_size):
    """
    Majority vote per sample that only counts ablations covering real file bytes.

    For each sample, only the first ceil(length / model_input_length)
    ablations (at least one) take part in the vote, so ablations lying
    entirely beyond the end of the file are ignored.

    @param y_detailed: Array indexed as y_detailed[ablation, sample]; non-zero entries count as malware votes, zeros as benign votes.
    @param sample_no: Number of samples (columns) to evaluate.
    @param lengths: Original length of every sample, in the same unit as model_input_length.
    @param model_input_length: Number of positions covered by one ablation.
    @param perturb_size: Size of the perturbation the certificate should hold for.
    @return: Tuple (votes, certified_votes). votes holds 1 or 0 per sample (ties go to 1). certified_votes holds the winning class when its lead is at least 2 * delta, otherwise -1.
    """
    # Loop-invariant: delta = ceil(perturb_size / model_input_length) + 1.
    # NOTE(review): presumably the max number of ablations one contiguous
    # perturbation can touch -- confirm against the certification scheme used.
    delta = math.ceil(perturb_size / model_input_length) + 1
    margin = 2 * delta

    votes = []
    certified_votes = []
    for i in range(sample_no):
        # ceil() instead of "floor + 1": for a length that is an exact multiple
        # of model_input_length the latter counts one extra ablation past the
        # end of the file. Always keep at least one voter.
        used_ablations = max(1, math.ceil(lengths[i] / model_input_length))
        y_per_sample = y_detailed[:used_ablations, i]

        mal_vote = np.count_nonzero(y_per_sample)
        ben_vote = np.count_nonzero(y_per_sample == 0)
        votes.append(1 if mal_vote >= ben_vote else 0)

        # Ties are resolved in favour of class 1, hence ">=" for malware but a
        # strict ">" for benign.
        if mal_vote >= ben_vote + margin:
            certified_votes.append(1)
        elif ben_vote > mal_vote + margin:
            certified_votes.append(0)
        else:
            certified_votes.append(-1)

    return votes, certified_votes
