import jax
import jax.numpy as jnp
import numpy as np 
from functools import partial
from approxml.utils import gen_simulation_samples, NeuralNetwork
from approxml.optimisers import nn_fit
from approxml.scorematching import score_matching_loss
from approxml.simulators import mvt_norm_simulator
from approxml.cosmo import lensing_simulator, sample_lensing_prior
import optax
import pickle
import importlib
import types
import tensorflow_probability as tfp
# Locate the JAX substrate of TensorFlow Probability. Two module paths are
# probed in order; the first one that imports successfully is used.
_jax_backend = None
for _candidate in (
        'tensorflow_probability.substrates.jax',
        'tensorflow_probability.experimental.substrates.jax'):
    try:
        _jax_backend = importlib.import_module(_candidate)
    except ModuleNotFoundError:
        # Not available under this path; try the next candidate.
        continue
    break
if _jax_backend is None:
    raise ImportError("Couldn’t locate the JAX substrate inside tensorflow_probability.")

# Make the backend reachable as both `tfp.substrates.jax` and
# `tfp.experimental.substrates.jax`, creating placeholder namespaces for any
# intermediate attribute that `tfp` does not already expose.
if not hasattr(tfp, 'substrates'):
    tfp.substrates = types.SimpleNamespace()
tfp.substrates.jax = _jax_backend

if not hasattr(tfp, 'experimental'):
    tfp.experimental = types.SimpleNamespace(substrates=types.SimpleNamespace())
elif not hasattr(tfp.experimental, 'substrates'):
    tfp.experimental.substrates = types.SimpleNamespace()
tfp.experimental.substrates.jax = _jax_backend


# Number of optimiser steps performed by the loop inside `run_sgd`.
N_ITER = 100
# Number of repeated runs executed for each observation count in the main loop.
N_RUNS = 100
# Forwarded as `n_prop` to `gen_simulation_samples` and to `score_matching_loss`
# (presumably the number of proposal draws -- confirm against approxml).
N_PROP = 25  
# Forwarded as `n_sim_dst` to `gen_simulation_samples` and to `score_matching_loss`
# (presumably simulations per proposal draw -- confirm against approxml).
N_SIM_DST = 4
# Size of the identity matrix used for the proposal covariance and the
# `output_dim` of the neural network.
N_PARAM_DIM = 6
# `input_dim` of the neural network.
N_DATA_DIM = 6
# Identity matrix of size N_PARAM_DIM; not referenced by the rest of this file.
MOD_SIGMA = jnp.eye(N_PARAM_DIM)
# Scalar passed to `run_sgd`, where it multiplies an identity matrix to form
# the proposal covariance.
sigma = 1e-3
# Learning rate passed to `run_sgd` for the Adam optimiser that updates theta.
lr = 1e-2

# Fixed inputs forwarded to `lensing_simulator` by `simulator_fn_jit`; they are
# loaded once when the module is executed.
lognormal_shifts_params = np.load("../../data/lognormal_shifts_LSSTY10_om_s8_w_bin.npy")

# NOTE: `pickle.load` can execute arbitrary code stored in the file, so these
# paths must only ever point at trusted artefacts.
# Context managers guarantee each handle is closed as soon as its payload has
# been read, instead of leaving the files open for the life of the process.
with open('../../data/params_compressor/opt_state_resnet_vmim.pkl', "rb") as a_file:
    opt_state_resnet = pickle.load(a_file)

with open('../../data/params_compressor/params_nd_compressor_vmim.pkl', "rb") as a_file:
    parameters_compressor = pickle.load(a_file)


def _grad_fn_impl(theta_t,
                key,
                lr,
                obs,
                model,
                gen_sim_fn,
                sm_loss_fn,
                n_iter):
    """Fit ``model`` with ``nn_fit`` and sum the fitted model's outputs over ``obs``.

    Args:
        theta_t: Current parameter value, forwarded to ``nn_fit``.
        key: PRNG key, forwarded to ``nn_fit``.
        lr: Learning rate of the Adam optimiser used for the fit.
        obs: Batch of observations; the fitted model is vmapped over it.
        model: Model handed to ``nn_fit`` as the starting point.
        gen_sim_fn: Simulation-sample generator, forwarded to ``nn_fit``.
        sm_loss_fn: Loss function, forwarded to ``nn_fit``.
        n_iter: Number of fitting iterations, forwarded to ``nn_fit``.

    Returns:
        A tuple ``(summed_output, fitted_model)`` where ``summed_output`` is
        the fitted model applied to every element of ``obs`` and summed over
        axis 0, and ``fitted_model`` is the model returned by ``nn_fit``.
    """
    # `nn_fit` returns five values; only the fitted model is needed here.
    fitted_model, _, _, _, _ = nn_fit(
        model,
        optax.adam(lr),
        theta_t,
        key,
        gen_sim_fn,
        sm_loss_fn,
        n_iter=n_iter,
    )

    per_obs_output = jax.vmap(fitted_model)(obs)
    return per_obs_output.sum(0), fitted_model

def _grad_fn_entry(theta_t, key, lr, obs, model, gen_sim_fn, sm_loss_fn, n_iter):
    """Forward every argument unchanged to ``_grad_fn_impl``."""
    return _grad_fn_impl(theta_t, key, lr, obs, model, gen_sim_fn, sm_loss_fn, n_iter)


# JIT-compiled entry point. Positional arguments 5, 6 and 7 (`gen_sim_fn`,
# `sm_loss_fn`, `n_iter`) are static: two callables and the iteration count.
grad_fn = jax.jit(_grad_fn_entry, static_argnums=(5, 6, 7))


def run_sgd(key,
            theta_init,
            obs,
            sigma,
            simulator_fn,
            learning_rate,
            *,
            n_iter=None,
            inner_lr=1e-2,
            inner_n_iter=10,
            ):
    """Iteratively update ``theta`` using the output of ``grad_fn`` as gradient.

    At every step a neural network is (re)fitted through ``grad_fn`` around the
    current ``theta``; the value it returns is negated and fed to an optax
    optimiser (global-norm clipping at 0.1 followed by Adam). The fitted
    network is carried over to the next step.

    Args:
        key: PRNG key; split once to initialise the network and once per step.
        theta_init: Starting parameter vector; also used to initialise the
            optimiser state.
        obs: Observations forwarded to ``grad_fn``.
        sigma: Scalar multiplying an ``N_PARAM_DIM`` identity matrix to form
            the proposal covariance.
        simulator_fn: Simulator forwarded to ``gen_simulation_samples``.
        learning_rate: Learning rate of the Adam optimiser acting on ``theta``.
        n_iter: Number of update steps. ``None`` (default) uses the
            module-level ``N_ITER``.
        inner_lr: Learning rate forwarded to ``grad_fn`` for the network fit.
        inner_n_iter: Number of fitting iterations forwarded to ``grad_fn``.

    Returns:
        Dict with arrays stacked over steps: ``'theta_values'`` (parameters
        after each update), ``'grad_values'`` (negated ``grad_fn`` outputs)
        and ``'sigma_values'`` (the value of ``sigma`` at each step).
    """
    # Resolve the step count at call time so later changes to N_ITER are seen.
    n_steps = N_ITER if n_iter is None else n_iter

    optimizer = optax.chain(
        optax.clip_by_global_norm(1e-1),
        optax.adam(learning_rate)
    )
    opt_state = optimizer.init(theta_init)

    @jax.jit
    def update(model, params, opt_state, key):
        key, subkey = jax.random.split(key)

        current_prop_cov = sigma * jnp.eye(N_PARAM_DIM)

        def sm_loss_fn(model, sims_q, thetas_q, theta_t):
            # The model is applied over the two leading axes of `sims_q`.
            pred_q = jax.vmap(jax.vmap(model))(sims_q)
            return score_matching_loss(pred_q, thetas_q, theta_t, N_PROP, N_SIM_DST, current_prop_cov)

        sm_loss_fn = jax.jit(sm_loss_fn)
        current_gen_sim_fn = partial(gen_simulation_samples,
            simulator_fn=simulator_fn,
            prop_sim_fn=partial(mvt_norm_simulator, cov=current_prop_cov),
            n_prop=N_PROP,
            n_sim_dst=N_SIM_DST)

        current_grad_fn = partial(grad_fn,
            lr=inner_lr,
            obs=obs,
            model=model,
            gen_sim_fn=current_gen_sim_fn,
            sm_loss_fn=sm_loss_fn,
            n_iter=inner_n_iter)

        grads, model = current_grad_fn(params, subkey)
        # Flip the sign before handing the value to the optimiser.
        grads = - grads
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return params, opt_state, key, grads, sigma, model

    theta_values = []
    grad_values = []
    sigma_values = []

    theta = theta_init.copy()
    key, subkey = jax.random.split(key)
    model = NeuralNetwork(key=subkey, input_dim=N_DATA_DIM, output_dim=N_PARAM_DIM)
    for _ in range(n_steps):
        # The returned sigma goes to its own name so that the `sigma` captured
        # by `update` is never rebound while the loop is running.
        theta, opt_state, key, grads, sigma_out, model = update(model, theta, opt_state, key)

        # JAX arrays are immutable, so the values can be stored without copying.
        theta_values.append(theta)
        grad_values.append(grads)
        sigma_values.append(sigma_out)

    return {
        'theta_values': jnp.array(theta_values),
        'grad_values': jnp.array(grad_values),
        'sigma_values': jnp.array(sigma_values),
    }

@partial(jax.jit, static_argnames=("n_sim", "compress"))
def simulator_fn_jit(
        key,
        unbound_params,
        n_sim,
        compress = True):
    """Run ``lensing_simulator`` on parameters given in unconstrained form.

    The first three components of ``unbound_params`` are exponentiated and the
    last three are passed through unchanged before calling the simulator.

    Args:
        key: PRNG key forwarded to ``lensing_simulator``.
        unbound_params: Sequence of exactly six values.
        n_sim: Forwarded to ``lensing_simulator``; static under ``jax.jit``.
        compress: Forwarded to ``lensing_simulator``; static under ``jax.jit``.

    Returns:
        Whatever ``lensing_simulator`` returns for these inputs.
    """
    log_p0, log_p1, log_p2, p3, p4, p5 = unbound_params
    exp_head = [jnp.exp(value) for value in (log_p0, log_p1, log_p2)]
    params_bound = jnp.array(exp_head + [p3, p4, p5])

    # Module-level artefacts loaded at import time are closed over here.
    simulator_kwargs = dict(
        compress=compress,
        opt_state_resnet=opt_state_resnet,
        parameters_compressor=parameters_compressor,
        lognormal_shifts_params=lognormal_shifts_params,
    )
    return lensing_simulator(key, params_bound, n_sim, **simulator_kwargs)


# Short alias used by the experiment loop.
sim_fn = simulator_fn_jit

# Reference parameter vector. Entries 0-2 are replaced by their logarithms,
# mirroring the `jnp.exp` applied to the same entries in `simulator_fn_jit`.
params = jnp.array([0.2664, 0.0492, 0.831, 0.6727, 0.9645, -1.0])
for _log_idx in (2, 1, 0):
    params = params.at[_log_idx].set(jnp.log(params[_log_idx]))

# The starting point is drawn with a fixed key, so it is the same for every
# observation count; draw it once instead of once per `n_obs`.
theta_init = sample_lensing_prior(jax.random.PRNGKey(0))

# For each observation count, repeat the experiment N_RUNS times on freshly
# simulated observations and pickle the per-run results to one file per count.
for n_obs in [100, 250, 500, 750, 1000]:
    key = jax.random.PRNGKey(0)
    res_list = []
    for i in range(N_RUNS):
        # Fresh observations for this run, simulated at the reference `params`.
        key, subkey = jax.random.split(key)
        obs = sim_fn(subkey, params, n_obs)

        key, subkey = jax.random.split(key)
        res_dict = run_sgd(subkey,
                           theta_init,
                           obs,
                           sigma,
                           sim_fn,
                           lr)
        # Point estimate: mean of the last 50 stored iterates.
        avg_theta = res_dict['theta_values'][-50:].mean(0)

        # Simulate at the estimate and compare against the observations.
        key, subkey = jax.random.split(key)
        sims = sim_fn(subkey, avg_theta, n_obs)
        err = jnp.linalg.norm(sims - obs, axis=1).mean()

        # NOTE: reported and stored under the name "mse", but this is the
        # Euclidean norm of the parameter error. The label is kept unchanged
        # so existing consumers of the output files keep working.
        theta_err = jnp.linalg.norm(avg_theta - params)
        print(f"Run {i+1} completed, error = {err}, mse = {theta_err}")

        res_list.append(
            {
                "n_obs": n_obs,
                "theta_hat": avg_theta,
                "err": err,
                "mse": theta_err,
                "sigma": sigma,
                "lr": lr,
                "res_dict": res_dict
            }
        )

    with open(f"cosmo_nn_n_obs_{n_obs}.pkl", "wb") as f:
        pickle.dump(res_list, f)

    
    
