import jax
import jax.numpy as jnp
import optax
import equinox as eqx
from approxml.mnist import load_model, get_loader, invert_latent
from approxml.simulators import gan_simulator, mvt_norm_simulator
from approxml.utils import gen_simulation_samples, grad_log_normal
from approxml.scorematching import fit_linear_sm
from functools import partial
from tqdm import tqdm
from approxml.aml_kde import spsa_aml_iid
from itertools import product
import pickle

# --- Experiment configuration -------------------------------------------
# Latent dimension handed to load_model and to gan_simulator.
LATENT_DIM = 50
# Number of independent repetitions of each estimation method.
N_RUNS = 100
# Number of optimisation iterations per run (run_sgd and spsa_aml_iid).
N_ITER = 500
# Passed as n_prop / n_sim_dst to gen_simulation_samples and fit_linear_sm;
# half of their product is also used as the SPSA simulation budget (n_sim).
N_PROP = 150
N_SIM_DST = 150
# The parameter being estimated has the same dimension as the latent.
N_PARAM_DIM = LATENT_DIM
# NOTE(review): not referenced in the visible code; 256 matches the
# reshape(1, 256) applied to the observation (16 * 16 * 1) -- confirm.
N_DATA_DIM = 256

# Load the generator and its state from ./saved_models.
# NOTE(review): the role of the PRNG key inside load_model is not visible
# here (presumably used to build the model skeleton) -- confirm.
gen, gen_state = load_model("./saved_models", 
                            "generator_1.eqx", 
                            "generator_state_1.eqx", 
                            jax.random.PRNGKey(42), 
                            LATENT_DIM, 
                            (16,16,1))


# NOTE(review): the second argument (36) is presumably the batch size --
# verify against get_loader.
mnist_loader = get_loader((16,16,1),
                          36)
# Only the first batch is used; the second element of the pair is discarded.
images_batch , _ = next(iter(mnist_loader))

def run_sgd(key, 
            theta_init, 
            obs,
            sigma,
            simulator_fn,
            learning_rate,
            lamb_val):
    """Optimise the parameter vector for N_ITER steps with a fitted linear map.

    Each step calls ``fit_linear_sm`` around the current parameters, contracts
    the returned matrix with the observations (augmented by a column of ones),
    sums over observations, negates the result and feeds it to a
    gradient-clipped Adam optimiser.

    Args:
        key: JAX PRNG key; split once per step.
        theta_init: starting parameter vector.
        obs: observations; a constant-one column is appended on the last axis.
        sigma: scalar multiplying the identity to form the proposal covariance.
        simulator_fn: simulator forwarded to ``gen_simulation_samples``.
        learning_rate: Adam learning rate.
        lamb_val: forwarded to ``fit_linear_sm`` as ``lamb``
            (presumably a regularisation strength -- confirm).

    Returns:
        Dict with stacked per-step arrays under the keys ``'theta_values'``,
        ``'grad_values'`` and ``'sigma_values'``.
    """
    # Augment the observations with a trailing column of ones.
    ones_col = jnp.ones_like(obs[..., :1])
    obs_aug = jnp.concatenate([obs, ones_col], axis=-1)

    optimizer = optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.adam(learning_rate=learning_rate),
    )
    opt_state = optimizer.init(theta_init)

    @jax.jit
    def update(params, opt_state, key, step):
        # `step` is accepted but not used by the update rule.
        key, fit_key = jax.random.split(key)

        prop_cov = sigma * jnp.eye(N_PARAM_DIM)
        proposal_sampler = partial(mvt_norm_simulator, cov=prop_cov)
        proposal_score = partial(grad_log_normal, cov=prop_cov)

        sample_fn = partial(
            gen_simulation_samples,
            simulator_fn=simulator_fn,
            prop_sim_fn=proposal_sampler,
            n_prop=N_PROP,
            n_sim_dst=N_SIM_DST,
        )
        fit_fn = partial(
            fit_linear_sm,
            gen_sim_fn=sample_fn,
            grad_log_prop_fn=proposal_score,
            n_sim_dst=N_SIM_DST,
            n_prop=N_PROP,
            lamb=lamb_val,
        )

        W, _, _, _ = fit_fn(fit_key, params)
        # Contract with the augmented observations, sum over observations,
        # then flip the sign before handing the result to the optimiser.
        per_obs = jnp.einsum('mk,ik->im', W.T, obs_aug)
        grads = -per_obs.sum(0)

        updates, opt_state = optimizer.update(grads, opt_state, params)
        new_params = optax.apply_updates(params, updates)
        return new_params, opt_state, key, grads, sigma

    history = {
        'theta_values': [],
        'grad_values': [],
        'sigma_values': [],
    }

    theta = theta_init.copy()
    for step_idx in range(N_ITER):
        theta, opt_state, key, step_grads, step_sigma = update(
            theta, opt_state, key, step_idx
        )
        history['theta_values'].append(theta.copy())
        history['grad_values'].append(step_grads.copy())
        history['sigma_values'].append(step_sigma)

    return {name: jnp.array(values) for name, values in history.items()}

# Quantities that do not depend on prior_sigma are built once, instead of
# being rebuilt on every sweep iteration.
obs_idx = 0
obs = jnp.array(images_batch[obs_idx].reshape(1, 256))
theta_init = jnp.zeros(N_PARAM_DIM)

# SPSA gain grids searched for every prior_sigma.
a_grid = [1e-4, 1e-3, 5e-3, 1e-2, 5e-2]
c_grid = [1e-4, 1e-3, 5e-3, 1e-2, 5e-2]
# Simulation budget handed to spsa_aml_iid.
n_sim_aml = (N_SIM_DST * N_PROP) // 2

for prior_sigma in [5e-2, 1e-1, 5e-1, 1e0, 5e0, 1e1]:
    print("Running prior_sigma: ", prior_sigma)

    # One fixed key for the final evaluation of all three methods, so their
    # reported errors are computed with the same simulator randomness.
    eval_key = jax.random.PRNGKey(42)

    # NOTE: this rebinds the imported name on every iteration. It is safe
    # because the keywords given to the outer partial override those stored
    # by the previous one, so `sigma` always equals the current prior_sigma.
    gan_simulator = partial(gan_simulator, 
                            generator=gen, 
                            state=gen_state,
                            latent_dim=LATENT_DIM,
                            sigma=prior_sigma)

    # ------------------------------------------------------------------
    # 1) Score-matching SGD
    # ------------------------------------------------------------------
    prop_sigma = 2 * prior_sigma
    lr = 5e-2
    lamb_val = 10

    res_list = []
    err_list = []

    key = jax.random.PRNGKey(0)

    for i in range(N_RUNS):
        key, subkey = jax.random.split(key)
        res = run_sgd(subkey, 
                      theta_init, 
                      obs, 
                      prop_sigma, 
                      gan_simulator, 
                      lr,
                      lamb_val)

        # Point estimate: average of the last 300 iterates.
        theta_sm = res['theta_values'][-300:, :].mean(0)
        sims = gan_simulator(eval_key, theta_sm, 100)

        err = jnp.linalg.norm(sims - obs, axis=1).mean()
        print(f"SM Run {i} error: {err}")

        err_list.append(err)
        res_list.append(res)

    with open(f'3_mnist_sm_prior_sigma_{prior_sigma}.pkl', 'wb') as f:
        pickle.dump({
            'res_list': res_list,
            'err_list': err_list,
            'prior_sigma': prior_sigma
        }, f)

    # ------------------------------------------------------------------
    # 2) SPSA approximate maximum likelihood: grid search, then N_RUNS runs
    # ------------------------------------------------------------------
    results = []

    key = jax.random.PRNGKey(0)
    for a_val, c_val in product(a_grid, c_grid):
        key, subkey = jax.random.split(key)
        theta_hat = spsa_aml_iid(
            subkey,
            gan_simulator,
            obs, 
            theta_init, 
            iterations=N_ITER,
            n_sim=n_sim_aml, 
            a=a_val, 
            c=c_val
        )

        key, subkey = jax.random.split(key)
        sims = gan_simulator(subkey, theta_hat, 100)
        err = jnp.linalg.norm(sims - obs, axis=1).mean()
        results.append((a_val, c_val, err))

    results.sort(key=lambda tup: tup[2])
    best_a, best_c, best_err = results[0]

    print("Grid search results (top 5):")
    for a_val, c_val, err in results[:5]:
        # `g` formatting: with `.3f` the grid value 1e-4 was shown as 0.000.
        print(f"  a={a_val:g}, c={c_val:g}  -->  error = {err:.4f}")

    print(f"\nBEST:  a = {best_a},  c = {best_c},  error = {best_err:.4f}")

    res_aml_list = []
    res_aml_err_list = []

    # The loop index is bound explicitly; previously the print below reused
    # the stale `i` left over from the score-matching loop.
    for i in tqdm(range(N_RUNS)):
        key, subkey = jax.random.split(key)
        res_aml = spsa_aml_iid(
                    subkey,
                    gan_simulator,
                    obs, 
                    theta0=theta_init, 
                    a=best_a,
                    c=best_c,
                    iterations=N_ITER,
                    n_sim=n_sim_aml
                    )

        sims = gan_simulator(eval_key, res_aml, 100)
        err = jnp.linalg.norm(sims - obs, axis=1).mean()

        print(f"AML Run {i} error: {err}")

        res_aml_err_list.append(err)
        res_aml_list.append(res_aml)

    with open(f'3_mnist_aml_prior_sigma_{prior_sigma}.pkl', 'wb') as f:
        pickle.dump({
            'res_list': res_aml_list,
            'err_list': res_aml_err_list,
            'prior_sigma': prior_sigma
        }, f)

    # ------------------------------------------------------------------
    # 3) Direct inversion by minimising the mean squared reconstruction error
    # ------------------------------------------------------------------
    # Defined inside the sweep so that each prior_sigma gets a fresh function
    # object wrapping the simulator bound above.
    def recon_loss(mean_latent, key, obs, n_mc):
        """Mean squared error between n_mc simulated samples and obs."""
        fake = gan_simulator(key, mean_latent, n_mc)
        obs = jnp.broadcast_to(obs, fake.shape)
        return jnp.mean((fake - obs) ** 2)

    loss_and_grad = eqx.filter_value_and_grad(recon_loss)

    det_res_list = []
    det_err_list = []
    key = jax.random.PRNGKey(42)
    for i in range(N_RUNS):
        key, subkey = jax.random.split(key)
        mu_star = invert_latent(loss_and_grad, 
                                obs, 
                                n_steps=1000, 
                                n_mc=500, 
                                lr=5e-1, 
                                theta_init=theta_init,
                                key=subkey)

        # Evaluate with the shared evaluation key. The carried `key` was used
        # here before, although it is split again on the next iteration.
        sims = gan_simulator(eval_key, mu_star, 100)
        err = jnp.linalg.norm(sims - obs, axis=1).mean()

        # Exactly one entry per run in each list (mu_star was appended twice).
        det_res_list.append(mu_star)
        det_err_list.append(err)

        print(f"Det Run {i} error: {err}")

    with open(f'3_mnist_det_prior_sigma_{prior_sigma}.pkl', 'wb') as f:
        pickle.dump({
            'res_list': det_res_list,
            'err_list': det_err_list,
            'prior_sigma': prior_sigma
        }, f)