import jax
import jax.numpy as np
from jax import grad, jit, vmap
from jax.scipy.special import logsumexp

import numpy as onp

"""
Loss Function: 
(parameters, state, inputs, targets, net_apply) -> loss_value, new_state
"""
@jit
def nll(logits, labels):
    """
    Negative log-likelihood averaged over examples.

    `logits` are normalised with a log-softmax over the last axis, multiplied
    elementwise by `labels`, summed over the last axis and averaged over the
    remaining axes; the negated average is returned.
    (presumably `labels` are one-hot or soft target vectors -- confirm with callers)
    """
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    per_example_ll = np.sum(log_probs * labels, axis=-1)
    return -np.mean(per_example_ll)

@jit
def entropy(logits, labels):
    """
    Mean entropy of the softmax distribution defined by `logits`.

    The entropy is computed over the last axis and averaged over the remaining
    axes. `labels` is accepted so the signature matches the other losses but is
    not used in the computation.
    """
    log_p = jax.nn.log_softmax(logits, axis=-1)
    p = np.exp(log_p)
    per_example = np.sum(log_p * p, axis=-1)
    return -np.mean(per_example)

@jit
def accuracy(logits, labels):
    """
    Fraction of examples whose argmax over `logits` equals the argmax over `labels`.

    Both argmaxes are taken along the last axis.
    """
    is_correct = np.argmax(logits, axis=-1) == np.argmax(labels, axis=-1)
    return np.mean(is_correct)

@jit
def brier(logits, labels):
    """
    Brier score: squared distance between softmax(`logits`) and `labels`.

    The squared differences are summed over the last axis and averaged over
    the remaining axes.
    """
    residual = jax.nn.softmax(logits, axis=-1) - labels
    per_example = np.sum(residual ** 2, axis=-1)
    return np.mean(per_example)

@jit
def mse(predictions, targets):
    """
    Squared error summed over the last axis and averaged over the remaining axes.
    """
    squared_error = (predictions - targets) ** 2
    per_example = np.sum(squared_error, axis=-1)
    return np.mean(per_example)

# Placeholder loss: handy when the actual objective is an unsupervised loss
# that is implemented as a regularizer.
@jit
def zero_loss(logits, labels):
    """
    Return a constant zero, ignoring both `logits` and `labels`.
    """
    return 0.0

def ece(logits, labels, n_buckets=20, uniform_buckets=False):
    """
    Expected calibration error (ECE) of the predictions in `logits`.

    Differs from other losses in that it needs to be executed on the whole dataset at once.

    An example's confidence is its largest softmax probability, and it counts
    as correct when the argmax of `logits` equals the argmax of `labels`
    (both over the last axis). Examples are grouped into `n_buckets`
    confidence buckets and the result is the example-weighted average over
    buckets of |bucket accuracy - bucket mean confidence|. Empty buckets
    contribute nothing.

    Args:
        logits: class scores with the classes on the last axis.
        labels: targets laid out like `logits`; only their argmax is used.
        n_buckets: number of confidence buckets, at least 1.
        uniform_buckets: if True, buckets are `n_buckets` equal-width
            intervals of [0, 1]. Otherwise the bucket boundaries are the
            sorted confidences found at multiples of
            `num_examples // n_buckets`.

    Returns:
        The ECE as a NumPy floating-point scalar.

    Raises:
        ValueError: if `n_buckets` is smaller than 1 or there are no examples.
    """
    if n_buckets < 1:
        raise ValueError("n_buckets must be at least 1")

    probs = jax.nn.softmax(logits, axis=-1)
    # Bucketing happens in NumPy; flatten so any leading batch axes are
    # treated as one list of examples.
    confs = onp.asarray(np.max(probs, axis=-1), dtype=onp.float64).reshape(-1)
    hits = np.argmax(logits, axis=-1) == np.argmax(labels, axis=-1)
    accs = onp.asarray(hits, dtype=onp.float64).reshape(-1)

    num_examples = confs.size
    if num_examples == 0:
        raise ValueError("ece requires at least one example")

    if uniform_buckets:
        # Equal-width buckets; a confidence of exactly 1.0 belongs to the
        # last bucket rather than to a non-existent extra one.
        bucket_ids = onp.minimum(
            (confs * n_buckets).astype(onp.int64), n_buckets - 1
        )
    else:
        step = num_examples // n_buckets
        boundaries = onp.sort(confs)[onp.arange(n_buckets) * step]
        # A confidence falls into the first bucket whose upper boundary is
        # strictly greater than it; anything at or above the last boundary
        # lands in the final bucket.
        bucket_ids = onp.searchsorted(boundaries[1:], confs, side="right")

    # n_b * |acc_b - conf_b| equals |sum of hits_b - sum of confs_b|, so the
    # per-bucket sums are enough and empty buckets naturally add zero.
    conf_sums = onp.bincount(bucket_ids, weights=confs, minlength=n_buckets)
    acc_sums = onp.bincount(bucket_ids, weights=accs, minlength=n_buckets)
    return onp.abs(acc_sums - conf_sums).sum() / num_examples
    
