import argparse
import logging
import os
import shutil
import sys

sys.path.append('../src')
import warnings

import pandas as pd
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

from models.model import CNNAE
from models.MultiTaskClassification import NonLinClassifier, MetaModel
from utils.global_var import OUTPATH
from utils.saver import Saver

from utils.seeg_datasets import test_datasets
from utils.utils import GPU_id, evaluate_class
from utils.default_config import get_exp_dict, window_time_dict, slide_time_dict

# Upper bound on CPU threads, kept as a string because it is also written
# into environment variables below.
num_threads = '16'
# Limit the number of threads PyTorch uses for CPU parallelism.
torch.set_num_threads(int(num_threads))
# Thread-count settings for the numeric backends named by each variable.
# NOTE(review): these are assigned after numpy/torch have already been imported
# above, so a backend that reads its variable at import time may not see the
# value - confirm whether they should be set before the imports.
os.environ['OMP_NUM_THREADS'] = num_threads
os.environ['OPENBLAS_NUM_THREADS'] = num_threads
os.environ['MKL_NUM_THREADS'] = num_threads
os.environ['VECLIB_MAXIMUM_THREADS'] = num_threads
os.environ['NUMEXPR_NUM_THREADS'] = num_threads

######################################################################################################

# Silence every Python warning for the whole process.
warnings.filterwarnings("ignore")
# Let cuDNN benchmark and pick convolution algorithms.
torch.backends.cudnn.benchmark = True
# Current terminal width; not used anywhere in the visible part of this file.
columns = shutil.get_terminal_size().columns

# Turn on autograd anomaly detection globally.
# NOTE(review): this is a debugging aid that adds overhead - confirm it is
# wanted in an evaluation-only script.
torch.autograd.set_detect_anomaly(True)

# Use the CUDA device with index GPU_id.id when CUDA is available, otherwise the CPU.
device = torch.device("cuda:{}".format(GPU_id.id) if torch.cuda.is_available() else "cpu")


######################################################################################################

def parse_args(argv=None):
    """
    Parse the command-line arguments of the test script.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) keeps the original
            behaviour of reading the process command line.

    Returns:
        argparse.Namespace holding one attribute per option declared below.
    """
    # TODO: make external configuration file -json or similar.
    # List handling: https://stackoverflow.com/questions/15753701/how-can-i-pass-a-list-as-a-command-line-argument-with-argparse

    # Add global parameters
    parser = argparse.ArgumentParser(description='sigua single experiment')

    # Synth Data
    parser.add_argument('--dataset', type=str, default='Plane', help='UCR datasets')

    parser.add_argument('--ni', type=float, nargs='+', default=[0], help='label noise ratio')
    parser.add_argument('--label_noise', type=int, default=0, help='Label noise type, sym or int for asymmetric, '
                                                                   'number as str for time-dependent noise')

    parser.add_argument('--M', type=int, nargs='+', default=[20, 40, 60, 80])
    # Help text previously claimed "default: 0." although the default is 1.
    parser.add_argument('--reg_term', type=float, default=1,
                        help="Parameter of the regularization term, default: 1.")
    parser.add_argument('--alpha', type=float, default=32,
                        help='alpha parameter for the mixup distribution, default: 32')

    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--lr', type=float, default=1e-2)
    parser.add_argument('--hidden_activation', type=str, default='nn.ReLU()')
    parser.add_argument('--num_gradual', type=int, default=100)
    parser.add_argument('--bad_weight', type=float, default=1e-3)

    parser.add_argument('--normalization', type=str, default='batch')
    parser.add_argument('--dropout', type=float, default=0.2)
    parser.add_argument('--l2penalty', type=float, default=1e-4)

    parser.add_argument('--num_workers', type=int, default=0, help='PyTorch dataloader worker. Set to 0 if debug.')
    parser.add_argument('--seed', type=int, default=0, help='RNG seed - only affects Network init')
    parser.add_argument('--n_runs', type=int, default=1, help='Number of runs')

    parser.add_argument('--classifier_dim', type=int, default=128)
    parser.add_argument('--embedding_size', type=int, default=3)

    parser.add_argument('--kernel_size', type=int, default=4)
    parser.add_argument('--filters', nargs='+', type=int, default=[128, 128, 256, 256])
    parser.add_argument('--stride', type=int, default=2)
    parser.add_argument('--padding', type=int, default=2)

    # Suppress terminal out
    parser.add_argument('--disable_print', action='store_true', default=False)
    parser.add_argument('--plt_embedding', action='store_true', default=False)
    parser.add_argument('--plt_loss_hist', action='store_true', default=False)
    parser.add_argument('--plt_cm', action='store_true', default=False)
    # NOTE(review): store_true combined with default=True means this option is
    # always True and cannot be switched off from the command line.
    parser.add_argument('--headless', action='store_true', default=True, help='Matplotlib backend')

    # args for loading seeg database
    parser.add_argument('--database_save_dir', type=str, default='/data/CL_database/',
                        help='Should give a path to load the database of one patient.')
    parser.add_argument('--data_name', type=str, default='Sleep',
                        help='Should give the name of the database [SEEG, fNIRS_2, Sleep].')
    parser.add_argument('--n_class', type=int, default=2,
                        help='class num of dataset')
    parser.add_argument('--exp_id', type=int, default=3,
                        help='The experimental id.')
    parser.add_argument('--noise_ratio', type=float, default=0,
                        help='The maximal ratio of adding noise.')
    parser.add_argument('--window_time', type=float, default=1,
                        help='The seconds of every sample segment.')
    parser.add_argument('--slide_time', type=float, default=0.5,
                        help='The sliding seconds between two sample segments.')
    parser.add_argument('--num_level', type=int, default=5,
                        help='The number of levels.')

    # args for checkpoint
    parser.add_argument('--path_checkpoint', type=str, default='/data/SREA/')
    parser.add_argument('--save_step', type=int, default=20, help='The step number to save checkpoint')
    parser.add_argument('--patience', type=int, default=10, help='The waiting epoch number for early stopping.')
    # A space was missing between the two literals ("candidatelist").
    parser.add_argument('--best_val_index', type=str, default='F1',
                        help='The index for saving models performing best in the validation dataset. The candidate '
                             'list includes: [loss, F1].')

    # Add parameters for each particular network
    return parser.parse_args(argv)


def main():
    """
    Evaluate the two saved models of an experiment on the test set.

    Steps performed:
      1. Parse the command line and attach the train/valid/test patient lists
         of the selected experiment (``exp_dict[args.exp_id]``) to ``args``.
         ``window_time`` and ``slide_time`` are overwritten with the values
         registered for ``args.data_name``, so the command-line values for
         those two options are ignored.
      2. Build a ``Saver`` and read ``checkpoint.pt`` from its directory.
      3. Load the test data and wrap it in a non-shuffled ``DataLoader``.
      4. For the ``"BestLossModel"`` and ``"BestF1Model"`` entries of the
         checkpoint, rebuild the network, load the weights and call
         ``evaluate_class``.
      5. Write one row per model to ``test_results.csv`` in the saver directory.

    Raises:
        KeyError: if ``args.exp_id`` / ``args.data_name`` is not registered, or
            the checkpoint lacks one of the two expected model entries.
    """
    args = parse_args()
    exp_patient_list = get_exp_dict(args.data_name)[args.exp_id]
    args.train_patient_list = exp_patient_list[0]
    args.valid_patient_list = exp_patient_list[1]
    args.test_patient_list = exp_patient_list[2]

    args.window_time = window_time_dict[args.data_name]
    args.slide_time = slide_time_dict[args.data_name]

    # Declare saver object and get check_point path
    saver = Saver(OUTPATH, 'sigua_single_experiment',
                  hierarchy=os.path.join(args.dataset), noise_ratio=args.noise_ratio, exp_id=args.exp_id)
    args.load_path = os.path.join(saver.path, 'checkpoint.pt')
    load_path = args.load_path

    # load data
    print('-' * 50)
    print('Load test dataset')
    x_test, Y_test, n_class = test_datasets(args)
    test_dataset = TensorDataset(torch.from_numpy(x_test).float(), torch.from_numpy(Y_test).long())
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, drop_last=False,
                             num_workers=args.num_workers, pin_memory=True)
    # The network below is sized with the n_class returned by test_datasets;
    # args.n_class records the number of distinct labels present in the test set.
    args.n_class = len(np.unique(Y_test))

    # load checkpoint
    print('-' * 50)
    print('Load checkpoint:', load_path)
    state_dict = torch.load(load_path, 'cpu')

    checkpoints = (
        ('BestLoss', state_dict["BestLossModel"]),
        ('BestF1', state_dict["BestF1Model"]),
    )

    # Rows are gathered in a list and turned into a DataFrame once at the end:
    # DataFrame.append was removed in pandas 2.0 and copied the frame on every call.
    result_rows = []
    for model_type, model_dict in checkpoints:
        # Network definition
        classifier = NonLinClassifier(args.embedding_size, n_class, d_hidd=args.classifier_dim, dropout=args.dropout,
                                      norm=args.normalization)
        autoencoder = CNNAE(input_size=x_test.shape[2], num_filters=args.filters, embedding_dim=args.embedding_size,
                            seq_len=x_test.shape[1], kernel_size=args.kernel_size, stride=args.stride,
                            padding=args.padding, dropout=args.dropout, normalization=args.normalization).to(device)
        model = MetaModel(ae=autoencoder, classifier=classifier, name='CNN').to(device)

        # strict=False tolerates key mismatches; report them so that a partially
        # loaded (partly randomly initialised) model does not go unnoticed.
        load_report = model.load_state_dict(model_dict, strict=False)
        if load_report.missing_keys or load_report.unexpected_keys:
            print('Warning: {} checkpoint does not fully match the model. Missing keys: {}; unexpected keys: {}'.format(
                model_type, list(load_report.missing_keys), list(load_report.unexpected_keys)))

        # NOTE(review): evaluate_class is assumed to return a dict-like mapping
        # of metric name -> value - confirm against utils.utils.
        test_results = evaluate_class(args, model, x_test, Y_test, None, test_loader, args.noise_ratio, saver, 'CNN',
                                      'Test', True, plt_cm=False, plt_lables=False)
        test_results['model_type'] = model_type
        result_rows.append(test_results)

    df_results = pd.DataFrame(result_rows)

    print('Save results')
    df_results.to_csv(os.path.join(saver.path, 'test_results.csv'), sep=',', index=False)

# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
