import torch
import torch.utils.data as data
import numpy as np
import scipy.io as scio
import os
from sklearn.model_selection import train_test_split
import math
from tools.get_path import get_project_path


class Dataset():
    """Base loader for a multi-label data set stored in a MATLAB ``.mat`` file.

    The data set is selected by the *class name*: a subclass called ``foo``
    reads ``<datadir>/foo/foo.mat`` (variables ``data`` and ``target``) and
    caches its cross-validation folds next to it as ``foo_<seed>_<k>.mat``.

    Attributes set by the constructor:
        X, Y: standardised features and labels (negative labels clipped to 0).
        feat_dim, num_class: number of columns of ``X`` and ``Y``.

    Attributes set by :meth:`cv`:
        X_train, Y_train, X_test, Y_test, train_dataset, test_dataset.
    """

    def __init__(self, datadir=get_project_path()+"MLC/Datasets", configs=None, nfold=3):
        """Load, standardise and (if not cached yet) split the data set.

        Args:
            datadir: Root directory holding one sub-directory per data set.
            configs: Mapping providing ``dtype``, ``data_standardizing``,
                ``eps``, ``split``, ``shuffle`` and ``rand_seed``.
            nfold: Number of cross-validation folds to cache on disk.

        Raises:
            KeyError: If ``configs`` lacks one of the keys listed above.
        """
        # ``None`` replaces a shared mutable ``{}`` default; behaviour is the same.
        configs = {} if configs is None else configs

        self.datadir = datadir
        self.datafile = os.path.join(datadir, self.name(), self.name() + '.mat')

        # Load data
        self.dtype = configs['dtype']
        # Stored for callers; neither value is read by the methods of this class.
        self.data_standardizing = configs['data_standardizing']
        self.eps = configs['eps']
        self._data(configs['split'], configs['shuffle'], configs['rand_seed'], nfold)
        self.rand_seed = configs['rand_seed']
        self.feat_dim = self.X.size(1)
        self.num_class = self.Y.size(1)

    def name(self):
        """Return the class name, which doubles as the data set name on disk."""
        return self.__class__.__name__

    def _fold_file(self, seed, count):
        """Return the path of the cached ``.mat`` file of fold ``count`` for ``seed``."""
        # '{:d}' cannot format None, so an unseeded split is cached under 'None'.
        seed_tag = 'None' if seed is None else '{:d}'.format(seed)
        return os.path.join(self.datadir, self.name(),
                            self.name() + '_' + seed_tag + '_' + str(count) + '.mat')

    def _load_fold(self, seed, count):
        """Read one cached fold and return its ``(X, Y)`` tensors in ``self.dtype``."""
        fold = scio.loadmat(self._fold_file(seed, count))
        features = torch.from_numpy(fold['X'].astype(float)).type(self.dtype)
        labels = torch.from_numpy(fold['Y']).type(self.dtype)
        return features, labels

    def _data(self, dataSplit, shuffle=False, random_state=None, nfold=3):
        """Run the pipeline: load the file, standardise ``X``, cache the folds."""
        self._data_loading()
        self._data_preprocess()
        self._data_split(dataSplit, shuffle, random_state, nfold)

    def _data_loading(self):
        """Read ``data`` / ``target`` from ``self.datafile`` into ``self.X`` / ``self.Y``."""
        # Named ``mat`` so the module alias ``data`` (torch.utils.data) is not shadowed.
        mat = scio.loadmat(self.datafile)
        self.X = torch.from_numpy(mat['data'].astype(float)).type(self.dtype)
        self.Y = torch.from_numpy(mat['target']).type(self.dtype)
        # Negative label entries are mapped to 0.
        self.Y[self.Y < 0] = 0

    def _data_preprocess(self):
        """Standardise every feature column of ``self.X`` to zero mean, unit std.

        The statistics are computed over the whole data set, before any split.
        """
        mu = torch.mean(self.X, dim=0)
        std = torch.std(self.X, dim=0)
        # A constant column has std 0 (NaN with a single sample); dividing by it
        # would fill the column with NaN. Use 1 there so the column becomes 0,
        # while every other column keeps exactly the same values as before.
        std = torch.where(std > 0, std, torch.ones_like(std))
        self.X = (self.X - mu) / std

    def _data_split(self, dataSplit, shuffle=False, random_state=None, nfold=3):
        """Partition the data into ``nfold`` folds and cache each one on disk.

        Nothing is written when every fold file already exists. If any file is
        missing, all folds are regenerated so that the cached files always form
        one consistent partition.

        Args:
            dataSplit: Accepted for interface compatibility; not used.
            shuffle: Forwarded to ``train_test_split``.
            random_state: Seed forwarded to ``train_test_split``; also part of
                the cache file names.
            nfold: Number of folds.

        NOTE: the file names encode the seed but not ``nfold``, so folds cached
        with a different ``nfold`` under the same seed are reused as they are.
        """
        fold_files = [self._fold_file(random_state, count) for count in range(1, nfold + 1)]
        if all(os.path.exists(path) for path in fold_files):
            return

        # Work on local references: ``self.X`` / ``self.Y`` keep the full data
        # set regardless of whether the folds came from the cache or not.
        rest_x, rest_y = self.X, self.Y
        for count, sub_file in enumerate(fold_files, start=1):
            if count < nfold:
                # Taking 1/(folds still to cut) of the remainder yields folds
                # of (nearly) equal size.
                fold_x, rest_x, fold_y, rest_y = train_test_split(
                    rest_x, rest_y,
                    train_size=1 / (nfold - count + 1),
                    random_state=random_state,
                    shuffle=shuffle)
            else:
                # The last fold is whatever is left.
                fold_x, fold_y = rest_x, rest_y
            scio.savemat(sub_file, {'X': fold_x.numpy(), 'Y': fold_y.numpy()})

    def cv(self, test_num, nfold):
        """Assemble the train/test tensors and datasets for one CV round.

        Args:
            test_num: 1-based index of the fold held out for testing.
            nfold: Number of cached folds to read.

        Raises:
            ValueError: If ``test_num`` matches no fold in ``1..nfold`` or no
                fold is left for training.
        """
        # Console trace of the fold count.
        print(nfold)

        held_out = None
        train_x, train_y = [], []
        for count in range(1, nfold + 1):
            features, labels = self._load_fold(self.rand_seed, count)
            if count == test_num:
                held_out = (features, labels)
            else:
                train_x.append(features)
                train_y.append(labels)

        # Fail loudly instead of reusing tensors left over from an earlier call.
        if held_out is None:
            raise ValueError(
                'test_num={!r} does not match any fold in 1..{}'.format(test_num, nfold))
        if not train_x:
            raise ValueError('nfold={!r} leaves no fold for training'.format(nfold))

        self.X_test, self.Y_test = held_out
        self.X_train = torch.cat(train_x, dim=0)
        self.Y_train = torch.cat(train_y, dim=0)
        self.test_dataset = dataset(self.X_test, self.Y_test)
        self.train_dataset = dataset(self.X_train, self.Y_train)



class dataset(torch.utils.data.Dataset):
    """Map-style dataset that pairs row ``i`` of ``X`` with row ``i`` of ``y``."""

    def __init__(self, X, y):
        """Keep references to the feature tensor ``X`` and the label tensor ``y``."""
        self.X, self.y = X, y

    def __len__(self):
        """Return the number of samples, i.e. the size of ``X`` along dim 0."""
        n_samples = self.X.size(0)
        return n_samples

    def __getitem__(self, index):
        """Return the ``(features, labels)`` pair stored at ``index``."""
        return self.X[index], self.y[index]

class emotions(Dataset):
    """Loader for the ``emotions`` data set.

    Adds no behaviour of its own; ``Dataset`` derives the file location
    ``<datadir>/emotions/emotions.mat`` from this class name.
    """
        
class genbase(Dataset):
    """Loader for the ``genbase`` data set.

    Adds no behaviour of its own; ``Dataset`` derives the file location
    ``<datadir>/genbase/genbase.mat`` from this class name.
    """
    
class medical(Dataset):
    """Loader for the ``medical`` data set.

    Adds no behaviour of its own; ``Dataset`` derives the file location
    ``<datadir>/medical/medical.mat`` from this class name.
    """

class enron(Dataset):
    """Loader for the ``enron`` data set.

    Adds no behaviour of its own; ``Dataset`` derives the file location
    ``<datadir>/enron/enron.mat`` from this class name.
    """

class scene(Dataset):
    """Loader for the ``scene`` data set.

    Adds no behaviour of its own; ``Dataset`` derives the file location
    ``<datadir>/scene/scene.mat`` from this class name.
    """

class yeast(Dataset):
    """Loader for the ``yeast`` data set.

    Adds no behaviour of its own; ``Dataset`` derives the file location
    ``<datadir>/yeast/yeast.mat`` from this class name.
    """

class corel5k(Dataset):
    """Loader for the ``corel5k`` data set.

    Adds no behaviour of its own; ``Dataset`` derives the file location
    ``<datadir>/corel5k/corel5k.mat`` from this class name.
    """

class rcv1subset1(Dataset):
    """Loader for the ``rcv1subset1`` data set.

    Adds no behaviour of its own; ``Dataset`` derives the file location
    ``<datadir>/rcv1subset1/rcv1subset1.mat`` from this class name.
    """

class rcv1subset2(Dataset):
    """Loader for the ``rcv1subset2`` data set.

    Adds no behaviour of its own; ``Dataset`` derives the file location
    ``<datadir>/rcv1subset2/rcv1subset2.mat`` from this class name.
    """

class bibtex(Dataset):
    """Loader for the ``bibtex`` data set.

    Adds no behaviour of its own; ``Dataset`` derives the file location
    ``<datadir>/bibtex/bibtex.mat`` from this class name.
    """

class delicious(Dataset):
    """Loader for the ``delicious`` data set.

    Adds no behaviour of its own; ``Dataset`` derives the file location
    ``<datadir>/delicious/delicious.mat`` from this class name.
    """

class iaprtc12(Dataset):
    """Loader for the ``iaprtc12`` data set.

    Adds no behaviour of its own; ``Dataset`` derives the file location
    ``<datadir>/iaprtc12/iaprtc12.mat`` from this class name.
    """

class espgame(Dataset):
    """Loader for the ``espgame`` data set.

    Adds no behaviour of its own; ``Dataset`` derives the file location
    ``<datadir>/espgame/espgame.mat`` from this class name.
    """

class mirflickr(Dataset):
    """Loader for the ``mirflickr`` data set.

    Adds no behaviour of its own; ``Dataset`` derives the file location
    ``<datadir>/mirflickr/mirflickr.mat`` from this class name.
    """

class tmc2007(Dataset):
    """Loader for the ``tmc2007`` data set.

    Adds no behaviour of its own; ``Dataset`` derives the file location
    ``<datadir>/tmc2007/tmc2007.mat`` from this class name.
    """

class mediamill(Dataset):
    """Loader for the ``mediamill`` data set.

    Adds no behaviour of its own; ``Dataset`` derives the file location
    ``<datadir>/mediamill/mediamill.mat`` from this class name.
    """
    
class CAL500(Dataset):
    """Loader for the ``CAL500`` data set.

    Adds no behaviour of its own; ``Dataset`` derives the file location
    ``<datadir>/CAL500/CAL500.mat`` from this class name.
    """

class language_log(Dataset):
    """Loader for the ``language_log`` data set.

    Adds no behaviour of its own; ``Dataset`` derives the file location
    ``<datadir>/language_log/language_log.mat`` from this class name.
    """

class Image(Dataset):
    """Loader for the ``Image`` data set.

    Adds no behaviour of its own; ``Dataset`` derives the file location
    ``<datadir>/Image/Image.mat`` from this class name.
    """

class slashdot(Dataset):
    """Loader for the ``slashdot`` data set.

    Adds no behaviour of its own; ``Dataset`` derives the file location
    ``<datadir>/slashdot/slashdot.mat`` from this class name.
    """

class eurlex_directory_codes(Dataset):
    """Loader for the ``eurlex_directory_codes`` data set.

    Adds no behaviour of its own; ``Dataset`` derives the file location
    ``<datadir>/eurlex_directory_codes/eurlex_directory_codes.mat`` from this
    class name.
    """

class eurlex_subject_matters(Dataset):
    """Loader for the ``eurlex_subject_matters`` data set.

    Adds no behaviour of its own; ``Dataset`` derives the file location
    ``<datadir>/eurlex_subject_matters/eurlex_subject_matters.mat`` from this
    class name.
    """

class bookmarks(Dataset):
    """Loader for the ``bookmarks`` data set.

    Adds no behaviour of its own; ``Dataset`` derives the file location
    ``<datadir>/bookmarks/bookmarks.mat`` from this class name.
    """

