import torch
import numpy as np
import argparse
import os
import time
from copy import deepcopy

from torch import optim

import random
import sys

sys.path.append('../../../utils/')
sys.path.append('../')
sys.path.append('../utils')

from utils.test_eval import test_eval
from utils.utils_plus import *
from utils.other_utils import *
from utils.models.preact_resnet import *

from torch.utils.data import DataLoader, TensorDataset
from utils.blurred_datasets import seeg_datasets
from default_config import get_exp_dict, window_time_dict, slide_time_dict

# Cap the size of the thread pools used by torch and by the common native
# BLAS / OpenMP / numexpr backends.
# NOTE(review): these environment variables are assigned after torch and numpy
# have already been imported at the top of this file, so any library that reads
# them only when it is first loaded may ignore them -- confirm they take effect.
num_threads = '32'
torch.set_num_threads(int(num_threads))
os.environ.update(dict.fromkeys(
    ('OMP_NUM_THREADS', 'OPENBLAS_NUM_THREADS', 'MKL_NUM_THREADS',
     'VECLIB_MAXIMUM_THREADS', 'NUMEXPR_NUM_THREADS'),
    num_threads,
))


def _str2bool(value):
    """Parse a command-line boolean value.

    ``argparse`` with ``type=bool`` turns every non-empty string (including
    ``"False"`` and ``"0"``) into ``True``. This helper parses the text instead.

    Args:
        value: Raw string from the command line (a ``bool`` is passed through).

    Returns:
        ``True`` or ``False``.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('1', 'true', 't', 'yes', 'y'):
        return True
    if text in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def _build_parser():
    """Return the ``ArgumentParser`` describing every command-line option of this script."""
    parser = argparse.ArgumentParser(description='command for the first train')
    parser.add_argument('--lr', type=float, default=0.001, help='learning rate')
    parser.add_argument('--batch_size', type=int, default=128, help='#images in each mini-batch')
    parser.add_argument('--test_batch_size', type=int, default=100, help='#images in each mini-batch')
    parser.add_argument('--cuda_dev', type=int, default=0, help='GPU to select')
    parser.add_argument('--epoch', type=int, default=50, help='training epochs')
    parser.add_argument('--wd', type=float, default=1e-4, help='weight decay')
    parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
    parser.add_argument('--noise_type', default='symmetric', help='noise type of the dataset')
    parser.add_argument('--train_root', default='./dataset', help='root for train data')
    parser.add_argument('--out', type=str, default='./output', help='Directory of the output')
    parser.add_argument('--alpha_m', type=float, default=1.0, help='Beta distribution parameter for mixup')
    # type=bool would treat "--download False" as True; parse the text explicitly.
    parser.add_argument('--download', type=_str2bool, default=False, help='download dataset (true/false)')
    parser.add_argument('--network', type=str, default='PR18', help='Network architecture')
    parser.add_argument('--seed_initialization', type=int, default=1, help='random seed (default: 1)')
    parser.add_argument('--seed_dataset', type=int, default=42, help='random seed (default: 42)')
    parser.add_argument('--M', action='append', type=int, default=[], help="Milestones for the LR scheduler")
    parser.add_argument('--experiment_name', type=str, default='Proof',
                        help='name of the experiment (for the output files)')
    parser.add_argument('--dataset', type=str, default='CIFAR-10', help='CIFAR-10, CIFAR-100')
    parser.add_argument('--initial_epoch', type=int, default=1, help="Start training at initial_epoch")
    parser.add_argument('--low_dim', type=int, default=128, help='Size of contrastive learning embedding')
    parser.add_argument('--headType', type=str, default="Linear", help='Linear, NonLinear')
    parser.add_argument('--startLabelCorrection', type=int, default=30, help='Epoch to start label correction')
    parser.add_argument('--ReInitializeClassif', type=int, default=0, help='Enable predictive label correction')
    parser.add_argument('--DA', type=str, default="simple", help='Choose simple or complex data augmentation')

    parser.add_argument('--database_save_dir', type=str, default='./data/CL_database/',
                        help='Should give a path to load the database of one patient.')
    parser.add_argument('--data_name', type=str, default='fNIRS_2',
                        help='Should give the name of the database [SEEG, fNIRS_2, Sleep].')
    parser.add_argument('--exp_id', type=int, default=1,
                        help='The experimental id.')
    parser.add_argument('--num_classes', type=int, default=2, help='Number of in-distribution classes')
    parser.add_argument('--noise_ratio', type=float, default=0.4, help='percent of noise')
    parser.add_argument('--window_time', type=float, default=1,
                        help='The seconds of every sample segment '
                             '(overridden by the per-dataset value in default_config).')
    parser.add_argument('--slide_time', type=float, default=0.5,
                        help='The sliding seconds between two sample segments '
                             '(overridden by the per-dataset value in default_config).')
    parser.add_argument('--num_level', type=int, default=5,
                        help='The number of levels.')
    parser.add_argument('--patience', type=int, default=10, help='patience for early stopping')
    return parser


def parse_args(argv=None):
    """Parse command-line options and attach the per-dataset experiment settings.

    Args:
        argv: Argument list to parse. ``None`` (the default) reads ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace`` with the parsed options plus ``train_patient_list``,
        ``valid_patient_list`` and ``test_patient_list``. These are elements 0, 1 and 2
        of ``get_exp_dict(data_name)[exp_id]``. ``window_time`` and ``slide_time`` are
        always replaced by the entries of ``window_time_dict`` / ``slide_time_dict``
        for ``data_name``, whatever was passed on the command line.
    """
    args = _build_parser().parse_args(argv)

    exp_patient_list = get_exp_dict(args.data_name)[args.exp_id]
    args.train_patient_list = exp_patient_list[0]
    args.valid_patient_list = exp_patient_list[1]
    args.test_patient_list = exp_patient_list[2]

    # Per-dataset segment timing takes precedence over the CLI values.
    args.window_time = window_time_dict[args.data_name]
    args.slide_time = slide_time_dict[args.data_name]

    return args


def _select_clean_subset(train_dataset, clean_idx):
    """Return a ``TensorDataset`` holding only the training samples flagged as clean.

    Args:
        train_dataset: ``TensorDataset`` whose first two tensors are inputs and labels.
        clean_idx: Array-like with one flag per training sample. Entries equal to 1 are kept.

    Returns:
        ``TensorDataset`` of ``(inputs, labels, positions)``. ``positions`` runs
        ``0..k-1``: it is the position inside the subset, not the original index.

    Raises:
        ValueError: If ``clean_idx`` does not hold exactly one flag per training sample.
    """
    flags = np.asarray(clean_idx).reshape(-1)
    n_train = len(train_dataset)
    if flags.shape[0] != n_train:
        raise ValueError(
            f'clean_idx has {flags.shape[0]} entries but the training set has {n_train} samples')

    keep = torch.from_numpy(np.flatnonzero(flags == 1))
    return TensorDataset(
        train_dataset.tensors[0][keep],
        train_dataset.tensors[1][keep],
        torch.arange(keep.numel()),
    )


def blurred_data_config(args, clean_idx):
    """Build the loaders for fine-tuning on the samples selected as clean.

    Sets ``args.num_classes`` to the class count reported by ``seeg_datasets``.

    Args:
        args: Parsed options; reads ``batch_size`` and ``test_batch_size``, and is
            passed through to ``seeg_datasets``.
        clean_idx: One flag per training sample. Samples whose flag equals 1 are kept.

    Returns:
        Tuple ``(train_loader, valid_loader, train_dataset, input_size)``.
        - ``train_loader`` shuffles only the clean subset.
        - ``train_dataset`` is the FULL (unfiltered) training set.
        - ``input_size`` is ``x_train.shape[-2]``, presumably the channel count (TODO confirm).

    Raises:
        ValueError: If ``clean_idx`` does not match the training-set length.
    """
    # The test split is not needed for this stage; only train/valid are used.
    _, _, x_train, y_train, x_valid, y_valid, _, _, n_class = seeg_datasets(args)
    args.num_classes = n_class

    train_dataset = TensorDataset(torch.from_numpy(x_train).float(), torch.from_numpy(y_train).long(),
                                  torch.from_numpy(np.arange(len(y_train))))
    valid_dataset = TensorDataset(torch.from_numpy(x_valid).float(), torch.from_numpy(y_valid).long())

    # Get only detected clean samples
    clean_train_dataset = _select_clean_subset(train_dataset, clean_idx)

    loader_kwargs = dict(drop_last=False, num_workers=8, pin_memory=True)
    train_loader = DataLoader(clean_train_dataset, batch_size=args.batch_size, shuffle=True, **loader_kwargs)
    valid_loader = DataLoader(valid_dataset, batch_size=args.test_batch_size, shuffle=False, **loader_kwargs)

    return train_loader, valid_loader, train_dataset, x_train.shape[-2]


def _seed_everything(seed, device):
    """Seed torch (CPU, plus all GPUs when ``device`` is CUDA) and Python's ``random``."""
    torch.backends.cudnn.deterministic = True  # fix the GPU to deterministic mode
    torch.manual_seed(seed)  # CPU seed
    # A torch.device never compares equal to the plain string "cuda", so check its type.
    if torch.device(device).type == "cuda":
        torch.cuda.manual_seed_all(seed)  # GPU seed
    random.seed(seed)  # python seed for image transformation


def _load_sel_cl_weights(model, checkpoint_path, device):
    """Load Sel-CL weights into ``model`` non-strictly.

    Accepts either a checkpoint dict with a ``'model'`` entry or a bare state dict.
    The file is read once and mapped onto ``device``.
    """
    checkpoint = torch.load(checkpoint_path, map_location=device)
    if isinstance(checkpoint, dict) and 'model' in checkpoint:
        checkpoint = checkpoint['model']
    model.load_state_dict(checkpoint, strict=False)


def _fine_tune(args, model, device, train_loader, valid_loader, model_path):
    """Run mixup fine-tuning with early stopping on validation F1.

    The best-F1 model/optimizer state is checkpointed to ``model_path`` at epoch 1,
    every 10th epoch, and once more when training ends (normally or early), so
    the best state is never lost.
    """
    print(args)

    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.wd)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.M, gamma=0.1)

    best_valid_f1 = -1
    early_stop_counter = 0
    best_f1_model_state = deepcopy(model.state_dict())
    best_f1_optimizer_state = deepcopy(optimizer.state_dict())
    last_epoch = None

    for epoch in range(args.initial_epoch, args.epoch + 1):
        last_epoch = epoch
        st = time.time()
        print("=================>    ", args.experiment_name, args.noise_ratio)
        train_mixup(args, model, device, train_loader, optimizer, epoch)
        # Step the LR schedule after this epoch's optimizer updates; stepping it
        # first would skip the first value of the schedule.
        scheduler.step()
        print('Epoch time: {:.2f} seconds\n'.format(time.time() - st))

        index = test_eval(args, model, device, valid_loader)
        valid_f1 = index.f1
        print(f'Epoch {epoch}: acc:{index.acc}, pre:{index.pre}, rec:{index.rec}, f1:{index.f1}')

        if valid_f1 > best_valid_f1:
            best_valid_f1 = deepcopy(valid_f1)
            early_stop_counter = 0
            best_f1_model_state = deepcopy(model.state_dict())
            best_f1_optimizer_state = deepcopy(optimizer.state_dict())
        else:
            early_stop_counter += 1

        if epoch % 10 == 0 or epoch == 1:
            save_model(best_f1_model_state, best_f1_optimizer_state, args, epoch, model_path)

        # Stop once there have been exactly `patience` consecutive epochs without improvement.
        if early_stop_counter >= args.patience:
            print(f'Validation f1 score did not improve for {args.patience} epochs. Early stopping.')
            break

    if last_epoch is not None:
        # Persist the best state even if the last epoch was not a periodic-save epoch.
        save_model(best_f1_model_state, best_f1_optimizer_state, args, last_epoch, model_path)


def main(args):
    """Fine-tune a pretrained Sel-CL model on the training samples it selected as clean.

    Reads ``selected_examples_train.npy`` and ``Sel-CL_model.pth`` from
    ``<out>/Sel-CL_experiment/<data_name>/<noise%>/exp<exp_id>/``. Writes
    ``results.log`` (stdout is redirected there while training) and
    ``Sel-CL_plus_model.pth`` into its ``plus/`` subdirectory. Early stopping and
    model selection use the validation F1 score.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_dev)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    _seed_everything(args.seed_initialization, device)

    # NOTE(review): int() truncates (e.g. 0.57 * 100 -> 56); presumably this must match
    # the directory written by the first-stage script, so it is kept unchanged.
    exp_path = os.path.join(args.out, 'Sel-CL_experiment', args.data_name, str(int(args.noise_ratio * 100)),
                            f'exp{args.exp_id}')

    clean_idx = np.load(os.path.join(exp_path, "selected_examples_train.npy"))

    # The second value is the *validation* loader; it drives early stopping.
    train_loader, valid_loader, _, input_size = blurred_data_config(args, clean_idx)

    model = PreActResNet18(input_size=input_size, num_classes=args.num_classes, low_dim=args.low_dim,
                           head=args.headType).to(device)
    _load_sel_cl_weights(model, os.path.join(exp_path, "Sel-CL_model.pth"), device)

    plus_path = os.path.join(exp_path, "plus")
    os.makedirs(plus_path, exist_ok=True)
    model_path = os.path.join(plus_path, "Sel-CL_plus_model.pth")

    console = sys.stdout
    with open(os.path.join(plus_path, "results.log"), 'a') as log_file:
        sys.stdout = log_file
        try:
            _fine_tune(args, model, device, train_loader, valid_loader, model_path)
        finally:
            # Always restore the console, even if training raises.
            sys.stdout = console


if __name__ == "__main__":
    # Script entry point: parse the command line, then run the fine-tuning stage.
    main(parse_args())
