import numpy as np
import torch
import os
import pandas as pd
from secml.data.loader import CDataLoader
from secml.data import CDataset, CDatasetHeader
from secml.utils import fm
from secml.utils.download_utils import dl_file, md5
from secml.settings import SECML_DS_DIR

from torch.utils.data import Dataset
from secml_malware.models.basee2e import End2EndModel

from secml.settings import SECML_PYTORCH_USE_CUDA

# Use CUDA only when torch can see a GPU and secml's settings enable it.
use_cuda = torch.cuda.is_available() and SECML_PYTORCH_USE_CUDA
# `torch.backends.mps` (Apple-silicon backend) is missing from older torch
# releases; guard the lookup so importing this module does not fail there.
use_mps = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()


class MyDataSet(Dataset):
    """Byte-level dataset of files listed in a CSV, exposing one ablation window.

    ``file_path`` is read as a header-less CSV whose column 0 holds a file name
    and column 1 an integer label. Files are resolved under ``root_dir`` using
    one of two layouts, detected from the FIRST listed file only:

    * ``'lucas'``: ``<root_dir>/<name[0:2]>/<name[2:4]>/<name>``
    * ``'smk'``:   ``<root_dir>/<name>``

    Each sample is encoded with ``End2EndModel.bytes_to_numpy`` and cut to the
    largest multiple of ``num_ablations`` not exceeding ``max_length``. That prefix
    is split into ``num_ablations`` equal windows, and window ``ablation_idx``
    is returned.

    Args:
        root_dir: Directory the listed file names are resolved against.
        file_path: Path of the header-less ``name,label`` CSV.
        max_length: Length passed to ``bytes_to_numpy`` and used to size windows.
        ablation_idx: Zero-based index of the window each sample exposes.
        num_ablations: Number of equal windows; assumed to be a positive int.
        no_samples: Stop after this many CSV rows. A value the row counter
            never equals (e.g. ``None``) loads every row.

    Raises:
        FileNotFoundError: If at least one row is requested and the first
            listed file exists in neither supported layout.
    """

    def __init__(self, root_dir, file_path, max_length, ablation_idx, num_ablations, no_samples):
        self.root_dir = root_dir
        self.file_path = file_path
        self.max_length = max_length
        self.ablation_idx = ablation_idx
        self.num_ablations = num_ablations

        # Read the name column as text: with type inference, an all-digit
        # column is parsed as numbers and loses leading zeros ("00123" -> "123").
        df = pd.read_csv(file_path, header=None, dtype={0: str})

        # The directory layout is detected once, from the first listed file.
        first_name = str(df.iloc[0, 0])
        choice = self.file_path_finder(first_name)
        # Row 0 is processed unless no_samples == 0, so only then is a layout needed.
        if choice == 'None' and no_samples != 0:
            raise FileNotFoundError(
                f"Check your file path! {first_name!r} was not found under "
                f"{root_dir!r} in either supported layout."
            )

        list_names = []
        labels = []
        for i, (name, label) in enumerate(zip(df.iloc[:, 0], df.iloc[:, 1])):
            if i == no_samples:  # cap on the number of rows loaded
                break
            list_names.append(self._candidate_paths(str(name))[choice])
            labels.append([int(label)])

        self.list_names = list_names
        self.labels = labels

    def __len__(self):
        """Return the number of samples loaded from the CSV."""
        return len(self.list_names)

    def __getitem__(self, idx):
        """Return ``(X, y, n_bytes)`` for sample ``idx``.

        ``X`` is the ``ablation_idx``-th window of the encoded file. ``y`` is the
        label as a float32 array of shape ``(1,)``. ``n_bytes`` is the size of the
        raw file in bytes.
        """
        file_path = self.list_names[idx]
        # open() raises FileNotFoundError naming the path if the file is missing.
        with open(file_path, "rb") as file_handle:
            code = file_handle.read()
        X = End2EndModel.bytes_to_numpy(code, self.max_length, 256, False)

        if use_mps:
            # Cast to float32 when the MPS backend is available
            # (NOTE(review): presumably MPS rejects the default dtype — confirm).
            X = X.astype('float32')

        cutoff_idx, ablation_size = self._ablation_geometry()
        start = ablation_size * self.ablation_idx
        # Slice in two steps so an out-of-range ablation_idx still yields an
        # empty window rather than reading past the cutoff.
        X = X[:cutoff_idx][start:start + ablation_size]

        y = np.asarray(self.labels[idx]).astype('float32')
        return X, y, len(code)

    def _ablation_geometry(self):
        """Return ``(cutoff_idx, ablation_size)`` for the configured split.

        ``cutoff_idx`` is the largest multiple of ``num_ablations`` that does not
        exceed ``max_length``. It is divided into ``num_ablations`` windows of
        ``ablation_size`` elements each. Assumes ``num_ablations`` is positive.
        """
        cutoff_idx = self.max_length - self.max_length % self.num_ablations
        return cutoff_idx, cutoff_idx // self.num_ablations

    def file_path_finder(self, file_name):
        """Return the layout in which ``file_name`` exists under ``root_dir``.

        Returns ``'lucas'`` (checked first), ``'smk'``, or the string
        ``'None'`` when the file exists in neither layout.
        """
        for layout, path in self._candidate_paths(file_name).items():
            if os.path.exists(path):
                return layout
        return 'None'

    def _candidate_paths(self, file_name):
        """Map each supported layout name to the path ``file_name`` would have."""
        return {
            'lucas': f"{self.root_dir}/{file_name[0:2]}/{file_name[2:4]}/{file_name}",
            'smk': f"{self.root_dir}/{file_name}",
        }
