import argparse
import os
import sys
import json
from tkinter.tix import Tree
from tqdm import tqdm
import pickle
from PIL import Image


import torch
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
import numpy as np
import torchvision.datasets as dset
from metrics.label_metrics import METRICS
from metrics.ood_metrics import OOD_METRICS
from utils import prepare_dset_test, check_dir
import torchvision.transforms as transforms
import torchvision
from datasets import ImagenetNoise
from networks.dul_resnet import resnet34, DulLoss
from metrics import crl_metrics
import crl_utils


# True when a CUDA device is available; the evaluation functions below move
# each batch to the GPU only when this flag is set.
use_cuda = torch.cuda.is_available()
# Disabled: loading of precomputed Mahalanobis statistics from
# './ssl/maha_dict_clip32_imagenet_512.npy'. While the lines below stay
# commented out, the names `class_cov_invs`, `class_means`, `cov_invs`,
# `means`, `cov` and `var` are NOT defined at module level.
# maha_intermediate_dict = np.load('./ssl/maha_dict_clip32_imagenet_512.npy', allow_pickle='TRUE')
# class_cov_invs = maha_intermediate_dict.item()['class_cov_invs']
# class_means = maha_intermediate_dict.item()['class_means']
# cov_invs = maha_intermediate_dict.item()['cov_inv']
# means = maha_intermediate_dict.item()['mean']
# cov = np.linalg.inv(cov_invs)
# means = torch.from_numpy(means.reshape(1, 512)).cuda()
# var = torch.from_numpy(np.diagonal(cov).reshape(1, 512)).cuda()
def evaluate(dataloader, metric_fn, eval_acc=False):
    """Collect per-sample confidence scores for every batch of ``dataloader``.

    Reads the module-level globals ``model``, ``args`` and ``use_cuda``.

    Args:
        dataloader: Iterable of ``(inputs, targets)`` batches.
        metric_fn: Callable applied to the model outputs of a batch; it must
            return a tensor of per-sample scores. Only used when
            ``args.method == "deep_ens"``.
        eval_acc: If true, also count arg-max predictions that equal
            ``targets`` and print the resulting accuracy after the loop
            (sanity check).

    Returns:
        list: One numpy array of scores per batch, in iteration order.

    Raises:
        NotImplementedError: If ``args.method`` is neither ``"deep_ens"`` nor
            ``"dul"``; no scoring step exists for any other method.
    """
    if args.method != 'prior_net':
        model.eval()
    model.training = False
    if args.method == "mc_dropout":
        # Put the dropout layers back into train mode so they stay active
        # during the forward passes.
        for m in model.modules():
            if m.__class__.__name__.startswith('Dropout'):
                m.train()

    confidence_score = []
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, targets in tqdm(dataloader):
            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda()  # GPU settings

            # Eval step
            if args.method == "deep_ens":
                outputs = model(inputs)
                scores = metric_fn(outputs)
            elif args.method == "dul":
                _, logvar, _, outputs = model(inputs)
                # Harmonic mean of the predicted variances along dim 1:
                # (number of summed entries) / sum(1 / var). The count must be
                # the size of the reduced dimension, not the batch size,
                # otherwise a shorter final batch is scored on another scale.
                scores = logvar.size(1) / torch.sum(torch.exp(-logvar), dim=1)
            else:
                raise NotImplementedError(
                    f"evaluate() has no scoring step for method '{args.method}'"
                )
            confidence_score.append(scores.cpu().numpy())

            # For sanity check
            if eval_acc:
                _, predicted = torch.max(outputs, -1)
                total += targets.size(0)
                # .item() keeps the counter a plain int so the printed
                # accuracy is a float rather than a tensor.
                correct += predicted.eq(targets).sum().item()

    if eval_acc and total > 0:
        print(correct / total)
    return confidence_score

def _denormalize_to_uint8(image):
    """Undo the hard-coded per-channel normalisation and return an HWC uint8 array.

    ``image`` is indexed per channel on its first axis and transposed with
    ``(1, 2, 0)``, i.e. it is treated as a channel-first 3-channel tensor.
    The input tensor is left untouched.
    """
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(3, 1, 1)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(3, 1, 1)
    pixels = (image.numpy() * std + mean) * 255
    pixels = np.transpose(pixels, (1, 2, 0))
    # Clip before the cast: values outside [0, 255] would otherwise wrap
    # around in uint8 and show up as speckle artefacts.
    return np.clip(pixels, 0, 255).astype(np.uint8)


def evaluate_dul(dataloader, metric_fn, eval_acc=False):
    """Score every sample with a DUL model and save the lowest-scoring images per class.

    Reads the module-level globals ``model``, ``args`` and ``use_cuda``. The
    dataloader is iterated twice: first to score all samples, then to write
    the selected images to
    ``./figs/<in_dataset>_dul_class_e_<class>/<rank>_<sample id>.jpg``.

    Args:
        dataloader: Iterable of ``(ids, inputs, targets)`` batches. ``targets``
            are used as indices into a list of length ``args.num_classes``.
        metric_fn: Unused; accepted so the signature matches ``evaluate``.
        eval_acc: Unused; accepted so the signature matches ``evaluate``.

    Returns:
        list: One numpy array of per-sample scores per batch, in iteration
        order of the first pass.
    """
    if args.method != 'prior_net':
        model.eval()
    model.training = False

    confidence_score = []
    # One dict per class mapping sample id -> score; a repeated id keeps the
    # score seen last.
    scores_by_class = [{} for _ in range(args.num_classes)]

    with torch.no_grad():
        for sample_ids, inputs, targets in tqdm(dataloader):
            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda()  # GPU settings

            # Eval step
            _, logvar, _, _ = model(inputs)
            # Harmonic mean of the predicted variances along dim 1. The
            # numerator is the size of the reduced dimension (not the batch
            # size) so every batch is scored on the same scale.
            harmonic_mean = logvar.size(1) / torch.sum(torch.exp(-logvar), dim=1)

            # Convert once per batch instead of once per sample.
            batch_scores = harmonic_mean.cpu().numpy()
            confidence_score.append(batch_scores)
            for sample_id, label, score in zip(sample_ids.tolist(), targets.tolist(), batch_scores):
                scores_by_class[label][sample_id] = score

    # Per class keep the (up to) ten lowest scores, ordered from highest to
    # lowest, and remember each kept sample's position in that ordering.
    rank_by_class = []
    for class_scores in scores_by_class:
        kept = sorted(class_scores.items(), key=lambda kv: (kv[1], kv[0]), reverse=True)[-10:]
        rank_by_class.append({sample_id: rank for rank, (sample_id, _) in enumerate(kept)})

    # Second pass: write out the selected images.
    for sample_ids, inputs, targets in dataloader:
        for i, (sample_id, label) in enumerate(zip(sample_ids.tolist(), targets.tolist())):
            rank = rank_by_class[label].get(sample_id)
            if rank is None:
                continue
            print('index:', rank)
            rgb_img = Image.fromarray(_denormalize_to_uint8(inputs[i]))
            out_dir = './figs/' + str(args.in_dataset) + '_dul_class_e_' + str(label)
            check_dir(out_dir)
            rgb_img.save(os.path.join(out_dir, str(rank) + '_' + str(sample_id) + '.jpg'))
    return confidence_score
            

if __name__ == '__main__':
    # Script entry point: parse the options, build the in-distribution and OOD
    # loaders, restore the checkpoint, score both datasets with `evaluate` and
    # print the OOD detection metrics.
    # `args` and `model` intentionally stay at module level because
    # `evaluate` / `evaluate_dul` read them as globals.
    parser = argparse.ArgumentParser(description='OOD Detection Evaluation')
    # General args
    parser.add_argument('--method', default='deep_ens', type=str, help='method')
    parser.add_argument('--in_dataset', default='cifar10', type=str, help='cifar10/cifar100')
    parser.add_argument('--ood_dataset', default='svhn', type=str, help='cifar10/cifar100/svhn/lsun')
    parser.add_argument('--distortion_name', default='gaussian_blur', type=str, help='only for imagenet-c')
    parser.add_argument('--model_path', type=str, default=None, help='Model trained with in_dataset')
    parser.add_argument('--args_path', type=str, default=None, help='Model args.')
    parser.add_argument('--label_metric', type=str, default='max_prob', help='Label metric to use (e.g., reject_score)')
    parser.add_argument('--tag', type=str, default="", help='Experiment tag')
    parser.add_argument('--save_conf', action='store_true', default=False)
    parser.add_argument('--ens_num', type=int, default=1, help='For ensemble-based methods, assign the number of model to locate the ckpt.')
    parser.add_argument('--mc_num', type=int, default=1, help='For mc-dropout, assign the sampling num to locate the ckpt.')
    parser.add_argument('--net_type', type=str, default="resnet34", help='Net type used in model, default: resnet34')
    parser.add_argument('--depth', default=4, type=int, help='depth of model')
    parser.add_argument('--widen_factor', default=4, type=int, help='width of model')
    parser.add_argument('--num_classes', default=1000, type=int)
    parser.add_argument('--batch_size', default=1024, type=int, help='Batch size of both evaluation dataloaders')
    parser.add_argument('--ood_root', default='/data/cuipeng/dataset/', type=str,
                        help='Directory containing the OOD image folders (imagenet runs only)')

    args = parser.parse_args()

    # NOTE(review): `num_classes` is not read anywhere in the visible code
    # (the rest of the script uses `args.num_classes`); kept as in the original.
    if args.in_dataset == "cifar10":
        num_classes = 10
    elif args.in_dataset == "cifar100":
        num_classes = 100

    # Loading datasets
    if args.in_dataset != 'imagenet':
        in_dataset = prepare_dset_test(args.in_dataset)
        # Without this assignment `ood_dataset` is undefined for every
        # non-imagenet run and the DataLoader below raises NameError.
        ood_dataset = prepare_dset_test(args.ood_dataset)
    else:
        eval_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])
        in_dataset = ImagenetNoise(
            train=False,
            transform=eval_transform,
            num_classes=args.num_classes
        )
        ood_dataset = dset.ImageFolder(
            # root='/data/ImageNet_C/' + args.distortion_name + '/' + str(5),
            root=os.path.join(args.ood_root, args.ood_dataset),
            transform=eval_transform,
        )
    in_dataloader = torch.utils.data.DataLoader(in_dataset, batch_size=args.batch_size, shuffle=False)
    ood_dataloader = torch.utils.data.DataLoader(ood_dataset, batch_size=args.batch_size, shuffle=False)

    # Label Metric
    metric_fn = METRICS[args.label_metric]

    # Load Model
    # Only a genuinely missing file is reported as "Can not find checkpoints.";
    # any other failure (bad state dict, missing key, ...) keeps its own
    # traceback instead of being relabelled by a bare `except`.
    for required_path in (args.model_path, args.args_path):
        if required_path is None or not os.path.isfile(required_path):
            raise FileNotFoundError(f"Can not find checkpoints. Missing file: {required_path}")

    # NOTE: torch.load unpickles arbitrary objects - only load trusted files.
    # NOTE(review): the args file must unpickle to an object accepted by
    # vars() below; on PyTorch releases where torch.load defaults to
    # weights_only=True this call may need weights_only=False - confirm.
    load_location = None if use_cuda else 'cpu'
    # model = resnet34(num_classes=1000).cuda()
    model = torchvision.models.resnet34(pretrained=False, num_classes=1000)
    if use_cuda:
        model = model.cuda()
    checkpoint = torch.load(args.model_path, map_location=load_location)
    model.load_state_dict(checkpoint['net'])
    # model_args = pickle.load(open(args.model_args_path, 'rb'))
    model_args = torch.load(args.args_path, map_location=load_location)

    # model = checkpoint['net'].cuda()
    model = torch.nn.DataParallel(model)

    # Only consumed by the (currently disabled) result dump at the bottom.
    res = {**vars(model_args), **vars(args)}

    # Sanity Check
    recorded_acc = checkpoint["acc"]
    print(f"Test Acc recorded in checkpoint: {recorded_acc}")
    cudnn.benchmark = True
    # cls_criterion = torch.nn.CrossEntropyLoss().cuda()
    # test_onehot = crl_utils.one_hot_encoding(in_dataset.targets)
    # test_label = in_dataset.targets
    # acc, aurc, eaurc, aupr, fpr, ece, nll, brier = crl_metrics.calc_metrics(in_dataloader,
    #                                                                     test_label,
    #                                                                     test_onehot,
    #                                                                     model,
    #                                                                     cls_criterion)
    # print(acc, aurc, eaurc, aupr, fpr, ece, nll, brier)

    # Calculate confidence score
    in_confidence_score = np.concatenate(evaluate(in_dataloader, metric_fn, eval_acc=True))
    # evaluate_dul(in_dataloader,metric_fn)
    # NOTE(review): eval_acc=True also prints an accuracy for the OOD set,
    # computed against that dataset's own labels - confirm it is wanted.
    ood_confidence_score = np.concatenate(evaluate(ood_dataloader, metric_fn, eval_acc=True))
    # import matplotlib.pyplot as plt
    # kwargs = dict(histtype='stepfilled', alpha=0.3, bins=40)
    # plt.hist(in_confidence_score, **kwargs)
    # plt.hist(ood_confidence_score, **kwargs)
    # plt.savefig('in_out_texture_maha_weight')

    # Save confidence path
    if args.save_conf:
        conf_path = os.path.join(os.getcwd(), 'ood_res', 'confidence_scores', args.method)
        check_dir(conf_path)
        np.savetxt(os.path.join(conf_path, f'confidence-In-{args.method}-id={args.in_dataset}-ood={args.ood_dataset}-lm-{args.label_metric}.txt'), in_confidence_score)
        np.savetxt(os.path.join(conf_path, f'confidence-Out-{args.method}-id={args.in_dataset}-ood={args.ood_dataset}-lm-{args.label_metric}.txt'), ood_confidence_score)

    print("Evaluating OOD Detection Perfermance on "+ args.ood_dataset + " using " + args.label_metric)
    # np.longdouble is the portable spelling: np.float128 is only an alias for
    # it on platforms that provide a 128-bit long double.
    scores = np.concatenate((in_confidence_score, ood_confidence_score), axis=0).astype(np.longdouble)
    if args.label_metric in ["max_prob", "kernel_distance"]:
        # Sign flip for these two metrics before they are compared with the
        # domain labels (0 = in-distribution, 1 = OOD).
        scores *= -1
    in_labels = np.zeros_like(in_confidence_score)
    out_labels = np.ones_like(ood_confidence_score)
    domain_labels = np.concatenate((in_labels, out_labels), axis=0)

    tpr95_score = OOD_METRICS["tpr95"](domain_labels, scores)
    auroc_score = OOD_METRICS["auroc"](domain_labels, scores)
    auprIn_score = OOD_METRICS["auprIn"](domain_labels, scores)
    auprOut_score = OOD_METRICS["auprOut"](domain_labels, scores)
    de_score = OOD_METRICS["detection_err"](domain_labels, scores)

    print("{:20}{:13.1f}% ".format("FPR at TPR 95%:", tpr95_score*100))
    print("{:20}{:13.1f}% ".format("Detection error:", de_score*100))
    print("{:20}{:13.1f}% ".format("AUROC:",auroc_score*100))
    print("{:20}{:13.1f}% ".format("AUPR In:",auprIn_score*100))
    print("{:20}{:13.1f}% ".format("AUPR Out:",auprOut_score*100))

    # Disabled: append the run configuration and metrics to ood_res/result.json.
    # res['tpr95'] = tpr95_score
    # res['auroc'] = auroc_score
    # res['auprIn'] = auprIn_score
    # res['auprOut'] = auprOut_score
    # res['detection_err'] = de_score

    # res_path = os.path.join(os.getcwd(), 'ood_res', 'result.json')

    # try:
    #     with open(res_path) as f:
    #         res_list = json.load(f)
    # except:
    #     res_list = []
    # res_list.append(res)

    # with open(res_path, "w+") as f:
    #     json.dump(res_list, f, indent=2)

    
