import torch
import numpy as np
import argparse
import os
import time
from copy import deepcopy

from torch import optim
import random
import sys

sys.path.append('../../../utils/')
sys.path.append('../')
sys.path.append('../utils')

from utils.utils_noise_v2 import *
from utils.test_eval import test_eval
from utils.queue_with_pro import *
from utils.kNN_test_v2 import *
from utils.MemoryMoCo import MemoryMoCo
from utils.other_utils import *
from utils.models.preact_resnet import *
from utils.lr_scheduler import get_scheduler
from apex import amp

from torch.utils.data import DataLoader, TensorDataset
from utils.blurred_datasets import seeg_datasets
from default_config import get_exp_dict, window_time_dict, slide_time_dict

# Thread budget shared by torch and the numerical libraries configured below.
# It is kept as a string because os.environ only accepts string values.
num_threads = '10'
torch.set_num_threads(int(num_threads))

# NOTE(review): these variables are set after torch and numpy have already
# been imported; confirm that the already-loaded libraries still honour them.
os.environ.update({
    'OMP_NUM_THREADS': num_threads,
    'OPENBLAS_NUM_THREADS': num_threads,
    'MKL_NUM_THREADS': num_threads,
    'VECLIB_MAXIMUM_THREADS': num_threads,
    'NUMEXPR_NUM_THREADS': num_threads,
})


def _str2bool(value):
    """Convert a command-line token into a bool for use as an argparse ``type``.

    ``type=bool`` is unsuitable because ``bool("False")`` is ``True``; this
    converter accepts the usual spellings explicitly instead.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays falsy, as it was under ``type=bool``.
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def parse_args(argv=None):
    """Parse the command line and attach the experiment split to the result.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        The parsed namespace, extended with ``train_patient_list``,
        ``valid_patient_list`` and ``test_patient_list`` (elements 0, 1 and 2
        of ``get_exp_dict(args.data_name)[args.exp_id]``).

    Note:
        ``window_time`` and ``slide_time`` are always overwritten with the
        entries of ``window_time_dict`` / ``slide_time_dict`` for
        ``args.data_name``, so the values given on the command line for
        ``--window_time`` / ``--slide_time`` are not used.

    Raises:
        KeyError: if ``args.data_name`` or ``args.exp_id`` is missing from the
            lookup tables indexed below.
    """
    parser = argparse.ArgumentParser(description='command for the first train')

    # --- optimisation / schedule ---
    parser.add_argument('--epoch', type=int, default=100, help='training epoches')
    parser.add_argument('--warmup_way', type=str, default="uns", help='uns, sup')
    parser.add_argument('--warmup-epoch', type=int, default=1, help='warmup epoch')
    parser.add_argument('--lr', '--base-learning-rate', '--base-lr', type=float, default=0.1, help='learning rate')
    parser.add_argument('--lr-scheduler', type=str, default='step',
                        choices=["step", "cosine"], help="learning rate scheduler")
    parser.add_argument('--lr-warmup-epoch', type=int, default=1, help='warmup epoch')
    parser.add_argument('--lr-warmup-multiplier', type=int, default=100, help='warmup multiplier')
    parser.add_argument('--lr-decay-epochs', type=int, default=[125, 200], nargs='+',
                        help='for step scheduler. where to decay lr, can be a list')
    parser.add_argument('--lr-decay-rate', type=float, default=0.1,
                        help='for step scheduler. decay rate for learning rate')
    parser.add_argument('--initial_epoch', type=int, default=1, help="Star training at initial_epoch")

    # --- run / data-set bookkeeping ---
    parser.add_argument('--batch_size', type=int, default=128, help='#images in each mini-batch')
    parser.add_argument('--test_batch_size', type=int, default=100, help='#images in each mini-batch')
    parser.add_argument('--cuda_dev', type=int, default=0, help='GPU to select')  # GPU number
    parser.add_argument('--wd', type=float, default=1e-4, help='weight decay')
    parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
    parser.add_argument('--noise_type', default='asymmetric', help='symmetric or asymmetric')
    parser.add_argument('--train_root', default='./dataset', help='root for train data')
    parser.add_argument('--out', type=str, default='./output', help='Directory of the output')
    parser.add_argument('--experiment_name', type=str, default='Proof',
                        help='name of the experiment (for the output files)')
    parser.add_argument('--dataset', type=str, default='CIFAR-10', help='CIFAR-10, CIFAR-100')
    # Was ``type=bool``, which turned "--download False" into True.
    parser.add_argument('--download', type=_str2bool, default=False, help='download dataset')

    # --- model ---
    parser.add_argument('--network', type=str, default='PR18', help='Network architecture')
    parser.add_argument('--low_dim', type=int, default=128, help='Size of contrastive learning embedding')
    parser.add_argument('--headType', type=str, default="Linear", help='Linear, NonLinear')
    parser.add_argument('--seed_initialization', type=int, default=1, help='random seed (default: 1)')
    parser.add_argument('--seed_dataset', type=int, default=42, help='random seed (default: 42)')
    parser.add_argument('--DA', type=str, default="complex", help='Choose simple or complex data augmentation')

    # --- contrastive learning / selection hyper-parameters ---
    parser.add_argument('--alpha_m', type=float, default=1.0, help='Beta distribution parameter for mixup')
    parser.add_argument('--alpha_moving', type=float, default=0.999, help='exponential moving average weight')
    parser.add_argument('--alpha', type=float, default=0.5, help='example selection th')
    parser.add_argument('--beta', type=float, default=0.25, help='pair selection th')
    parser.add_argument('--uns_queue_k', type=int, default=10000, help='uns-cl num negative sampler')
    parser.add_argument('--uns_t', type=float, default=0.1, help='uns-cl temperature')
    parser.add_argument('--sup_t', default=0.1, type=float, help='sup-cl temperature')
    parser.add_argument('--sup_queue_use', type=int, default=1, help='1: Use queue for sup-cl')
    parser.add_argument('--sup_queue_begin', type=int, default=3, help='Epoch to begin using queue for sup-cl')
    parser.add_argument('--queue_per_class', type=int, default=1000,
                        help='Num of samples per class to store in the queue. queue size = queue_per_class*num_classes*2')
    parser.add_argument('--aprox', type=int, default=1,
                        help='Approximation for numerical stability taken from supervised contrastive learning')
    parser.add_argument('--lambda_s', type=float, default=0.01, help='weight for similarity loss')
    parser.add_argument('--lambda_c', type=float, default=1, help='weight for classification loss')
    parser.add_argument('--k_val', type=int, default=250, help='k for k-nn correction')

    # --- database / experiment selection ---
    parser.add_argument('--database_save_dir', type=str, default='./data/CL_database/',
                        help='Should give a path to load the database of one patient.')
    parser.add_argument('--data_name', type=str, default='fNIRS_2',
                        help='Should give the name of the database [SEEG, fNIRS_2, Sleep].')
    parser.add_argument('--exp_id', type=int, default=1,
                        help='The experimental id.')
    parser.add_argument('--num_classes', type=int, default=2, help='Number of in-distribution classes')
    parser.add_argument('--noise_ratio', type=float, default=0.4, help='percent of noise')
    parser.add_argument('--window_time', type=float, default=1,
                        help='The seconds of every sample segment.')
    parser.add_argument('--slide_time', type=float, default=0.5,
                        help='The sliding seconds between two sample segments.')
    parser.add_argument('--num_level', type=int, default=5,
                        help='The number of levels.')

    parser.add_argument('--patience', type=int, default=10, help='patience fot early stopping')

    args = parser.parse_args(argv)

    # Resolve the (train, valid, test) patient split for the chosen experiment.
    exp_dict = get_exp_dict(args.data_name)
    exp_patient_list = exp_dict[args.exp_id]
    args.train_patient_list = exp_patient_list[0]
    args.valid_patient_list = exp_patient_list[1]
    args.test_patient_list = exp_patient_list[2]

    # The segmentation settings are dictated by the data set, not by the CLI.
    args.window_time = window_time_dict[args.data_name]
    args.slide_time = slide_time_dict[args.data_name]

    return args


def blurred_data_config(args):
    """Load the data through ``seeg_datasets`` and wrap it in loaders.

    Side effect: ``args.num_classes`` is overwritten with the class count
    reported by ``seeg_datasets``.

    The test split returned by ``seeg_datasets`` is deliberately not wrapped:
    nothing returned from this function uses it, and building it only cost an
    extra float copy of the test array.

    Returns:
        A 4-tuple ``(train_loader, valid_loader, train_dataset, input_size)``
        where ``train_dataset`` yields ``(x, y, sample_index)`` triples and
        ``input_size`` is ``x_train.shape[-2]``
        (presumably the channel count; TODO confirm against ``seeg_datasets``).
    """
    _, _, x_train, y_train, x_valid, y_valid, _x_test, _y_test, n_class = seeg_datasets(args)
    args.num_classes = n_class

    # The third tensor is each sample's position in the training set, so a
    # batch can be mapped back to its rows.
    train_dataset = TensorDataset(
        torch.from_numpy(x_train).float(),
        torch.from_numpy(y_train).long(),
        torch.from_numpy(np.arange(len(y_train))),
    )
    valid_dataset = TensorDataset(torch.from_numpy(x_valid).float(), torch.from_numpy(y_valid).long())

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, drop_last=False,
                              num_workers=8, pin_memory=True)
    valid_loader = DataLoader(valid_dataset, batch_size=args.test_batch_size, shuffle=False, drop_last=False,
                              num_workers=8, pin_memory=True)

    return train_loader, valid_loader, train_dataset, x_train.shape[-2]


def build_model(args, input_size, device):
    """Create the trainable network and its companion ``model_ema`` on *device*.

    Both are ``PreActResNet18`` instances built from the same constructor
    arguments (``input_size`` plus ``num_classes``, ``low_dim`` and
    ``headType`` read from *args*).

    Returns:
        ``(model, model_ema)``.
    """
    net_kwargs = dict(input_size=input_size, num_classes=args.num_classes, low_dim=args.low_dim,
                      head=args.headType)

    # Construction order is kept (model first) so random initialisation
    # consumes the RNG in the same sequence as before.
    model = PreActResNet18(**net_kwargs).to(device)
    model_ema = PreActResNet18(**net_kwargs).to(device)

    # Per the original author's note, this call with 0 copies the weights
    # from `model` to `model_ema`.
    moment_update(model, model_ema, 0)
    return model, model_ema


def _weighted_loader(trainset, weights, batch_size):
    """Return a loader over *trainset* that draws ``len(weights)`` samples per
    epoch through a ``WeightedRandomSampler`` driven by *weights*."""
    sampler = torch.utils.data.WeightedRandomSampler(weights, len(weights))
    return torch.utils.data.DataLoader(trainset, batch_size=batch_size, num_workers=4,
                                       pin_memory=True, sampler=sampler)


def _save_checkpoint(args, epoch, exp_path, model_state, optimizer_state, selected_examples):
    """Write the best model/optimizer state and, when available, the current
    example selection into *exp_path*."""
    save_model(model_state, optimizer_state, args, epoch, exp_path + "/Sel-CL_model.pth")
    # No selection exists until the first pair-wise selection has run
    # (i.e. while epoch < warmup_epoch).
    if selected_examples is not None:
        np.save(exp_path + '/' + 'selected_examples_train.npy', selected_examples.data.cpu().numpy())


def _train(args, exp_path, log_file):
    """Run the training loop; everything printed here goes to *log_file*
    because ``main`` has redirected ``sys.stdout`` to it."""
    print(args)

    # Kept on ``args`` because helper modules receive ``args`` and may read it.
    args.best_acc = 0
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_dev)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    torch.backends.cudnn.deterministic = True  # fix the GPU to deterministic mode
    torch.manual_seed(args.seed_initialization)  # CPU seed
    # ``device`` is a torch.device, so compare its ``type`` rather than the
    # object itself against a string.
    if device.type == "cuda":
        torch.cuda.manual_seed_all(args.seed_initialization)  # GPU seed

    random.seed(args.seed_initialization)  # python seed for image transformation

    # data loader: the second loader is the validation loader and drives
    # feature computation, pair selection and early stopping below.
    train_loader, valid_loader, trainset, input_size = blurred_data_config(args)

    model, model_ema = build_model(args, input_size, device)

    # Follow the selected device instead of unconditionally calling .cuda().
    uns_contrast = MemoryMoCo(args.low_dim, args.uns_queue_k, args.uns_t, thresh=0).to(device)

    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.wd)
    model, optimizer = amp.initialize(model, optimizer, opt_level="O1", num_losses=2)
    scheduler = get_scheduler(optimizer, len(train_loader), args)

    if args.sup_queue_use == 1:
        queue = queue_with_pro(args, device)
    else:
        queue = []

    best_valid_f1 = -1
    early_stop_counter = 0
    best_f1_model_state = deepcopy(model.state_dict())
    best_f1_optimizer_state = deepcopy(optimizer.state_dict())

    # Produced at the end of an epoch and consumed by the next one.
    features = None
    selected_examples = None
    selected_pair_th = None

    epoch = None
    last_saved_epoch = None

    for epoch in range(args.initial_epoch, args.epoch + 1):
        st = time.time()
        print("=================>    ", args.experiment_name, args.noise_ratio)
        if epoch <= args.warmup_epoch:
            if args.warmup_way == 'uns':
                train_uns(args, scheduler, model, model_ema, uns_contrast, queue, device, train_loader, optimizer,
                          epoch)
            else:
                # Uniform weights: every training sample is equally likely.
                train_selected_loader = _weighted_loader(trainset, torch.ones(len(trainset)), args.batch_size)
                train_sup(args, scheduler, model, model_ema, uns_contrast, queue, device, train_loader,
                          train_selected_loader, optimizer, epoch)
        else:
            # Sample according to the selection computed after the previous epoch.
            train_selected_loader = _weighted_loader(trainset, selected_examples, args.batch_size)
            train_sel(args, scheduler, model, model_ema, uns_contrast, queue, device, train_loader,
                      train_selected_loader, optimizer, epoch, features, selected_pair_th, selected_examples)

        features = compute_features(args, model, train_loader, valid_loader)
        if epoch >= args.warmup_epoch:
            print('######## Pair-wise selection ########')
            selected_examples, selected_pair_th = pair_selection(args, model, device, train_loader, valid_loader,
                                                                 epoch, features)

        print('Epoch time: {:.2f} seconds\n'.format(time.time() - st))
        log_file.flush()

        print('######## Test ########')
        index = test_eval(args, model, device, valid_loader)
        valid_f1 = index.f1
        print(f'Epoch {epoch}: acc:{index.acc}, pre:{index.pre}, rec:{index.rec}, f1:{index.f1}')

        if valid_f1 > best_valid_f1:
            best_valid_f1 = deepcopy(valid_f1)
            early_stop_counter = 0
            best_f1_model_state = deepcopy(model.state_dict())
            best_f1_optimizer_state = deepcopy(optimizer.state_dict())
        else:
            early_stop_counter += 1

        if epoch % 10 == 0 or epoch == 1:
            _save_checkpoint(args, epoch, exp_path, best_f1_model_state, best_f1_optimizer_state,
                             selected_examples)
            last_saved_epoch = epoch

        if early_stop_counter > args.patience:
            print(f'Validation f1 score did not improve for {args.patience} epochs. Early stopping.')
            break

    # Training can end (early stop or final epoch) on an epoch that is not a
    # periodic save point; persist the best state so it is not lost.
    if epoch is not None and last_saved_epoch != epoch:
        _save_checkpoint(args, epoch, exp_path, best_f1_model_state, best_f1_optimizer_state,
                         selected_examples)


def main(args):
    """Train with warm-up followed by selection-based training and log to disk.

    Results go under
    ``<args.out>/Sel-CL_experiment/<data_name>/<noise_ratio*100>/exp<exp_id>``:
    ``results.log`` (everything printed during the run), ``Sel-CL_model.pth``
    (the state with the best validation F1 so far) and
    ``selected_examples_train.npy``.

    ``sys.stdout`` is redirected to the log file for the duration of the run
    and restored afterwards, even if training raises.

    Raises:
        ValueError: if training would start after the warm-up phase. The
            selection-based epochs need the features and example selection
            produced at the end of the previous epoch, and this script has no
            way to restore them.
    """
    if args.warmup_epoch < args.initial_epoch <= args.epoch:
        raise ValueError(
            f'initial_epoch ({args.initial_epoch}) must not exceed warmup_epoch ({args.warmup_epoch}): '
            'selection-based training needs the selection computed by a previous epoch.')

    # define result path
    exp_path = os.path.join(args.out, 'Sel-CL_experiment', args.data_name, str(int(args.noise_ratio * 100)),
                            f'exp{args.exp_id}')
    os.makedirs(exp_path, exist_ok=True)

    # set log
    console = sys.stdout
    log_file = open(exp_path + "/results" + ".log", 'a')
    sys.stdout = log_file
    try:
        _train(args, exp_path, log_file)
    finally:
        sys.stdout = console
        log_file.close()


# Script entry point: parse the command line, then run training.
if __name__ == "__main__":
    args = parse_args()
    main(args)
