import numpy as np
import h5py
import torch
import config as cfg
import os.path as op
import sys
import pickle
from reorganize_dataset import reorganize_WM, reorganize_SK
import argparse
import os

parser = argparse.ArgumentParser(
    description='Train and/or evaluate per-subject EEG models on the Watermelon (WM) '
                'or SparrKULee (SK) dataset.')

# Whether the raw dataset must first be reorganized into per-subject pickles.
parser.add_argument('--reorganize_dataset', type=int, default=0, choices=[0, 1],
                    help='1: reorganize the raw dataset first, 0: the dataset is already reorganized')
# Root of the raw dataset, e.g. /home/public/NIPS/sparrKULee/ or /home/public/NIPS/watermelon_dataset/
parser.add_argument('--project_dir', type=str, default='/home/public/NIPS/watermelon_dataset/',
                    help='path of the raw dataset')
# Restricted to the two dataset tags the reorganize step knows about; any other
# value used to be silently treated as SK.
parser.add_argument('--dataset_type', type=str, default='WM', choices=['WM', 'SK'],
                    help='Watermelon dataset (WM) or SparrKULee dataset (SK)')
parser.add_argument('--band', type=str, default='full',
                    help='EEG band: full, delta, theta, alpha, beta, low_gamma or high_gamma')
parser.add_argument('--evaluate_only', type=int, default=0, choices=[-1, 0, 1],
                    help='1: only evaluate the model, 0: train and evaluate the model, '
                         '-1: skip training/evaluation (only reorganize, if requested)')
# Restricted so a typo fails here instead of silently selecting the
# leave-trials-out implementation below and crashing later in train_and_test.
parser.add_argument('--p', type=str, default='xz', choices=['xy', 'xz', 'zy', 'lt'],
                    help='xy: validate p(y|x); xz/zy: validate p(z|x) and p(y|z), which are '
                         'computed by the same code; lt: leave-trials-out, gives the chance-level results')

# Unrecognized command-line arguments are ignored rather than rejected.
args, unknown = parser.parse_known_args()

# Pick the training/evaluation implementation for the requested mode.
if args.p == 'xy':
    from end_to_end import train_valid_model, test_model
elif args.p == 'xz' or args.p == 'zy':
    from domain_feature import train_valid_model, test_model
else:  # 'lt' -- the only remaining value allowed by `choices`
    from without_domain_feature import train_valid_model, test_model

def _load_subjects(save_dir, n_subjects, n_trials, n_samples, n_channels):
    """Load the reorganized subject pickles ``S0.pkl`` .. ``S<n_subjects-1>.pkl`` from *save_dir*.

    Each pickle is expected to be a dict whose ``'EEG'`` entry is assignable to an
    ``(n_trials, n_samples, n_channels)`` array and whose ``'label'`` entry is
    assignable to an ``(n_trials,)`` array.

    Returns:
        ``(data, label)``: float64 arrays of shape
        ``(n_subjects, n_trials, n_samples, n_channels)`` and ``(n_subjects, n_trials)``.
    """
    data = np.zeros((n_subjects, n_trials, n_samples, n_channels))
    label = np.zeros((n_subjects, n_trials))
    for i in range(n_subjects):
        # NOTE: pickle can execute code on load; only point save_dir at trusted,
        # locally generated files (e.g. the output of reorganize_dataset).
        with open(op.join(save_dir, f'S{i}.pkl'), 'rb') as f:
            subject = pickle.load(f)
        data[i] = subject['EEG']
        label[i] = subject['label']
    return data, label


def train_and_test(eval_only=False, band='full', dataset_type='WM', save_dir='./DEAP_WM_full', p='xy'):
    """Train one model per subject (unless *eval_only*), evaluate it, and write accuracies to CSV.

    Every subject's trials are cut into non-overlapping windows of
    ``cfg.decision_window`` samples. The evaluation depends on *p*:

    * ``'xy'``: K-fold CV (``cfg.kfold_num`` folds) over windows; the results go to
      ``DEAP_<dataset_type>_<band>_xy.csv``.
    * ``'xz'`` / ``'zy'``: the same K-fold split, but each window also carries the
      index of the trial it came from (its "domain"). One model yields both scores,
      written to ``..._xz.csv`` and ``..._zy.csv``.
    * ``'lt'``: leave-trials-out with 10 folds of 4 held-out trials each; the results
      go to ``..._lt.csv``.

    Checkpoints are stored under
    ``./model_<dataset_type>_<band>_<xy|xzy|lt>/sb<subject>/fold<k>.ckpt``.

    Args:
        eval_only: if true, skip training and only evaluate existing checkpoints.
        band: EEG band name; used only in checkpoint and result file names.
        dataset_type: dataset tag ('WM' or 'SK'); used only in file names.
        save_dir: directory holding the reorganized ``S<k>.pkl`` subject files.
        p: evaluation mode, one of 'xy', 'xz', 'zy', 'lt'.

    Raises:
        ValueError: if *p* is unknown, or ``cfg.decision_window`` does not evenly
            divide a trial.
    """
    n_trials = 40                  # trials per subject in the reorganized files
    n_channels = 32                # EEG channels per sample
    samples_per_trial = 60 * 128   # presumably 60 s at 128 Hz -- TODO confirm against reorganize_dataset
    n_lt_folds = 10                # number of leave-trials-out folds

    # Fail fast, before the (large) data load, on inputs that would crash later.
    if p not in ('xy', 'xz', 'zy', 'lt'):
        raise ValueError(f"unknown evaluation mode p={p!r}; expected 'xy', 'xz', 'zy' or 'lt'")
    if samples_per_trial % cfg.decision_window != 0:
        raise ValueError(
            f'cfg.decision_window={cfg.decision_window} must evenly divide '
            f'the {samples_per_trial} samples of a trial')
    n_windows = samples_per_trial // cfg.decision_window  # windows cut from each trial
    win_shape = (cfg.decision_window, n_channels)

    data, label = _load_subjects(save_dir, cfg.sbnum, n_trials, samples_per_trial, n_channels)

    # Fix the RNG seeds so runs are reproducible.
    torch.manual_seed(2024)
    if torch.cuda.is_available():
        torch.backends.cudnn.deterministic = True
        torch.cuda.manual_seed_all(2024)

    res_xy = torch.zeros((cfg.sbnum, cfg.kfold_num))
    res_xz = torch.zeros((cfg.sbnum, cfg.kfold_num))
    res_zy = torch.zeros((cfg.sbnum, cfg.kfold_num))
    res_lt = torch.zeros((cfg.sbnum, n_lt_folds))

    from sklearn.model_selection import KFold
    kfold = KFold(n_splits=cfg.kfold_num, shuffle=True, random_state=2024)

    for sb in range(cfg.sbnum):
        eegdata = data[sb]
        # Repeat each trial's label once per window -> shape (n_trials, n_windows).
        # (Previously hard-coded to 30, which only matched decision_window == 256.)
        eeglabel = np.tile(label[sb], (n_windows, 1)).T

        if p != 'lt':
            # Flatten trials into windows: (n_trials * n_windows, decision_window, n_channels).
            eegdata = eegdata.reshape(n_trials * n_windows, *win_shape)
            eeglabel = eeglabel.reshape(n_trials * n_windows)
            use_domain = p in ('xz', 'zy')
            if use_domain:
                # Domain id of each window = index of the trial it was cut from.
                eegdomain = np.repeat(np.arange(n_trials), n_windows)

            # xz and zy share one checkpoint directory, since one model yields both scores.
            ckpt_tag = 'xzy' if use_domain else p
            savedir = f'./model_{dataset_type}_{band}_{ckpt_tag}/sb{sb}'
            os.makedirs(savedir, exist_ok=True)

            for fold, (train_ids, test_ids) in enumerate(kfold.split(eegdata)):
                saveckpt = f'{savedir}/fold{fold}.ckpt'
                if use_domain:
                    if not eval_only:
                        train_valid_model(eegdata[train_ids], eeglabel[train_ids],
                                          eegdomain[train_ids], saveckpt)
                    res_xz[sb, fold], res_zy[sb, fold] = test_model(
                        eegdata[test_ids], eeglabel[test_ids], eegdomain[test_ids], saveckpt)
                else:
                    if not eval_only:
                        train_valid_model(eegdata[train_ids], eeglabel[train_ids], saveckpt)
                    res_xy[sb, fold] = test_model(eegdata[test_ids], eeglabel[test_ids], saveckpt)

            print("good job!")
        else:
            savedir = f'./model_{dataset_type}_{band}_{p}/sb{sb}'
            os.makedirs(savedir, exist_ok=True)
            for fold in range(n_lt_folds):
                # Hold out trials fold, fold+10, fold+20, fold+30. The original intent is
                # to keep the test labels balanced; this presumably relies on the trial
                # ordering in the reorganized files -- TODO confirm.
                test_id = list(range(fold, n_trials, n_lt_folds))
                held_out = set(test_id)
                train_id = [i for i in range(n_trials) if i not in held_out]
                saveckpt = f'{savedir}/fold{fold}.ckpt'

                # Windowing happens after the trial split, so no trial leaks across sets.
                test_data = eegdata[test_id].reshape(-1, *win_shape)
                test_label = eeglabel[test_id].reshape(-1)
                if not eval_only:
                    train_data = eegdata[train_id].reshape(-1, *win_shape)
                    train_label = eeglabel[train_id].reshape(-1)
                    train_valid_model(train_data, train_label, saveckpt)
                res_lt[sb, fold] = test_model(test_data, test_label, saveckpt)

    prefix = f'DEAP_{dataset_type}_{band}'
    if p == 'xy':
        np.savetxt(f'{prefix}_xy.csv', res_xy.numpy(), delimiter=',')
    elif p in ('xz', 'zy'):
        # One run produces both scores, so both files are written either way.
        np.savetxt(f'{prefix}_xz.csv', res_xz.numpy(), delimiter=',')
        np.savetxt(f'{prefix}_zy.csv', res_zy.numpy(), delimiter=',')
    else:
        np.savetxt(f'{prefix}_lt.csv', res_lt.numpy(), delimiter=',')



if __name__ == '__main__':
    mode = args.p
    # Record the evaluation mode on the shared config module
    # (presumably read by the model code -- TODO confirm).
    cfg.p = mode
    # Directory holding (or receiving) the reorganized per-subject pickles.
    pickle_dir = f'./DEAP_{args.dataset_type}_{args.band}'

    if args.reorganize_dataset == 1:
        reorganize = reorganize_WM if args.dataset_type == 'WM' else reorganize_SK
        reorganize(args.project_dir, args.band, pickle_dir)

    # evaluate_only == -1 means "stop after the optional reorganization step".
    if args.evaluate_only != -1:
        eval_only = args.evaluate_only == 1
        train_and_test(eval_only, args.band, args.dataset_type, pickle_dir, mode)


