import numpy as np

import tensorflow as tf
import tensorflow.keras.layers as layers



def get_prediction(predict_model, features):
    """Run ``predict_model`` on ``features`` and return its logits.

    Args:
        predict_model: Callable applied to ``features``. It returns either
            the logits directly, or a 3-element list/tuple whose element at
            index 1 is taken as the logits (the TopicModel case).
        features: Input forwarded unchanged to ``predict_model``.

    Returns:
        The logits produced by ``predict_model``.
    """
    outputs = predict_model(features)
    # Only unpack genuine multi-output containers. A bare ``len(...) == 3``
    # check would also match a plain logits tensor whose batch size happens
    # to be 3 and silently return a single row of it.
    if isinstance(outputs, (list, tuple)) and len(outputs) == 3:
        return outputs[1]
    return outputs


def iterate_data_msp(x, feature_model, predict_model, features=None):
    """Score a batch by its maximum softmax probability (MSP).

    Args:
        x: Batch passed to ``feature_model``; only used when ``features``
            is None.
        feature_model: Callable mapping ``x`` to features.
        predict_model: Callable handed to ``get_prediction``.
        features: Optional precomputed features; when given,
            ``feature_model`` is not called.

    Returns:
        Tensor holding, for each row, the largest softmax value
        (maximum taken over axis 1).
    """
    if features is None:
        features = feature_model(x)
    class_probs = layers.Activation('softmax')(get_prediction(predict_model, features))
    # Returned as a tensor: ``.numpy()`` would only be available in eager mode.
    return tf.math.reduce_max(class_probs, axis=1)


def iterate_data_odin(x, feature_model, predict_model, epsilon, temper, num_classes, features=None):
    """Score a batch with ODIN (temperature scaling plus input perturbation).

    Steps:
      1. Use the predicted class of every sample as a pseudo-label.
      2. Compute the gradient, w.r.t. the input, of the cross-entropy
         between the temperature-scaled logits and those pseudo-labels.
      3. Shift the input by ``epsilon`` against the sign of that gradient.
      4. Return the maximum temperature-scaled softmax of the shifted input.

    Args:
        x: Input batch; cast to float32 so it can be differentiated.
        feature_model: Callable mapping inputs to features.
        predict_model: Callable handed to ``get_prediction``.
        epsilon: Magnitude of the input perturbation.
        temper: Temperature the logits are divided by.
        num_classes: Depth of the one-hot pseudo-labels.
        features: Optional precomputed features of ``x``. They are only used
            to pick the pseudo-labels; the gradient pass always recomputes
            ``feature_model(x)`` so the tape can trace back to ``x``.

    Returns:
        Tensor holding, for each row, the largest softmax value
        (maximum taken over axis 1) after perturbation.
    """
    x = tf.cast(x, tf.float32)

    # Pseudo-labels: the class the model currently predicts for each sample.
    clean_logits = get_prediction(predict_model, feature_model(x) if features is None else features)
    predicted = tf.math.argmax(clean_logits, axis=1, output_type=tf.int32)
    labels = tf.one_hot(predicted, num_classes)

    loss_object = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
    with tf.GradientTape() as tape:
        tape.watch(x)
        # The temperature must be applied to the logits that feed the loss;
        # previously it was applied to a value that was never used, so the
        # perturbation direction ignored ``temper``.
        scaled_logits = get_prediction(predict_model, feature_model(x)) / temper
        loss = loss_object(labels, scaled_logits)

    signed_grad = tf.sign(tape.gradient(loss, x))

    # Add a small perturbation that lowers the loss of the predicted class.
    perturbed = x - epsilon * signed_grad
    perturbed_logits = get_prediction(predict_model, feature_model(perturbed)) / temper

    # Confidence after perturbation: max of the temperature-scaled softmax.
    probs = tf.nn.softmax(perturbed_logits, axis=1)
    return tf.math.reduce_max(probs, axis=1)


def iterate_data_energy(x, feature_model, predict_model, temper, features=None):
    """Score a batch with the negative free energy of its logits.

    Args:
        x: Batch passed to ``feature_model``; only used when ``features``
            is None.
        feature_model: Callable mapping ``x`` to features.
        predict_model: Callable handed to ``get_prediction``.
        temper: Temperature used to scale the logits.
        features: Optional precomputed features; when given,
            ``feature_model`` is not called.

    Returns:
        Tensor equal to ``temper * logsumexp(logits / temper)`` reduced over
        axis 1, i.e. the negation of the energy
        ``-temper * logsumexp(logits / temper)``.
    """
    if features is None:
        features = feature_model(x)
    logits = get_prediction(predict_model, features)
    # Returned as a tensor; no ``.numpy()`` conversion is done here.
    return temper * tf.reduce_logsumexp(logits / temper, axis=1)


def run_ood_over_batch(x, feature_model, predict_model, args, num_classes, features=None):
    """Compute the OOD score selected by ``args.score`` for one batch.

    Args:
        x: Input batch forwarded to the selected scoring function.
        feature_model: Callable mapping inputs to features.
        predict_model: Callable producing logits from features.
        args: Object exposing ``score`` (one of ``'msp'``, ``'odin'``,
            ``'energy'``, case-insensitive) and, depending on the score,
            ``epsilon_odin`` / ``temperature_odin`` or ``temperature_energy``.
        num_classes: Number of classes; only used by the ODIN score.
        features: Optional precomputed features forwarded to the scoring
            function.

    Returns:
        The scores returned by the selected scoring function.

    Raises:
        ValueError: If ``args.score`` names no supported score. Previously
            this surfaced as an ``UnboundLocalError`` on the return line.
    """
    score_name = str(args.score).lower()

    if score_name == 'msp':
        return iterate_data_msp(x, feature_model, predict_model, features)
    if score_name == 'odin':
        return iterate_data_odin(x, feature_model, predict_model, args.epsilon_odin,
                                 args.temperature_odin, num_classes, features)
    if score_name == 'energy':
        return iterate_data_energy(x, feature_model, predict_model, args.temperature_energy, features)

    raise ValueError(
        f"Unknown OOD score {args.score!r}; expected one of 'msp', 'odin', 'energy'."
    )
