import torch
import random
import numpy as np
from torch.utils import data
import scipy.io as sio
from sklearn.model_selection import KFold


# %% load & preprocess data
def load_data(fname, normalize_flag=True):
    """Load a feature matrix and its labels from a MATLAB ``.mat`` file.

    The file is read with ``scipy.io.loadmat``. Features and labels are taken
    from the variables ``'X'`` and ``'Y'``; if either one is missing, the pair
    ``'M'`` and ``'L'`` is used instead.

    Args:
        fname: Path of the ``.mat`` file to read.
        normalize_flag: When true, every column of X is shifted to zero mean
            and divided by its standard deviation (plus 1e-6, so constant
            columns do not divide by zero).

    Returns:
        A tuple ``(X_total, Y_total)``. When the labels form a single row or
        column they are cast to integers, shifted so the smallest label maps
        to column 0, and expanded to a one-hot float matrix with
        ``max - min + 1`` columns. A label matrix that already has several
        columns is returned unchanged.

    Raises:
        KeyError: If neither ``'X'``/``'Y'`` nor ``'M'``/``'L'`` is present.
    """
    mat = sio.loadmat(fname)
    try:
        X_total = mat['X']
        Y_total = mat['Y']
    except KeyError:
        # Only a missing variable triggers the fallback naming scheme; any
        # other failure should surface instead of being masked.
        X_total = mat['M']
        Y_total = mat['L']

    # Labels stored as a row vector are flipped into a column.
    if Y_total.shape[0] == 1:
        Y_total = np.transpose(Y_total)

    # normalize X
    if normalize_flag:
        X_mean = np.mean(X_total, axis=0)
        X_std = np.std(X_total, axis=0)
        X_total = (X_total - X_mean) / (X_std + 1e-6)

    # onehot Y
    if Y_total.ndim == 1 or Y_total.shape[1] == 1:
        # MATLAB commonly stores labels as doubles (and sometimes as small
        # unsigned ints); cast before using them as sizes and indices.
        labels = Y_total.reshape(-1).astype(np.int64)
        labels = labels - labels.min()
        num_classes = int(labels.max()) + 1
        Y_total = np.eye(num_classes)[labels]

    return X_total, Y_total

# %% split_dataset
def split_dataset(data, kfold_num, ratio):
    """Produce train/test index sets for the rows of ``data``.

    Args:
        data: Array-like whose first dimension enumerates the samples.
        kfold_num: When positive, the number of ``KFold`` splits to generate.
            Otherwise a single random hold-out split is drawn.
        ratio: Fraction of the rows placed in the training set for the
            hold-out split; not used when ``kfold_num`` is positive.

    Returns:
        ``(train_idxes, test_idxes)``: two lists of equal length holding one
        entry per split. K-fold entries are whatever ``KFold.split`` yields;
        hold-out entries are plain lists of row indices.
    """
    if kfold_num > 0:
        folds = list(KFold(n_splits=kfold_num).split(data))
        return [train for train, _ in folds], [test for _, test in folds]

    n = data.shape[0]
    # Draw the training rows without replacement; the rest become the test set.
    picked = random.sample(range(n), round(n * ratio))
    held_out = list(set(range(n)).difference(set(picked)))
    return [picked], [held_out]


# %% generate_dataloader
def generate_dataloader(X, Y, train_idx, test_idx, batch_size=None):
    """Wrap the selected rows of ``X`` and ``Y`` in train and test DataLoaders.

    Args:
        X: Features, indexed as ``X[idx, :]`` and passed to ``torch.tensor``
            (presumably a 2-D NumPy array).
        Y: Targets, indexed the same way as ``X``.
        train_idx: Row indices of the training samples.
        test_idx: Row indices of the test samples.
        batch_size: Mini-batch size for both loaders. When ``None`` each
            loader uses the length of its index list as the batch size.

    Returns:
        ``(train_iter, test_iter)``. The training loader shuffles only when an
        explicit ``batch_size`` is given; the test loader never shuffles.
    """
    def _subset(indices):
        # Build a TensorDataset from the given rows of X and Y.
        features = torch.tensor(X[indices, :])
        targets = torch.tensor(Y[indices, :])
        return data.TensorDataset(features, targets)

    train_set = _subset(train_idx)
    test_set = _subset(test_idx)

    if batch_size is None:
        train_iter = data.DataLoader(
            train_set, batch_size=len(train_idx), shuffle=False
        )
        test_iter = data.DataLoader(
            test_set, batch_size=len(test_idx), shuffle=False
        )
    else:
        train_iter = data.DataLoader(
            train_set, batch_size=batch_size, shuffle=True
        )
        test_iter = data.DataLoader(
            test_set, batch_size=batch_size, shuffle=False
        )
    return train_iter, test_iter

# %% process dataset
def select_dataset(fpath, normalize_flag=True, ratio=0.8):
    """Load a dataset and return one random train/test split as tensors.

    Args:
        fpath: Path of the data file, forwarded to ``load_data``.
        normalize_flag: Forwarded to ``load_data``.
        ratio: Fraction of the samples placed in the training set, forwarded
            to ``split_dataset``. Defaults to 0.8.

    Returns:
        ``((X_train, Y_train), (X_test, Y_test), n, d, c)`` where the first
        two items hold ``torch`` tensors, ``n`` and ``d`` are the two
        dimensions of ``X_train`` (training rows and columns), and ``c`` is
        the number of columns of the label matrix.
    """
    # prepare data
    X, Y = load_data(fpath, normalize_flag)

    # A fold count of 0 makes split_dataset draw a single random hold-out
    # split instead of running K-fold.
    train_idxes, test_idxes = split_dataset(X, 0, ratio)
    train_idx, test_idx = train_idxes[0], test_idxes[0]

    X_train = torch.tensor(X[train_idx, :])
    Y_train = torch.tensor(Y[train_idx, :])
    X_test = torch.tensor(X[test_idx, :])
    Y_test = torch.tensor(Y[test_idx, :])

    # set network parameters
    n, d = X_train.shape
    c = Y.shape[1]
    return (X_train, Y_train), (X_test, Y_test), n, d, c
