import argparse
import os
import logging
import copy

import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix, roc_auc_score

import torch
import torch.nn.functional as F

from src.core.utils import (
    get_logger, 
    get_net_builder, 
    over_write_args_from_file)
from src.fl_datasets.cv_datasets.ood import get_inlier, get_outlier
from src.algorithms import OursNet, OpenNet


def get_config():
    """Parse the command line and apply the config file it points to.

    The only CLI option is ``--c``, the path of a configuration file. The
    parsed namespace and that path are handed to ``over_write_args_from_file``,
    and the resulting namespace is returned.
    """
    cli = argparse.ArgumentParser(
        description='Open-Set Semi-Supervised Learning Framework')
    cli.add_argument('--c', type=str, default='')

    parsed = cli.parse_args()
    over_write_args_from_file(parsed, parsed.c)
    return parsed


def eval_seen(model, full_loader, num_classes, device):
    """Evaluate closed-set accuracy and OOD AUROCs on a mixed labelled loader.

    Each batch of ``full_loader`` is a dict with ``'x_lb'`` (a tensor or a dict
    of tensors) and ``'y_lb'`` (integer labels). Samples with
    ``y_lb < num_classes`` are known classes and are used for accuracy.
    Samples with ``y_lb >= num_classes`` are the positive (unknown) class for
    every AUROC.

    Args:
        model: callable returning a dict with ``'logits'`` and, optionally,
            ``'logits_out'``.
        full_loader: iterable of batch dicts as described above.
        num_classes: number of known classes.
        device: target passed to ``Tensor.to``.

    Returns:
        ``(top1, auroc_msp, auroc_entropy, auroc_energy, auroc_logits_out)``.
        ``top1`` is 0.0 when no known-class sample is seen. All AUROCs are 0.0
        when ``roc_auc_score`` raises ``ValueError`` (e.g. only one label value
        is present). ``auroc_logits_out`` is 0.0 when the model never emits
        ``'logits_out'``.
    """
    y_in_true = []
    y_pred = []

    # OOD labels (1 = unknown, 0 = known) and scores (higher = more unknown).
    ood_labels = []
    ood_scores_msp = []
    ood_scores_entropy = []
    ood_scores_energy = []
    ood_scores_logits_out = []  # scores from the 'logits_out' head, if present

    with torch.no_grad():
        for data in full_loader:
            x = data['x_lb']
            y = data['y_lb'].to(device)
            if isinstance(x, dict):
                x = {k: v.to(device) for k, v in x.items()}
            else:
                x = x.to(device)

            outputs = model(x)
            logits = outputs['logits']

            # ---------- Closed-set classification (known classes only) ----------
            in_idx = torch.where(y < num_classes)[0]
            if in_idx.numel() > 0:
                y_in_true.extend(y[in_idx].cpu().tolist())
                y_pred.extend(torch.max(logits[in_idx], dim=-1)[1].cpu().tolist())

            ood_labels.extend((y >= num_classes).int().cpu().tolist())

            # ---------- OOD scores ----------
            probs = F.softmax(logits, dim=1)
            ood_scores_msp.extend((1.0 - probs.max(dim=1)[0]).cpu().tolist())
            ood_scores_entropy.extend(
                (-(probs * torch.log(probs + 1e-10)).sum(dim=1)).cpu().tolist())
            ood_scores_energy.extend((-torch.logsumexp(logits, dim=1)).cpu().tolist())

            if 'logits_out' in outputs:
                out_logits = outputs['logits_out']
                # Reshaped to (B, 2, K) and softmaxed over dim 1; index 0 is
                # read as the "unknown" probability for the predicted class.
                r = F.softmax(out_logits.view(out_logits.size(0), 2, -1), 1)
                rows = torch.arange(out_logits.size(0), device=out_logits.device)
                pred_all = torch.max(logits, dim=-1)[1]
                ood_scores_logits_out.extend(r[rows, 0, pred_all].cpu().tolist())

    # Accuracy over known-class samples only.
    top1 = accuracy_score(y_in_true, y_pred) if y_in_true else 0.0

    # AUROC (unknown = positive class).
    try:
        auroc_msp = roc_auc_score(ood_labels, ood_scores_msp)
        auroc_entropy = roc_auc_score(ood_labels, ood_scores_entropy)
        auroc_energy = roc_auc_score(ood_labels, ood_scores_energy)
        auroc_logits_out = (
            roc_auc_score(ood_labels, ood_scores_logits_out)
            if ood_scores_logits_out
            else 0.0
        )
    except ValueError:
        # Deliberate best-effort fallback, kept from the original behaviour.
        auroc_msp = auroc_entropy = auroc_energy = auroc_logits_out = 0.0

    return top1, auroc_msp, auroc_entropy, auroc_energy, auroc_logits_out


def _eval_unseen_batch_scores(model, x, device):
    """Move one input batch to ``device``, run ``model`` and return OOD scores.

    ``x`` may be a tensor or a dict of tensors. The returned dict maps
    ``'msp'``, ``'entropy'`` and ``'energy'`` to per-sample float lists (higher
    means more likely unknown). It also contains ``'logits_out'`` when the model
    output has that key. The caller is expected to run this under
    ``torch.no_grad()``.
    """
    if isinstance(x, dict):
        x = {k: v.to(device) for k, v in x.items()}
    else:
        x = x.to(device)

    outputs = model(x)
    logits = outputs['logits']
    probs = F.softmax(logits, dim=1)

    scores = {
        'msp': (1.0 - probs.max(dim=1)[0]).cpu().tolist(),
        'entropy': (-(probs * torch.log(probs + 1e-10)).sum(dim=1)).cpu().tolist(),
        'energy': (-torch.logsumexp(logits, dim=1)).cpu().tolist(),
    }

    if 'logits_out' in outputs:
        out_logits = outputs['logits_out']
        # Reshaped to (B, 2, K) and softmaxed over dim 1; index 0 is read as
        # the "unknown" probability for the predicted class.
        r = F.softmax(out_logits.view(out_logits.size(0), 2, -1), 1)
        rows = torch.arange(out_logits.size(0), device=out_logits.device)
        pred_all = torch.max(logits, dim=-1)[1]
        scores['logits_out'] = r[rows, 0, pred_all].cpu().tolist()

    return scores


def eval_unseen(model, in_loader, ood_loader, device):
    """Compute AUROCs separating in-distribution data from one OOD dataset.

    Samples from ``in_loader`` (batch dicts with an ``'x_lb'`` entry) are
    labelled 0. Samples from ``ood_loader`` (``(x, y)`` pairs; ``y`` is ignored)
    are labelled 1. Inputs may be tensors or dicts of tensors, and are moved
    with ``Tensor.to(device)`` as in ``eval_seen``.

    Returns:
        ``(auroc_msp, auroc_entropy, auroc_energy, auroc_logits_out)``. All are
        0.0 if ``roc_auc_score`` raises ``ValueError``. ``auroc_logits_out`` is
        0.0 when the model never emits ``'logits_out'``.
    """
    y_out_gt = []
    scores = {'msp': [], 'entropy': [], 'energy': [], 'logits_out': []}

    def _labelled_inputs():
        # In-distribution batches first (label 0), then OOD batches (label 1).
        for data in in_loader:
            yield data['x_lb'], 0
        for x, _ in ood_loader:
            yield x, 1

    with torch.no_grad():
        for x, label in _labelled_inputs():
            batch_scores = _eval_unseen_batch_scores(model, x, device)
            # Batch size taken from the model output, so dict inputs work too.
            y_out_gt.extend([label] * len(batch_scores['msp']))
            for key, values in batch_scores.items():
                scores[key].extend(values)

    # ---------- AUROC ----------
    try:
        auroc_msp = roc_auc_score(y_out_gt, scores['msp'])
        auroc_entropy = roc_auc_score(y_out_gt, scores['entropy'])
        auroc_energy = roc_auc_score(y_out_gt, scores['energy'])
        auroc_logits_out = (
            roc_auc_score(y_out_gt, scores['logits_out'])
            if scores['logits_out']
            else 0.0
        )
    except ValueError:
        # Deliberate best-effort fallback, kept from the original behaviour.
        auroc_msp = auroc_entropy = auroc_energy = auroc_logits_out = 0.0

    return auroc_msp, auroc_entropy, auroc_energy, auroc_logits_out


def evaluation(args, model, ema=False):
    """Run seen and unseen (per-OOD-dataset) evaluation for ``model``.

    The model is switched to eval mode. ``eval_seen`` runs on the mixed inlier
    loader, and ``eval_unseen`` runs once per name in ``args.ood_datasets``.

    Args:
        args: namespace providing ``dataset``, ``data_dir``, ``num_classes``,
            ``gpu`` and ``ood_datasets``.
        model: model to evaluate.
        ema: unused; kept so existing callers keep working.

    Returns:
        A 13-tuple::

            (in_acc,
             seen_auroc_msp, seen_auroc_entropy, seen_auroc_energy, seen_auroc_out,
             unseen_auc_msp, unseen_auc_energy, unseen_auc_entropy, unseen_auc_out,
             mean_unseen_msp_auc, mean_unseen_out_auc,
             mean1, mean2)

        The ``unseen_auc_*`` entries are dicts keyed by OOD dataset name.
        ``mean1``/``mean2`` average the seen AUROC with the mean unseen AUROC
        (MSP and logits_out based, respectively). When ``args.ood_datasets`` is
        empty, the unseen means and ``mean1``/``mean2`` are NaN instead of
        raising ``ZeroDivisionError``.
    """
    model.eval()

    in_loader, full_loader = get_inlier(args, args.dataset, args.data_dir)

    # Seen evaluation (also includes the logits_out-based score).
    in_acc, seen_auroc_msp, seen_auroc_entropy, seen_auroc_energy, seen_auroc_out = \
        eval_seen(model, full_loader, args.num_classes, args.gpu)

    unseen_auc_msp = {}
    unseen_auc_energy = {}
    unseen_auc_entropy = {}
    unseen_auc_out = {}
    # One entry per iteration, so duplicated dataset names count as in the
    # original per-iteration sum.
    msp_aucs = []
    out_aucs = []

    for oods in args.ood_datasets:
        ood_loader = get_outlier(args, oods, args.data_dir)
        # NOTE: eval_unseen re-scores in_loader on every call.
        u_auroc_msp, u_auroc_entropy, u_auroc_energy, u_auroc_out = \
            eval_unseen(model, in_loader, ood_loader, args.gpu)

        unseen_auc_msp[oods] = u_auroc_msp
        unseen_auc_entropy[oods] = u_auroc_entropy
        unseen_auc_energy[oods] = u_auroc_energy
        unseen_auc_out[oods] = u_auroc_out

        msp_aucs.append(u_auroc_msp)
        out_aucs.append(u_auroc_out)

    if msp_aucs:
        mean_unseen_msp_auc = sum(msp_aucs) / len(msp_aucs)
        mean_unseen_out_auc = sum(out_aucs) / len(out_aucs)
    else:
        # No OOD datasets configured: the unseen means are undefined.
        mean_unseen_msp_auc = mean_unseen_out_auc = float('nan')

    mean1 = (seen_auroc_msp + mean_unseen_msp_auc) / 2
    mean2 = (seen_auroc_out + mean_unseen_out_auc) / 2  # logits_out-based mean

    return (
        in_acc,
        seen_auroc_msp, seen_auroc_entropy, seen_auroc_energy, seen_auroc_out,
        unseen_auc_msp, unseen_auc_energy, unseen_auc_entropy, unseen_auc_out,
        mean_unseen_msp_auc, mean_unseen_out_auc,
        mean1, mean2
    )


def load_ckpt(model, ckpt_path, mode='model'):
    """Load the state dict stored under ``ckpt[mode]`` into ``model``.

    Keys that begin with ``'module.'`` have that prefix removed, so weights
    saved from a wrapped model (e.g. ``nn.DataParallel``) fit the unwrapped
    ``model``. Only the exact ``'module.'`` prefix is stripped. A parameter whose
    name merely starts with the letters ``module`` (e.g. ``module_list.0.weight``)
    is left untouched.

    Loading is non-strict, so mismatched keys do not raise. Any missing or
    unexpected keys are reported through a logging warning instead of being
    silently ignored.

    Args:
        model: ``nn.Module`` to load into (modified in place).
        ckpt_path: path to a ``torch.save`` file whose top level is a dict
            containing ``mode``.
        mode: key of the state dict inside the checkpoint, e.g. ``'model'``
            or ``'ema_model'``.

    Returns:
        The same ``model`` object, with the weights loaded.
    """
    ckpt = torch.load(ckpt_path, map_location='cpu')

    prefix = 'module.'
    new_state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in ckpt[mode].items()
    }

    result = model.load_state_dict(new_state_dict, strict=False)
    if result.missing_keys or result.unexpected_keys:
        logging.getLogger(__name__).warning(
            "load_ckpt(%s, mode=%r): missing keys=%s, unexpected keys=%s",
            ckpt_path, mode, result.missing_keys, result.unexpected_keys)

    return model

    
def _log_eval_results(logger, prefix, results):
    """Log every metric in the 13-tuple returned by ``evaluation``.

    ``prefix`` is inserted right after ``">> "`` on each metric line (``"EMA "``
    for the EMA weights, ``""`` for the raw weights).
    """
    (in_acc,
     s_auroc_msp, s_auroc_entropy, s_auroc_energy, s_auroc_out,
     u_auc_msp, u_auc_energy, u_auc_entropy, u_auc_out,
     mean_u_msp_auc, mean_u_out_auc, mean1, mean2) = results

    logger.info(f">> {prefix}Closed-set ACC: {in_acc}")
    logger.info(f">> {prefix}SEEN AUROC - MSP: {s_auroc_msp}")
    logger.info(f">> {prefix}SEEN AUROC - ENERGY: {s_auroc_energy}")
    logger.info(f">> {prefix}SEEN AUROC - ENTROPY: {s_auroc_entropy}")
    logger.info(f">> {prefix}SEEN AUROC - OUT: {s_auroc_out}")
    logger.info("============================================")
    logger.info(f">> {prefix}UNSEEN AUROC - MSP: {u_auc_msp}")
    logger.info(f">> {prefix}UNSEEN AUROC - ENERGY: {u_auc_energy}")
    logger.info(f">> {prefix}UNSEEN AUROC - ENTROPY: {u_auc_entropy}")
    logger.info(f">> {prefix}UNSEEN AUROC - OUT: {u_auc_out}")
    logger.info("============================================")
    logger.info(f">> {prefix}MEAN UNSEEN AUROC (MSP): {mean_u_msp_auc}")
    logger.info(f">> {prefix}MEAN UNSEEN AUROC (OUT): {mean_u_out_auc}")
    logger.info(f">> {prefix}FINAL MEAN AUROC 1 (MSP): {mean1}")
    logger.info(f">> {prefix}FINAL MEAN AUROC 2 (OUT): {mean2}")


def main(args):
    """Build the network for ``args.algorithm`` and evaluate a checkpoint.

    The checkpoint at ``args.load_path`` is evaluated twice. First the weights
    stored under ``'ema_model'`` are evaluated, then those under ``'model'``.
    All metrics are logged to ``args.save_dir/args.save_name``.

    Raises:
        ValueError: if ``args.algorithm`` is not ``"Ours"``, ``"ssb"`` or
            ``"fixmatch"``.
    """
    save_path = os.path.join(args.save_dir, args.save_name)
    logger = get_logger(args.save_name, save_path, 'INFO')

    _net_builder = get_net_builder(args.net, args.net_from_name)
    net_builder = _net_builder(num_classes=args.num_classes)

    if args.algorithm == "Ours":
        model_ = OursNet(base=net_builder, num_classes=args.num_classes,
                         cls_hidden=args.cls_hidden,
                         proj_hidden=args.proj_hidden,
                         proj_size=args.proj_size,
                         out_hidden=args.out_hidden)
    elif args.algorithm == "ssb":
        model_ = OpenNet(base=net_builder,
                         num_classes=args.num_classes,
                         cls_hidden=args.cls_hidden,
                         out_hidden=args.out_hidden,
                         mlp=True)
    elif args.algorithm == "fixmatch":
        model_ = net_builder
    else:
        # Previously fell through and failed later with an UnboundLocalError.
        raise ValueError(
            f"Unknown algorithm {args.algorithm!r}; "
            f"expected one of 'Ours', 'ssb', 'fixmatch'")

    # EMA weights first, then the raw weights, both from the same checkpoint.
    for mode, prefix in (('ema_model', 'EMA '), ('model', '')):
        net = load_ckpt(copy.deepcopy(model_), args.load_path, mode=mode)
        net = net.cuda(args.gpu)

        logger.info(f"{prefix}MODEL loaded!")
        logger.info(f"ckpt path: {args.load_path}")

        logger.info(f"[!] {prefix}MODEL EVALUATION...")
        _log_eval_results(logger, prefix, evaluation(args, net))

        # Release the reference so this GPU copy can be freed before the next load.
        del net
    

if __name__ == "__main__":
    # Script entry point: parse the CLI/config file, then run the evaluation.
    main(get_config())