import argparse
import os
import sys
import torch
from exp.exp_classification import Exp_Classification
import random

sys.path.append('../../')
sys.path.append('../../utils')

from utils.default_config import get_exp_dict, window_time_dict, slide_time_dict
from utils.misc import set_seed
from utils.excel_manager import ExcelManager

# Cap the number of CPU threads used by torch and publish the same limit
# through the environment variables read by the common native math backends.
# The value is kept as a string because os.environ only accepts strings.
# NOTE(review): in this file torch is imported before these variables are
# set; whether already-loaded native libraries still honour them cannot be
# confirmed from here.
num_threads = '20'
torch.set_num_threads(int(num_threads))
os.environ.update(dict.fromkeys(
    (
        'OMP_NUM_THREADS',
        'OPENBLAS_NUM_THREADS',
        'MKL_NUM_THREADS',
        'VECLIB_MAXIMUM_THREADS',
        'NUMEXPR_NUM_THREADS',
    ),
    num_threads,
))


def str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so
    every non-empty value (including ``"False"`` and ``"0"``) becomes
    ``True``. This converter interprets the usual spellings instead.

    Args:
        value: A ``bool`` (returned unchanged) or a string such as
            ``"true"``, ``"False"``, ``"1"``, ``"0"``, ``"yes"``, ``"no"``.

    Returns:
        The boolean the token stands for.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised
            boolean spelling.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string is kept falsy because bool('') was already False
    # under the previous ``type=bool`` behaviour.
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f'Boolean value expected, got {value!r}.')


def build_parser():
    """Build the command-line parser for the classification experiments.

    Returns:
        argparse.ArgumentParser: Parser with every option of this script
        registered. ``--summary`` and ``--use_gpu`` accept an explicit
        boolean value (``--summary False``) or can be given as a bare flag
        (``--summary``), which means ``True``.
    """
    parser = argparse.ArgumentParser(description='TSLib')

    parser.add_argument('--window_time', type=float, default=1,
                        help='The seconds of every sample segment.')
    parser.add_argument('--slide_time', type=float, default=0.5,
                        help='The sliding seconds between two sample segments.')
    parser.add_argument('--num_level', type=int, default=5,
                        help='The number of levels.')
    parser.add_argument('--n_process_loader', type=int, default=50,
                        help='Number of processes to call to load the dataset.')
    parser.add_argument('--random_seed', type=int, default=None,
                        help="Set a specific random seed.")
    parser.add_argument('--model_label', action='store_true',
                        help="Whether to use the corrected labels for training.")

    # basic config
    parser.add_argument('--task_name', type=str, default='classification')
    parser.add_argument('--is_training', type=int, default=1, help='status')
    parser.add_argument('--model', type=str, default='PatchTST',
                        help='model name, options: [TimesNet, PatchTST]')
    parser.add_argument('--summary', type=str2bool, nargs='?', const=True, default=False,
                        help='Whether to summary the results of all experiments.')

    # data loader
    parser.add_argument('--database_save_dir', type=str, default='/data/CL_database/',
                        help='Should give a path to load the database of patients.')
    parser.add_argument('--data_name', type=str, default='Sleep',
                        help='Should give the name of the database [SEEG, fNIRS_2, Sleep].')
    parser.add_argument('--noise_ratio', type=float, default=.0,
                        help='The maximal ratio of adding noise.')
    parser.add_argument('--exp_id', type=int, default=1,
                        help='The experimental id.')
    parser.add_argument('--path_checkpoint', type=str, default='/data/CL_result/',
                        help='The path to save checkpoint.')
    parser.add_argument('--freq', type=str, default='h',
                        help='freq for time features encoding, options:'
                             '[s:secondly, t:minutely, h:hourly, d:daily, b:business days, w:weekly, m:monthly], '
                             'you can also use more detailed freq like 15min or 3h')

    # model define
    parser.add_argument('--top_k', type=int, default=3, help='for TimesBlock')
    parser.add_argument('--num_kernels', type=int, default=6, help='for Inception')
    parser.add_argument('--enc_in', type=int, default=7, help='encoder input size')
    parser.add_argument('--dec_in', type=int, default=7, help='decoder input size')
    parser.add_argument('--c_out', type=int, default=7, help='output size')
    parser.add_argument('--d_model', type=int, default=64, help='dimension of model')
    parser.add_argument('--n_heads', type=int, default=8, help='num of heads')
    parser.add_argument('--e_layers', type=int, default=3, help='num of encoder layers')
    parser.add_argument('--d_layers', type=int, default=1, help='num of decoder layers')
    parser.add_argument('--d_ff', type=int, default=128, help='dimension of fcn')
    parser.add_argument('--moving_avg', type=int, default=25, help='window size of moving average')
    parser.add_argument('--factor', type=int, default=1, help='attn factor')
    parser.add_argument('--distil', action='store_false',
                        help='whether to use distilling in encoder, using this argument means not using distilling',
                        default=True)
    parser.add_argument('--dropout', type=float, default=0.1, help='dropout')
    parser.add_argument('--embed', type=str, default='timeF',
                        help='time features encoding, options:[timeF, fixed, learned]')
    parser.add_argument('--activation', type=str, default='gelu', help='activation')
    parser.add_argument('--output_attention', action='store_true', help='whether to output attention in ecoder')

    # optimization
    parser.add_argument('--num_workers', type=int, default=10, help='data loader num workers')
    parser.add_argument('--train_epochs', type=int, default=100, help='train epochs')
    parser.add_argument('--batch_size', type=int, default=32, help='batch size of train input data')
    parser.add_argument('--patience', type=int, default=10, help='early stopping patience')
    parser.add_argument('--learning_rate', type=float, default=0.0001, help='optimizer learning rate')
    parser.add_argument('--des', type=str, default='test', help='exp description')
    parser.add_argument('--loss', type=str, default='MSE', help='loss function')
    parser.add_argument('--lradj', type=str, default='type1', help='adjust learning rate')
    parser.add_argument('--use_amp', action='store_true', help='use automatic mixed precision training', default=False)

    # GPU
    parser.add_argument('--use_gpu', type=str2bool, nargs='?', const=True, default=True, help='use gpu')
    parser.add_argument('--gpu_id', type=int, default=2, help='gpu')

    # de-stationary projector params
    parser.add_argument('--p_hidden_dims', type=int, nargs='+', default=[128, 128],
                        help='hidden layer dimensions of projector (List)')
    parser.add_argument('--p_hidden_layers', type=int, default=2, help='number of hidden layers in projector')

    return parser


if __name__ == '__main__':
    args = build_parser().parse_args()

    # Draw a seed when none was requested; it is stored on args so it shows
    # up in the printed configuration below.
    if args.random_seed is None:
        args.random_seed = random.randint(0, 2 ** 31)
    set_seed(args.random_seed)

    # The per-dataset tables always replace --window_time / --slide_time,
    # so values given on the command line for those two options are ignored.
    args.window_time = window_time_dict[args.data_name]
    args.slide_time = slide_time_dict[args.data_name]

    # The entry selected by exp_id is unpacked by position into the
    # train / valid / test patient lists.
    exp_patient_list = get_exp_dict(args.data_name)[args.exp_id]
    args.train_patient_list = exp_patient_list[0]
    args.valid_patient_list = exp_patient_list[1]
    args.test_patient_list = exp_patient_list[2]

    # Result directory layout:
    #   <path_checkpoint>/<data_name>[_C]/<model>/[<noise %>/]exp<exp_id>
    # "_C" is appended when --model_label is set; the noise level is
    # omitted for the SEEG dataset.
    label_suffix = '_C' if args.model_label else ''
    args.path_checkpoint = os.path.join(
        args.path_checkpoint, f'{args.data_name}{label_suffix}/{args.model}/')
    if args.data_name != 'SEEG':
        args.path_checkpoint = os.path.join(args.path_checkpoint, f'{int(args.noise_ratio * 100)}/')
    args.path_checkpoint = os.path.join(args.path_checkpoint, f'exp{args.exp_id}')
    # exist_ok avoids the check-then-create race when several runs start
    # at the same time.
    os.makedirs(args.path_checkpoint, exist_ok=True)

    # Save to Excel file
    excel = ExcelManager(args.path_checkpoint, 'test_result')
    if args.summary:
        # Summary mode only aggregates existing results; nothing is trained.
        excel.summary_results()
        sys.exit(0)

    print('Args in experiment:')
    print(args)

    exp = Exp_Classification(args)  # set experiments
    if args.is_training:
        print('>>>>>>>start training : >>>>>>>>>>>>>>>>>>>>>>>>>>')
        exp.train()

        print('>>>>>>>testing : <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<')
        index = exp.test()
    else:
        print('>>>>>>>testing : <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<')
        index = exp.test(test=1)

    excel.res2excel(str(index), tar_pat_name=f'exp{args.exp_id}')
    excel.excel_save(args.exp_id)
