import tensorflow as tf
import numpy as np

def model_logits(model, ds, batch_size=512):
    """Run ``model`` over ``ds`` and return the raw ``predict`` output.

    ``ds`` is batched here with ``batch_size`` before inference, so it is
    presumably an unbatched ``tf.data.Dataset`` (TODO confirm with callers).
    ``predict`` is called with ``verbose=0``; with Keras this presumably
    silences progress output. Callers in this file treat the returned
    array as logits.
    """
    batched_ds = ds.batch(batch_size)
    return model.predict(batched_ds, verbose=0)

def max_softmax_prob(model, ds, opt_temp=None, batch_size=512):
    """Return the maximum softmax probability of each example.

    Args:
        model: object exposing ``predict``; its output is treated as logits
            (presumably a Keras model).
        ds: dataset forwarded to ``model_logits``, which batches it.
        opt_temp: optional strictly positive temperature (scalar, or anything
            broadcastable against the logits). Logits are divided by it
            before the softmax.
        batch_size: batch size used for inference.

    Returns:
        NumPy array holding the largest softmax probability per example.
        Higher values mean a more confident prediction.

    Raises:
        ValueError: if ``opt_temp`` is given but is not strictly positive.
    """
    logits = model_logits(model, ds, batch_size)
    if opt_temp is not None:
        # Zero would divide by zero (inf/NaN probabilities), a negative value
        # silently inverts the class ranking, and NaN propagates everywhere.
        # np.asarray also accepts per-class temperature vectors.
        if not np.all(np.asarray(opt_temp) > 0):
            raise ValueError(f"opt_temp must be strictly positive, got {opt_temp!r}")
        logits = logits / opt_temp
    # axis=1 reduction assumes 2-D (examples, classes) logits.
    max_prob = np.max(tf.nn.softmax(logits, axis=-1), axis=1)
    return max_prob
    
def softmax_margin(model, ds, opt_temp=None, batch_size=512):
    """Return the gap between the top-1 and top-2 softmax probabilities.

    Args:
        model: object exposing ``predict``; its output is treated as logits
            (presumably a Keras model).
        ds: dataset forwarded to ``model_logits``, which batches it.
        opt_temp: optional strictly positive temperature (scalar, or anything
            broadcastable against the logits). Logits are divided by it
            before the softmax.
        batch_size: batch size used for inference.

    Returns:
        ``tf.Tensor`` of per-example margins in [0, 1]. Larger values mean
        the model separates its best class more clearly from the runner-up.

    Raises:
        ValueError: if ``opt_temp`` is given but is not strictly positive.
    """
    logits = model_logits(model, ds, batch_size)
    if opt_temp is not None:
        # Reject 0 (division by zero), negatives (ranking inversion) and NaN.
        if not np.all(np.asarray(opt_temp) > 0):
            raise ValueError(f"opt_temp must be strictly positive, got {opt_temp!r}")
        logits = logits / opt_temp
    softmax_probs = tf.nn.softmax(logits, axis=-1)
    # top_k returns values sorted descending, so column 0 is the best class.
    # Needs at least two classes.
    top2 = tf.math.top_k(softmax_probs, k=2).values
    margins = tf.subtract(top2[:, 0], top2[:, 1])
    return margins
    
def max_logits(model, ds, batch_size=512):
    """Return the largest raw (unnormalised) logit of each example.

    No temperature is applied, since scaling would not change which
    logit is largest.

    Returns:
        NumPy array with one value per example. The reduction is over
        axis 1, which assumes 2-D (examples, classes) logits.
    """
    raw_scores = model_logits(model, ds, batch_size=batch_size)
    per_example_max = np.max(raw_scores, axis=1)
    return per_example_max

def logits_margin(model, ds, batch_size=512):
    """Return the gap between the two largest raw logits of each example.

    Returns:
        ``tf.Tensor`` with one non-negative margin per example. The logits
        must have at least two classes along the last axis.
    """
    raw_scores = model_logits(model, ds, batch_size=batch_size)
    # top_k sorts descending: column 0 is the best logit, column 1 the runner-up.
    leaders = tf.math.top_k(raw_scores, k=2).values
    best, runner_up = leaders[:, 0], leaders[:, 1]
    return best - runner_up

def negative_entropy(model, ds, opt_temp=None, batch_size=512):
    """Return the negative Shannon entropy of each softmax distribution.

    Computes ``sum_c p_c * log(p_c)`` (natural log), which is <= 0. Values
    closer to 0 mean a more peaked, more confident distribution.

    Args:
        model: object exposing ``predict``; its output is treated as logits
            (presumably a Keras model).
        ds: dataset forwarded to ``model_logits``, which batches it.
        opt_temp: optional strictly positive temperature (scalar, or anything
            broadcastable against the logits). Logits are divided by it
            before the softmax.
        batch_size: batch size used for inference.

    Returns:
        ``tf.Tensor`` with one value per example.

    Raises:
        ValueError: if ``opt_temp`` is given but is not strictly positive.
    """
    logits = model_logits(model, ds, batch_size)
    if opt_temp is not None:
        # Reject 0 (division by zero), negatives (ranking inversion) and NaN.
        if not np.all(np.asarray(opt_temp) > 0):
            raise ValueError(f"opt_temp must be strictly positive, got {opt_temp!r}")
        logits = logits / opt_temp
    # log_softmax gives exact, numerically stable log-probabilities. This
    # avoids the bias of adding an epsilon inside log(p + eps).
    log_probs = tf.nn.log_softmax(logits, axis=-1)
    probs = tf.exp(log_probs)
    # multiply_no_nan returns 0 wherever probs == 0, even if log_probs is
    # -inf there. This honours the limit p*log(p) -> 0 as p -> 0.
    neg_ent = tf.reduce_sum(tf.math.multiply_no_nan(log_probs, probs), axis=-1)
    return neg_ent

def negative_gini(model, ds, opt_temp=None, batch_size=512):
    """Return the negative Gini impurity of each softmax distribution.

    Computes ``sum_c p_c**2 - 1``, which lies in ``[1/C - 1, 0]`` for
    ``C`` classes. Values closer to 0 mean a more confident distribution.

    Args:
        model: object exposing ``predict``; its output is treated as logits
            (presumably a Keras model).
        ds: dataset forwarded to ``model_logits``, which batches it.
        opt_temp: optional strictly positive temperature (scalar, or anything
            broadcastable against the logits). Logits are divided by it
            before the softmax.
        batch_size: batch size used for inference.

    Returns:
        ``tf.Tensor`` with one value per example.

    Raises:
        ValueError: if ``opt_temp`` is given but is not strictly positive.
    """
    logits = model_logits(model, ds, batch_size)
    if opt_temp is not None:
        # Reject 0 (division by zero), negatives (ranking inversion) and NaN.
        if not np.all(np.asarray(opt_temp) > 0):
            raise ValueError(f"opt_temp must be strictly positive, got {opt_temp!r}")
        logits = logits / opt_temp
    softmax_probs = tf.nn.softmax(logits, axis=-1)
    neg_gini = tf.reduce_sum(tf.square(softmax_probs), axis=-1) - 1
    return neg_gini