import numpy as np
import h5py
import torch
import config as cfg
import os.path as op
import sys
import pickle
from reorganize_dataset import reorganize_WM, reorganize_SK
import argparse
import os

parser = argparse.ArgumentParser(
    description='Per-subject k-fold training/evaluation on the Watermelon (WM) '
                'or SparrKULee (SK) EEG dataset.')

# Whether to (re)build the per-subject files from the raw dataset before running.
parser.add_argument('--reorganize_dataset', type=int, default=0, choices=[0, 1],
                    help='1: reorganize the raw dataset into per-subject files first; '
                         '0: skip (files already exist)')
# Root of the raw dataset, e.g.
#   /home/public/NIPS/sparrKULee/
#   /home/public/NIPS/watermelon_dataset/
parser.add_argument('--project_dir', type=str, default='/home/public/NIPS/watermelon_dataset/',
                    help='root directory of the raw dataset (only read when reorganizing)')
# Which dataset: the Watermelon dataset (WM) or the SparrKULee dataset (SK).
parser.add_argument('--dataset_type', type=str, default='WM', choices=['WM', 'SK'],
                    help='WM: Watermelon dataset, SK: SparrKULee dataset')

# EEG band: full, delta, theta, alpha, beta, low_gamma, high_gamma.
parser.add_argument('--band', type=str, default='full',
                    help='EEG band: full, delta, theta, alpha, beta, low_gamma or high_gamma')

# Run mode. The whole pipeline is training plus testing; the model runs very fast on a GPU.
parser.add_argument('--evaluate_only', type=int, default=0, choices=[-1, 0, 1],
                    help='1: only evaluate saved models; 0: train and evaluate; '
                         '-1: skip training/evaluation (e.g. only reorganize the dataset)')
# Validation setting (see the paper): xy validates p(y|x), xz validates p(z|x),
# zy validates p(y|z) (xz and zy share one code path), lt = leave-trials-out
# (yields chance-level results).
# NOTE: train_and_test currently overrides this and always runs 'xy'.
parser.add_argument('--p', type=str, default='xy', choices=['xy', 'xz', 'zy', 'lt'],
                    help='validation setting: xy = p(y|x), xz = p(z|x), zy = p(y|z), '
                         'lt = leave-trials-out')

# parse_known_args (not parse_args): unrecognised options are collected in
# ``unknown`` instead of raising, so a misspelled flag is silently ignored.
args, unknown = parser.parse_known_args()

from end_to_end import train_valid_model, test_model


def _load_subject(path, shape=(40, 50, 500, 128)):
    """Load one reorganized subject pickle and return ``(eeg, label)``.

    The pickle must hold a dict with ``'EEG'`` and ``'label'`` entries. They are
    copied into freshly allocated float64 arrays of ``shape`` and ``(shape[0],)``,
    which applies the same broadcasting and dtype conversion as assigning them
    into a preallocated ``np.zeros`` array.
    """
    with open(path, 'rb') as f:
        # NOTE(review): pickle.load can execute arbitrary code; only load files
        # you produced yourself (e.g. via reorganize_WM / reorganize_SK).
        eeg_data_label = pickle.load(f)
    eeg = np.empty(shape)
    eeg[...] = eeg_data_label['EEG']
    label = np.empty(shape[0])
    # 0 or 1, representing the attended direction (per the original comment)
    label[...] = eeg_data_label['label']
    return eeg, label


def train_and_test(eval_only=False, band='full', dataset_type='WM', save_dir='./CVPR2017_WM_full', p='xy'):
    """Train/evaluate one model per subject and per k-fold split; save accuracies as CSV.

    For every subject ``k`` in ``range(cfg.sbnum)`` the file ``<save_dir>/S<k>.pkl``
    is loaded. Each of its 40 trials (50 * 500 samples) is cut into
    ``cfg.decision_window``-sample windows, and the windows are split with a
    shuffled ``KFold`` (``cfg.kfold_num`` folds, seed 2024).

    Args:
        eval_only: if truthy, skip ``train_valid_model`` and only run ``test_model``
            on the existing checkpoint paths.
        band: EEG band name; used only to build checkpoint/result file names.
        dataset_type: dataset tag ('WM' or 'SK'); used only in file names.
        save_dir: directory holding the per-subject ``S<k>.pkl`` files.
        p: validation setting. NOTE: overridden to 'xy' below, so it currently
            has no effect.

    Raises:
        FileNotFoundError: if any ``S<k>.pkl`` file is missing (checked before
            any training starts).

    Side effects:
        Checkpoint paths are ``./model_<dataset_type>_<band>_xy/sb<k>/fold<f>.ckpt``
        (directories are created here). The p(y|x) accuracy matrix
        ``(cfg.sbnum, cfg.kfold_num)`` is written to
        ``CVPR2017_<dataset_type>_<band>_<p>.csv`` in the working directory.
    """
    # Only p(y|x) needs to be implemented (more information in the paper): the
    # ``p`` argument is deliberately overridden, so the 'xz'/'zy'/'lt' paths
    # below are currently unreachable.
    p = 'xy'

    n_trials = 40
    # Number of decision windows per trial (each trial holds 50 * 500 samples).
    n_windows = int(50 * 500 / cfg.decision_window)

    # Subject files are S0.pkl ... S<sbnum-1>.pkl. Check them all up front so a
    # missing file fails before any (long) training run starts.
    subject_paths = [op.join(save_dir, f'S{sb}.pkl') for sb in range(cfg.sbnum)]
    missing = [path for path in subject_paths if not op.isfile(path)]
    if missing:
        raise FileNotFoundError(f'missing reorganized subject file(s): {missing}')

    # random seed
    torch.manual_seed(2024)
    if torch.cuda.is_available():
        torch.backends.cudnn.deterministic = True
        torch.cuda.manual_seed_all(2024)

    res_xy = torch.zeros((cfg.sbnum, cfg.kfold_num))
    res_xz = torch.zeros((cfg.sbnum, cfg.kfold_num))
    res_zy = torch.zeros((cfg.sbnum, cfg.kfold_num))

    from sklearn.model_selection import KFold
    kfold = KFold(n_splits=cfg.kfold_num, shuffle=True, random_state=2024)

    for sb, subject_path in enumerate(subject_paths):
        if p == 'lt':
            # leave-trials-out is not implemented in this script
            continue

        # Load one subject at a time: preallocating all subjects at once needs
        # cfg.sbnum * 40*50*500*128 float64 values (~1 GB per subject).
        eegdata, eeglabel = _load_subject(subject_path)
        eegdata = eegdata.reshape(n_trials * n_windows, cfg.decision_window, 128)
        # Every decision window inherits its trial's label; np.repeat keeps the
        # trial-major order produced by the reshape above.
        eeglabel = np.repeat(eeglabel, n_windows)

        # only used in domain_feature: the trial index of every window
        if p == 'xz' or p == 'zy':
            eegdomain = np.repeat(np.arange(n_trials), n_windows)

        model_tag = p if p == 'xy' else 'xzy'
        savedir = f'./model_{dataset_type}_{band}_{model_tag}/sb{sb}'
        os.makedirs(savedir, exist_ok=True)

        for fold, (train_ids, test_ids) in enumerate(kfold.split(eegdata)):
            saveckpt = f'{savedir}/fold{fold}.ckpt'
            if p == 'xy':
                if not eval_only:
                    train_valid_model(eegdata[train_ids], eeglabel[train_ids], saveckpt)
                res_xy[sb, fold] = test_model(eegdata[test_ids], eeglabel[test_ids], saveckpt)
            else:
                if not eval_only:
                    train_valid_model(eegdata[train_ids], eeglabel[train_ids], eegdomain[train_ids], saveckpt)
                res_xz[sb, fold], res_zy[sb, fold] = test_model(
                    eegdata[test_ids], eeglabel[test_ids], eegdomain[test_ids], saveckpt)

        print("good job!")

    # Only the p(y|x) results are written; res_xz / res_zy are not saved.
    saveres = f'CVPR2017_{dataset_type}_{band}_{p}'
    np.savetxt(saveres + '.csv', res_xy.numpy(), delimiter=',')


if __name__ == '__main__':
    # Unpack the CLI options and expose the requested validation mode on the shared config.
    project_dir, band, dataset_type, p = args.project_dir, args.band, args.dataset_type, args.p
    cfg.p = p
    # Directory holding (or receiving) the reorganized per-subject files.
    savedir = f'./CVPR2017_{dataset_type}_{band}'

    # Optionally rebuild the per-subject files; any type other than 'WM' uses the SK reorganizer.
    if args.reorganize_dataset == 1:
        reorganize = reorganize_WM if dataset_type == 'WM' else reorganize_SK
        reorganize(project_dir, band, savedir)

    # -1 stops here (no training/evaluation); 1 evaluates only; anything else trains and evaluates.
    if args.evaluate_only != -1:
        train_and_test(args.evaluate_only == 1, band, dataset_type, savedir, p)




