import numpy as np
import h5py
import torch
import config as cfg
import os.path as op
import sys
import pickle
from reorganize_dataset import reorganize_WM, reorganize_SK
import argparse
# Module-level command-line parser; arguments are registered and parsed
# (via parse_known_args) further down, at import time.
parser = argparse.ArgumentParser()
import os
from end_to_end import train_valid_model, test_model
from sklearn.model_selection import train_test_split


# Whether the raw dataset must first be reorganized into per-subject pickles:
# 1 = reorganize now, 0 = already reorganized.
parser.add_argument('--reorganize_dataset', type=int, default=0, choices=[0, 1])
# Root directory of the raw dataset, e.g.
#   /home/public/NIPS/sparrKULee/
#   /home/public/NIPS/watermelon_dataset/
# NOTE(review): this default points at SparrKULee while --dataset_type defaults
# to WM; pass a matching --project_dir when using --reorganize_dataset 1.
parser.add_argument('--project_dir', type=str, default='/home/public/NIPS/sparrKULee/')
# Dataset type: the Watermelon dataset (WM) or the SparrKULee dataset (SK).
parser.add_argument('--dataset_type', type=str, default='WM', choices=['WM', 'SK'])

# EEG frequency band: full, delta, theta, alpha, beta, low_gamma, high_gamma.
parser.add_argument('--band', type=str, default='full')

# What to run (the model trains quickly on a GPU):
#   1: only evaluate saved models, 0: train and evaluate, -1: only reorganize the dataset.
parser.add_argument('--evaluate_only', type=int, default=0, choices=[-1, 0, 1])

# Cross-subject split (more details in the paper):
#   ls:  subjects are left out for testing only; validation windows are drawn
#        at random from the training subjects.
#   lsv: subjects are left out for both validation and testing.
parser.add_argument('--split', type=str, default='ls', choices=['ls', 'lsv'])

# parse_known_args returns unrecognized arguments in `unknown` instead of erroring.
args, unknown = parser.parse_known_args()



def _to_windows(x, y, window, n_channels=64):
    """Flatten per-trial EEG and labels into non-overlapping decision windows.

    ``x`` is (subjects, trials, samples, channels) and ``y`` is
    (subjects, trials, windows_per_trial); returns ``(x_win, y_win)`` shaped
    ``(N, window, n_channels)`` and ``(N,)`` with matching row order.
    """
    return x.reshape(-1, window, n_channels), y.reshape(-1)


def _split_train_valid(x, y, split, window, rng, n_channels=64):
    """Split the training subjects' data into train and validation windows.

    'ls':  a random 12.5 % of all training windows (sklearn, random_state=2024).
    'lsv': every window of one training subject, chosen with ``rng``.
    Returns ``(train_x, train_y, valid_x, valid_y)``.
    """
    if split == 'ls':
        x_win, y_win = _to_windows(x, y, window, n_channels)
        train_x, valid_x, train_y, valid_y = train_test_split(
            x_win, y_win, test_size=0.125, random_state=2024)
        return train_x, train_y, valid_x, valid_y

    # 'lsv': hold out one whole training subject for validation.
    valid_pos = int(rng.integers(len(x)))
    train_pos = [i for i in range(len(x)) if i != valid_pos]
    train_x, train_y = _to_windows(x[train_pos], y[train_pos], window, n_channels)
    valid_x, valid_y = _to_windows(x[[valid_pos]], y[[valid_pos]], window, n_channels)
    return train_x, train_y, valid_x, valid_y


def train_and_test(eval_only=False, band='full', dataset_type='WM', save_dir='./KUL_WM_full', split='ls'):
    """Run leave-two-subjects-out cross-validation and save per-fold results.

    Loads ``S0.pkl`` ... ``S<cfg.sbnum-1>.pkl`` from ``save_dir``. Each file is a
    pickled dict with 'EEG' and 'label' entries. For each of the
    ``cfg.sbnum // 2`` folds, subjects ``sb`` and ``sb + cfg.sbnum // 2`` are
    held out for testing. A model is trained on the remaining subjects (unless
    ``eval_only``), and its checkpoint is evaluated on the held-out pair.

    Args:
        eval_only: if truthy, skip training and only evaluate the checkpoint in
            ``./model_<dataset_type>_<band>_<split>_<fold>/model.ckpt``.
        band: band name; used here only to build checkpoint/result names.
        dataset_type: dataset name; used here only to build checkpoint/result names.
        save_dir: directory with the reorganized per-subject pickles.
        split: 'ls' (validation windows sampled from the training subjects) or
            'lsv' (one whole training subject held out for validation).

    Raises:
        ValueError: for an unknown ``split``, or when ``cfg.decision_window``
            does not evenly divide a trial.

    Side effects:
        Creates one model directory per fold. Writes
        ``KUL_<dataset_type>_<band>_<split>_res.csv`` with one row per fold:
        the two values returned by ``train_valid_model`` (left at 0 when
        ``eval_only``), then the value returned by ``test_model``.
    """
    n_trials = 8               # trials per subject
    trial_samples = 128 * 360  # samples per trial (presumably 360 s at 128 Hz — confirm)
    n_channels = 64
    window = cfg.decision_window

    # Validate cheap preconditions before loading several GB of EEG.
    if split not in ('ls', 'lsv'):
        raise ValueError(f"unknown split {split!r}; expected 'ls' or 'lsv'")
    if trial_samples % window:
        raise ValueError(
            f'decision_window={window} does not evenly divide a trial of {trial_samples} samples')
    win_per_trial = trial_samples // window

    n_sub = cfg.sbnum
    n_folds = n_sub // 2

    # Random seeds, so repeated runs are reproducible.
    torch.manual_seed(2024)
    if torch.cuda.is_available():
        torch.backends.cudnn.deterministic = True
        torch.cuda.manual_seed_all(2024)
    # The 'lsv' validation subject is drawn at random; use a seeded generator
    # rather than the unseeded global numpy RNG.
    rng = np.random.default_rng(2024)

    # Load every subject's EEG and its per-trial label (the attended direction, 0 or 1).
    data = np.zeros((n_sub, n_trials, trial_samples, n_channels))
    label = np.zeros((n_sub, n_trials))
    for i in range(n_sub):
        with open(op.join(save_dir, f'S{i}.pkl'), 'rb') as f:
            # NOTE: pickle.load can execute arbitrary code; only load trusted files.
            eeg_data_label = pickle.load(f)
        data[i] = eeg_data_label['EEG']
        label[i] = eeg_data_label['label']

    # Every decision window inherits its trial's label. The repeat count must
    # match the number of windows produced by _to_windows (was hard-coded to 360).
    label = np.repeat(label[:, :, np.newaxis], win_per_trial, axis=2)

    res = np.zeros((n_folds, 3))
    for sb in range(n_folds):
        test_id = [sb, sb + n_folds]
        train_id = [i for i in range(n_sub) if i not in test_id]

        savedir = './model_' + dataset_type + '_' + band + '_' + split + '_' + str(sb)
        os.makedirs(savedir, exist_ok=True)
        saveckpt = savedir + '/model.ckpt'

        test_data, test_label = _to_windows(data[test_id], label[test_id], window, n_channels)

        if not eval_only:
            # Only build the (large) training/validation copies when training.
            train_data, train_label, valid_data, valid_label = _split_train_valid(
                data[train_id], label[train_id], split, window, rng, n_channels)
            res[sb, 0], res[sb, 1] = train_valid_model(
                train_data, train_label, valid_data, valid_label, saveckpt)
        res[sb, 2] = test_model(test_data, test_label, saveckpt)

    # Save the per-fold results as CSV.
    saveres = 'KUL' + '_' + dataset_type + '_' + band + '_' + split
    np.savetxt(saveres + '_res.csv', res, delimiter=',')



if __name__ == '__main__':
    # Pull the parsed CLI options into module-level names.
    project_dir = args.project_dir
    band = args.band
    dataset_type = args.dataset_type
    split = args.split
    # Directory holding the reorganized per-subject pickles.
    savedir = f'./KUL_{dataset_type}_{band}'

    # Optional preprocessing step: WM has its own reorganizer, anything else uses SK's.
    if args.reorganize_dataset == 1:
        reorganize = reorganize_WM if dataset_type == 'WM' else reorganize_SK
        reorganize(project_dir, band, savedir)

    # -1 means "reorganize only"; otherwise 1 evaluates only and 0 trains first.
    run_mode = args.evaluate_only
    if run_mode != -1:
        train_and_test(run_mode == 1, band, dataset_type, savedir, split)




