import utils
from jax import vmap, lax # for auto-vectorizing functions
from jax import jacfwd, jacrev,hessian
from functools import partial # for use with vmap
from jax import jit, grad, random # for compiling functions for speedup
from jax.experimental import stax # neural network library
from jax.experimental.stax import Conv, Dense, Relu, LeakyRelu, Flatten, LogSoftmax # neural network layers
from jax.nn.initializers import zeros
from jax.nn import leaky_relu
from jax.nn import log_softmax
from jax.experimental.stax import elementwise
from jax import numpy as np
import numpy as onp
from jax.interpreters import xla

from jax.experimental import optimizers
from jax.tree_util import tree_multimap  # Element-wise manipulation of collections of numpy arrays

import numpy.random as npr

def cross_entropy_logits(logits, targets):
    """Return the mean cross-entropy loss computed from unnormalized logits.

    ``logits`` are log-normalized with ``log_softmax`` along the last axis.
    They are then weighted by ``targets`` and summed over that same axis to
    get a per-example negative log-likelihood. Finally they are averaged
    over every remaining axis.

    For the usual ``(batch, num_classes)`` input this is exactly "sum over
    axis 1, mean over the batch". Reducing along the axis that was actually
    normalized also makes 1-D (single example) and higher-rank inputs
    correct.

    Args:
        logits: unnormalized class scores; the class axis is the last one.
        targets: target weights over classes, broadcastable against
            ``logits`` (presumably one-hot or probability vectors -- TODO
            confirm with callers).

    Returns:
        Scalar array holding the mean cross-entropy.
    """
    log_probs = log_softmax(logits, axis=-1)  # log normalize over classes
    # Reduce over the same axis log_softmax normalized. The previous
    # hard-coded axis=1 only coincides with it for 2-D inputs.
    return -np.mean(np.sum(log_probs * targets, axis=-1))

def loss(params, inputs, targets):
    """Return the mean cross-entropy of the network's predictions.

    Runs ``net_apply`` (a module-level name defined outside this section --
    presumably the apply function of the ``stax`` model; TODO confirm) on
    ``inputs`` with ``params``. It then scores the resulting logits against
    ``targets`` with the shared ``cross_entropy_logits`` helper, so all
    losses in this module use one implementation.

    Args:
        params: network parameters, passed straight to ``net_apply``.
        inputs: batch of network inputs.
        targets: target class weights matching the logits' shape
            (presumably one-hot -- verify against callers).

    Returns:
        Scalar cross-entropy loss.
    """
    logits = net_apply(params, inputs)
    return cross_entropy_logits(logits, targets)
def lo(batch, params):
    """Return the cross-entropy loss, taking the data batch as first argument.

    This is the loss used for gradients with respect to the input image. It
    computes the same quantity as ``loss``, but ``(inputs, targets)`` arrive
    packed in ``batch`` in first position. With ``jax.grad``'s default
    ``argnums=0``, differentiation is therefore with respect to the batch
    rather than ``params``.

    Args:
        batch: ``(inputs, targets)`` pair.
        params: network parameters, passed to ``net_apply`` (defined outside
            this section).

    Returns:
        Scalar cross-entropy loss.
    """
    inputs, targets = batch
    logits = net_apply(params, inputs)
    return cross_entropy_logits(logits, targets)