from functools import partial, reduce

import jax
import utils
from jax import jit
from jax import numpy as jnp
from jax import value_and_grad
from jax.example_libraries import optimizers

# Number of target classes; train_step uses it as the width of the one-hot
# encoded labels.
NUM_CLASSES = 10

@jit
def loss_fn(params, data, labels):
    """Return the softmax cross-entropy loss of the linear model.

    Args:
        params: Indexable holding the weight matrix at index 0 and the bias
            at index 1.
        data: Input array that is matrix-multiplied with the weights.
        labels: Targets handed unchanged to ``utils.softmax_cross_entropy``
            together with the logits.

    Returns:
        Whatever ``utils.softmax_cross_entropy(logits, labels)`` returns.
    """
    weights = params[0]
    bias = params[1]
    # Affine projection of the inputs; these are the pre-softmax scores.
    scores = jnp.matmul(data, weights) + bias
    return utils.softmax_cross_entropy(scores, labels)

@partial(jit, static_argnums=(4, 5))
def train_step(opt_state, data, labels, step, opt_update, get_params):
    """Run one optimisation step and return the loss with the new state.

    ``opt_update`` and ``get_params`` (positional arguments 4 and 5) are
    marked static for ``jit``, so they are treated as compile-time constants
    rather than traced values.

    Args:
        opt_state: Current optimizer state.
        data: Batch of inputs; every axis after the first is flattened.
        labels: Class indices, one-hot encoded here to ``NUM_CLASSES`` columns.
        step: Step counter forwarded to ``opt_update``.
        opt_update: Callable ``(step, grads, opt_state) -> opt_state``.
        get_params: Callable extracting the parameters from ``opt_state``.

    Returns:
        Tuple ``(loss, new_opt_state)``.
    """
    current_params = get_params(opt_state)
    flat_inputs = data.reshape(data.shape[0], -1)
    one_hot_targets = jax.nn.one_hot(labels, NUM_CLASSES)
    # value_and_grad differentiates with respect to the first argument (params).
    loss_and_grad = value_and_grad(loss_fn)
    loss, gradients = loss_and_grad(current_params, flat_inputs, one_hot_targets)
    return loss, opt_update(step, gradients, opt_state)

def init(sample_batch, learning_rate):
    """Create the Adam optimizer and the initial linear-model parameters.

    Args:
        sample_batch: Indexable whose first element exposes ``.shape``; all of
            its axes after the first are multiplied together to obtain the
            flattened feature size (presumably a ``(data, labels)`` pair --
            confirm against callers).
        learning_rate: Initial step size passed to the exponential-decay
            schedule.

    Returns:
        Tuple ``(opt_state, opt_update, get_params)``: the initial optimizer
        state plus the update and parameter-extraction functions returned by
        ``optimizers.adam``.
    """
    lr_schedule = optimizers.exponential_decay(learning_rate, decay_rate=0.005, decay_steps=1000)
    opt_init, opt_update, get_params = optimizers.adam(lr_schedule)
    # Fixed seed: the initial parameters are identical on every call.
    w_key, b_key = jax.random.split(jax.random.PRNGKey(0))
    # Product of the non-batch axes; the start value 1 covers an empty tail.
    feature_size = reduce(lambda a, b: a * b, sample_batch[0].shape[1:], 1)
    # Size the output layer from NUM_CLASSES rather than a literal so the
    # logits always match the one-hot label width used by train_step.
    params = (
        jax.random.normal(w_key, (feature_size, NUM_CLASSES)),
        jax.random.normal(b_key, (NUM_CLASSES,)),
    )
    opt_state = opt_init(params)
    return opt_state, opt_update, get_params

@jit
def get_acc(params, data, labels):
    """Return the fraction of samples whose predicted class equals its label.

    Args:
        params: Indexable holding the weight matrix at index 0 and the bias
            at index 1.
        data: Batch of inputs; every axis after the first is flattened.
        labels: Class indices compared element-wise with the predictions.

    Returns:
        Mean of the element-wise ``prediction == label`` comparison.
    """
    weights = params[0]
    bias = params[1]
    batch_size = data.shape[0]
    flat_inputs = data.reshape(batch_size, -1)
    # The predicted class is the index of the largest logit in each row.
    predicted = jnp.argmax(flat_inputs @ weights + bias, axis=1)
    hits = predicted == labels
    return jnp.mean(hits)