import numpy as np
import torch
import os
from secml.data.loader import CDataLoader
from secml.data import CDataset, CDatasetHeader
from secml.utils import fm
from secml.utils.download_utils import dl_file, md5
from secml.settings import SECML_DS_DIR

from torch.utils.data import Dataset
from secml_malware.models.basee2e import End2EndModel

from secml.settings import SECML_PYTORCH_USE_CUDA
# CUDA is used only when a device is available AND it is enabled in the secml settings.
use_cuda = torch.cuda.is_available() and SECML_PYTORCH_USE_CUDA
# Older PyTorch releases ship no ``torch.backends.mps`` module; treat them as
# "MPS unavailable" instead of failing with AttributeError at import time.
_mps_backend = getattr(torch.backends, "mps", None)
use_mps = _mps_backend is not None and _mps_backend.is_available()

class MyDataSet(Dataset):
    """Dataset of raw binaries, each cropped to a single ablation window.

    Entries of ``malware_path`` are labelled ``[1.0]`` and entries of
    ``benign_path`` are labelled ``[0.0]``. Every item is converted with
    ``End2EndModel.bytes_to_numpy`` and then reduced to the
    ``ablation_idx``-th of ``num_ablations`` equally sized windows.
    """

    def __init__(self, malware_path, benign_path, max_length, ablation_idx, num_ablations, no_samples):
        """Index the sample files; file contents are only read in ``__getitem__``.

        Args:
            malware_path: Directory holding malware samples, or ``None`` to
                load no malware.
            benign_path: Directory holding benign samples, or ``None`` to
                load no benign files.
            max_length: Length forwarded to ``End2EndModel.bytes_to_numpy``;
                also the basis for the window size.
            ablation_idx: Zero-based index of the window returned per sample.
            num_ablations: Number of equally sized windows the input is
                divided into.
            no_samples: Maximum number of entries taken from EACH directory.
        """
        self.malware_path = malware_path
        self.benign_path = benign_path
        self.max_length = max_length
        self.ablation_idx = ablation_idx
        self.num_ablations = num_ablations

        malware_files = self._list_samples(malware_path, no_samples)
        benign_files = self._list_samples(benign_path, no_samples)

        # Malware entries come first, then benign ones; labels stay aligned
        # with list_names by position.
        self.list_names = malware_files + benign_files
        self.labels = [[1.0] for _ in malware_files] + [[0.0] for _ in benign_files]

    @staticmethod
    def _list_samples(directory, no_samples):
        """Return the paths of up to ``no_samples`` entries of ``directory``."""
        # Check for None BEFORE listing: os.listdir(None) would silently list
        # the current working directory.
        if directory is None:
            return []

        paths = []
        # os.listdir returns entries in arbitrary order; sorting makes both the
        # sample order and the subset kept by ``no_samples`` reproducible.
        for position, name in enumerate(sorted(os.listdir(directory))):
            # Equality (not >=) is kept on purpose: a limit that is never
            # reached, e.g. a negative value, means "take every entry".
            if position == no_samples:
                break
            paths.append(os.path.join(directory, name))
        return paths

    def _ablation_bounds(self):
        """Return ``(cutoff, start, stop)`` describing the ablation window.

        ``cutoff`` is ``max_length`` rounded down to a multiple of
        ``num_ablations``; ``start``/``stop`` delimit the ``ablation_idx``-th
        window of size ``cutoff // num_ablations``. Raises
        ``ZeroDivisionError`` when ``num_ablations`` is 0.
        """
        cutoff = self.max_length - self.max_length % self.num_ablations
        ablation_size = cutoff // self.num_ablations
        start = ablation_size * self.ablation_idx
        return cutoff, start, start + ablation_size

    def __len__(self):
        """Return the number of indexed sample files."""
        return len(self.list_names)

    def __getitem__(self, idx):
        """Load one sample and return its ablation window.

        Returns:
            A tuple ``(X, y, length)``: ``X`` is the selected window of the
            array produced by ``End2EndModel.bytes_to_numpy``, ``y`` is a
            float32 array of shape ``(1,)`` holding the label, and ``length``
            is the size in bytes of the file as read from disk.
        """
        file_path = self.list_names[idx]

        with open(file_path, "rb") as file_handle:
            code = file_handle.read()
        X = End2EndModel.bytes_to_numpy(
            code, self.max_length, 256, False
        )

        # NOTE(review): presumably needed because the MPS backend cannot
        # handle 64-bit floats -- confirm against bytes_to_numpy's dtype.
        if use_mps:
            X = X.astype('float32')

        cutoff_idx, start, stop = self._ablation_bounds()
        # Two-step slice on purpose: trimming to the cutoff first guarantees
        # an out-of-range window never picks up the trailing remainder.
        X = X[:cutoff_idx][start:stop]

        y = np.asarray(self.labels[idx], dtype='float32')

        return X, y, len(code)


