import jax
import jax.numpy as jnp
import numpy as np
from functools import partial
from approxml.cosmo import lensing_simulator, sample_lensing_prior
from approxml.aml_kde import spsa_aml_iid
import pickle
import importlib
import types
import tensorflow_probability as tfp
# Locate TensorFlow Probability's JAX substrate. Two import paths are tried in
# order: ``substrates.jax`` first, then the ``experimental.substrates.jax``
# location (presumably for older TFP releases). Only a missing module moves on
# to the next candidate; any other import error propagates unchanged.
for _candidate in (
        'tensorflow_probability.substrates.jax',
        'tensorflow_probability.experimental.substrates.jax'):
    try:
        _jax_backend = importlib.import_module(_candidate)
    except ModuleNotFoundError:
        continue
    break
else:
    # No candidate could be imported.
    raise ImportError("Couldn’t locate the JAX substrate inside tensorflow_probability.")


def _ensure_attr_namespace(owner, attr):
    """Return ``owner.<attr>``, first attaching an empty SimpleNamespace if it is absent."""
    if not hasattr(owner, attr):
        setattr(owner, attr, types.SimpleNamespace())
    return getattr(owner, attr)


# Expose the located backend under both ``tfp.substrates.jax`` and
# ``tfp.experimental.substrates.jax``. Presumably this lets downstream code
# written against either spelling resolve the same module.
_ensure_attr_namespace(tfp, 'substrates').jax = _jax_backend
_ensure_attr_namespace(
    _ensure_attr_namespace(tfp, 'experimental'), 'substrates'
).jax = _jax_backend


# --- Experiment configuration -------------------------------------------------
N_ITER = 100       # intended number of SPSA iterations per run
N_RUNS = 100       # number of independent estimation runs
N_PROP = 5
N_SIM_DST = 20     # (N_SIM_DST * N_PROP) // 2 is used as the per-step simulation count
N_PARAM_DIM = 6    # length of the parameter vector
N_DATA_DIM = 6     # presumably the dimension of each (compressed) data sample — TODO confirm
MOD_SIGMA = jnp.eye(N_PARAM_DIM)  # identity matrix; not referenced in the visible code

# --- Simulator inputs loaded from disk ----------------------------------------
# All paths are relative to the current working directory, so the script must
# be launched from the directory that sits next to ``../data``.
lognormal_shifts_params = np.load("../data/lognormal_shifts_LSSTY10_om_s8_w_bin.npy")

# NOTE: pickle.load can execute arbitrary code; only load trusted files.
# The ``with`` blocks ensure the file handles are closed after loading
# (previously they were left open for the lifetime of the process).
with open('../data/params_compressor/opt_state_resnet_vmim.pkl', "rb") as a_file:
    opt_state_resnet = pickle.load(a_file)

with open('../data/params_compressor/params_nd_compressor_vmim.pkl', "rb") as a_file:
    parameters_compressor = pickle.load(a_file)

@partial(jax.jit, static_argnames=("n_sim", "compress"))
def simulator_fn_jit(key, unbound_params, n_sim, compress=True):
    """Run ``lensing_simulator`` on parameters given in unconstrained form.

    ``unbound_params`` must unpack into exactly six entries. The first three
    are treated as logarithms and exponentiated, presumably to keep those
    parameters positive. The last three are forwarded unchanged.

    Args:
        key: JAX PRNG key forwarded to ``lensing_simulator``.
        unbound_params: six-entry parameter vector, with entries 0-2 on a log scale.
        n_sim: number of simulations, forwarded as-is. Static under ``jit``, so
            each distinct value triggers a recompilation.
        compress: forwarded as ``compress=``. Static under ``jit``.

    Returns:
        Whatever ``lensing_simulator`` returns for the transformed parameters.

    Note:
        ``opt_state_resnet``, ``parameters_compressor`` and
        ``lognormal_shifts_params`` are module globals captured at trace time.
        Rebinding them after the first call does not affect already-compiled
        versions of this function.
    """
    (log_0, log_1, log_2, lin_3, lin_4, lin_5) = unbound_params
    positive_part = [jnp.exp(v) for v in (log_0, log_1, log_2)]
    params_bound = jnp.array(positive_part + [lin_3, lin_4, lin_5])

    simulator_kwargs = dict(
        compress=compress,
        opt_state_resnet=opt_state_resnet,
        parameters_compressor=parameters_compressor,
        lognormal_shifts_params=lognormal_shifts_params,
    )
    return lensing_simulator(key, params_bound, n_sim, **simulator_kwargs)
    
sim_fn = simulator_fn_jit

# Ground-truth parameters. Entries 0-2 are converted to log scale because
# simulator_fn_jit exponentiates those entries before simulating.
params = jnp.array([0.2664, 0.0492, 0.831, 0.6727, 0.9645, -1.0])
params = params.at[:3].set(jnp.log(params[:3]))

n_obs = 250
# Every run starts the optimiser from this same initial point.
# NOTE(review): PRNGKey(0) is also used below as the root key that is split
# for each run. JAX recommends never reusing a key. The seed is kept as-is so
# previously saved results stay reproducible.
theta_init = sample_lensing_prior(jax.random.PRNGKey(0))

best_a, best_c = 1e-4, 1e-1  # step-size hyperparameters passed to spsa_aml_iid

key = jax.random.PRNGKey(0)
res_list = []
for i in range(N_RUNS):
    # Fresh observed dataset for this run. (A pre-loop simulation of ``obs``
    # was removed: it was overwritten here before ever being used.)
    key, subkey = jax.random.split(key)
    obs = sim_fn(subkey, params, n_obs)

    key, subkey = jax.random.split(key)
    theta_hat = spsa_aml_iid(
        subkey,
        sim_fn,
        obs,
        theta_init,
        iterations=N_ITER,
        n_sim=(N_SIM_DST * N_PROP) // 2,
        a=best_a,
        c=best_c,
    )

    # Data-space check: mean Euclidean distance between the rows of fresh
    # simulations at theta_hat and the observed rows. Assumes both are 2-D,
    # with samples along axis 0.
    key, subkey = jax.random.split(key)
    sims = sim_fn(subkey, theta_hat, n_obs)
    err = jnp.linalg.norm(sims - obs, axis=1).mean()

    # Parameter-space error. This is the Euclidean (L2) distance to the true
    # parameters, measured with entries 0-2 on log scale. It is NOT a mean
    # squared error; the "mse" label is kept for compatibility with existing
    # result files and printed logs.
    param_dist = jnp.linalg.norm(theta_hat - params)
    print(f"Run {i+1} completed, error = {err}, mse = {param_dist}")
    res_list.append(
        {
            "n_obs": n_obs,
            "a": best_a,
            "c": best_c,
            "theta_hat": theta_hat,
            "err": err,
            "mse": param_dist,
        }
    )

with open(f"cosmo_aml_n_obs_{n_obs}_patched.pkl", "wb") as f:
    pickle.dump(res_list, f)

    
    
