import argparse
import logging
import os
import shutil
import sys

sys.path.append('../src')
import warnings

import matplotlib

matplotlib.use('agg')
import matplotlib.pyplot as plt

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

from utils.SREA_utils import main_wrapper
from utils.global_var import OUTPATH
from utils.log_utils import StreamToLogger
from utils.saver import Saver

from utils.seeg_datasets import seeg_datasets
from utils.utils import GPU_id

from utils.default_config import get_exp_dict, window_time_dict, slide_time_dict

# Cap the number of CPU threads used by torch and by the usual BLAS/OpenMP backends.
num_threads = '16'
torch.set_num_threads(int(num_threads))
# NOTE(review): numpy/torch are imported above, before these variables are set; libraries that
# read them only at import/initialisation time may ignore them — confirm the cap is effective.
os.environ.update({
    env_name: num_threads
    for env_name in (
        'OMP_NUM_THREADS',
        'OPENBLAS_NUM_THREADS',
        'MKL_NUM_THREADS',
        'VECLIB_MAXIMUM_THREADS',
        'NUMEXPR_NUM_THREADS',
    )
})

######################################################################################################

# Silence every warning for the whole process.
warnings.filterwarnings("ignore")
# Let cuDNN benchmark and pick convolution algorithms.
torch.backends.cudnn.benchmark = True
# Terminal width, used to center the dataset banner printed in main().
columns = shutil.get_terminal_size().columns


######################################################################################################

def _str2bool(value):
    """Parse a command-line token into a bool.

    argparse has no built-in boolean ``type``; without one, ``--correct False``
    yields the non-empty (and therefore truthy) string ``'False'``.

    Args:
        value: Raw token from the command line, or an already-boolean value.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'Boolean value expected, got {value!r}')


def parse_args(argv=None):
    """Build the SREA command-line parser and parse the arguments.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace holding every option. ``main()`` overwrites some of
        them after parsing (``window_time``, ``slide_time``, ``n_class``,
        ``path_checkpoint`` and ``batch_size``), so their command-line values
        are not used as-is.
    """
    # Add global parameters
    parser = argparse.ArgumentParser(
        description='SREA Single Experiment. It run n_runs independent experiments with different random seeds.'
                    ' Each run evaluate different noise ratios (ni).')

    parser.add_argument('--dataset', type=str, default='SEEG', help='SEEG datasets')

    # ! label noise ratio to be determined (e.g. --ni 0 0.30)
    parser.add_argument('--ni', type=float, nargs='+', default=[0], help='label noise ratio')

    # ! noise type to be determined
    # No `type`: the default is the int 0, but values given on the command line arrive as str.
    parser.add_argument('--label_noise', default=0, help='Label noise type, sym or int for asymmetric, '
                                                         'number as str for time-dependent noise')

    parser.add_argument('--M', type=int, nargs='+', default=[20, 40, 60, 80], help='Scheduler milestones')
    parser.add_argument('--abg', type=float, nargs='+',
                        help='Loss function coefficients. a (alpha) = AE, b (beta) = classifier, g (gamma) = clusterer',
                        default=[1, 1, 1])
    parser.add_argument('--class_reg', type=int, default=1, help='Distribution regularization coeff')
    # float, not int: the default is 0. and fractional coefficients must be accepted.
    parser.add_argument('--entropy_reg', type=float, default=0., help='Entropy regularization coeff')

    # Parsed as booleans so that `--correct False` really disables correction.
    parser.add_argument('--correct', nargs='+', type=_str2bool, default=[True],
                        help='Correct labels. Set to false to not correct labels.')
    parser.add_argument('--track', type=int, default=5, help='Number or past predictions snapshots')
    parser.add_argument('--init_centers', type=int, default=1, help='Initialize cluster centers. Warm up phase.')
    parser.add_argument('--delta_start', type=int, default=10, help='Start re-labeling phase')
    parser.add_argument('--delta_end', type=int, default=30,
                        help='Begin fine-tuning phase')

    parser.add_argument('--preprocessing', type=str, default='StandardScaler',
                        help='Any available preprocessing method from sklearn.preprocessing')

    # Upper bound only: main() caps it relative to the training-set size.
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--learning_rate', type=float, default=1e-2)
    parser.add_argument('--hidden_activation', type=str, default='nn.ReLU()')

    parser.add_argument('--normalization', type=str, default='batch')
    parser.add_argument('--dropout', type=float, default=0.2)
    parser.add_argument('--l2penalty', type=float, default=1e-4)

    parser.add_argument('--num_workers', type=int, default=0, help='PyTorch dataloader worker. Set to 0 if debug.')
    parser.add_argument('--seed', type=int, default=1, help='Initial RNG seed. Only for reproducibility')

    # ! N_RUNS to be determined
    parser.add_argument('--n_runs', type=int, default=1, help='Number of runs, each run has different rng seed.')

    parser.add_argument('--classifier_dim', type=int, default=32, help='Dimension of final classifier')
    parser.add_argument('--embedding_size', type=int, default=3, help='Dimension of embedding')

    parser.add_argument('--kernel_size', type=int, default=4)
    parser.add_argument('--filters', nargs='+', type=int, default=[128, 128, 256, 256])
    parser.add_argument('--stride', type=int, default=2)
    parser.add_argument('--padding', type=int, default=2)

    # Suppress terminal out
    parser.add_argument('--disable_print', action='store_true', default=False,
                        help='Suppress screen print, keep log.txt')
    parser.add_argument('--plt_loss', action='store_true', default=False, help='plot loss function each epoch')
    parser.add_argument('--plt_embedding', action='store_true', default=False, help='plot embedding representation')
    parser.add_argument('--plt_loss_hist', action='store_true', default=False,
                        help='plot loss history for clean and mislabled samples')
    parser.add_argument('--plt_cm', action='store_true', default=False, help='plot confusion matrix')
    parser.add_argument('--plt_recons', action='store_true', default=False, help='plot AE reconstructions')
    # --headless is store_true with default=True, so on its own it can never turn headless off;
    # --no_headless (same dest) provides the missing way to select the interactive backend.
    parser.add_argument('--headless', action='store_true', default=True,
                        help='Matplotlib backend. Set true if no display connected.')
    parser.add_argument('--no_headless', dest='headless', action='store_false',
                        help='Use the interactive (Qt5Agg) matplotlib backend instead of Agg.')

    # args for loading seeg database
    parser.add_argument('--database_save_dir', type=str, default='/data/CL_database/',
                        help='Should give a path to load the database of one patient.')
    parser.add_argument('--data_name', type=str, default='Sleep',
                        help='Should give the name of the database [SEEG, fNIRS_2, Sleep].')
    # Recomputed in main() from the training labels.
    parser.add_argument('--n_class', type=int, default=2,
                        help='class num of dataset')
    parser.add_argument('--exp_id', type=int, default=3,
                        help='The experimental id.')
    # NOTE(review): main() does not read this; the CUDA device index comes from utils.GPU_id — confirm.
    parser.add_argument('--gpu_id', type=int, default=1, help='id of cuda device')
    parser.add_argument('--noise_ratio', type=float, default=0,
                        help='The maximal ratio of adding noise.')
    # window_time / slide_time are overwritten in main() from the per-dataset default config.
    parser.add_argument('--window_time', type=float, default=1,
                        help='The seconds of every sample segment.')
    parser.add_argument('--slide_time', type=float, default=0.5,
                        help='The sliding seconds between two sample segments.')
    parser.add_argument('--num_level', type=int, default=5,
                        help='The number of levels.')

    # Overwritten in main() with the Saver output path.
    parser.add_argument('--path_checkpoint', type=str, default='/data/SREA/')
    parser.add_argument('--save_step', type=int, default=20, help='The step number to save checkpoint')
    parser.add_argument('--patience', type=int, default=10, help='The waiting epoch number for early stopping.')
    parser.add_argument('--load_path', type=str, default=None, help='The path to load checkpoint.')
    # store_false: the default is True, and passing --load_best turns loading the best state OFF.
    parser.add_argument('--load_best', action='store_false', help='Whether to load the best state in the checkpoint.')
    parser.add_argument('--best_val_index', type=str, default='F1',
                        help='The index for saving models performing best in the validation dataset. The candidate '
                             'list includes: [loss, F1].')
    # Add parameters for each particular network
    args = parser.parse_args(argv)

    return args


def clean_step(
        clean_label,
        ori_label,
        correct_label,
        evaluation_f,
        n_class,
):
    """Compare the model-corrected training labels with the clean and original labels.

    Prints a summary. Returns counts and ratios describing three things: how many labels
    the model changed, how many noisy labels it repaired, and how many originally-correct
    labels it made wrong.

    Args:
        clean_label: Tensor of ground-truth class indices; flattened to 1-D.
        ori_label: Tensor of the (possibly noisy) training labels. It is reduced with
            argmax over the last dim, so presumably one-hot or per-class scores — TODO confirm.
        correct_label: Tensor of class indices produced by the label-correction model;
            flattened to 1-D.
        evaluation_f: Callable ``(y_true, y_pred, n_class)`` that receives NumPy arrays;
            its return value is printed.
        n_class: Number of classes, forwarded to ``evaluation_f``.

    Returns:
        Tuple ``(total_change_num, n_noisy, total_corrected_num, corrected_ratio,
        n_correct, total_wrong_num, wrong_ratio)``. The counts produced by ``.sum()`` and
        the ratios are 0-d tensors; ``n_noisy``/``n_correct`` are Python ints. A ratio
        whose denominator is zero (e.g. no noisy labels) is NaN rather than an error.
    """
    # reshape (not view) so non-contiguous inputs are accepted too.
    clean_label = clean_label.reshape(-1)
    ori_label = torch.argmax(ori_label, dim=-1).reshape(-1)
    # Flatten so an (N, 1) prediction array cannot silently broadcast to (N, N)
    # in the element-wise comparisons below.
    correct_label = correct_label.reshape(-1)

    index = evaluation_f(
        clean_label.numpy(),
        correct_label.numpy(),
        n_class,
    )
    print('-' * 10, 'The clean results', '-' * 10)
    print(index)

    # Labels the model changed relative to the (possibly noisy) training labels.
    total_change_num = ori_label.ne(correct_label).sum()
    print(f'The number of changed labels from original labels is: {total_change_num}')

    # Samples whose training label disagrees with the ground truth, and how many were repaired.
    noise_index = torch.where(clean_label.ne(ori_label))[0]
    n_noisy = len(noise_index)
    print(f'The number of total noisy labels is: {n_noisy}')
    total_corrected_num = clean_label[noise_index].eq(correct_label[noise_index]).sum()
    corrected_ratio = total_corrected_num / n_noisy
    print(f'The number of corrected labels is: {total_corrected_num}')
    print(f'The ratio of corrected labels is: {corrected_ratio}')

    # Samples whose training label was already right, and how many the model broke.
    correct_index = torch.where(clean_label.eq(ori_label))[0]
    n_correct = len(correct_index)
    print(f'The number of total correct labels is: {n_correct}')
    total_wrong_num = clean_label[correct_index].ne(correct_label[correct_index]).sum()
    wrong_ratio = total_wrong_num / n_correct
    print(f'The number of wrong labels is: {total_wrong_num}')
    print(f'The ratio of wrong labels is: {wrong_ratio}')

    return (total_change_num, n_noisy, total_corrected_num, corrected_ratio,
            n_correct, total_wrong_num, wrong_ratio)


######################################################################################################
def main():
    """Run one SREA experiment end to end and write ``results.csv``.

    Steps:
    1. Parse CLI args and resolve the per-dataset patient split and windowing.
    2. Create the output directory through ``Saver``.
    3. Seed the RNGs and configure matplotlib and logging (stdout/stderr are redirected to loggers).
    4. Load the data and train/evaluate through ``main_wrapper``.
    5. Score the corrected training labels with ``clean_step``.
    6. Save the combined results.
    """
    # Local import: pandas is only needed at the end, to append the label-correction row.
    import pandas as pd

    args = parse_args()
    # Resolve the patient split and windowing for the chosen dataset; these replace
    # whatever was given on the command line.
    exp_dict = get_exp_dict(args.data_name)
    exp_patient_list = exp_dict[args.exp_id]
    args.train_patient_list = exp_patient_list[0]
    args.valid_patient_list = exp_patient_list[1]
    args.test_patient_list = exp_patient_list[2]

    args.window_time = window_time_dict[args.data_name]
    args.slide_time = slide_time_dict[args.data_name]

    # Declare saver object and get check_point path
    saver = Saver(OUTPATH, os.path.basename(__file__).split(sep='.py')[0],
                  hierarchy=os.path.join(args.dataset), noise_ratio=args.noise_ratio, exp_id=args.exp_id)

    args.path_checkpoint = saver.path

    print(args)
    print()

    ######################################################################################################
    SEED = args.seed
    # NOTE(review): args.gpu_id is not used here; the device index comes from GPU_id.id.
    device = torch.device("cuda:{}".format(GPU_id.id) if torch.cuda.is_available() else "cpu")
    torch.manual_seed(SEED)
    # Compare the device *type*: a torch.device never compares equal to a plain str,
    # so the former `device != 'cpu'` test was always true.
    if device.type != 'cpu':
        torch.cuda.manual_seed(SEED)
    np.random.seed(SEED)

    if args.headless:
        print('Setting Headless support')
        plt.switch_backend('Agg')
    else:
        backend = 'Qt5Agg'
        print('Swtiching matplotlib backend to', backend)
        plt.switch_backend(backend)
    print()

    ######################################################################################################
    # LOG STUFF

    # Logging setting
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(name)s: %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        filename=os.path.join(saver.path, 'logfile.log'),
        filemode='a'
    )

    # Redirect stdout
    stdout_logger = logging.getLogger('STDOUT')
    slout = StreamToLogger(stdout_logger, logging.INFO)
    sys.stdout = slout

    # Redirect stderr
    stderr_logger = logging.getLogger('STDERR')
    slerr = StreamToLogger(stderr_logger, logging.ERROR)
    sys.stderr = slerr

    # Suppress terminal output
    if args.disable_print:
        # NOTE(review): these devnull handles are never closed; they stay open for the rest of the run.
        slout.terminal = open(os.devnull, 'w')
        slerr.terminal = open(os.devnull, 'w')

    ######################################################################################################
    # Data
    print('*' * shutil.get_terminal_size().columns)
    print('{} Dataset: {}'.format(args.data_name, args.dataset).center(columns))
    print('*' * shutil.get_terminal_size().columns)
    print()

    train_data_handler, train_label, x_train, Y_train_clean, x_valid, Y_valid_clean, x_test, Y_test_clean, n_class = seeg_datasets(
        args)

    # NOTE(review): args.n_class is recomputed from the training labels, while clean_step below
    # receives n_class as returned by seeg_datasets — confirm the two always agree.
    args.n_class = len(np.unique(Y_train_clean))

    if args.data_name != 'SEEG':
        clean_label = torch.tensor(train_data_handler.get_data(clean_label=True).label, dtype=torch.long)
    else:
        # For SEEG the training labels double as the clean reference, so clean_step
        # finds no noisy labels.
        clean_label = torch.argmax(train_label, dim=-1)

    ori_label = train_label

    n_train = x_train.shape[0]
    # At most a tenth of the training set per batch, but never 0: n_train // 10 is 0 for
    # fewer than 10 samples, which used to raise ZeroDivisionError in the modulo below.
    batch_size = max(1, min(n_train // 10, args.batch_size))
    # Avoid a final batch holding a single sample — presumably because batch normalization
    # (the default --normalization) cannot train on one sample.
    if n_train % batch_size == 1:
        batch_size -= 1
    print('Batch size: ', batch_size)
    args.batch_size = batch_size

    ###########################
    saver.make_log(**vars(args))

    ######################################################################################################
    df_results, correct_label = main_wrapper(args, x_train, x_valid, x_test, Y_train_clean, Y_valid_clean, Y_test_clean,
                                             saver)

    label_stats = clean_step(
        clean_label,
        ori_label,
        torch.from_numpy(correct_label),
        train_data_handler.model_evaluation,
        n_class,
    )
    stat_names = ('change_num', 'noisy_labels', 'corrected_labels', 'r_corrected_labels',
                  'correct_labels', 'wronged_labels', 'r_wronged_labels')
    # clean_step returns a mix of 0-d tensors and ints; store plain Python scalars so the
    # CSV holds numbers instead of "tensor(...)" strings.
    test_results = {
        name: value.item() if torch.is_tensor(value) else value
        for name, value in zip(stat_names, label_stats)
    }
    # DataFrame.append was removed in pandas 2.0; concat is the supported equivalent.
    df_results = pd.concat([df_results, pd.DataFrame([test_results])], ignore_index=True)

    print('Save results')
    df_results.to_csv(os.path.join(saver.path, 'results.csv'), sep=',', index=False)


######################################################################################################
if __name__ == '__main__':
    # Script entry point: importing this module does not start an experiment.
    main()
