import argparse
import logging
import os
import shutil
import sys

sys.path.append('../src')
import warnings

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, TensorDataset

from models.MultiTaskClassification import AEandClass, NonLinClassifier
from models.model import CNNAE
from utils.global_var import OUTPATH
from utils.saver import Saver
from utils.utils import evaluate_class_recons, GPU_id
from utils.seeg_datasets import test_datasets

from utils.default_config import get_exp_dict, window_time_dict, slide_time_dict

######################################################################################################

# Suppress every Python warning emitted while this script runs.
warnings.filterwarnings("ignore")

# Let cuDNN auto-tune convolution algorithms (typically beneficial when input shapes stay fixed).
torch.backends.cudnn.benchmark = True

# Current terminal width in characters (not referenced elsewhere in the visible part of this script).
columns = shutil.get_terminal_size().columns

# Target device: the configured CUDA card (GPU_id.id) when CUDA is available, CPU otherwise.
# GPU_id is only consulted on the CUDA branch.
device = torch.device(
    f"cuda:{GPU_id.id}" if torch.cuda.is_available() else "cpu"
)


######################################################################################################


def parse_args():
    # Add global parameters
    parser = argparse.ArgumentParser(
        description='SREA Single Experiment. It run n_runs independent experiments with different random seeds.'
                    ' Each run evaluate different noise ratios (ni).')

    parser.add_argument('--dataset', type=str, default='SEEG', help='SEEG datasets')

    # ! label noise ratio to be determined
    # parser.add_argument('--ni', type=float, nargs='+', default=[0, 0.30], help='label noise ratio')
    parser.add_argument('--ni', type=float, nargs='+', default=[0], help='label noise ratio')

    # ! noise type to be determined
    parser.add_argument('--label_noise', default=0, help='Label noise type, sym or int for asymmetric, '
                                                         'number as str for time-dependent noise')

    parser.add_argument('--M', type=int, nargs='+', default=[20, 40, 60, 80], help='Scheduler milestones')
    parser.add_argument('--abg', type=float, nargs='+',
                        help='Loss function coefficients. a (alpha) = AE, b (beta) = classifier, g (gamma) = clusterer',
                        default=[1, 1, 1])
    parser.add_argument('--class_reg', type=int, default=1, help='Distribution regularization coeff')
    parser.add_argument('--entropy_reg', type=int, default=0., help='Entropy regularization coeff')

    parser.add_argument('--correct', nargs='+', default=[True],
                        help='Correct labels. Set to false to not correct labels.')
    parser.add_argument('--track', type=int, default=5, help='Number or past predictions snapshots')
    parser.add_argument('--init_centers', type=int, default=1, help='Initialize cluster centers. Warm up phase.')
    parser.add_argument('--delta_start', type=int, default=10, help='Start re-labeling phase')
    parser.add_argument('--delta_end', type=int, default=30,
                        help='Begin fine-tuning phase')

    parser.add_argument('--preprocessing', type=str, default='StandardScaler',
                        help='Any available preprocessing method from sklearn.preprocessing')

    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--learning_rate', type=float, default=1e-2)
    parser.add_argument('--hidden_activation', type=str, default='nn.ReLU()')

    parser.add_argument('--normalization', type=str, default='batch')
    parser.add_argument('--dropout', type=float, default=0.2)
    parser.add_argument('--l2penalty', type=float, default=1e-4)

    parser.add_argument('--num_workers', type=int, default=0, help='PyTorch dataloader worker. Set to 0 if debug.')
    parser.add_argument('--seed', type=int, default=1, help='Initial RNG seed. Only for reproducibility')

    # ! N_RUNS to be determined
    parser.add_argument('--n_runs', type=int, default=1, help='Number of runs, each run has different rng seed.')

    parser.add_argument('--classifier_dim', type=int, default=32, help='Dimension of final classifier')
    parser.add_argument('--embedding_size', type=int, default=3, help='Dimension of embedding')

    parser.add_argument('--kernel_size', type=int, default=4)
    parser.add_argument('--filters', nargs='+', type=int, default=[128, 128, 256, 256])
    parser.add_argument('--stride', type=int, default=2)
    parser.add_argument('--padding', type=int, default=2)

    # Suppress terminal out
    parser.add_argument('--disable_print', action='store_true', default=False,
                        help='Suppress screen print, keep log.txt')
    parser.add_argument('--plt_loss', action='store_true', default=False, help='plot loss function each epoch')
    parser.add_argument('--plt_embedding', action='store_true', default=False, help='plot embedding representation')
    parser.add_argument('--plt_loss_hist', action='store_true', default=False,
                        help='plot loss history for clean and mislabled samples')
    parser.add_argument('--plt_cm', action='store_true', default=False, help='plot confusion matrix')
    parser.add_argument('--plt_recons', action='store_true', default=False, help='plot AE reconstructions')
    parser.add_argument('--headless', action='store_true', default=True,
                        help='Matplotlib backend. Set true if no display connected.')

    # args for loading seeg database
    parser.add_argument('--database_save_dir', type=str, default='/data/CL_database/',
                        help='Should give a path to load the database of one patient.')
    parser.add_argument('--data_name', type=str, default='Sleep',
                        help='Should give the name of the database [SEEG, fNIRS_2, Sleep].')
    parser.add_argument('--n_class', type=int, default=2,
                        help='class num of dataset')
    parser.add_argument('--exp_id', type=int, default=3,
                        help='The experimental id.')
    parser.add_argument('--gpu_id', type=int, default=1, help='id of cuda device')
    parser.add_argument('--noise_ratio', type=float, default=0,
                        help='The maximal ratio of adding noise.')
    parser.add_argument('--window_time', type=float, default=1,
                        help='The seconds of every sample segment.')
    parser.add_argument('--slide_time', type=float, default=0.5,
                        help='The sliding seconds between two sample segments.')
    parser.add_argument('--num_level', type=int, default=5,
                        help='The number of levels.')

    parser.add_argument('--path_checkpoint', type=str, default='/data/SREA/')
    parser.add_argument('--save_step', type=int, default=20, help='The step number to save checkpoint')
    parser.add_argument('--patience', type=int, default=10, help='The waiting epoch number for early stopping.')
    parser.add_argument('--load_path', type=str, default=None, help='The path to load checkpoint.')
    parser.add_argument('--load_best', action='store_false', help='Whether to load the best state in the checkpoint.')
    parser.add_argument('--best_val_index', type=str, default='F1',
                        help='The index for saving models performing best in the validation dataset. The candidate'
                             'list includes: [loss, F1].')

    # Add parameters for each particular network
    args = parser.parse_args()

    return args


def _build_model(args, n_class, input_size, seq_len):
    """Instantiate the multi-task network (CNN autoencoder plus non-linear classifier head), not yet moved to a device."""
    classifier = NonLinClassifier(args.embedding_size, n_class, d_hidd=args.classifier_dim, dropout=args.dropout,
                                  norm=args.normalization)

    model_ae = CNNAE(input_size=input_size, num_filters=args.filters, embedding_dim=args.embedding_size,
                     seq_len=seq_len, kernel_size=args.kernel_size, stride=args.stride,
                     padding=args.padding, dropout=args.dropout, normalization=args.normalization)

    # model is multi task - AE Branch and Classification branch
    return AEandClass(ae=model_ae, classifier=classifier, name='CNN')


def main():
    """Evaluate the two checkpointed SREA models on the test split and write ``test_results.csv``.

    Steps:

    * Parse the CLI and resolve the patient split for ``args.exp_id``.
    * Load the test data via ``test_datasets(args)``.
    * Load the checkpoint (``--load_path``, or ``<saver.path>/checkpoint.pt`` when omitted).
    * Evaluate the ``BestLossModel`` and ``BestF1Model`` state dicts with ``evaluate_class_recons``.
    * Save one row per model, tagged with ``model_type``, to ``<saver.path>/test_results.csv``.
    """
    args = parse_args()
    exp_dict = get_exp_dict(args.data_name)
    exp_patient_list = exp_dict[args.exp_id]
    args.train_patient_list = exp_patient_list[0]
    args.valid_patient_list = exp_patient_list[1]
    args.test_patient_list = exp_patient_list[2]

    # Per-dataset segmentation settings take precedence over --window_time / --slide_time.
    args.window_time = window_time_dict[args.data_name]
    args.slide_time = slide_time_dict[args.data_name]

    # Declare saver object and get check_point path
    saver = Saver(OUTPATH, 'SEEG_SREA_single_experiment',
                  hierarchy=args.dataset, noise_ratio=args.noise_ratio, exp_id=args.exp_id)
    # Honour an explicit --load_path; previously it was always overwritten and silently ignored.
    if args.load_path is None:
        args.load_path = os.path.join(saver.path, 'checkpoint.pt')
    load_path = args.load_path

    # load data
    print('-' * 50)
    print('Load test dataset')
    x_test, Y_test, n_class = test_datasets(args)
    test_dataset = TensorDataset(torch.from_numpy(x_test).float(), torch.from_numpy(Y_test).long())
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, drop_last=False,
                             num_workers=args.num_workers, pin_memory=True)
    # NOTE(review): this counts only the labels present in the test split. It can differ from `n_class`
    # (used below to size the classifier) if a class is absent from the test set. Confirm which value
    # evaluate_class_recons expects.
    args.n_class = len(np.unique(Y_test))

    # load checkpoint
    print('-' * 50)
    print('Load checkpoint:', load_path)
    # torch.load unpickles arbitrary objects: only load checkpoints from a trusted source.
    state_dict = torch.load(load_path, map_location='cpu')

    checkpoints = {
        'BestLoss': state_dict["BestLossModel"],
        'BestF1': state_dict["BestF1Model"],
    }

    rows = []
    for model_type, model_dict in checkpoints.items():
        # Network definition (x_test indexed as [sample, seq_len, input_size] by the model constructor).
        model = _build_model(args, n_class, input_size=x_test.shape[2], seq_len=x_test.shape[1])
        # strict=False tolerates key mismatches; report them instead of silently evaluating a partly
        # initialised network.
        incompatible = model.load_state_dict(model_dict, strict=False)
        missing = getattr(incompatible, 'missing_keys', [])
        unexpected = getattr(incompatible, 'unexpected_keys', [])
        if missing or unexpected:
            print('Warning [{}]: missing keys {}, unexpected keys {}'.format(model_type, missing, unexpected))
        model.to(device)

        test_results = evaluate_class_recons(model, x_test, Y_test, None, test_loader, args.noise_ratio, saver, 'CNN',
                                             'Test', args.correct, args.n_class)
        test_results['model_type'] = model_type
        rows.append(test_results)

    # DataFrame.append was removed in pandas 2.0; build the frame once from the collected rows.
    df_results = pd.DataFrame(rows)

    print('Save results')
    df_results.to_csv(os.path.join(saver.path, 'test_results.csv'), sep=',', index=False)


if __name__ == '__main__':
    # Script entry point: evaluation runs only when executed directly, not when the module is imported.
    main()
