import jax
import optax

def nn_fit(
    model,
    optimizer,
    theta_t,
    key,
    gen_sim_fn,
    fixed_loss,
    n_iter=100
):
    """Run ``n_iter`` gradient updates of ``model`` on one simulated batch.

    ``gen_sim_fn`` is called exactly once, so every update step reuses the
    same ``thetas_q`` / ``sims_q`` batch.

    Args:
        model: Parameter pytree to optimise; it is passed to
            ``optimizer.init``, differentiated by ``jax.value_and_grad`` and
            updated with ``optax.apply_updates``.
        optimizer: Object exposing ``init(params)`` and
            ``update(grads, state, params)`` (presumably an optax
            ``GradientTransformation`` -- confirm against the caller).
        theta_t: Forwarded to ``gen_sim_fn`` and, as the last positional
            argument, to ``fixed_loss``.
        key: Forwarded to ``gen_sim_fn`` (presumably a JAX PRNG key --
            confirm against the caller).
        gen_sim_fn: Callable ``(key, theta_t) -> (thetas_q, sims_q, key)``.
        fixed_loss: Callable invoked as
            ``fixed_loss(model, sims_q, thetas_q, theta_t)``; its value is
            differentiated with respect to ``model``.
        n_iter: Number of update steps to perform.

    Returns:
        A tuple ``(model, loss_vals, thetas_q, sims_q, key)``: the updated
        model, the list of per-step loss values, the batch returned by
        ``gen_sim_fn`` and the key returned by ``gen_sim_fn``.
    """

    def _update(params, state, target, batch_thetas, batch_sims):
        # value_and_grad differentiates w.r.t. its first positional argument,
        # i.e. the model parameters.
        value, gradient = jax.value_and_grad(fixed_loss)(
            params, batch_sims, batch_thetas, target
        )
        deltas, next_state = optimizer.update(gradient, state, params)
        return optax.apply_updates(params, deltas), next_state, value

    # The batch and theta_t are passed as arguments (not closed over) so they
    # are traced inputs of the compiled step.
    compiled_update = jax.jit(_update)

    state = optimizer.init(model)
    sim_thetas, sim_data, next_key = gen_sim_fn(key, theta_t)

    history = []
    for _ in range(n_iter):
        model, state, step_loss = compiled_update(
            model, state, theta_t, sim_thetas, sim_data
        )
        history.append(step_loss)

    return model, history, sim_thetas, sim_data, next_key