import os
import json
import argparse
import sys

# Thread limits for the native math backends (OpenMP, OpenBLAS, MKL,
# Accelerate/vecLib, numexpr). These backends typically read the variables
# once, when the library is first loaded, so they must be set BEFORE any
# project import below that may (transitively) import numpy / sktime.
num_threads = '20'
os.environ['OMP_NUM_THREADS'] = num_threads
os.environ['OPENBLAS_NUM_THREADS'] = num_threads
os.environ['MKL_NUM_THREADS'] = num_threads
os.environ['VECLIB_MAXIMUM_THREADS'] = num_threads
os.environ['NUMEXPR_NUM_THREADS'] = num_threads

from model_tool import sktime_dataset, sktime_model  # noqa: E402

# NOTE(review): these paths are relative to the current working directory,
# not to this file, so the script presumably must be launched from its own
# directory for `utils` to be importable.
sys.path.append('../../')
sys.path.append('../../utils')

from utils.default_config import set_default_cpc_config, get_choice_default_config  # noqa: E402
from utils.excel_manager import ExcelManager  # noqa: E402


def str2bool(value):
    """Parse a command-line boolean option value.

    ``argparse`` with ``type=bool`` converts every non-empty string to True
    (``bool('False') is True``), so ``--save_model False`` would *enable*
    saving. This converter interprets the usual spellings explicitly.

    Args:
        value: The raw option string, or an already-boolean value.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'Boolean value expected, got {value!r}.')


def resolve_load_path(args):
    """Return the per-experiment result directory described by ``args``.

    Layout: ``<load_path>/<data_name>[_C]/<model_name>/[<noise%>/]exp<exp_id>``.
    The ``_C`` suffix is added when ``args.model_label`` is truthy, and the
    noise-percentage component is omitted for the ``SEEG`` dataset.

    Args:
        args: Namespace providing ``load_path``, ``model_label``,
            ``data_name``, ``model_name``, ``noise_ratio`` and ``exp_id``.

    Returns:
        The directory path as a string. The directory is not created here.
    """
    suffix = '_C' if args.model_label else ''
    path = os.path.join(args.load_path, f'{args.data_name}{suffix}/{args.model_name}/')
    if args.data_name != 'SEEG':
        # round() rather than int(): e.g. 0.29 * 100 == 28.999999999999996,
        # which int() would truncate to 28.
        path = os.path.join(path, f'{round(args.noise_ratio * 100)}/')
    return os.path.join(path, f'exp{args.exp_id}')


def main(argv=None):
    """Train/evaluate one sktime model and record the test result in Excel.

    With ``--summary`` it only summarises the existing Excel results and
    returns without training.

    Args:
        argv: Command-line arguments; defaults to ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(description='SktimeModel')
    parser = set_default_cpc_config(parser)

    group_train = parser.add_argument_group('Train')
    group_train.add_argument('--database_save_dir', type=str, default='/data/CL_database/',
                             help='Should give a path to load the database of one patient.')
    group_train.add_argument('--data_name', type=str, default='Sleep',
                             help='Should give the name of the database [SEEG, fNIRS_2, Sleep].')
    group_train.add_argument('--noise_ratio', type=float, default=.0,
                             help='The maximal ratio of adding noise.')
    group_train.add_argument('--exp_id', type=int, default=1,
                             help='The experimental id.')
    group_train.add_argument('--model_name', type=str, default='minirocket',
                             help='The model name in sktime.')
    group_train.add_argument('--n_jobs', type=int, default=20,
                             help='The number of jobs to run in parallel.')
    group_train.add_argument('--load_path', type=str, default='/data/CL_result/',
                             help='The path to load checkpoint.')
    # str2bool instead of bool: bool('False') is True.
    group_train.add_argument('--load_model', type=str2bool, default=False,
                             help='Whether to load checkpoint.')
    group_train.add_argument('--save_model', type=str2bool, default=True,
                             help='Whether to save checkpoint.')
    group_train.add_argument('--summary', type=str2bool, default=False,
                             help='Whether to summary the results of all experiments.')
    if argv is None:
        argv = sys.argv[1:]
    args = parser.parse_args(argv)
    args, _ = get_choice_default_config(args)

    # Resolve and create the per-experiment directory. It is used both for
    # model checkpoints (load_dir/save_dir below) and for the Excel results.
    args.load_path = resolve_load_path(args)
    os.makedirs(args.load_path, exist_ok=True)
    excel = ExcelManager(args.load_path, 'test_result')
    if args.summary:
        excel.summary_results()
        return

    data_handler, x_train, y_train, x_test, y_test, n_class = sktime_dataset(args)

    print(f'CONFIG:\n{json.dumps(vars(args), indent=4, sort_keys=True)}')
    print('-' * 50)

    y_pred = sktime_model(
        args.model_name,
        x_train,
        y_train,
        x_test,
        args.n_jobs,
        load_dir=args.load_path if args.load_model else None,
        save_dir=args.load_path if args.save_model else None,
    )
    index = data_handler.model_evaluation(
        y_test,
        y_pred,
        n_class,
    )
    print('-' * 10, 'The average testing results of ' + args.model_name, '-' * 10)
    print(index)

    excel.res2excel(str(index), tar_pat_name=f'exp{args.exp_id}')
    excel.excel_save(args.exp_id)


if __name__ == '__main__':
    main()
