import jax
import jax.numpy as jnp
from functools import partial
from approxml.utils import gen_simulation_samples, grad_log_normal
from approxml.scorematching import fit_linear_sm
from approxml.simulators import mvt_norm_simulator
import time
import pickle

# Experiment: for each proposal scale in PROP_SIGMAS, estimate a score-matching
# gradient at `theta` N_RUNS times and pickle the estimates together with their
# wall-clock timings, one file per proposal scale.

# Experiment configuration (identical for every proposal scale).
N_RUNS = 100
N_PROP = 100
N_SIM_DST = 250
N_PARAM_DIM = 2
N_DATA_DIM = N_PARAM_DIM
MOD_SIGMA = jnp.eye(N_PARAM_DIM)
N_OBS = 10
# NOTE(review): each value multiplies an identity matrix that is passed as `cov`,
# so it acts as a proposal *variance* scale rather than a standard deviation.
PROP_SIGMAS = [1e-2, 5e-2, 1e-1, 2e-1, 3e-1, 4e-1, 5e-1,
               1.0, 1.5, 2.0, 3.0, 4.0, 5.0, 10.0, 20.0]

simulator_fn = partial(mvt_norm_simulator, cov=MOD_SIGMA)

theta_true = jnp.array([0.0, 0.0])
theta = jnp.array([1.0, 1.0])

# Derive independent keys for the observed data and for the simulations instead
# of consuming PRNGKey(0) directly and then also splitting it (JAX key reuse).
DATA_KEY, SIM_ROOT_KEY = jax.random.split(jax.random.PRNGKey(0))

# Observations do not depend on the proposal scale, so draw them once. They are
# drawn from the same model (cov=MOD_SIGMA) that inference uses.
obs = simulator_fn(DATA_KEY, theta_true, N_OBS)
theta_mle = obs.mean(axis=0)


def sm_grads(key,
             sim_fn,
             theta,
             obs,
             prop_sigma,
             ):
    """Estimate a score-matching gradient at ``theta`` summed over ``obs``.

    A linear score-matching model is fitted via ``fit_linear_sm``. Simulations
    come from ``sim_fn`` at proposals drawn with covariance
    ``prop_sigma * I``. The fitted weights are then applied to each observation
    augmented with a constant 1 column, and the results are summed over
    observations.

    Args:
        key: JAX PRNG key; it is split once and the second half is passed to
            ``fit_linear_sm``.
        sim_fn: model simulator forwarded to ``gen_simulation_samples``.
        theta: parameter vector at which the model is fitted.
        obs: observations, indexed as ``obs[..., :1]`` along the last axis.
        prop_sigma: scalar multiplying the identity proposal covariance.

    Returns:
        The per-observation products ``obs_aug @ W`` summed over observations.
    """
    # Append a bias column of ones to the observations.
    obs_aug = jnp.concatenate([obs, jnp.ones_like(obs[..., :1])], axis=-1)

    # Proposal covariance, shared by the sampler and its log-density gradient.
    prop_cov = prop_sigma * jnp.eye(N_PARAM_DIM)

    current_gen_sim_fn = partial(gen_simulation_samples,
                                 simulator_fn=sim_fn,
                                 prop_sim_fn=partial(mvt_norm_simulator, cov=prop_cov),
                                 n_prop=N_PROP,
                                 n_sim_dst=N_SIM_DST)

    current_grad_fn = partial(fit_linear_sm,
                              gen_sim_fn=current_gen_sim_fn,
                              grad_log_prop_fn=partial(grad_log_normal, cov=prop_cov),
                              n_sim_dst=N_SIM_DST,
                              n_prop=N_PROP)

    _, subkey = jax.random.split(key)
    # Only the first of fit_linear_sm's four outputs is used; presumably the
    # linear weights with shape (N_DATA_DIM + 1, N_PARAM_DIM) — TODO confirm.
    W, _, _, _ = current_grad_fn(subkey, theta)
    return jnp.einsum('mk,ik->im', W.T, obs_aug).sum(0)


for sigma in PROP_SIGMAS:
    PROP_SIGMA = sigma
    # Restart the simulation key stream for every proposal scale so all scales
    # are compared on the same sequence of random keys.
    key = SIM_ROOT_KEY

    sm_grad_list = []
    sm_ns_list = []
    for _ in range(N_RUNS):
        key, subkey = jax.random.split(key)
        start_ns = time.perf_counter_ns()
        # JAX dispatches asynchronously: block on the result so the timer
        # measures the computation itself, not just its dispatch.
        # NOTE(review): the first run for each scale may include tracing or
        # compilation overhead if the callees are jitted.
        grads = sm_grads(subkey,
                         simulator_fn,
                         theta,
                         obs,
                         PROP_SIGMA).block_until_ready()
        sm_ns_list.append(time.perf_counter_ns() - start_ns)
        sm_grad_list.append(grads)

    with open(f'sm_comparison_results_sigma_{sigma}.pkl', 'wb') as f:
        pickle.dump({'sm_grad_list': sm_grad_list,
                     'sm_ns_list': sm_ns_list,
                     'N_RUNS': N_RUNS,
                     'N_SIM_DST': N_SIM_DST,
                     'N_PROP': N_PROP,
                     'N_PARAM_DIM': N_PARAM_DIM,
                     'theta': theta,
                     'PROP_SIGMA': PROP_SIGMA,
                     'N_OBS': N_OBS,
                     'theta_mle': theta_mle,
                     }, f)