import torch

import numpy as np
import pandas as pd
from uci import UCI
from toy import Toy
from sklearn import preprocessing

from torch.utils.data import Dataset
from extract_metafeatures import extract

class D2V(Dataset):
    """Dataset that serves freshly sampled batches from a ``UCI`` or ``Toy`` source.

    Every ``__getitem__`` call asks the underlying source for a new batch via
    ``get_batch``; the indices passed in are only used for their count. When
    ``only_handcrafted`` is set, the raw batches are replaced by min-max scaled
    handcrafted meta-features computed with ``extract``.
    """

    # Default locations of the precomputed handcrafted meta-feature tables,
    # keyed by dataset name. Used only when no explicit path is supplied.
    _DEFAULT_HANDCRAFTED_CSV = {
        "uci": "/linkhome/rech/genini01/uvp29is/Code/metanal_v2/checkpoints/metafeatures_handcrafted_cc18.csv",
        "toy": "/linkhome/rech/genini01/uvp29is/Code/metanal_v2/checkpoints/metafeatures_handcrafted_toy.csv",
    }

    def __init__(self, name_dataset, seed, test=False, only_handcrafted=False,
                 handcrafted_csv=None):
        """Build the data source and, if requested, the meta-feature scaler.

        Args:
            name_dataset: Either ``"uci"`` or ``"toy"``.
            seed: Seed forwarded to the underlying data source.
            test: Whether the dataset is used in test mode. Controls the
                reported length and is forwarded to ``get_batch``.
            only_handcrafted: If true, items are handcrafted meta-features
                instead of raw batches.
            handcrafted_csv: Optional path to the CSV of precomputed
                handcrafted meta-features (must have an ``index`` column).
                Defaults to the built-in path for ``name_dataset``.

        Raises:
            ValueError: If ``name_dataset`` is not a known dataset name.
        """
        super().__init__()
        self.name_dataset = name_dataset
        self.only_handcrafted = only_handcrafted
        if name_dataset == "uci":
            self.data = UCI(seed=seed)
        elif name_dataset == "toy":
            self.data = Toy(seed=seed)
        else:
            raise ValueError(f"Dataset {name_dataset!r} does not exist.")
        # Honour the caller's flag; it was previously hard-coded to False,
        # which made the test-mode paths unreachable.
        self.test = test

        if self.only_handcrafted:
            if handcrafted_csv is None:
                handcrafted_csv = self._DEFAULT_HANDCRAFTED_CSV[name_dataset]
            handcrafted_data = pd.read_csv(handcrafted_csv, index_col="index")
            # The CSV columns define which meta-features are extracted later,
            # and the scaler is fitted on the precomputed table.
            self.list_cols_metafeatures = handcrafted_data.columns
            self.mf_scaler = preprocessing.MinMaxScaler()
            self.mf_scaler.fit(handcrafted_data)

    def __len__(self):
        """Return the nominal number of items: 10000 in test mode, else 3200."""
        if self.test:
            return 10000
        return 3200

    def get_size_dataset(self):
        """Return the ``(Ns, Ls, Ms)`` sizes requested for the next batch.

        The toy dataset always uses ``(200, 2, 2)``. Otherwise the sizes are
        drawn with ``np.random.randint`` (upper bound exclusive): ``Ns`` in
        [200, 500), ``Ms`` in [2, 15) and ``Ls`` in [2, 10).
        Presumably Ns/Ms/Ls are the sample/feature/class counts -- TODO confirm
        against ``get_batch``.
        """
        if self.name_dataset == "toy":
            return 200, 2, 2
        # Draw order (Ns, Ms, Ls) is kept so seeded runs stay reproducible.
        Ns = np.random.randint(200, 500)
        Ms = np.random.randint(2, 15)
        Ls = np.random.randint(2, 10)
        return Ns, Ls, Ms

    def _handcrafted_features(self, list_X, list_Y):
        """Extract and min-max scale the handcrafted meta-features of each (x, y) pair."""
        features = [
            extract(x, y.argmax(1), list_columns=self.list_cols_metafeatures)
            for x, y in zip(list_X, list_Y)
        ]
        return self.mf_scaler.transform(features)

    def __getitem__(self, list_idx):
        """Sample a new batch with one entry per index in ``list_idx``.

        Only ``len(list_idx)`` is used; the index values are ignored.

        Returns:
            A 5-tuple of float tensors ``(X_1, Y_1, X_2, Y_2, I)``. In
            handcrafted mode the ``X`` and ``Y`` slots of each pair both hold
            the same scaled meta-feature matrix.
        """
        Ns, Ls, Ms = self.get_size_dataset()

        list_X_1, list_Y_1, list_X_2, list_Y_2, I = self.data.get_batch(
            len(list_idx), Ns, Ls, Ms, self.test, stratification_pos_ratio=0.5
        )

        # Sanity-check that the source honoured the requested sizes.
        assert Ns == list_X_1.shape[1]
        assert Ms == list_X_1.shape[2]
        assert Ls == list_Y_1.shape[2]

        if self.only_handcrafted:
            mf_1 = self._handcrafted_features(list_X_1, list_Y_1)
            mf_2 = self._handcrafted_features(list_X_2, list_Y_2)
            list_X_1, list_Y_1, list_X_2, list_Y_2 = mf_1, mf_1, mf_2, mf_2

        # torch.FloatTensor already yields float32, so no extra cast is needed.
        return (
            torch.FloatTensor(list_X_1),
            torch.FloatTensor(list_Y_1),
            torch.FloatTensor(list_X_2),
            torch.FloatTensor(list_Y_2),
            torch.FloatTensor(I),
        )
