##################################
# Acknowledgment:
# Part of this code is adopted from https://github.com/kahnchana/opl
# Part of this code is adopted from https://www.kaggle.com/code/yiweiwangau/cifar-100-resnet-pytorch-75-17-accuracy
# Part of this code is adopted from https://github.com/deeplearning-wisc/cider
##################################


import os
import argparse

import numpy as np
import torch
import torch.nn.functional as F

import faiss
from tqdm import tqdm

 
from utils.detection_util import set_ood_loader_ImageNet, obtain_feature_from_loader, set_ood_loader_small, get_and_print_results
from utils.util import set_loader_ImageNet, set_loader_small, set_model
from utils.display_results import  plot_distribution, print_measures, save_as_dataframe

from utils.ICR import ICR
def process_args():
    """Parse the command-line options and attach the experiment preset.

    Returns:
        argparse.Namespace: the parsed options, extended with
        ``ClassLabels_list`` (the class names, indexed by label id) and
        ``iter_ICR`` (iteration count handed to ``ICR``).

    NOTE: the CIFAR-100 preset at the end of this function unconditionally
    overwrites ``in_dataset``, ``epoch``, ``gpu``, ``score`` and ``K``, so the
    corresponding command-line flags currently have no effect.
    """

    def _str2bool(value):
        """Turn a command-line token into a real boolean.

        ``type=bool`` would call ``bool("False")``, which is True for every
        non-empty string, so the flag could never be switched off explicitly.
        """
        if isinstance(value, bool):
            return value
        token = value.strip().lower()
        if token in ('true', 't', 'yes', 'y', '1', 'on'):
            return True
        if token in ('false', 'f', 'no', 'n', '0', 'off', ''):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser(description='Evaluates OOD Detector',formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument('--in_dataset', default="CIFAR-100", type=str, help='in-distribution dataset')
    parser.add_argument('-b', '--batch-size', default=512, type=int, help='mini-batch size')
    parser.add_argument('--epoch', default ="500", type=str, help='which epoch to test')
    parser.add_argument('--gpu', default=0,  type=int, help='which GPU to use')
    parser.add_argument('--loss', default = 'cider', type=str, choices = ['supcon', 'cider'],
                    help='loss of experiment')
    parser.add_argument('--name', type=str, default = 'ckpt_c100')
    parser.add_argument('--id_loc', default="datasets/CIFAR100", type=str, help='location of in-distribution dataset')
    parser.add_argument('--ood_loc', default="datasets/small_OOD_dataset", type=str, help='location of ood datasets')

    parser.add_argument('--score', default='knn', type=str, help='score options: knn|maha|msp|odin|energy')
    parser.add_argument('--K', default=300, type=int, help='K in KNN score')
    parser.add_argument('--subset', default=False, type=_str2bool, help='whether to use subset for KNN')
    # parser.add_argument('--norm_pe', type = bool, default = True, help='if normalize penultimate layer')
    parser.add_argument('--multiplier', default=1, type=float,
                      help='norm multipler to help solve numerical issues with precision matrix')
    parser.add_argument('--model', default='resnet34', type=str, help='model architecture')
    parser.add_argument('--embedding_dim', default = 512, type=int, help='encoder feature dim')
    parser.add_argument('--feat_dim', default = 128, type=int, help='head feature dim')
    parser.add_argument('--head', default='mlp', type=str, help='either mlp or linear head')
    parser.add_argument('--normalize', action='store_true',
                        help='normalize feat embeddings')
    parser.add_argument('--out_as_pos', action='store_true', help='if OOD data defined as positive class.')
    parser.add_argument('--T', default=1000, type=float, help='temperature: energy|Odin')
    args = parser.parse_args()
    print(args)

    ###########################################################################
    # Disabled preset: CIFAR-10. Swap it with the CIFAR-100 preset below to
    # evaluate a CIFAR-10 checkpoint.
    #
    # args.epoch = 500
    # args.model = "resnet18"
    # args.head = "mlp"
    # args.gpu = 0
    # args.score = "knn"
    # args.K = 100
    # args.in_dataset = "CIFAR-10"
    # args.id_loc = "datasets/CIFAR10"
    # args.ood_loc = "datasets/small_OOD_dataset"
    # args.name = "ckpt_c10"
    #
    # ClassLabels_list = ["airplane", "automobile", "bird", "cat", "deer",
    #                     "dog", "frog", "horse", "ship", "truck"]
    #
    # args.ClassLabels_list = ClassLabels_list
    # args.iter_ICR = 2
    ###########################################################################

    # Active preset: CIFAR-100. These assignments win over whatever was given
    # on the command line (note that ``epoch`` becomes an int here).
    args.in_dataset = "CIFAR100"
    args.epoch = 500
    args.gpu = 0
    args.score = "knn"
    args.K = 300
    # args.normalize = True

    # Class names in label-id order; position i names label i.
    ClassLabels_list = ['apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle', 'bicycle', 'bottle',
          'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel', 'can', 'castle', 'caterpillar', 'cattle',
          'chair', 'chimpanzee', 'clock', 'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur',
          'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster', 'house', 'kangaroo', 'keyboard',
          'lamp', 'lawn_mower', 'leopard', 'lion', 'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle',
          'mountain', 'mouse', 'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear',
          'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine', 'possum', 'rabbit', 'raccoon',
          'ray', 'road', 'rocket', 'rose', 'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail',
          'snake', 'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table', 'tank', 'telephone',
          'television', 'tiger', 'tractor', 'train', 'trout', 'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree',
          'wolf', 'woman', 'worm']

    args.ClassLabels_list = ClassLabels_list
    args.iter_ICR = 1

    return args

###############################################################################

def set_up(args):
    """Record the results directory on ``args`` and make sure it exists.

    Sets ``args.log_directory`` to ``results/<in_dataset>/<score>`` (relative
    to the current working directory) and creates it, including parents.
    """
    args.log_directory = f"results/{args.in_dataset}/{args.score}"
    # exist_ok avoids the check-then-create race when several runs start at once.
    os.makedirs(args.log_directory, exist_ok=True)
 


def loadData(feat_dir, test = False):
    """Load saved feature/label arrays from ``feat_dir``.

    Args:
        feat_dir: directory containing ``train_feat.npy`` and
            ``train_lab.npy`` (plus ``test_feat.npy`` and ``test_lab.npy``
            when ``test`` is true).
        test: also load and return the test arrays.

    Returns:
        ``(train_feat, train_lab)`` or, with ``test=True``,
        ``(train_feat, train_lab, test_feat, test_lab)``.
    """
    def _read(path):
        # Read one .npy file through an explicitly closed binary handle.
        with open(path, 'rb') as handle:
            return np.load(handle)

    train_arrays = (_read(f'{feat_dir}/train_feat.npy'), _read(f'{feat_dir}/train_lab.npy'))
    if not test:
        return train_arrays

    test_arrays = (_read(f'{feat_dir}/test_feat.npy'), _read(f'{feat_dir}/test_lab.npy'))
    return train_arrays + test_arrays
        
 
from collections import defaultdict

def dictionaryEmb(args, fdata_train, flabels_train):
    """Group training embeddings and per-sample word ids by class name.

    Args:
        args: namespace providing ``ClassLabels_list``; position ``i`` of
            that list is the name used for label id ``i``.
        fdata_train: training embeddings, one row per sample.
        flabels_train: label id of each training sample.

    Returns:
        ``(dictOfEmb, ClassVocab)``: two ``defaultdict(list)`` objects keyed
        by class name. The first maps to the rows of ``fdata_train`` carrying
        that label, the second to the ``"word-<n>"`` ids (``n`` is the
        1-based sample position) of those samples.
    """
    class_names = args.ClassLabels_list

    embeddings_by_class = defaultdict(list)
    for label_id, class_name in enumerate(class_names):
        member_rows = np.where(flabels_train == label_id)[0]
        embeddings_by_class[class_name] = np.take(fdata_train, member_rows, axis=0)

    # Pair every sample position with its class name first, then bucket the
    # generated word ids per class.
    tagged_words = []
    for position, label in zip(range(len(fdata_train)), flabels_train):
        tagged_words.append(("word-" + str(position + 1), class_names[int(label)]))

    words_by_class = defaultdict(list)
    for word, class_name in tagged_words:
        words_by_class[class_name].append(word)

    return embeddings_by_class, words_by_class
    
def get_features_ICR_Features(fdata, fdata_train, flabels_train, fname = None, args = None):
    """Return ICR-transformed features, cached on disk under ``ICR_features/Iter1``.

    If ``ICR_features/Iter1/<fname>_feat.npy`` exists it is loaded and
    returned; otherwise the features are computed with ``ICR`` from the
    training embeddings, saved to that file, and returned.

    Args:
        fdata: features to transform (ignored on a cache hit).
        fdata_train: training embeddings used to build the class dictionary.
        flabels_train: label ids of ``fdata_train``.
        fname: cache-file stem; also forwarded to ``ICR`` as ``mode``.
        args: namespace providing ``ClassLabels_list`` and ``iter_ICR``. When
            omitted, the module-level ``args`` created by the script entry
            point is used, which is what the original implementation relied on.

    Returns:
        The transformed features as a ``float32`` array.

    Raises:
        TypeError: if ``fname`` is not a string (checked before anything is
            written to disk).
        RuntimeError: if features must be computed and no ``args`` namespace
            is available.

    NOTE(review): the cache directory is always ``Iter1`` regardless of
    ``iter_ICR``, so caches made with different iteration counts would
    collide -- confirm whether that is intended.
    """
    if not isinstance(fname, str):
        raise TypeError("fname must be a string naming the cached feature file")

    feat_dir = "ICR_features/Iter1"
    os.makedirs(feat_dir, exist_ok=True)
    cache_path = f'{feat_dir}/{fname}_feat.npy'

    if os.path.isfile(cache_path):
        with open(cache_path, 'rb') as f:
            fdata = np.load(f)
    else:
        # Only the compute path needs the experiment configuration.
        if args is None:
            args = globals().get("args")
        if args is None:
            raise RuntimeError("no `args` namespace available: pass args=... or run as a script")

        dictOfEmb, ClassVocab = dictionaryEmb(args, fdata_train, flabels_train)
        fdata = ICR(fdata, dictOfEmb, args.ClassLabels_list, ClassVocab, fdata_train, args.iter_ICR, mode = fname)
        with open(cache_path, 'wb') as f:
            np.save(f, fdata)

    return fdata.astype("float32")

###############################################################################
# Module-level settings for the Mahalanobis scoring code below. They are
# independent of the command-line options returned by process_args().

# Root folder of the pre-extracted features; main() appends a dataset name to it.
base_path = "OOD/GetFeatures/CIDER_Feat/"
# Tensor type string passed to Tensor.type() when NumPy features become tensors.
dtype = 'torch.cuda.FloatTensor'
# Device the tensors are moved to; a CUDA-capable GPU is therefore required.
device = "cuda"
# Default number of feature rows scored per chunk.
batch_size = 512
# Number of in-distribution classes for the CIFAR-100 setup.
n_cls = 100
def ClassMean_Precision(feat_train, feat_label):
    """Estimate unit-norm class means and a shared precision matrix.

    Args:
        feat_train: NumPy array of training features, one row per sample.
        feat_label: 1-D NumPy array with the label of each row.

    Returns:
        ``(classwise_mean, precision)`` as tensors of type ``dtype`` on
        ``device``: one L2-normalised mean per distinct label (rows ordered
        as ``np.unique`` orders the labels), and the inverse of the
        covariance of all training features pooled together.
    """
    # Glue the labels on as a last column so rows can be selected by label.
    labelled = np.hstack((feat_train, feat_label[:, None]))
    label_column = labelled[:, -1]

    per_class_means = []
    for k in np.unique(label_column):
        rows_of_k = labelled[label_column == k]
        per_class_means.append(np.mean(rows_of_k[:, :-1], axis = 0))

    classwise_mean = np.stack(per_class_means, axis=0)
    # Scale every class mean to unit length.
    classwise_mean = classwise_mean / np.linalg.norm(classwise_mean, axis=1, keepdims=True)
    classwise_mean = torch.from_numpy(classwise_mean).type(dtype).to(device)

    # torch.cov treats rows as variables, hence the transpose; the covariance
    # is inverted in double precision and then cast back to float.
    train_tensor = torch.from_numpy(feat_train).type(dtype).to(device)
    covariance = torch.cov(train_tensor.T.double())
    # covariance = covariance + 1e-7*torch.eye(train_tensor.shape[1]).cuda()
    precision = torch.linalg.inv(covariance).float()
    return classwise_mean, precision

def compute_Mahalanobis_score(feat_test,  classwise_mean, precision, chunk_size = None):
    """Score samples by their Mahalanobis distance to the nearest class mean.

    For every row ``x`` of ``feat_test`` the result is
    ``min_c 0.5 * (x - mu_c)^T P (x - mu_c)``, so larger values mean the
    sample lies farther from every class mean.

    Args:
        feat_test: NumPy array of features, one row per sample.
        classwise_mean: tensor with one class mean per row. Its device and
            dtype are used for the whole computation, and its row count
            defines the number of classes (no hard-coded class count).
        precision: precision (inverse covariance) matrix tensor.
        chunk_size: rows processed per step; defaults to the module-level
            ``batch_size``.

    Returns:
        1-D ``float32`` NumPy array with one score per row of ``feat_test``
        (empty when ``feat_test`` is empty).

    Raises:
        ValueError: if ``classwise_mean`` contains no class means.
    """
    if chunk_size is None:
        chunk_size = batch_size
    if classwise_mean.shape[0] == 0:
        raise ValueError("classwise_mean must contain at least one class mean")

    target_device = classwise_mean.device
    target_dtype = classwise_mean.dtype
    precision = precision.to(device=target_device, dtype=target_dtype)

    chunk_scores = []
    for start in range(0, len(feat_test), chunk_size):
        features = torch.from_numpy(feat_test[start:start + chunk_size])
        features = features.to(device=target_device, dtype=target_dtype)

        nearest = None
        for class_mean in classwise_mean:
            centered = features - class_mean
            # Row-wise quadratic form; avoids building the full
            # (chunk x chunk) product just to read its diagonal.
            half_sq_dist = 0.5 * ((centered @ precision) * centered).sum(dim=1)
            nearest = half_sq_dist if nearest is None else torch.minimum(nearest, half_sq_dist)
        chunk_scores.append(nearest.cpu().numpy())

    if not chunk_scores:
        return np.empty(0, dtype=np.float32)
    return np.concatenate(chunk_scores).astype(np.float32)

def main(args):
    """Evaluate Mahalanobis-distance OOD detection on pre-extracted features.

    Loads the in-distribution train/test features stored under
    ``base_path + args.in_dataset``, fits class means and a precision matrix
    on the training features, scores the in-distribution test set and every
    OOD dataset, then reports per-dataset and averaged metrics.

    Args:
        args: namespace returned by ``process_args``.

    NOTE(review): the score computed here is always the Mahalanobis one;
    ``args.score`` is not consulted in this function even though ``set_up``
    uses it to name the results directory -- confirm the two agree for the
    run at hand.
    """
    set_up(args)

    ID_feat_path = base_path + args.in_dataset
    ftrain, ftrain_labels, ftest, _ftest_labels = loadData(ID_feat_path,  test = True)

    # Disabled alternative: ICR-transformed features scored by the K-th
    # nearest-neighbour distance instead of Mahalanobis.
    #   ftrain_CIDER, ftrain_labels_CIDER = ftrain.copy(), ftrain_labels.copy()
    #   ftrain = get_features_ICR_Features(ftrain, ftrain_CIDER, ftrain_labels_CIDER, fname = "train_ICR")
    #   ftest = get_features_ICR_Features(ftest, ftrain_CIDER, ftrain_labels_CIDER, fname = "test_ICR")
    #   index = faiss.IndexFlatL2(ftrain.shape[1])
    #   index.add(ftrain)
    #   D, _ = index.search(ftest, args.K)
    #   in_score = D[:,-1]

    classwise_mean, precision = ClassMean_Precision(ftrain, ftrain_labels)
    in_score = compute_Mahalanobis_score(ftest,  classwise_mean, precision)

    out_datasets =  [ 'SVHN', 'places365', 'LSUN', 'iSUN', 'dtd' ]

    # Filled by get_and_print_results (presumably one entry per OOD dataset)
    # and averaged after the loop.
    auroc_list, aupr_list, fpr_list = [], [], []
    for out_dataset in out_datasets:
        print(f"Evaluting OOD dataset {out_dataset}")

        OOD_feat_path = base_path + out_dataset
        ood_feat, _ = loadData(OOD_feat_path,  test = False)

        # Disabled alternative (see above):
        #   ood_feat = get_features_ICR_Features(ood_feat, ftrain_CIDER, ftrain_labels_CIDER, fname = "OOD_ICR_" + out_dataset)
        #   D, _ = index.search(ood_feat, args.K)
        #   out_score = D[:,-1]

        out_score = compute_Mahalanobis_score(ood_feat,  classwise_mean, precision)

        # Quick sanity peek at the first few in-distribution and OOD scores.
        print(in_score[:3], out_score[:3])
        plot_distribution(args, in_score, out_score, out_dataset)
        get_and_print_results(args, in_score, out_score, auroc_list, aupr_list, fpr_list, log = None)

    print("AVG")
    print_measures(None, np.mean(auroc_list), np.mean(aupr_list), np.mean(fpr_list), method_name=args.name)
    save_as_dataframe(args, out_datasets, fpr_list, auroc_list, aupr_list)


if __name__ == '__main__':
    # `args` is bound at module level on purpose: get_features_ICR_Features
    # can look this name up as a global.
    args = process_args()
    # perform OOD detection
    main(args)





















 