"""==================================================================================================="""
################### LIBRARIES ###################
### Basic Libraries
import warnings
warnings.filterwarnings("ignore")

import os, numpy as np, argparse
import time, random, datetime, pathlib

from tqdm import tqdm

import parameters    as par

"""==================================================================================================="""
################### INPUT ARGUMENTS ###################
# Build the command-line parser: every parameter group defined in
# `parameters` receives the parser, registers its arguments on it and hands
# the parser back for the next group.
parser = argparse.ArgumentParser()
for add_parameter_group in (
    par.basic_training_parameters,
    par.batch_creation_parameters,
    par.batchmining_specific_parameters,
    par.loss_specific_parameters,
    par.log_parameters,
):
    parser = add_parameter_group(parser)

##### Read in parameters
# parse_known_args() returns (namespace, leftover_argv); only the namespace is
# kept, so command-line arguments the parser does not know are ignored
# instead of aborting the script.
opt = parser.parse_known_args()[0]

import torch, torch.nn as nn
from torch.nn import DataParallel
import architectures as archs
import datasampler   as dsamplers
import datasets      as datasets
import criteria      as criteria
import metrics       as metrics
import batchminer    as bmine
import evaluation    as eval
from utilities import misc
from utilities import logger


from datasets import dsets, label_sets
import pickle
import architectures as archs

##### Print parameters
logger.print_args(opt)

#################### SEEDS FOR REPROD. #####################
# Ask cuDNN for deterministic algorithms, then seed every random number
# generator in play (NumPy, Python, PyTorch CPU, PyTorch CUDA) with opt.seed.
torch.backends.cudnn.deterministic = True
np.random.seed(opt.seed)
random.seed(opt.seed)
torch.manual_seed(opt.seed)
torch.cuda.manual_seed(opt.seed)
torch.cuda.manual_seed_all(opt.seed)

# Resolve the compute device from the GPU indices in opt.gpu:
#   * no indices  -> run on the CPU
#   * otherwise   -> run on the first listed GPU and remember the full list
#                    in opt.device_ids.
if len(opt.gpu) == 0:
    opt.device = torch.device('cpu')
else:
    num_devices = torch.cuda.device_count()
    device_ids = opt.gpu
    # Refuse to continue when more GPUs are requested than CUDA reports.
    if num_devices < len(device_ids):
        raise Exception(
            '#available gpu : {} < --device_ids : {}'.format(
                num_devices, len(device_ids)
            )
        )
    opt.device_ids = device_ids
    opt.device = torch.device('cuda:{}'.format(device_ids[0]))

# Files inside the run directory (opt.save_path): the trained checkpoint and
# the pickled hyperparameter namespace that was stored alongside it.
final_save_path = os.path.join(opt.save_path, 'model.pt')
opt_path = os.path.join(opt.save_path, 'hypa.pkl')

# The stored hyperparameters replace the command-line namespace wholesale.
# Remember the device resolved for *this* process first; otherwise it would be
# silently discarded in favour of whatever device was pickled with the run.
eval_device = opt.device

# NOTE(security): pickle.load can execute arbitrary code contained in the
# file - only evaluate run directories that come from a trusted source.
with open(opt_path, "rb") as f:
    opt = pickle.load(f)
opt.device = eval_device

# map_location remaps the stored tensors onto the evaluation device, so a
# checkpoint written on one GPU can be restored on another GPU or on the CPU.
model_files = torch.load(final_save_path, map_location=opt.device)

# Rebuild the architecture from the stored hyperparameters, restore the
# weights and switch to inference mode.
model = archs.select(opt.arch, opt).to(opt.device)
model.load_state_dict(model_files['model_state_dict'])
model.eval()

# Build one pair of (validation, testing) loaders per label set of the
# dataset. `dataloaders` maps the label set's position in `all_sets` to a
# dict keyed by split name.
dataloaders = {}
dset = dsets.get(opt.dataset, None)
if dset is None:
    # Fail with a readable message instead of an AttributeError on None below.
    raise ValueError(f'Unknown dataset: {opt.dataset!r}')
all_sets = label_sets.get_all_label_sets(opt.dataset)
for c, ls in enumerate(all_sets):
    dss = dset.Give(opt, ls, opt.source_path)
    # shuffle=False keeps the sample order fixed for evaluation.
    dataloaders[c] = {
        split: torch.utils.data.DataLoader(
            dss[split],
            num_workers=opt.kernels,
            batch_size=opt.bs,
            shuffle=False,
        )
        for split in ('validation', 'testing')
    }

# Container for everything that is written to disk:
#   'meta'     - all label sets that were evaluated
#   'metrics'  - one entry per '<label-set index>_<split>'
#   'train_ls' - the label set stored in the run's hyperparameters
results = {'meta': all_sets, 'metrics': {}, 'train_ls': opt.label_set}

# NOTE(review): presumably read by one of the evaluation metrics - confirm.
opt.rho_spectrum_embed_dim = opt.embed_dim

# The original code branched on opt.exclusive, but both branches built exactly
# the same MetricComputer, so the conditional was dead and has been collapsed.
metric_computer = metrics.MetricComputer(opt.evaluation_metrics, opt)

# Evaluate the model on every split of every label set.
for ls, loaders in dataloaders.items():
    for split, loader in loaders.items():
        results['metrics'][f'{ls}_{split}'] = metric_computer.compute_standard(
            opt, model, loader, opt.evaltypes, opt.device
        )

# Use a context manager so the file is flushed and closed deterministically
# (the previous bare open() left the handle to the garbage collector).
with open(os.path.join(opt.save_path, 'agg_res.pkl'), 'wb') as f:
    pickle.dump(results, f)

# Leave a marker file in the run directory to signal that evaluation finished.
done_marker = os.path.join(opt.save_path, 'done_eval')
with open(done_marker, 'w') as marker_file:
    marker_file.write('done')
