import argparse
import logging
import os
import shutil
import sys

sys.path.append('../src')
import warnings

import matplotlib.pyplot as plt
import numpy as np
import torch

from utils.global_var import OUTPATH
from utils.log_utils import StreamToLogger
from utils.saver import Saver
from utils.training_helper_coteaching import main_wrapper

from utils.seeg_datasets import seeg_datasets
from utils.utils import GPU_id

from utils.default_config import get_exp_dict, window_time_dict, slide_time_dict

# Cap the thread count used by PyTorch and by the common native math backends.
# NOTE(review): numpy and torch are imported above this point, so any library
# that reads these environment variables only when it is first loaded may not
# honour them -- confirm whether they should be set before those imports.
num_threads = '16'
torch.set_num_threads(int(num_threads))
os.environ.update(dict.fromkeys(
    (
        'OMP_NUM_THREADS',
        'OPENBLAS_NUM_THREADS',
        'MKL_NUM_THREADS',
        'VECLIB_MAXIMUM_THREADS',
        'NUMEXPR_NUM_THREADS',
    ),
    num_threads,
))

######################################################################################################

# Silence every Python warning for the whole run.
warnings.filterwarnings("ignore")
# Let cuDNN benchmark convolution algorithms for the input shapes it sees.
torch.backends.cudnn.benchmark = True
# Terminal width, used to centre the dataset banner printed in main().
columns = shutil.get_terminal_size().columns

# Autograd anomaly detection: a debugging aid that PyTorch documents as
# slowing execution down.
torch.autograd.set_detect_anomaly(True)


######################################################################################################

def parse_args(argv=None):
    """Parse the command-line options for a single co-teaching experiment.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before; an explicit list
            lets the parser be driven from tests or notebooks.

    Returns:
        argparse.Namespace holding every option declared below.
    """
    # TODO: move these options to an external configuration file (JSON or similar).
    # List handling: https://stackoverflow.com/questions/15753701/how-can-i-pass-a-list-as-a-command-line-argument-with-argparse

    # Add global parameters
    parser = argparse.ArgumentParser(description='coteaching single experiment')

    # Synth Data
    parser.add_argument('--dataset', type=str, default='Plane', help='UCR datasets')

    parser.add_argument('--ni', type=float, nargs='+', default=[0], help='label noise ratio')
    parser.add_argument('--label_noise', type=int, default=0, help='Label noise type, sym or int for asymmetric, '
                                                                   'number as str for time-dependent noise')

    parser.add_argument('--M', type=int, nargs='+', default=[20, 40, 60, 80])
    parser.add_argument('--reg_term', type=float, default=1,
                        help="Parameter of the regularization term, default: 1.")
    parser.add_argument('--alpha', type=float, default=32,
                        help='alpha parameter for the mixup distribution, default: 32')

    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--lr', type=float, default=1e-2)
    parser.add_argument('--hidden_activation', type=str, default='nn.ReLU()')
    parser.add_argument('--num_gradual', type=int, default=100)

    parser.add_argument('--normalization', type=str, default='batch')
    parser.add_argument('--dropout', type=float, default=0.2)
    parser.add_argument('--l2penalty', type=float, default=1e-4)

    parser.add_argument('--num_workers', type=int, default=0, help='PyTorch dataloader worker. Set to 0 if debug.')
    parser.add_argument('--seed', type=int, default=0, help='RNG seed - only affects Network init')
    parser.add_argument('--n_runs', type=int, default=1, help='Number of runs')

    parser.add_argument('--classifier_dim', type=int, default=128)
    parser.add_argument('--embedding_size', type=int, default=3)

    parser.add_argument('--kernel_size', type=int, default=4)
    parser.add_argument('--filters', nargs='+', type=int, default=[128, 128, 256, 256])
    parser.add_argument('--stride', type=int, default=2)
    parser.add_argument('--padding', type=int, default=2)

    # Suppress terminal out
    parser.add_argument('--disable_print', action='store_true', default=False)
    parser.add_argument('--plt_embedding', action='store_true', default=False)
    parser.add_argument('--plt_loss_hist', action='store_true', default=False)
    parser.add_argument('--plt_cm', action='store_true', default=False)
    # --headless is kept for backward compatibility, but since it defaults to
    # True it can never switch anything off on its own. --no_headless writes
    # to the same destination so the interactive backend is reachable.
    parser.add_argument('--headless', action='store_true', default=True, help='Matplotlib backend')
    parser.add_argument('--no_headless', dest='headless', action='store_false', default=True,
                        help='Use the interactive Qt5Agg matplotlib backend instead of Agg.')

    # args for loading seeg database
    parser.add_argument('--database_save_dir', type=str, default='/data/CL_database/',
                        help='Should give a path to load the database of one patient.')
    parser.add_argument('--data_name', type=str, default='Sleep',
                        help='Should give the name of the database [SEEG, fNIRS_2, Sleep].')
    parser.add_argument('--n_class', type=int, default=2,
                        help='class num of dataset')
    parser.add_argument('--gpu_id', type=int, default=1, help='id of cuda device')
    parser.add_argument('--exp_id', type=int, default=3,
                        help='The experimental id.')
    parser.add_argument('--noise_ratio', type=float, default=0,
                        help='The maximal ratio of adding noise.')
    parser.add_argument('--window_time', type=float, default=1,
                        help='The seconds of every sample segment.')
    parser.add_argument('--slide_time', type=float, default=0.5,
                        help='The sliding seconds between two sample segments.')
    parser.add_argument('--num_level', type=int, default=5,
                        help='The number of levels.')

    # args for checkpoint
    parser.add_argument('--path_checkpoint', type=str, default='/data/SREA/')
    parser.add_argument('--save_step', type=int, default=20, help='The step number to save checkpoint')
    parser.add_argument('--patience', type=int, default=10, help='The waiting epoch number for early stopping.')
    parser.add_argument('--best_val_index', type=str, default='F1',
                        help='The index for saving models performing best in the validation dataset. The candidate '
                             'list includes: [loss, F1].')

    # argv=None -> argparse falls back to sys.argv[1:].
    return parser.parse_args(argv)


def clean_step(
        clean_label,
        ori_label,
        correct_label,
        evaluation_f,
        n_class,
):
    """Report how a label-correction step changed the labels.

    Compares three label sets: the ground-truth labels, the original labels
    (presumably the noisy training labels -- confirm with callers), and the
    labels after correction. Prints a summary and returns the statistics.

    Args:
        clean_label: Tensor of ground-truth class indices; flattened to 1-D.
        ori_label: Tensor whose last dimension holds one entry per class; it is
            converted to class indices with ``argmax(dim=-1)``.
        correct_label: Tensor of corrected class indices; flattened to 1-D.
        evaluation_f: Callable ``(y_true, y_pred, n_class)`` receiving NumPy
            arrays; its return value is only printed.
        n_class: Number of classes, forwarded to ``evaluation_f``.

    Returns:
        Tuple ``(total_change_num, n_noisy, total_corrected_num,
        corrected_ratio, n_clean, total_wrong_num, wrong_ratio)``:

        * ``total_change_num``: positions where ``correct_label`` differs from
          the original labels (0-d tensor);
        * ``n_noisy`` / ``n_clean``: number of original labels that differ
          from / agree with ``clean_label`` (int);
        * ``total_corrected_num``: noisy labels that the correction fixed;
        * ``total_wrong_num``: originally clean labels that the correction broke;
        * the two ratios divide those counts by ``n_noisy`` / ``n_clean``. This
          is tensor division, so an empty group yields ``nan`` instead of
          raising.
    """
    # Flatten all three label sets the same way. correct_label used to be left
    # as-is, so an (N, 1) tensor would broadcast against the (N,) tensors below
    # into an (N, N) comparison and silently inflate every count.
    clean_label = clean_label.reshape(-1)
    ori_label = torch.argmax(ori_label, dim=-1).reshape(-1)
    correct_label = correct_label.reshape(-1)

    # .cpu() is a no-op for CPU tensors and avoids the error .numpy() raises on
    # CUDA tensors.
    index = evaluation_f(
        clean_label.cpu().numpy(),
        correct_label.cpu().numpy(),
        n_class,
    )
    print('-' * 10, 'The clean results', '-' * 10)
    print(index)

    # Number of labels the correction modified relative to the original labels.
    total_change_num = ori_label.ne(correct_label).sum()
    print(f'The number of changed labels from original labels is: {total_change_num}')

    # Originally wrong labels: how many did the correction repair?
    noise_index = torch.where(clean_label.ne(ori_label))[0]
    n_noisy = len(noise_index)
    print(f'The number of total noisy labels is: {n_noisy}')
    total_corrected_num = clean_label[noise_index].eq(correct_label[noise_index]).sum()
    corrected_ratio = total_corrected_num / n_noisy
    print(f'The number of corrected labels is: {total_corrected_num}')
    print(f'The ratio of corrected labels is: {corrected_ratio}')

    # Originally right labels: how many did the correction break?
    correct_index = torch.where(clean_label.eq(ori_label))[0]
    n_clean = len(correct_index)
    print(f'The number of total correct labels is: {n_clean}')
    total_wrong_num = clean_label[correct_index].ne(correct_label[correct_index]).sum()
    wrong_ratio = total_wrong_num / n_clean
    print(f'The number of wrong labels is: {total_wrong_num}')
    print(f'The ratio of wrong labels is: {wrong_ratio}')

    return (total_change_num, n_noisy, total_corrected_num, corrected_ratio,
            n_clean, total_wrong_num, wrong_ratio)


######################################################################################################
def main():
    args = parse_args()
    exp_dict = get_exp_dict(args.data_name)
    exp_patient_list = exp_dict[args.exp_id]
    args.train_patient_list = exp_patient_list[0]
    args.valid_patient_list = exp_patient_list[1]
    args.test_patient_list = exp_patient_list[2]

    args.window_time = window_time_dict[args.data_name]
    args.slide_time = slide_time_dict[args.data_name]

    # LOG STUFF
    # Declare saver object
    saver = Saver(OUTPATH, os.path.basename(__file__).split(sep='.py')[0],
                  hierarchy=os.path.join(args.dataset), noise_ratio=args.noise_ratio, exp_id=args.exp_id)
    args.path_checkpoint = saver.path

    ## Save json of args/parameters. This is handy for TL
    # with open(os.path.join(saver.path, 'args.json'), 'w') as fp:
    #    json.dump(vars(args), fp, indent=4)

    print('run logfile at: ', os.path.join(saver.path, 'logfile.log'))
    # Logging setting
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(name)s: %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        filename=os.path.join(saver.path, 'logfile.log'),
        filemode='a'
    )

    # Redirect stdout
    stdout_logger = logging.getLogger('STDOUT')
    slout = StreamToLogger(stdout_logger, logging.INFO)
    sys.stdout = slout

    # Redirect stderr
    stderr_logger = logging.getLogger('STDERR')
    slerr = StreamToLogger(stderr_logger, logging.ERROR)
    sys.stderr = slerr

    # Suppress output
    if args.disable_print:
        slout.terminal = open(os.devnull, 'w')
        slerr.terminal = open(os.devnull, 'w')

    ######################################################################################################
    print(args)
    print()

    ######################################################################################################
    SEED = args.seed
    # os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
    # TODO: implement multi device and different GPU selection
    device = torch.device("cuda:{}".format(GPU_id.id) if torch.cuda.is_available() else "cpu")
    torch.manual_seed(SEED)
    if device != 'cpu':
        torch.cuda.manual_seed(SEED)
    np.random.seed(SEED)

    if args.headless:
        print('Setting Headless support')
        plt.switch_backend('Agg')
    else:
        backend = 'Qt5Agg'
        print('Swtiching matplotlib backend to', backend)
        plt.switch_backend(backend)
    print()

    ######################################################################################################
    # Data
    print('*' * shutil.get_terminal_size().columns)
    print('{} Dataset: {}'.format(args.data_name, args.dataset).center(columns))
    print('*' * shutil.get_terminal_size().columns)
    print()

    # X, Y = load_data(args.dataset)
    # x_train, x_test, Y_train_clean, Y_test_clean = train_test_split(X, Y, stratify=Y, test_size=0.2)

    # data_handler, train_label, x_train, Y_train_clean, x_test, Y_test_clean, n_class = seeg_datasets(args)
    train_data_handler, train_label, x_train, Y_train_clean, x_valid, Y_valid_clean, x_test, Y_test_clean, n_class = seeg_datasets(
        args)
    args.n_class = len(np.unique(Y_train_clean))

    batch_size = min(x_train.shape[0] // 10, args.batch_size)
    if x_train.shape[0] % batch_size == 1:
        batch_size += -1
    print('Batch size: ', batch_size)
    args.batch_size = batch_size
    args.test_batch_size = batch_size

    ###########################
    saver.make_log(**vars(args))

    ######################################################################################################
    df_results = main_wrapper(args, x_train, x_valid, x_test, Y_train_clean, Y_valid_clean, Y_test_clean, saver)

    print('Save results')
    df_results.to_csv(os.path.join(saver.path, 'results.csv'), sep=',', index=False)

    return df_results


######################################################################################################
if __name__ == '__main__':
    # Run the experiment only when executed as a script; the DataFrame that
    # main() returns is discarded here.
    main()
