import os
import joblib
import argparse

import numpy as np
import torch
import torch.nn.functional as F
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.metrics.pairwise import cosine_distances

import faiss
from tqdm import tqdm

from models.resnet import *
from utils.detection_util import set_ood_loader_ImageNet, obtain_feature_from_loader, set_ood_loader_small, get_and_print_results, obtain_logit_from_loader
from utils.util import set_loader_ImageNet, set_loader_small, set_model
from utils.display_results import  plot_distribution, print_measures, save_as_dataframe, plot_scatter
import models.hyptorch as hypb
from models.ash import get_score
from utils import scores
from utils.attack import Attack


def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool("False")``, which is ``True``
    for every non-empty string. This parser maps the usual spellings
    explicitly instead.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays falsy, as it was under ``type=bool``.
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def process_args():
    """Parse the command-line options of the OOD evaluation script.

    Besides the declared options, two derived attributes are attached to the
    returned namespace:

    * ``ckpt``  -- checkpoint path built from ``main_dir``, ``in_dataset``,
      ``name`` and ``epoch``.
    * ``n_cls`` -- 10 for ``CIFAR-10``, 100 for ``CIFAR-100`` and
      ``ImageNet-100``. It is left unset for any other ``in_dataset`` value.

    The parsed namespace is printed before the derived attributes are added.

    Returns:
        argparse.Namespace: the parsed and augmented arguments.
    """
    parser = argparse.ArgumentParser(
        description='Evaluates OOD Detector',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument('--in_dataset', default="CIFAR-100", type=str, help='in-distribution dataset')
    parser.add_argument('-b', '--batch-size', default=512, type=int, help='mini-batch size')
    parser.add_argument('--epoch', default="500", type=str, help='which epoch to test')
    parser.add_argument('--gpu', default=4, type=int, help='which GPU to use')
    parser.add_argument('--loss', default='cider', type=str, choices=['supcon', 'cider'],
                        help='loss of experiment')
    parser.add_argument('--name', type=str, default='')
    parser.add_argument('--id_loc', default="datasets/CIFAR100", type=str, help='location of in-distribution dataset')
    parser.add_argument('--ood_loc', default="datasets/small_OOD_dataset", type=str, help='location of ood datasets')

    parser.add_argument('--score', default='maha', type=str, help='score options: knn|maha')
    parser.add_argument('--K', default=300, type=int, help='K in KNN score')
    # Boolean options go through _str2bool so that "--subset False" is False.
    parser.add_argument('--subset', default=False, type=_str2bool, help='whether to use subset for KNN')
    parser.add_argument('--multiplier', default=1, type=float,
                        help='norm multipler to help solve numerical issues with precision matrix')
    parser.add_argument('--normalize', action='store_true',
                        help='normalize feat embeddings')
    parser.add_argument('--model', default='resnet34', type=str, help='model architecture')
    parser.add_argument('--embedding_dim', default=512, type=int, help='encoder feature dim')
    parser.add_argument('--feat_dim', default=128, type=int, help='head feature dim')
    parser.add_argument('--head', default='mlp', type=str, help='either mlp or linear head')
    parser.add_argument('--out_as_pos', action='store_true', help='if OOD data defined as positive class.')
    parser.add_argument('--T', default=1, type=float, help='temperature: energy|Odin')
    parser.add_argument('--main_dir', default="./", help='working space')
    parser.add_argument('--ash_method', default="", help='apply ash layer with percentage')
    parser.add_argument('--c_ball', type=float, default=1.0,
                        help='curvature of the Poincare ball')
    parser.add_argument('--train_origin', type=_str2bool, default=False,
                        help='train origin of the Poincare ball')
    parser.add_argument('--train_c', type=_str2bool, default=False,
                        help='train curative of the Poincare ball')
    parser.add_argument('--neg_shot', type=int, default=0,
                        help='observe number of ood samples as negative anchors')
    parser.add_argument('--attack', type=str, default="",
                        help='attack type')
    args = parser.parse_args()
    print(args)

    args.ckpt = f"{args.main_dir}/checkpoints/{args.in_dataset}/{args.name}/checkpoint_{args.epoch}.pth.tar"

    if args.in_dataset == "CIFAR-10":
        args.n_cls = 10
    elif args.in_dataset in ["CIFAR-100", 'ImageNet-100']:
        args.n_cls = 100

    return args


def get_features(args, net, train_loader, test_loader, attacker):
    """Return the training-set and test-set features used by the KNN score.

    Training features are cached in
    ``{main_dir}/feat/{in_dataset}/{name}/{epoch}/feat.npy``. The cache is
    keyed on the file itself rather than on its directory, so a directory
    left behind by an interrupted run triggers a recomputation instead of a
    ``FileNotFoundError``. Test features are always recomputed and are the
    only ones extracted with ``attacker``.

    Args:
        args: parsed options; ``main_dir``, ``in_dataset``, ``name`` and
            ``epoch`` are read here, and ``args`` is forwarded to
            ``obtain_feature_from_loader``.
        net: model forwarded to ``obtain_feature_from_loader``.
        train_loader: loader for the training data.
        test_loader: loader for the in-distribution test data.
        attacker: forwarded to the test-set feature extraction only.

    Returns:
        tuple: ``(ftrain, ftest)``.
    """
    feat_dir = f"{args.main_dir}/feat/{args.in_dataset}/{args.name}/{args.epoch}"
    feat_path = f'{feat_dir}/feat.npy'

    if os.path.isfile(feat_path):
        with open(feat_path, 'rb') as f:
            ftrain = np.load(f)
    else:
        os.makedirs(feat_dir, exist_ok=True)
        ftrain = obtain_feature_from_loader(args, net, train_loader,
            num_batches=None, normalize=True
        )
        # Write to a temporary file and rename it, so that a crash while
        # saving never leaves a truncated cache file behind.
        tmp_path = f'{feat_path}.tmp'
        with open(tmp_path, 'wb') as f:
            np.save(f, ftrain)
        os.replace(tmp_path, feat_path)

    ftest = obtain_feature_from_loader(args, net, test_loader,
        num_batches=None,
        attacker=attacker,
        normalize=True
    )
    return ftrain, ftest


def set_up(args):
    """Prepare the model, the optional attacker and the ID data loaders.

    Creates ``args.log_directory`` (and stores it on ``args``), loads the
    checkpoint at ``args.ckpt`` onto the CPU, builds the network with
    ``set_model`` and puts it in eval mode.

    Args:
        args: parsed options; ``args.log_directory`` is set as a side effect.

    Returns:
        tuple: ``(train_loader, test_loader, net, atk)`` where ``atk`` is
        ``None`` when ``args.attack`` is empty.

    Raises:
        KeyError: if the checkpoint has neither a ``'state_dict'`` nor a
            ``'model'`` entry.
    """
    args.log_directory = (f"{args.main_dir}/results/"
        f"{args.in_dataset}/{args.name}/{args.loss}/epoch_{args.epoch}/{args.score}"
        f"/{args.attack}"
    )
    os.makedirs(args.log_directory, exist_ok=True)

    # Load the checkpoint once. Only a missing 'state_dict' key triggers the
    # fallback; other failures (missing file, corrupt archive) propagate as is.
    checkpoint = torch.load(args.ckpt, map_location='cpu')
    try:
        pretrained_dict = checkpoint['state_dict']
    except KeyError:
        print("loading model as SupCE format")
        pretrained_dict = checkpoint['model']
    net = set_model(args)
    net.load_state_dict(pretrained_dict)
    net.eval()

    # adversarial attack
    atk = Attack(net, args.attack, n_class=args.n_cls) if args.attack else None

    if args.in_dataset == 'ImageNet-100':
        train_loader, test_loader = set_loader_ImageNet(args,
            eval=True,
            # attacker=atk
        )
    else:
        train_loader, test_loader = set_loader_small(args,
            eval=True,
            # attacker=atk
        )

    return train_loader, test_loader, net, atk

def main(args):
    """Score the ID test set and every OOD dataset, then report the metrics.

    Steps:

    1. Build the model, attacker and loaders with ``set_up``.
    2. Compute class-wise means and the precision matrix with
       ``scores.get_mean_prec``.
    3. Compute the in-distribution score with the method chosen by
       ``args.score`` (``knn``, ``ash``, ``grad``; anything else uses
       ``scores.get_dist_score``).
    4. For each OOD dataset compute the OOD score the same way, dump both
       score arrays with ``joblib`` into the current working directory,
       plot the distributions and accumulate AUROC / AUPR / FPR.
    5. Print the averaged metrics and save them with ``save_as_dataframe``.

    With ``--score knn --neg_shot N`` the scores are additionally corrected
    with ``get_neg_score`` using the mean of ``N`` randomly drawn OOD
    features. The correction always starts from the uncorrected ID score, so
    it no longer accumulates from one OOD dataset to the next.

    Args:
        args: namespace returned by ``process_args``.
    """
    train_loader, test_loader, net, atk = set_up(args)
    ood_num_examples = len(test_loader.dataset)
    num_batches = ood_num_examples // args.batch_size

    if args.ash_method:
        setattr(net, 'ash_method', args.ash_method)

    classwise_mean, precision = scores.get_mean_prec(args, net,
        train_loader,
        attacker=atk
    )
    # The negative-anchor correction only exists for the KNN score.
    use_neg_shot = args.score == "knn" and bool(args.neg_shot)

    if args.score == "knn":
        classwise_mean = classwise_mean.cpu().detach().numpy()
        ftrain, ftest = get_features(args, net, train_loader, test_loader, atk)

        # Only the first 512 feature columns are indexed.
        knn = scores.KNN(args, ftrain[:, :512].copy())
        knn_p = scores.KNN(args, classwise_mean)

        in_score = knn.get_knn_score(ftest)

        if use_neg_shot:
            # get_neg_score needs the ID logits; they used to be referenced
            # without ever being computed (NameError).
            # NOTE(review): num_batches=None mirrors the ftest extraction so
            # that rows line up with ftest -- confirm that
            # obtain_logit_from_loader accepts None and keeps the same order.
            in_logits, _ = obtain_logit_from_loader(args, net, test_loader,
                None, normalize=True
            )
    elif args.score == "ash":
        in_score = scores.get_ash_score(args, net, test_loader)
    elif args.score == "grad":
        in_score = scores.get_grad_score(args, net, test_loader)
    else:
        in_score = scores.get_dist_score(args, net, test_loader,
            classwise_mean, precision, in_dist=True, attacker=atk
        )

    print('preprocessing ID finished')

    if args.in_dataset == 'ImageNet-100':
        out_datasets = ['SUN', 'places365', 'dtd', 'iNaturalist']
    else:
        out_datasets = ["SVHN", 'places365', "LSUN", "LSUN_resize", "iSUN", 'dtd']

    # File-name stem for the ID score dump; "CIFAR-100" -> "CIFAR100" keeps
    # the previous default file name while no longer mislabelling other
    # in-distribution datasets.
    in_name = args.in_dataset.replace("-", "")

    auroc_list, aupr_list, fpr_list = [], [], []
    for out_dataset in out_datasets:
        print(f"Evaluting OOD dataset {out_dataset}")
        if args.in_dataset == 'ImageNet-100':
            ood_loader = set_ood_loader_ImageNet(args, out_dataset, attacker=atk)
        else:
            ood_loader = set_ood_loader_small(args, out_dataset, attacker=atk)

        # ID score used for this OOD dataset; start from the uncorrected one
        # on every iteration.
        id_score = in_score

        if args.score == "knn":
            ood_feat = obtain_feature_from_loader(args, net,
                ood_loader, num_batches, attacker=atk, normalize=True
            )
            print(f'preprocessing OOD {out_dataset} finished')
            out_score = knn.get_knn_score(ood_feat)

            if use_neg_shot:
                # NOTE(review): this is a second pass over ood_loader; if the
                # loader shuffles, ood_logits rows may not match ood_feat
                # rows -- confirm before relying on neg_shot results.
                ood_logits, _ = obtain_logit_from_loader(args, net, ood_loader,
                    num_batches, normalize=True
                )
                # np.random.choice samples with replacement by default.
                neg_sample = ood_feat[np.random.choice(len(ood_feat), args.neg_shot)]
                neg_sample = np.mean(neg_sample, 0, keepdims=True)
                id_score = in_score - knn_p.get_neg_score(ftest, in_logits, neg_sample)
                out_score -= knn_p.get_neg_score(ood_feat, ood_logits, neg_sample)

        elif args.score == "ash":
            out_score = scores.get_ash_score(args, net, ood_loader)
        elif args.score == "grad":
            out_score = scores.get_grad_score(args, net, ood_loader)
        else:
            # NOTE(review): in_dist=True is also passed for OOD data and no
            # attacker is forwarded here, unlike the ID call -- kept as in
            # the original; confirm this is intended.
            out_score = scores.get_dist_score(args, net, ood_loader,
                classwise_mean, precision, in_dist=True
            )

        print(id_score[:3], out_score[:3])

        joblib.dump(out_score, f"{out_dataset}_out_score.pkl")
        joblib.dump(id_score, f"{in_name}_in_score.pkl")

        plot_distribution(args, id_score, out_score, out_dataset)
        get_and_print_results(args, id_score, out_score, auroc_list, aupr_list, fpr_list, log = None)

    print("AVG")
    print_measures(None, np.mean(auroc_list), np.mean(aupr_list), np.mean(fpr_list), method_name=args.name)
    save_as_dataframe(args, out_datasets, fpr_list, auroc_list, aupr_list)


if __name__ == '__main__':
    # Script entry point: parse the CLI options, then perform OOD detection.
    main(process_args())




