import numpy as np
import h5py
import torch
import config as cfg
import os.path as op
import sys
import pickle
from reorganize_dataset import reorganize_WM, reorganize_SK
import argparse
parser = argparse.ArgumentParser()
import os


# Command-line options. Every option carries its description as `help=` text so
# that `--help` documents the script instead of only the source comments.

# Attention: in zero-shot tasks the order of the blocks has to be saved, so that
# one can verify whether the "zero-shot" result is only decoding the domain feature.
parser.add_argument(
    '--reorganize_dataset', type=int, default=0,
    help='1: run the dataset reorganization step before anything else; '
         '0: skip it (the dataset has already been reorganized)')
# Known dataset locations:
#   /home/public/NIPS/sparrKULee/
#   /home/public/NIPS/watermelon_dataset/
parser.add_argument(
    '--project_dir', type=str, default='/home/public/NIPS/watermelon_dataset/',
    help='path of the dataset')
parser.add_argument(
    '--dataset_type', type=str, default='WM',
    help='type of the dataset: WM (Watermelon dataset) or SK (SparrKULee dataset)')

parser.add_argument(
    '--band', type=str, default='full',
    help='band of the EEG data: full, delta, theta, alpha, beta, low_gamma, high_gamma')

# The code covers both training and testing of the model; the model runs very
# fast on the GPU.
parser.add_argument(
    '--evaluate_only', type=int, default=1,
    help='1: only evaluate the model; 0: train and evaluate the model; '
         '-1: only reorganize the dataset')
parser.add_argument(
    '--split', type=str, default='random',
    help='random: randomly select six blocks as the zero-shot test set; '
         'first: select the first six blocks as the zero-shot test set')

# parse_known_args tolerates extra arguments that are meant for other modules.
args, unknown = parser.parse_known_args()

from end_to_end import train_valid_model, test_model


def _window_trials(block_data, block_labels, window):
    """Cut every trial into consecutive decision windows and label each window.

    Args:
        block_data: array indexed as (block, trial, sample, channel).
        block_labels: one label per block, shape (block,).
        window: decision-window length in samples; must divide the trial length.

    Returns:
        A ``(windows, labels)`` pair: ``windows`` has shape
        (n_windows, window, channel) and ``labels`` has shape (n_windows,),
        where every window carries the label of the block it was cut from.

    Raises:
        ValueError: if ``window`` does not evenly divide the trial length.
    """
    n_blocks, n_trials, n_samples, n_channels = block_data.shape
    windows_per_trial = n_samples // window
    if windows_per_trial * window != n_samples:
        raise ValueError(
            f'decision window of {window} samples does not evenly divide '
            f'trials of {n_samples} samples')
    windows = block_data.reshape(n_blocks * n_trials * windows_per_trial, window, n_channels)
    # The reshape keeps block-major order, so repeating each block label once per
    # (trial, window) pair lines the labels up with the windows.
    labels = np.repeat(block_labels, n_trials * windows_per_trial)
    return windows, labels


def train_and_test(eval_only=False, band='full', dataset_type='WM', save_dir='./CVPR2017_WM_full', split='random'):
    """Train (optionally) and test one model per subject on a zero-shot block split.

    For every subject, six of the 40 blocks are held out as the zero-shot test
    set and the remaining 34 blocks form the training set. The two values
    returned by ``test_model`` for each subject are written to
    ``CVPR2017_<dataset_type>_<band>_<split>.csv``.

    Args:
        eval_only: if falsy, train a model before testing; otherwise only load
            and test the checkpoint saved by an earlier training run.
        band: EEG band name; only used to build output paths here.
        dataset_type: dataset tag; only used to build output paths here.
        save_dir: directory holding the per-subject ``S<i>.pkl`` files, each a
            dict with the keys ``'EEG'`` and ``'label'``.
        split: ``'random'`` holds out six randomly drawn blocks; any other
            value holds out the first six blocks.

    Note:
        The random split is drawn from a dedicated, fixed-seed generator, so a
        later evaluate-only run holds out exactly the blocks that were held out
        during training. Checkpoints trained before the split was seeded used
        an unrecorded split and have to be retrained for a clean evaluation.
    """
    # Fixed layout of the reorganized data: blocks x trials x samples x channels.
    n_blocks, n_trials, n_samples, n_channels = 40, 50, 500, 128
    n_test_blocks = 6

    data = np.zeros((cfg.sbnum, n_blocks, n_trials, n_samples, n_channels))
    # One label per block: 0 or 1, representing the attended direction.
    label = np.zeros((cfg.sbnum, n_blocks))

    # NOTE(review): the subject list is fixed at ten subjects while the arrays
    # are sized by cfg.sbnum - presumably cfg.sbnum == 10; confirm.
    sub_name = ['S0', 'S1', 'S2', 'S3', 'S4', 'S5', 'S6', 'S7', 'S8', 'S9']
    for i, name in enumerate(sub_name):
        eeg_savedir = op.join(save_dir, f'{name}.pkl')
        # NOTE: pickle must only be used on files produced by the local
        # reorganization step, never on untrusted input.
        with open(eeg_savedir, 'rb') as f:
            eeg_data_label = pickle.load(f)
        data[i] = eeg_data_label['EEG']
        label[i] = eeg_data_label['label']

    # random seed
    torch.manual_seed(2024)
    if torch.cuda.is_available():
        torch.backends.cudnn.deterministic = True
        torch.cuda.manual_seed_all(2024)
    # Separate generator for the split only: whatever randomness training
    # consumes, the sequence of held-out blocks stays the same across runs.
    split_rng = np.random.RandomState(2024)

    res = torch.zeros((cfg.sbnum, 2))

    # The checkpoint directory does not depend on the subject: create it once.
    savedir = './model_' + dataset_type + '_' + band + '_' + split
    os.makedirs(savedir, exist_ok=True)

    for sb in range(cfg.sbnum):
        saveckpt = savedir + '/' + str(sb) + '_model.ckpt'

        # get the data of the specific subject
        eegdata = data[sb]
        eeglabel = label[sb]

        # select the zero-shot test set
        if split == 'random':
            test_ids = split_rng.choice(n_blocks, n_test_blocks, replace=False)
        else:
            test_ids = np.arange(n_test_blocks)
        train_ids = np.setdiff1d(np.arange(n_blocks), test_ids)

        train_data, train_label = _window_trials(
            eegdata[train_ids], eeglabel[train_ids], cfg.decision_window)
        test_data, test_label = _window_trials(
            eegdata[test_ids], eeglabel[test_ids], cfg.decision_window)

        if not eval_only:
            train_valid_model(train_data, train_label, saveckpt)
        res[sb, 0], res[sb, 1] = test_model(test_data, test_label, saveckpt)

        print("good job!")

    saveres = 'CVPR2017' + '_' + dataset_type + '_' + band + '_' + split
    np.savetxt(saveres + '.csv', res.numpy(), delimiter=',')


if __name__ == '__main__':
    # Script entry point: optionally reorganize the dataset, then train and/or
    # evaluate according to the parsed command-line options.
    dataset_type, band = args.dataset_type, args.band
    savedir = f'./CVPR2017_{dataset_type}_{band}'

    if args.reorganize_dataset == 1:
        # 'WM' selects the Watermelon reorganizer; anything else uses SparrKULee.
        reorganize = reorganize_WM if dataset_type == 'WM' else reorganize_SK
        reorganize(args.project_dir, band, savedir)

    # evaluate_only == -1 means "only reorganize": skip training and testing.
    if args.evaluate_only != -1:
        # evaluate_only == 1 tests existing checkpoints; other values train first.
        train_and_test(args.evaluate_only == 1, band, dataset_type, savedir, args.split)




