from utils import log
import time
import os
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow.keras.layers as layers

import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style("white")


import ipca_v2
import AwA2_helper
from utils.test_utils import arg_parser, prepare_data, get_measures
from utils.test_utils import run_concept
from sklearn.linear_model import LogisticRegressionCV
from utils.models import prepare_InceptionV3
from utils.plot_utils import plot_stats


def get_prediction(predict_model, features):
    """Run ``predict_model`` on ``features`` and return the class logits.

    Parameters
    ----------
    predict_model : callable
        Maps features to logits. It may instead return a 3-element tuple/list
        (the case when TopicModel_V2 is used as predict_model), in which case
        the logits are taken from position 1.
    features : model input, passed through unchanged.

    Returns
    -------
    The logits produced by ``predict_model``.
    """
    outputs = predict_model(features)
    # Test the container type rather than len() alone: a plain logits tensor
    # whose batch size happens to be 3 also has len() == 3 and must not be
    # indexed as if it were a multi-output result.
    if isinstance(outputs, (tuple, list)) and len(outputs) == 3:
        return outputs[1]
    return outputs

def iterate_data_msp(data_loader, feature_model, predict_model):
    """Maximum softmax probability score (Hendrycks et al., ICLR'17).

    The whole loader is pushed through ``feature_model.predict`` in one call,
    the resulting features are classified with ``predict_model`` (via
    ``get_prediction``), and the largest softmax probability along axis 1 is
    returned for every sample as a NumPy array.
    """
    features = feature_model.predict(data_loader)
    class_logits = get_prediction(predict_model, features)
    class_probs = layers.Activation('softmax')(class_logits)
    max_prob = tf.math.reduce_max(class_probs, axis=1)
    return max_prob.numpy()

def iterate_data_odin(data_loader, feature_model, predict_model, epsilon, temper, num_classes, logger):
    """Compute ODIN confidence scores batch by batch.

    For each batch the input is shifted by ``epsilon`` against the sign of the
    gradient of the temperature-scaled cross-entropy loss, where the target is
    the model's own predicted class. The score of a sample is the maximum
    temperature-scaled softmax probability of the perturbed input.

    Parameters
    ----------
    data_loader : object providing ``__len__`` (number of batches) and
        ``next()``; each batch is either the inputs alone or a tuple/list
        whose first element is the inputs.
    feature_model, predict_model : callables chained as
        ``predict_model(feature_model(x))`` through ``get_prediction``.
    epsilon : perturbation magnitude.
    temper : temperature dividing the logits.
    num_classes : width of the one-hot pseudo-labels.
    logger : unused here; kept so existing callers keep working.

    Returns
    -------
    np.ndarray with one confidence value per sample.
    """
    confs = []
    loss_object = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
    for _ in range(len(data_loader)):
        batch = data_loader.next()
        # Decide by container type, not len(): a bare input batch holding
        # exactly two samples also has len() == 2 and is not an (x, y) pair.
        if isinstance(batch, (tuple, list)):
            batch = batch[0]
        x = tf.cast(batch, tf.float32)

        with tf.GradientTape() as tape:
            tape.watch(x)
            logits = get_prediction(predict_model, feature_model(x))
            # Pseudo-labels are the model's own predictions; they are built
            # through NumPy, so they are constants for the gradient.
            predicted = np.argmax(logits, axis=1)
            one_hot = tf.keras.utils.to_categorical(predicted, num_classes)
            labels = tf.convert_to_tensor(one_hot, dtype=tf.float32)
            # Temperature scaling has to be part of the loss that drives the
            # perturbation; previously the scaled logits were computed and
            # then discarded, so the loss saw unscaled logits.
            loss = loss_object(labels, logits / temper)

        gradient = tape.gradient(loss, x)
        signed_grad = tf.sign(gradient)
        # Adding small perturbations to images
        perturbed = x - epsilon * signed_grad
        scaled = get_prediction(predict_model, feature_model(perturbed)) / temper

        # Numerically stable softmax of the perturbed, temperature-scaled logits
        shifted = scaled.numpy()
        shifted = shifted - np.max(shifted, axis=1, keepdims=True)
        exp_shifted = np.exp(shifted)
        softmax_probs = exp_shifted / np.sum(exp_shifted, axis=1, keepdims=True)

        confs.extend(np.max(softmax_probs, axis=1))

    return np.array(confs)

def iterate_data_energy(data_loader, feature_model, predict_model, temper):
    """Negative free-energy score for every sample in ``data_loader``.

    The energy is ``-temper * logsumexp(logits / temper)`` over axis 1; the
    function returns its negation as a NumPy array.
    """
    features = feature_model.predict(data_loader)
    class_logits = get_prediction(predict_model, features)
    energy = -temper * tf.reduce_logsumexp(class_logits / temper, axis=1)
    # Callers receive the negated energy.
    return -(energy.numpy())


def iterate_data_kl_div(data_loader, feature_model, predict_model, out=False):
    """Collect softmax probabilities (and labels) for the KL-divergence score.

    Parameters
    ----------
    data_loader : object providing ``__len__`` (number of batches) and
        ``next()``.
    feature_model, predict_model : callables chained as
        ``predict_model(feature_model(x))`` through ``get_prediction``.
    out : bool
        ``False`` for in-distribution data, whose batches must be ``(x, y)``
        pairs with ``y`` reduced to class indices by ``argmax`` over axis 1.
        ``True`` for out-of-distribution data, where labels are not needed.

    Returns
    -------
    (probs, labels) : two np.ndarrays; ``labels`` is empty when ``out`` is True.
    """
    # The signature now matches the way run_eval calls this function; the old
    # two-argument form referenced feature_model/predict_model without
    # defining them.
    softmax = layers.Activation('softmax')

    probs, labels = [], []
    for _ in range(len(data_loader)):
        batch = data_loader.next()
        if not out:  # for in-distribution data
            x, y = batch
        elif isinstance(batch, (tuple, list)):
            # Out-of-distribution loaders may still yield (x, y, ...);
            # only the inputs are used.
            x = batch[0]
        else:  # for out-of-distribution data, labels not needed
            x = batch

        logits = get_prediction(predict_model, feature_model(x))
        probs.extend(softmax(logits).numpy())

        if not out:
            labels.extend(np.argmax(y, axis=1))
    return np.array(probs), np.array(labels)

def kl(p, q):
    """Kullback-Leibler divergence D(P || Q) for discrete distributions

    Parameters
    ----------
    p, q : array-like, dtype=float, shape=n
        Discrete probability distributions.

    Returns
    -------
    Sum over all entries of ``p * log(p / q)``; entries where ``p == 0``
    contribute 0.
    """
    # Accept any array-like (lists, tuples, arrays), as documented.
    p = np.asarray(p)
    q = np.asarray(q)

    # np.where evaluates both branches, so entries with p == 0 would emit
    # log-of-zero / invalid-value warnings even though they are masked out.
    with np.errstate(divide='ignore', invalid='ignore'):
        terms = np.where(p != 0, p * np.log(p / q), 0)
    return np.sum(terms)


def _min_kl_scores(probs, class_means, logger):
    """Return, per row of ``probs``, the negated smallest KL divergence to any class mean."""
    scores = []
    for i, prob in enumerate(probs):
        if i % 1000 == 0:
            logger.info('{} samples processed...'.format(i))
        min_div = float('inf')
        for c_mean in class_means:
            cur_div = kl(prob, c_mean)
            # Explicit '<' keeps NaN divergences (e.g. from a class with no
            # samples) from ever becoming the minimum.
            if cur_div < min_div:
                min_div = cur_div
        scores.append(-min_div)
    return np.array(scores)


def run_eval(feature_model, predict_model, in_loader, out_loader, logger, args, num_classes, flag_eval=True):
    """Score in- and out-of-distribution data with the detector named by ``args.score``.

    Parameters
    ----------
    feature_model, predict_model : models forwarded to the ``iterate_data_*``
        scoring functions.
    in_loader, out_loader : loaders for in-distribution and
        out-of-distribution data.
    logger : object with an ``info`` method.
    args : namespace providing ``score`` and ``out_data``, plus
        ``epsilon_odin`` / ``temperature_odin`` for ODIN and
        ``temperature_energy`` for Energy. ``score`` is matched
        case-insensitively against 'msp', 'odin', 'energy' and 'kl_div'.
    num_classes : number of classes (used by ODIN and KL_Div).
    flag_eval : when True, metrics are computed with ``get_measures`` and logged.

    Returns
    -------
    ``(in_scores, out_scores, thres95, auroc)`` when ``flag_eval`` is True,
    otherwise ``(in_scores, out_scores)``.

    Raises
    ------
    ValueError
        If ``args.score`` names no supported score.
    """
    logger.info("Running test...")

    # Normalise once so that every score name, KL_Div included, is matched
    # without regard to case.
    score = str(args.score).lower()

    # Lambdas defer reading args.* until the matching score is selected.
    loader_scorers = {
        'msp': lambda loader: iterate_data_msp(loader, feature_model, predict_model),
        'odin': lambda loader: iterate_data_odin(
            loader, feature_model, predict_model,
            args.epsilon_odin, args.temperature_odin, num_classes, logger),
        'energy': lambda loader: iterate_data_energy(
            loader, feature_model, predict_model, args.temperature_energy),
    }

    if score in loader_scorers:
        score_fn = loader_scorers[score]
        logger.info("Processing in-distribution data...")
        in_scores = score_fn(in_loader)
        logger.info("Processing out-of-distribution data...")
        out_scores = score_fn(out_loader)
    elif score == 'kl_div':
        # iterate_data_kl_div returns softmax probabilities, not raw logits.
        logger.info("Processing in-distribution data...")
        in_probs, in_labels = iterate_data_kl_div(in_loader, feature_model, predict_model, out=False)
        logger.info("Processing out-of-distribution data...")
        out_probs, _ = iterate_data_kl_div(out_loader, feature_model, predict_model, out=True)

        # Mean in-distribution probability vector of every class.
        class_means = np.array([
            np.mean(in_probs[in_labels == c], axis=0) for c in range(num_classes)
        ])

        logger.info("Compute distance for in-distribution data...")
        in_scores = _min_kl_scores(in_probs, class_means, logger)

        logger.info("Compute distance for out-of-distribution data...")
        out_scores = _min_kl_scores(out_probs, class_means, logger)
    else:
        raise ValueError("Unknown score type {}".format(args.score))

    in_examples = in_scores
    out_examples = out_scores

    if not flag_eval:
        return in_examples, out_examples

    # get_measures is handed column vectors (one score per row).
    auroc, aupr_in, aupr_out, fpr95, thres95 = get_measures(in_examples[:, None], out_examples[:, None])
    logger.info('============Results for {} on {}============'.format(args.score, args.out_data))
    logger.info('AUROC: {}'.format(auroc))
    logger.info('AUPR (In): {}'.format(aupr_in))
    logger.info('AUPR (Out): {}'.format(aupr_out))
    logger.info('FPR95: {}'.format(fpr95))
    logger.info('THRES: {}'.format(thres95))

    return in_examples, out_examples, thres95, auroc
"""
def main(args):
    _=0
    logger = log.setup_logger(args)

    in_loader, out_loader = prepare_data(args, logger)

    logger.info(f"Loading model from {args.model_path}")
    # load trained_model
    #model = prepare_InceptionV3(modelpath=args.model_path, input_size=(32,32), pretrain=True, return_model=True)
    #model.evaluate(in_loader) # loss: 2.0356, accuracy: 0.9213

    _=None
    feature_model, predict_model = AwA2_helper.load_model_inception_new(_, _, input_size=(32,32), pretrain=True, n_gpus=1, modelname=args.model_path)
    feature = feature_model.predict(in_loader)
    topic_model = ipca_v2.TopicModel(feature[:2], n_concept=70, thres=0.2, predict=predict_model)
    _ = topic_model(feature[:2]) # call the subclassed model first
    topic_model.load_weights('results/Animals_with_Attributes2/latest_topic_AwA2.h5', by_name=True)
    logits_ = topic_model(feature, training=False)
    label = np.argmax(np.load('data/Animals_with_Attributes2/y_test.npy'), axis=1)
    #logits=predict_model(feature_model.predict(in_loader))
    acc = np.sum(label == np.argmax(logits_, axis=1)) / len(label)
    print('===================================')
    print('accuracy with original features: {}'.format(acc))

    start_time = time.time()
    run_eval(feature_model, predict_model, in_loader, out_loader, logger, args, num_classes=50)
    end_time = time.time()

    logger.info("Total running time: {}".format(end_time - start_time))
    logger.flush()

if __name__ == "__main__":
    parser = arg_parser()

    parser.add_argument('--score', choices=['MSP', 'ODIN', 'Energy', 'Mahalanobis', 'KL_Div'], default='MSP')
    parser.add_argument('--temperature_odin', default=1000, type=int,
                        help='temperature scaling for odin')
    parser.add_argument('--epsilon_odin', default=0.0, type=float,
                        help='perturbation magnitude for odin')
    parser.add_argument('--temperature_energy', default=1, type=int,
                        help='temperature scaling for energy')

    main(parser.parse_args())
"""
