# Copyright (c) Alibaba Group
import argparse
import torch
import torchvision.datasets as datasets
import torch.nn.functional as F
import os
import math
import numpy as np
import random
import logging
import faiss
from gmpy2 import random_state
from sympy.abc import alpha
from tqdm import tqdm
from sklearn.cluster import KMeans

from utils.detection_util import print_measures, get_and_print_results
from utils.file_ops import save_as_dataframe, setup_log
from utils.plot_util import plot_distribution

import scipy.optimize as sopt

# Command-line interface of the OOD evaluation script.
parser = argparse.ArgumentParser(description='OnZeta for ImageNet')

parser.add_argument('--seed', default=2, type=int, help="random seed")  # 42

# Grouping of the negative text embeddings.
parser.add_argument('--ngroup', type=int, default=10,
                    help='number of groups the negative text embeddings are split into')

# Scoring hyper-parameters (see get_score for where each one enters the score).
parser.add_argument('--alpha', type=float, default=0.3,
                    help='weight of the global image feature when it is mixed with a '
                         'local feature (the local feature gets 1 - alpha)')

parser.add_argument('--beta', type=float, default=0.8,
                    help='scale of the final log-sum-exp over the per-class scores')

parser.add_argument('--t1', type=float, default=0.08,
                    help='temperature of the per-group softmax on local features')
parser.add_argument('--t2', type=float, default=0.02,
                    help='temperature of the logits of the mixed global/local feature')
parser.add_argument('--t3', type=float, default=50,
                    help='temperature of the softmax weights computed from local features')

parser.add_argument(
    "--mc_local_number",
    type=int,
    default=2,
    help="Number of random local crops.",
)

# Within this file the value is only forwarded as the method name when the
# mean results are printed.
parser.add_argument('--score', default='ours', type=str, choices=[
    'MCM', 'energy', 'max-logit', 'entropy', 'var', 'maha', 'ours'], help='score options')


# end

# The script targets GPU 1. Only select it when that device actually exists:
# an unconditional set_device(1) raises at import time on hosts with fewer
# than two GPUs, before any argument parsing can happen.
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    torch.cuda.set_device(1)


def to_np(x):
    """Return tensor ``x`` as a NumPy array in host memory."""
    return x.detach().cpu().numpy()


def concat(x):
    """Concatenate a sequence of arrays along their first axis."""
    return np.concatenate(x, axis=0)


relu = torch.nn.ReLU()

# In-distribution dataset and the OOD datasets it is evaluated against.
in_dataset = 'ImageNet'
out_datasets = ['iNaturalist', 'SUN', 'places365', 'dtd']



def setup_log(args):
    """Create (or reuse) the module logger that writes to file and console.

    Messages of level DEBUG and above go both to ``ood_eval_info.log`` in the
    current working directory (truncated when the handlers are first created)
    and to the default stream of ``logging.StreamHandler``.

    NOTE(review): this definition shadows the ``setup_log`` imported from
    ``utils.file_ops`` at the top of the file.

    Args:
        args: Parsed command-line arguments. Currently unused; kept so the
            call signature stays unchanged.

    Returns:
        The configured ``logging.Logger``.
    """
    log = logging.getLogger(__name__)
    log.setLevel(logging.DEBUG)

    # Attach handlers only once; re-attaching on a second call would make
    # every message appear multiple times.
    if not log.handlers:
        formatter = logging.Formatter('%(asctime)s : %(message)s')
        # A plain relative name instead of ".\\...": the backslash form is an
        # invalid escape sequence and is not a path separator on POSIX.
        file_handler = logging.FileHandler("ood_eval_info.log", mode='w')
        stream_handler = logging.StreamHandler()
        for handler in (file_handler, stream_handler):
            handler.setFormatter(formatter)
            log.addHandler(handler)

    log.debug("#########eval_ood############")
    return log

def setup_seed(seed):
    """Seed the torch (CPU and CUDA), NumPy and ``random`` generators with ``seed``."""
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

def get_score(image_features, text_features_pos, text_features_neg, args, log):
    """Compute one OOD score per image from global/local image features.

    For each of the first ``args.mc_local_number`` local feature sets the
    function combines two per-class terms in log space:

    * the mean, over the ``args.ngroup`` negative groups, of the softmax
      probability (temperature ``args.t1``) that the local feature assigns to
      each positive class against the positives plus that group's negatives;
    * the ratio between ``exp`` of the positive-class logit of the mixed
      feature ``alpha * global + (1 - alpha) * local`` (temperature
      ``args.t2``) and a normaliser that averages ``exp`` of all logits with
      softmax weights obtained from the local feature (temperature
      ``args.t3``).

    The per-class log scores are averaged over the local feature sets and
    reduced over the positive classes with ``beta * logsumexp(score / beta)``.

    All reductions use ``softmax`` / ``log_softmax`` / ``logsumexp`` so that
    small temperatures or a small ``beta`` cannot overflow or underflow to
    ``inf`` / ``-inf`` the way explicit ``exp`` followed by ``log`` can.

    Args:
        image_features: Mapping with a ``"global"`` tensor, used as an
            ``(N, D)`` matrix, and a ``"local"`` tensor whose first axis is
            indexed by the local crop, i.e. ``(>= mc_local_number, N, D)``.
        text_features_pos: ``(C, D)`` tensor of positive text embeddings.
        text_features_neg: 3-D tensor of negative text embeddings whose first
            axis is the group index, i.e. ``(>= ngroup, K, D)``.
        args: Namespace providing ``mc_local_number``, ``ngroup``, ``alpha``,
            ``beta``, ``t1``, ``t2`` and ``t3``.
        log: Unused; kept so the call signature stays unchanged.

    Returns:
        NumPy array of shape ``(N,)`` holding one score per image.

    Raises:
        ValueError: If ``args.mc_local_number`` is smaller than 1.
    """
    if args.mc_local_number < 1:
        raise ValueError("mc_local_number must be at least 1")

    # Run on whatever device the text embeddings live on instead of forcing
    # CUDA; for CUDA-resident text embeddings this is the same placement.
    device = text_features_pos.device
    global_features = image_features["global"].to(device)
    local_features = image_features["local"].to(device)

    num_pos = text_features_pos.size(0)

    # Everything below depends only on the text embeddings, so it is built
    # once instead of once per local crop / per group.
    text_features_neg_merged = torch.reshape(
        text_features_neg,
        (text_features_neg.size(0) * text_features_neg.size(1), -1))
    text_features_merged = torch.cat((text_features_pos, text_features_neg_merged), dim=0)
    group_text_features = [
        torch.cat((text_features_pos, text_features_neg[j]), dim=0)
        for j in range(args.ngroup)
    ]

    score = 0
    for i in range(args.mc_local_number):
        local_feature = local_features[i]

        # log of the softmax weights w(x_p, y) over positives and all negatives.
        log_w_xp_y = F.log_softmax(local_feature @ text_features_merged.T / args.t3, dim=1)

        global_local_features = args.alpha * global_features + (1 - args.alpha) * local_feature
        global_local_output_merged = global_local_features @ text_features_merged.T / args.t2

        # log sum_y w(y) * exp(logit_y), evaluated without an explicit exp.
        log_z_xp_y = torch.logsumexp(log_w_xp_y + global_local_output_merged, dim=1, keepdim=True)

        # The positives are the first num_pos rows of text_features_merged,
        # so their logits are the first num_pos columns.
        log_score_x_y_I_xp = global_local_output_merged[:, :num_pos] - log_z_xp_y

        score_i_xp_y = 0
        for text_features in group_text_features:
            local_output = local_feature @ text_features.T / args.t1
            # softmax[:, :num_pos] == exp(pos logits) / sum(exp(all logits)).
            score_i_xp_y = score_i_xp_y + F.softmax(local_output, dim=1)[:, :num_pos] / args.ngroup

        score_i = torch.log(score_i_xp_y) + log_score_x_y_I_xp
        score = score + score_i / args.mc_local_number

    score = args.beta * torch.logsumexp(score / args.beta, dim=1)

    return score.detach().cpu().numpy()



def _load_clip_features(dataset):
    """Load the pre-extracted global and local CLIP features of ``dataset``.

    Reads ``global_features.npy``, ``even_features.npy`` and
    ``odd_features.npy`` from ``./CLIP-features/<dataset>/``.

    Returns:
        Dict with float32 CPU tensors: ``'global'`` (the global features as
        stored) and ``'local'`` (the "even" and "odd" arrays stacked along a
        new leading axis, in that order).
    """
    root = './CLIP-features/' + dataset
    global_features = np.load(root + '/global_features.npy')
    # Stacking on a new first axis is what get_score indexes per local crop.
    local_features = np.stack(
        [np.load(root + '/' + part + '_features.npy') for part in ("even", "odd")],
        axis=0)
    return {
        'global': torch.from_numpy(global_features).to(torch.float32),
        'local': torch.from_numpy(local_features).to(torch.float32),
    }


def main():
    """Score the in-distribution set and every OOD set, then report metrics.

    Steps: parse arguments and seed the RNGs, load the positive / negative
    text embeddings, split the (randomly permuted) negatives into
    ``args.ngroup`` equally sized groups, score ``in_dataset`` once, then
    score each entry of ``out_datasets`` and hand the negated scores to
    ``get_and_print_results``; finally log the mean of the collected metrics.

    Raises:
        ValueError: If there are fewer negative embeddings than groups.
    """
    args = parser.parse_args()

    setup_seed(args.seed)

    log = setup_log(args)

    dump_dict = torch.load('./selected_neg_labels/neg_dump_new2.pth')
    text_features_neg = dump_dict['neg_emb_selected'].cuda().to(torch.float32)
    text_features_pos = dump_dict['pos_emb'].cuda().to(torch.float32)

    print(text_features_neg.size())

    num_neg = text_features_neg.shape[0]
    if num_neg < args.ngroup:
        # Without this check the reshape below fails with an opaque error.
        raise ValueError(
            f"ngroup ({args.ngroup}) exceeds the number of negative embeddings ({num_neg})")

    # Drop the trailing remainder so the negatives split evenly into groups.
    drop = num_neg % args.ngroup

    random_permute = True

    if drop > 0:
        text_features_neg = text_features_neg[:-drop, :]
    if random_permute:
        idx = torch.randperm(text_features_neg.size(0)).cuda()
        text_features_neg = text_features_neg[idx]
    text_features_neg = torch.reshape(text_features_neg, (args.ngroup, -1, text_features_neg.size(1)))

    ID_features = _load_clip_features(in_dataset)

    in_scores = get_score(ID_features, text_features_pos, text_features_neg, args, log)

    auroc_list, aupr_list, fpr_list = [], [], []

    for out_dataset in out_datasets:

        print(f"Evaluting OOD dataset {out_dataset}")

        OOD_features = _load_clip_features(out_dataset)

        out_scores = get_score(OOD_features, text_features_pos, text_features_neg, args, log)

        # Scores are passed negated; get_and_print_results appends this
        # dataset's metrics to the three lists (they are averaged below).
        get_and_print_results(args, log, -in_scores, -out_scores,
                              auroc_list, aupr_list, fpr_list)

    log.debug('\n\nMean Test Results')
    print_measures(log, np.mean(auroc_list), np.mean(aupr_list),
                   np.mean(fpr_list), method_name=args.score)









if __name__ == '__main__':
    main()

