import jax
import optax
from functools import partial

@partial(jax.jit, static_argnames=('loss_fn',))
def train_step(state, batch, loss_fn):
    """Run one gradient update and return the resulting state.

    Args:
        state: Object exposing ``params`` and ``apply_gradients(grads=...)``.
            It is also forwarded unchanged to ``loss_fn`` as its third
            argument.
        batch: Data for this step; passed to ``loss_fn`` as-is.
        loss_fn: Callable ``loss_fn(params, batch, state)`` returning a
            scalar loss. Declared static for ``jax.jit``, so it must be
            hashable, and each distinct function gets its own compilation.

    Returns:
        Whatever ``state.apply_gradients`` returns for the computed gradients.
    """
    # jax.grad differentiates with respect to the first positional argument,
    # i.e. state.params; batch and state are treated as constants.
    gradients = jax.grad(loss_fn)(state.params, batch, state)
    return state.apply_gradients(grads=gradients)


def squared_error_loss(params, batch, state):
    """Return half the summed squared error of the model on ``batch``.

    Args:
        params: Parameters passed to ``state.apply_fn`` under the
            ``'params'`` key.
        batch: Indexable; ``batch[0]`` is fed to the model and ``batch[1]``
            is used as the regression targets.
        state: Object providing ``apply_fn``.

    Returns:
        ``0.5 * sum(squared_error)`` — a sum over all elements, not a mean.
    """
    predictions = state.apply_fn({'params': params}, batch[0])
    targets = batch[1]
    per_element = optax.squared_error(predictions=predictions, targets=targets)
    return 1/2 * per_element.sum()

def cross_entropy_loss(params, batch, state):
    """Return the summed softmax cross-entropy of the model on ``batch``.

    Args:
        params: Parameters passed to ``state.apply_fn`` under the
            ``'params'`` key.
        batch: Indexable; ``batch[0]`` is fed to the model and ``batch[1]``
            is handed to ``optax.softmax_cross_entropy`` as ``labels``
            (presumably one-hot / class probabilities — confirm with caller).
        state: Object providing ``apply_fn``.

    Returns:
        The cross-entropy values summed (not averaged) over all entries.
    """
    model_out = state.apply_fn({'params': params}, batch[0])
    targets = batch[1]
    per_example = optax.softmax_cross_entropy(logits=model_out, labels=targets)
    return per_example.sum()

def bin_cross_entropy_loss(params, batch, state):
    """Return the summed sigmoid binary cross-entropy of the model on ``batch``.

    Args:
        params: Parameters passed to ``state.apply_fn`` under the
            ``'params'`` key.
        batch: Indexable; ``batch[0]`` is fed to the model and ``batch[1]``
            is handed to ``optax.sigmoid_binary_cross_entropy`` as ``labels``.
        state: Object providing ``apply_fn``.

    Returns:
        The element-wise binary cross-entropy summed (not averaged) over all
        entries.
    """
    model_out = state.apply_fn({'params': params}, batch[0])
    targets = batch[1]
    per_element = optax.sigmoid_binary_cross_entropy(logits=model_out, labels=targets)
    return per_element.sum()