from __future__ import print_function
import sys

sys.path.append('../../')
sys.path.append('../../pipeline')
sys.path.append('../../utils')

import torch
import random
import os
import argparse
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader, TensorDataset

from dataloader_blurred import test_dataset
from PreResNet_blurred import *
from default_config import get_exp_dict, window_time_dict, slide_time_dict
from pipeline.cb_evaluation_api import class_evaluation

def _str2bool(value):
    """Parse a command-line boolean such as 'true'/'false', 'yes'/'no', '1'/'0'.

    ``type=bool`` cannot be used for this because ``bool('False')`` is ``True``.
    Non-string values (e.g. the ``False`` default, which argparse does not
    convert) are returned unchanged.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


## Arguments to pass
parser = argparse.ArgumentParser(description='PyTorch CIFAR Training')
parser.add_argument('--batch_size', default=64, type=int, help='train batchsize')
parser.add_argument('--lr', '--learning_rate', default=0.02, type=float, help='initial learning rate')
parser.add_argument('--noise_mode', default='sym')
parser.add_argument('--alpha', default=4, type=float, help='parameter for Beta')
parser.add_argument('--lambda_u', default=30, type=float, help='weight for unsupervised loss')
parser.add_argument('--lambda_c', default=0.025, type=float, help='weight for contrastive loss')
parser.add_argument('--T', default=0.5, type=float, help='sharpening temperature')
parser.add_argument('--num_epochs', default=100, type=int)
parser.add_argument('--d_u', default=0.7, type=float)
parser.add_argument('--tau', default=5, type=float, help='filtering coefficient')
parser.add_argument('--metric', type=str, default='JSD', help='Comparison Metric')
# type=int so a CLI-supplied seed seeds `random` the same way as the default does
# (random.seed('123') and random.seed(123) produce different streams).
parser.add_argument('--seed', default=123, type=int)
parser.add_argument('--cuda_dev', default=0, type=int)
# _str2bool instead of type=bool: with type=bool, '--resume False' would parse as True.
parser.add_argument('--resume', default=False, type=_str2bool, help='Resume from the warmup checkpoint')
parser.add_argument('--data_path', default='./data/cifar10', type=str, help='path to dataset')
parser.add_argument('--dataset', default='cifar10', type=str)

parser.add_argument('--out_path', type=str, default='./output', help='Output path for result')
parser.add_argument('--database_save_dir', type=str, default='./data/CL_database/',
                    help='Should give a path to load the database of one patient.')
parser.add_argument('--data_name', type=str, default='fNIRS_2',
                    help='Should give the name of the database [SEEG, fNIRS_2, Sleep].')
parser.add_argument('--exp_id', type=int, default=3,
                    help='The experimental id.')
# NOTE: later in this script num_classes is replaced by the number of distinct
# test labels, so this flag does not affect the evaluation.
parser.add_argument('--num_classes', type=int, default=2, help='Number of in-distribution classes')
parser.add_argument('--noise_ratio', type=float, default=0.4, help='percent of noise')
# NOTE: window_time / slide_time are overwritten from window_time_dict /
# slide_time_dict (keyed by --data_name) right after parsing.
parser.add_argument('--window_time', type=float, default=1,
                    help='The seconds of every sample segment.')
parser.add_argument('--slide_time', type=float, default=0.5,
                    help='The sliding seconds between two sample segments.')
parser.add_argument('--num_level', type=int, default=5,
                    help='The number of levels.')
parser.add_argument('--patience', type=int, default=10, help='patience for early stopping')

args = parser.parse_args()

# Look up the (train, valid, test) patient lists for the chosen experiment id.
exp_dict = get_exp_dict(args.data_name)
exp_patient_list = exp_dict[args.exp_id]
(args.train_patient_list,
 args.valid_patient_list,
 args.test_patient_list) = (exp_patient_list[0],
                            exp_patient_list[1],
                            exp_patient_list[2])

# Segment timing comes from the per-dataset tables; this replaces any
# --window_time / --slide_time value given on the command line.
args.window_time, args.slide_time = (window_time_dict[args.data_name],
                                     slide_time_dict[args.data_name])

## GPU Setup: select the device, then seed Python's and torch's RNGs.
torch.cuda.set_device(args.cuda_dev)
for _seed_fn in (random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(args.seed)

# Result directory: <out_path>/UNICON_experiment_2/<data_name>/<noise %>/exp<id>
_noise_pct = str(int(args.noise_ratio * 100))
res_path = os.path.join(args.out_path, 'UNICON_experiment_2', args.data_name,
                        _noise_pct, f'exp{args.exp_id}')


def test(args, net1, net2, test_loader, device='cuda'):
    """Evaluate the two-network ensemble on ``test_loader``.

    Both networks are put in eval mode and run under ``torch.no_grad()``. For
    each batch, the second element returned by each network is summed across
    the two networks, and the arg-max over dim 1 is the predicted class.
    (Presumably each net returns ``(features, logits)`` — confirm against
    PreResNet_blurred.)

    Args:
        args: parsed CLI namespace; only ``args.num_classes`` is read.
        net1, net2: models already placed on ``device``.
        test_loader: iterable yielding ``(inputs, targets)`` batches.
        device: device the inputs are moved to before the forward pass.
            Defaults to ``'cuda'`` (the current CUDA device), matching the
            original ``.cuda()`` behaviour.

    Returns:
        The result of ``class_evaluation(y_true, y_pred, args.num_classes)``
        on the concatenated NumPy label/prediction arrays. The caller reads
        ``.acc``, ``.pre``, ``.rec`` and ``.f1`` from it.
    """
    net1.eval()
    net2.eval()

    # Collect per-batch chunks and concatenate once at the end; growing a
    # tensor with torch.cat inside the loop re-copies it on every batch.
    true_chunks = []
    pred_chunks = []

    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs = inputs.to(device, non_blocking=True)
            _, outputs1 = net1(inputs)
            _, outputs2 = net2(inputs)
            _, predicted = torch.max(outputs1 + outputs2, 1)

            # Targets are only needed on the CPU, so they are not sent to the GPU.
            true_chunks.append(targets.reshape(-1).cpu().long())
            pred_chunks.append(predicted.reshape(-1).cpu())

    # An empty loader yields empty long tensors instead of failing in torch.cat.
    y_true = torch.cat(true_chunks) if true_chunks else torch.empty(0, dtype=torch.long)
    y_pred = torch.cat(pred_chunks) if pred_chunks else torch.empty(0, dtype=torch.long)

    index = class_evaluation(
        y_true.numpy(),
        y_pred.numpy(),
        args.num_classes
    )

    return index


## call the dataloader
x_test, y_test = test_dataset(args)
testdataset = TensorDataset(torch.from_numpy(x_test).float(), torch.from_numpy(y_test).long())
testloader = DataLoader(testdataset, batch_size=100, shuffle=False, drop_last=False,
                        num_workers=8, pin_memory=True)
# The model's input_size is taken from the second-to-last axis of x_test.
input_size = x_test.shape[-2]
# NOTE(review): this overrides --num_classes with the number of distinct labels
# present in the test split. If a class is absent from y_test, the model is
# built with fewer outputs than the checkpoint and load_state_dict will fail.
args.num_classes = len(np.unique(y_test))


def _load_checkpoint(net, res_path, name):
    """Load ``{res_path}/{name}.pth`` into ``net``, else fall back to ``{name}_warmup.pth``.

    Only a missing final checkpoint triggers the fallback. Any other failure
    (corrupt file, missing 'net' key, shape mismatch) is raised, so it is not
    masked by silently evaluating the warm-up weights. Checkpoints are
    deserialized onto the CPU; load_state_dict then copies the tensors into
    ``net``'s parameters on its own device, so files saved from another GPU
    index still load.
    """
    primary = f'{res_path}/{name}.pth'
    fallback = f'{res_path}/{name}_warmup.pth'
    try:
        checkpoint = torch.load(primary, map_location='cpu')
    except FileNotFoundError:
        print(f'{primary} not found; falling back to {fallback}')
        checkpoint = torch.load(fallback, map_location='cpu')
    net.load_state_dict(checkpoint['net'])


## load the model
net1 = PreActResNet18(input_size=input_size, num_classes=args.num_classes, low_dim=128)
net2 = PreActResNet18(input_size=input_size, num_classes=args.num_classes, low_dim=128)
net1 = net1.cuda()
net2 = net2.cuda()

_load_checkpoint(net1, res_path, 'Net1')
_load_checkpoint(net2, res_path, 'Net2')

index = test(args, net1, net2, testloader)

# DataFrame.append was removed in pandas 2.0; build the one-row frame directly.
df_results = pd.DataFrame([{
    'acc': index.acc,
    'pre': index.pre,
    'rec': index.rec,
    'f1': index.f1,
}])

df_results.to_csv(os.path.join(res_path, 'test_results.csv'), sep=',', index=False, header=True)

print(f'Successfully save test results of {args.data_name}_exp{args.exp_id}_r{int(args.noise_ratio * 100)}')
