import torch
import argparse
import os
import random
import sys

import pandas as pd

sys.path.append('../../../utils/')
sys.path.append('../')
sys.path.append('../utils')

from utils.test_eval import test_eval
from utils.utils_plus import *
from utils.other_utils import *
from utils.models.preact_resnet import *

from torch.utils.data import DataLoader, TensorDataset
from utils.blurred_datasets import seeg_datasets
from default_config import get_exp_dict, window_time_dict, slide_time_dict

# Cap CPU parallelism: torch's intra-op thread pool plus the thread-count
# environment variables read by common native math backends.
# NOTE(review): these variables are assigned after torch and pandas (which
# loads numpy) were already imported above, so a backend that reads them only
# at load time may ignore them; torch.set_num_threads is what applies to torch.
num_threads = '32'
torch.set_num_threads(int(num_threads))
os.environ.update(dict.fromkeys(
    (
        'OMP_NUM_THREADS',
        'OPENBLAS_NUM_THREADS',
        'MKL_NUM_THREADS',
        'VECLIB_MAXIMUM_THREADS',
        'NUMEXPR_NUM_THREADS',
    ),
    num_threads,
))


def str2bool(value):
    """Convert a command-line string such as ``'true'`` / ``'false'`` to a bool.

    ``argparse`` with ``type=bool`` turns every non-empty string (including
    ``"False"``) into ``True``, so boolean options need an explicit converter.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean word.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def parse_args(argv=None):
    """Parse command-line options for the test run and resolve the experiment split.

    Args:
        argv: optional list of argument strings; ``None`` (the default) means
            ``sys.argv[1:]``, exactly as ``argparse`` normally does.

    Returns:
        argparse.Namespace with every option below, plus:
          * ``train_patient_list`` / ``valid_patient_list`` / ``test_patient_list``
            taken from ``get_exp_dict(data_name)[exp_id]`` (items 0, 1, 2);
          * ``window_time`` / ``slide_time`` replaced by the per-dataset values in
            ``window_time_dict`` / ``slide_time_dict`` (any CLI value is ignored).

    Exits (via ``parser.error``) if ``exp_id`` is not present in the experiment table.
    """
    parser = argparse.ArgumentParser(description='command for the first train')
    parser.add_argument('--lr', type=float, default=0.001, help='learning rate')
    parser.add_argument('--batch_size', type=int, default=128, help='#images in each mini-batch')
    parser.add_argument('--test_batch_size', type=int, default=100, help='#images in each mini-batch')
    parser.add_argument('--cuda_dev', type=int, default=0, help='GPU to select')
    parser.add_argument('--epoch', type=int, default=50, help='training epochs')
    parser.add_argument('--wd', type=float, default=1e-4, help='weight decay')
    parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
    parser.add_argument('--noise_type', default='symmetric', help='noise type of the dataset')
    parser.add_argument('--train_root', default='./dataset', help='root for train data')
    parser.add_argument('--out', type=str, default='./output', help='Directory of the output')
    parser.add_argument('--alpha_m', type=float, default=1.0, help='Beta distribution parameter for mixup')
    # type=bool would treat any non-empty string (even "False") as True.
    parser.add_argument('--download', type=str2bool, default=False, help='download dataset')
    parser.add_argument('--network', type=str, default='PR18', help='Network architecture')
    parser.add_argument('--seed_initialization', type=int, default=1, help='random seed (default: 1)')
    parser.add_argument('--seed_dataset', type=int, default=42, help='random seed (default: 42)')
    parser.add_argument('--M', action='append', type=int, default=[], help="Milestones for the LR scheduler")
    parser.add_argument('--experiment_name', type=str, default='Proof',
                        help='name of the experiment (for the output files)')
    parser.add_argument('--dataset', type=str, default='CIFAR-10', help='CIFAR-10, CIFAR-100')
    parser.add_argument('--initial_epoch', type=int, default=1, help="Start training at initial_epoch")
    parser.add_argument('--low_dim', type=int, default=128, help='Size of contrastive learning embedding')
    parser.add_argument('--headType', type=str, default="Linear", help='Linear, NonLinear')
    parser.add_argument('--startLabelCorrection', type=int, default=30, help='Epoch to start label correction')
    parser.add_argument('--ReInitializeClassif', type=int, default=0, help='Enable predictive label correction')
    parser.add_argument('--DA', type=str, default="simple", help='Choose simple or complex data augmentation')

    parser.add_argument('--database_save_dir', type=str, default='./data/CL_database/',
                        help='Should give a path to load the database of one patient.')
    parser.add_argument('--data_name', type=str, default='fNIRS_2',
                        help='Should give the name of the database [SEEG, fNIRS_2, Sleep].')
    parser.add_argument('--exp_id', type=int, default=1,
                        help='The experimental id.')
    parser.add_argument('--num_classes', type=int, default=2, help='Number of in-distribution classes')
    parser.add_argument('--noise_ratio', type=float, default=0.4, help='percent of noise')
    parser.add_argument('--window_time', type=float, default=1,
                        help='The seconds of every sample segment.')
    parser.add_argument('--slide_time', type=float, default=0.5,
                        help='The sliding seconds between two sample segments.')
    parser.add_argument('--num_level', type=int, default=5,
                        help='The number of levels.')
    parser.add_argument('--patience', type=int, default=10, help='patience for early stopping')

    args = parser.parse_args(argv)

    exp_dict = get_exp_dict(args.data_name)
    try:
        exp_patient_list = exp_dict[args.exp_id]
    except (KeyError, IndexError):
        # Report a usable CLI error instead of a raw lookup traceback.
        parser.error(f'unknown --exp_id {args.exp_id} for --data_name {args.data_name}')
    args.train_patient_list = exp_patient_list[0]
    args.valid_patient_list = exp_patient_list[1]
    args.test_patient_list = exp_patient_list[2]

    # The per-dataset tables always win over --window_time / --slide_time.
    args.window_time = window_time_dict[args.data_name]
    args.slide_time = slide_time_dict[args.data_name]

    return args


def blurred_data_config(args):
    """Build the test-split DataLoader from ``seeg_datasets`` and report the input size.

    Side effect: overwrites ``args.num_classes`` with the class count returned by
    ``seeg_datasets``.

    Returns:
        (test_loader, input_size), where ``input_size`` is
        ``x_train.shape[-2]``. ``main`` passes it as the model's ``input_size``.
        Presumably this is the channel dimension of the segments; TODO confirm.
    """
    (_, _, train_x, _train_y, _valid_x, _valid_y,
     test_x, test_y, class_count) = seeg_datasets(args)
    args.num_classes = class_count

    features = torch.from_numpy(test_x).float()
    targets = torch.from_numpy(test_y).long()
    loader = DataLoader(
        TensorDataset(features, targets),
        batch_size=args.test_batch_size,
        shuffle=False,      # keep evaluation order deterministic
        drop_last=False,    # score every test sample
        num_workers=8,
        pin_memory=True,
    )

    return loader, train_x.shape[-2]


def main(args):
    """Evaluate the saved Sel-CL+ model on the test split and write the metrics to CSV.

    The checkpoint is read from
    ``<out>/Sel-CL_experiment/<data_name>/<noise%>/exp<exp_id>/plus/Sel-CL_plus_model.pth``.
    The result file ``test_results.csv`` (columns acc, pre, rec, f1) is written
    to the same directory.

    Args:
        args: namespace produced by ``parse_args``.
    """
    # Only effective if CUDA has not been initialised yet in this process.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_dev)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    torch.backends.cudnn.deterministic = True  # fix the GPU to deterministic mode
    torch.manual_seed(args.seed_initialization)  # CPU seed
    # Compare the device *type*: a torch.device object never equals the str "cuda",
    # so the original `device == "cuda"` check silently skipped GPU seeding.
    if device.type == "cuda":
        torch.cuda.manual_seed_all(args.seed_initialization)  # GPU seed

    random.seed(args.seed_initialization)  # Python `random` seed

    model_path = os.path.join(args.out, 'Sel-CL_experiment', args.data_name, str(int(args.noise_ratio * 100)),
                              f'exp{args.exp_id}', 'plus')
    res_path = model_path

    test_loader, input_size = blurred_data_config(args)

    model = PreActResNet18(input_size=input_size, num_classes=args.num_classes, low_dim=args.low_dim,
                           head=args.headType).to(device)

    # Load the checkpoint once, mapped onto the active device so GPU-saved weights
    # also load on CPU-only hosts. It may be either {'model': state_dict, ...}
    # or a bare state_dict.
    checkpoint = torch.load(os.path.join(model_path, "Sel-CL_plus_model.pth"), map_location=device)
    if isinstance(checkpoint, dict) and 'model' in checkpoint:
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint
    # strict=False: missing/unexpected keys are tolerated (as in the original code).
    model.load_state_dict(state_dict, strict=False)

    print(args)

    # `index` exposes acc/pre/rec/f1 (presumably accuracy/precision/recall/F1).
    index = test_eval(args, model, device, test_loader)

    res_dict = {
        'acc': index.acc,
        'pre': index.pre,
        'rec': index.rec,
        'f1': index.f1,
    }
    # DataFrame.append was removed in pandas 2.0; build the one-row frame directly.
    df_results = pd.DataFrame([res_dict])

    df_results.to_csv(os.path.join(res_path, 'test_results.csv'), sep=',', index=False, header=True)

    print(f'Successfully save test results of {args.data_name}_exp{args.exp_id}_r{int(args.noise_ratio * 100)}')


if __name__ == '__main__':
    # Script entry point: parse the CLI (resolving the experiment split), then evaluate.
    main(parse_args())
