import jax
import jax.numpy as jnp
import optax
import equinox as eqx
from approxml.mnist import load_model, get_loader
from approxml.simulators import gan_simulator, mvt_norm_simulator
from approxml.utils import gen_simulation_samples
from approxml.scorematching import score_matching_loss
from approxml.optimisers import nn_fit  
from functools import partial
import pickle

class NeuralNetwork(eqx.Module):
    """Fully connected network: two ReLU hidden layers of width 64, then a linear head."""

    layers: list

    def __init__(self, key, input_dim, output_dim):
        """Create the three linear layers, each initialised from its own sub-key.

        Args:
            key: JAX PRNG key; it is split three ways, one sub-key per layer
                in input-to-output order.
            input_dim: Number of input features of the first layer.
            output_dim: Number of output features of the last layer.
        """
        hidden_dim = 64
        widths = (input_dim, hidden_dim, hidden_dim, output_dim)
        layer_keys = jax.random.split(key, 3)
        # Consecutive width pairs give the (in, out) sizes of each layer.
        self.layers = [
            eqx.nn.Linear(n_in, n_out, key=layer_key)
            for n_in, n_out, layer_key in zip(widths[:-1], widths[1:], layer_keys)
        ]

    def __call__(self, x):
        """Run `x` through the hidden layers (with ReLU) and the final linear layer."""
        *hidden_layers, output_layer = self.layers
        for linear in hidden_layers:
            x = jax.nn.relu(linear(x))
        # No activation on the output layer.
        return output_layer(x)
    
# --- Experiment configuration -------------------------------------------
# Latent size handed to load_model and to gan_simulator (as latent_dim).
LATENT_DIM = 50
# Number of independent run_sgd repetitions per observation.
N_RUNS = 100
# Number of outer update steps performed inside one run_sgd call.
N_ITER = 500
# Forwarded as n_prop to gen_simulation_samples and to score_matching_loss.
N_PROP = 150
# Forwarded as n_sim_dst to gen_simulation_samples and to score_matching_loss.
N_SIM_DST = 150
# Dimension of theta; also the network output size and the size of the
# proposal covariance matrix. Tied to the generator's latent size.
N_PARAM_DIM = LATENT_DIM
# Network input size. Equals 16 * 16 * 1, i.e. presumably a flattened image
# of the (16,16,1) shape used below -- TODO confirm.
N_DATA_DIM = 256
# Passed as `sigma` to gan_simulator. NOTE(review): the name suggests a prior
# scale on the latent parameters -- confirm against approxml.simulators.
PRIOR_SIGMA = 1e-1

# Load the saved generator and its state from ./saved_models.
# NOTE(review): (16,16,1) is presumably the image shape expected by the
# generator -- confirm against approxml.mnist.load_model.
gen, gen_state = load_model("./saved_models", 
                            "generator_1.eqx", 
                            "generator_state_1.eqx", 
                            jax.random.PRNGKey(42), 
                            LATENT_DIM, 
                            (16,16,1))

# Rebind the imported `gan_simulator` name to a partial with the generator,
# its state, the latent size and sigma fixed. Every later use of
# `gan_simulator` in this file refers to this partial, not the raw import.
gan_simulator = partial(gan_simulator, 
                        generator=gen, 
                        state=gen_state,
                        latent_dim=LATENT_DIM,
                        sigma=PRIOR_SIGMA)

# NOTE(review): the second argument (36) is presumably the batch size -- confirm.
mnist_loader = get_loader((16,16,1),
                          36)
# Take only the first batch from the loader; the second element of the pair
# (presumably labels) is discarded.
images_batch , _ = next(iter(mnist_loader))


def _grad_fn_impl(theta_t, key, lr, obs, model, gen_sim_fn, sm_loss_fn, n_iter):
    """Fit `model` with `nn_fit`, then evaluate the fitted model on `obs`.

    Args:
        theta_t: Current parameter value, forwarded to `nn_fit`.
        key: JAX PRNG key, forwarded to `nn_fit`.
        lr: Learning rate of the Adam optimiser used for the network fit.
        obs: Observations; the fitted model is mapped over the leading axis.
        model: Network to start the fit from.
        gen_sim_fn: Simulation-sample generator, forwarded to `nn_fit`.
        sm_loss_fn: Loss function, forwarded to `nn_fit`.
        n_iter: Number of fitting iterations, forwarded to `nn_fit`.

    Returns:
        A pair `(summed_output, fitted_model)`: the fitted model's outputs on
        `obs` summed over axis 0, and the fitted model itself.
    """
    # nn_fit returns five values; only the fitted model is used here.
    fitted_model, _loss_vals, _thetas_q, _sims_q, _key = nn_fit(
        model,
        optax.adam(lr),
        theta_t,
        key,
        gen_sim_fn,
        sm_loss_fn,
        n_iter=n_iter,
    )

    per_obs_output = jax.vmap(fitted_model)(obs)
    return per_obs_output.sum(0), fitted_model

# JIT-compiled entry point for _grad_fn_impl.
# gen_sim_fn, sm_loss_fn and n_iter (positions 5-7) are compile-time constants.
# They are declared by name as well as by position because run_sgd binds them
# as keyword arguments through functools.partial; listing the names explicitly
# makes them static for keyword calls without relying on jax.jit inferring the
# names from the wrapped function's signature.
grad_fn = jax.jit(
    lambda theta_t, key, lr, obs, model, gen_sim_fn, sm_loss_fn, n_iter: _grad_fn_impl(
        theta_t, key, lr, obs, model, gen_sim_fn, sm_loss_fn, n_iter
    ),
    static_argnums=(5, 6, 7),
    static_argnames=("gen_sim_fn", "sm_loss_fn", "n_iter"),
)

def run_sgd(key, theta_init, obs, sigma, simulator_fn, learning_rate):
    """Optimise theta with Adam for N_ITER steps, driven by `grad_fn`.

    Each step refits the network through `grad_fn` (10 inner iterations with
    an inner learning rate of 1e-3), negates the returned vector and feeds it
    to the Adam optimiser as the gradient for theta. The fitted network is
    carried over to the next step.

    Args:
        key: JAX PRNG key for network initialisation and the per-step fits.
        theta_init: Starting value of theta.
        obs: Observations the fitted network is evaluated on.
        sigma: Scalar multiplying the identity matrix to form the proposal
            covariance.
        simulator_fn: Simulator forwarded to `gen_simulation_samples`.
        learning_rate: Learning rate of the Adam optimiser acting on theta.

    Returns:
        dict with keys 'theta_values', 'grad_values' and 'sigma_values'
        (one entry per step, stacked into arrays) and 'avg_theta' (mean of
        the last 50 theta values).
    """
    optimizer = optax.chain(optax.adam(learning_rate=learning_rate))
    opt_state = optimizer.init(theta_init)

    @jax.jit
    def update(model, params, opt_state, key, step):
        # `step` is part of the signature but is not used in the body.
        next_key, fit_key = jax.random.split(key)
        prop_cov = sigma * jnp.eye(N_PARAM_DIM)

        @jax.jit
        def sm_loss_fn(model, sims_q, thetas_q, theta_t):
            pred_q = jax.vmap(jax.vmap(model))(sims_q)
            return score_matching_loss(pred_q, thetas_q, theta_t, N_PROP, N_SIM_DST, prop_cov)

        sample_fn = partial(
            gen_simulation_samples,
            simulator_fn=simulator_fn,
            prop_sim_fn=partial(mvt_norm_simulator, cov=prop_cov),
            n_prop=N_PROP,
            n_sim_dst=N_SIM_DST,
        )

        raw_grads, fitted_model = grad_fn(
            params,
            fit_key,
            lr=1e-3,
            obs=obs,
            model=model,
            gen_sim_fn=sample_fn,
            sm_loss_fn=sm_loss_fn,
            n_iter=10,
        )
        # The optimiser receives the negated output of grad_fn.
        neg_grads = -raw_grads
        updates, opt_state = optimizer.update(neg_grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return params, opt_state, next_key, neg_grads, sigma, fitted_model

    theta = theta_init.copy()

    key, init_key = jax.random.split(key)
    model = NeuralNetwork(key=init_key, input_dim=N_DATA_DIM, output_dim=N_PARAM_DIM)

    theta_values, grad_values, sigma_values = [], [], []
    for i in range(N_ITER):
        # `sigma` is rebound to the value returned by `update`, which is the
        # same sigma that `update` read from this enclosing scope.
        theta, opt_state, key, grads, sigma, model = update(model, theta, opt_state, key, i)
        theta_values.append(theta.copy())
        grad_values.append(grads.copy())
        sigma_values.append(sigma)

    theta_trace = jnp.array(theta_values)

    return {
        'theta_values': theta_trace,
        'grad_values': jnp.array(grad_values),
        'sigma_values': jnp.array(sigma_values),
        'avg_theta': theta_trace[-50:, :].mean(axis=0),
    }


# Experiment driver: for each of the first 10 images of the batch, repeat the
# SGD procedure N_RUNS times and pickle all results for that observation.
for obs_idx in range(10):
    print("Running obs_idx: ", obs_idx)
    # Flatten the selected image into a single-row array.
    obs = jnp.array(images_batch[obs_idx].reshape(1, 256))

    theta_init = jnp.zeros(N_PARAM_DIM)
    prop_sigma = 0.2
    lr = 5e-2

    res_list, err_list = [], []

    # The key is reset for every observation, so each observation goes
    # through the same sequence of per-run sub-keys.
    key = jax.random.PRNGKey(0)

    for i in range(N_RUNS):
        key, subkey = jax.random.split(key)
        res = run_sgd(subkey, theta_init, obs, prop_sigma, gan_simulator, lr)

        # Evaluate at the mean of the last 300 theta values, with a fixed
        # simulation key so the error is comparable across runs.
        theta_eval = res['theta_values'][-300:, :].mean(0)
        sims = gan_simulator(jax.random.PRNGKey(42), theta_eval, 100)

        err = jnp.linalg.norm(sims - obs, axis=1).mean()
        print(f"SM Run {i} error: {err}")

        err_list.append(err)
        res_list.append(res)

    payload = {
        'res_list': res_list,
        'err_list': err_list,
    }
    with open(f'3_mnist_sm_NN_obs_idx_{obs_idx}.pkl', 'wb') as f:
        pickle.dump(payload, f)