#!/usr/bin/env python3.7
import os
import time
from util.args_loader import get_args
from util import metrics
import torch
import faiss
import numpy as np
import joblib
import eval_ood

# Fix the random seeds for torch (CPU), torch (CUDA) and NumPy's legacy
# global RNG so that repeated runs of this script are reproducible.
SEED = 1
for _seed_rng in (torch.manual_seed, torch.cuda.manual_seed, np.random.seed):
    _seed_rng(SEED)
del _seed_rng

# Preferred compute device.
# NOTE(review): `device` is not referenced by the code shown in this script —
# confirm whether anything else relies on it before removing.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

args = get_args()

# NOTE: this script does not set CUDA_VISIBLE_DEVICES itself (an override
# based on `args.gpu` was left disabled); GPU visibility comes from the
# launching environment.
print(f'==> Replicas: {args.replicas}')

# --- In-distribution training split -----------------------------------------
# The cache unpacks into (features, logits, labels). Features and logits are
# transposed and cast to float32; this presumes the cache stores samples along
# the last axis — TODO confirm against the code that writes these caches.
cache_name = f"cache/{args.in_dataset}_{args.replicas}_train_{args.name}_in_alllayers.joblib"
feat_log, score_log, label_log = joblib.load(cache_name)
feat_log = feat_log.T.astype(np.float32)
score_log = score_log.T.astype(np.float32)
# `class_num` is the column count of the transposed training logits
# (not referenced further in the code shown here).
class_num = score_log.shape[1]

# --- In-distribution validation split ---------------------------------------
# Same layout as the training cache; this split is used as the ID test set.
cache_name = f"cache/{args.in_dataset}_{args.replicas}_val_{args.name}_in_alllayers.joblib"
feat_log_val, score_log_val, label_log_val = joblib.load(cache_name)
feat_log_val = feat_log_val.T.astype(np.float32)
score_log_val = score_log_val.T.astype(np.float32)

for ood_dataset in args.out_datasets:
    print(f"==> {ood_dataset}")

    # OOD caches unpack into (features, logits) only — there are no labels.
    cache_name = f"cache/{ood_dataset}vs{args.in_dataset}_{args.replicas}_{args.name}_out_alllayers.joblib"
    ood_feat_log, ood_score_log = joblib.load(cache_name)
    ood_feat_log = ood_feat_log.T.astype(np.float32)
    ood_score_log = ood_score_log.T.astype(np.float32)

    # Build a fresh container for every OOD dataset. torch.from_numpy shares
    # memory with the source arrays, so re-wrapping the train/ID arrays here
    # does not copy them. A new dict each time also keeps one evaluation from
    # seeing changes a previous eval_detectors call may have made to it.
    model_outputs = {
        'train': {
            'feas': torch.from_numpy(feat_log),
            'logits': torch.from_numpy(score_log),
            'labels': torch.from_numpy(label_log),
        },
        'id': {
            'feas': torch.from_numpy(feat_log_val),
            'logits': torch.from_numpy(score_log_val),
            'labels': torch.from_numpy(label_log_val),
        },
        'ood': {
            'feas': torch.from_numpy(ood_feat_log),
            'logits': torch.from_numpy(ood_score_log),
            'labels': None,  # OOD samples carry no labels
        },
    }

    # Run the detectors for this ID/OOD pair. NOTE(review): the return value
    # is not used by the code shown here — confirm whether it should be saved.
    scores_set_raw = eval_ood.eval_detectors(model_outputs, args.in_dataset, ood_dataset, args)
    print()


