#!/usr/bin/env python3.7
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

import os
import argparse
import numpy as np
from models import *
import pickle
from misc.utils_python import mkdir, import_yaml_config, save_dict, load_dict
from ood_detectors.factory import create_ood_detector
from eval_assets import save_performance

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Module-level dict; eval_detectors() returns this object.
scores_set_raw = dict()

def _median_over_replicas(scores, replicas):
    """Reduce every consecutive run of `replicas` scores to its median.

    Args:
        scores: Tensor of detector scores; its element count must be a
            multiple of `replicas` for the reshape to succeed.
        replicas: Number of consecutive scores that form one row.

    Returns:
        1-D NumPy array holding one median per row.
    """
    # detach()/cpu() keep the NumPy conversion valid for tensors that track
    # gradients or live on the GPU; for plain CPU tensors they change nothing.
    per_row = scores.detach().cpu().reshape(-1, replicas).numpy()
    return np.median(per_row, axis=1)


def eval_detectors(model_outputs, in_dataset, out_dataset, args):
    """Score ID and OOD samples with every listed detector and save the results.

    For each detector the YAML config under ``./configs/detector/`` is merged
    into `args`, the detector is set up on ``model_outputs['train']`` and then
    run on ``model_outputs['id']`` and ``model_outputs['ood']``. Scores are
    collapsed to one median per group of ``args.replicas`` consecutive values,
    concatenated (ID first, OOD second) and handed to ``save_performance``
    together with the detection labels and the ID classification accuracy.

    Args:
        model_outputs: Mapping with ``'train'``, ``'id'`` and ``'ood'``
            entries; the ``'id'`` entry must provide ``'logits'`` and
            ``'labels'``.
        in_dataset: Name of the in-distribution dataset. Used in the output
            file name and to pick a dataset-specific ``knn_k``.
        out_dataset: Name of the OOD dataset. Used in the output file name.
        args: Configuration object exposing ``replicas``; after the detector
            config is imported it must also expose a ``detector`` mapping.

    Returns:
        The module-level ``scores_set_raw`` dict.
        NOTE(review): this function never populates that dict, while the
        per-detector scores it computes live in the local ``scores_set``.
        The return value is kept as-is for backward compatibility -- confirm
        whether ``scores_set`` was the intended result.
    """
    # "knn_plus" is currently disabled.
    all_detectors = ["energy", "nnguide", "msp", "maxlogit", "ssd", "mahalanobis", "knn"]
    # Dataset-specific `knn_k` values for the detectors that read that setting.
    knn_k_overrides = {'pathmnist': 9, 'CIFAR-10': 10}

    # Model accuracy on the ID split. It does not depend on the detector, so
    # it is computed once and stored (in percent) under every detector name.
    id_outputs = model_outputs['id']
    preds_id = torch.max(id_outputs['logits'], dim=1)[1]
    acc = (preds_id == id_outputs['labels']).float().mean().detach().cpu().numpy()
    acc_percent = acc * 100

    scores_set = {}
    accs = {}
    for selected_detector in all_detectors:
        # Setup OOD detector
        args = import_yaml_config(args, f"./configs/detector/{selected_detector}.yaml")
        if selected_detector in ('knn', 'nnguide') and in_dataset in knn_k_overrides:
            args.detector['knn_k'] = knn_k_overrides[in_dataset]

        ood_detector = create_ood_detector(selected_detector)
        ood_detector.setup(args, model_outputs['train'])
        raw_id_scores = ood_detector.infer(model_outputs['id'])
        raw_ood_scores = ood_detector.infer(model_outputs['ood'])

        # Aggregate the replicas of each sample with the median (not the mean).
        id_scores = _median_over_replicas(raw_id_scores, args.replicas)
        ood_scores = _median_over_replicas(raw_ood_scores, args.replicas)
        scores_set[selected_detector] = np.concatenate((id_scores, ood_scores), axis=0)

        # Detection labels: 1 for ID entries, 0 for OOD entries, in the same
        # order as the concatenated scores. The labels from the last detector
        # are the ones passed to save_performance below.
        detection_labels = np.concatenate(
            (np.ones(len(id_scores)), np.zeros(len(ood_scores))), axis=0
        )

        accs[selected_detector] = acc_percent

    save_performance(scores_set, detection_labels, accs, f"./logs/{in_dataset}_vs_ood_{out_dataset}_{args.replicas}.csv")

    return scores_set_raw

