import jax
import jax.numpy as jnp
from functools import partial
from approxml.utils import gen_simulation_samples, grad_log_normal
from approxml.scorematching import fit_linear_sm
from approxml.aml_kde import silverman_bandwidth, kde_logdensity_modified
from approxml.simulators import mvt_norm_simulator
import time
import pickle

# Benchmark settings that do not depend on the dimension being swept.
N_RUNS = 1000      # timed gradient evaluations per method and per dimension
N_PROP = 100       # forwarded as n_prop to gen_simulation_samples / fit_linear_sm
N_SIM_DST = 250    # forwarded as n_sim_dst to gen_simulation_samples / fit_linear_sm
PROP_SIGMA = 1e-3  # proposal covariance is PROP_SIGMA * I
N_OBS = 10         # number of observations drawn at theta_true
C_K = 1e-1         # SPSA perturbation magnitude for the KDE estimator


def sm_grads(key,
             sim_fn,
             theta,
             obs,
             prop_sigma,
             n_prop=N_PROP,
             n_sim_dst=N_SIM_DST,
             ):
    """Return the score-matching gradient estimate at ``theta``.

    Fits a linear model with ``fit_linear_sm`` on samples generated by
    ``gen_simulation_samples`` around ``theta``. Both use a Gaussian proposal
    with covariance ``prop_sigma * I``. The fitted weights ``W`` are then
    applied to every observation augmented with a trailing constant 1, and
    the results are summed over observations.

    Args:
        key: JAX PRNG key.
        sim_fn: model simulator, forwarded to ``gen_simulation_samples``.
        theta: parameter vector; its last axis sets the parameter dimension.
        obs: observations, assumed (n_obs, data_dim) -- TODO confirm.
        prop_sigma: scalar scale of the proposal covariance.
        n_prop: forwarded to ``gen_simulation_samples`` and ``fit_linear_sm``.
        n_sim_dst: forwarded to ``gen_simulation_samples`` and ``fit_linear_sm``.

    Returns:
        jax array equal to ``sum_i W.T @ [obs_i, 1]``.
    """
    n_param_dim = theta.shape[-1]
    prop_cov = prop_sigma * jnp.eye(n_param_dim)

    # Constant feature appended to each observation (presumably the
    # intercept term of the linear score-matching model).
    obs_aug = jnp.concatenate([obs, jnp.ones_like(obs[..., :1])], axis=-1)

    gen_sim_fn = partial(gen_simulation_samples,
                         simulator_fn=sim_fn,
                         prop_sim_fn=partial(mvt_norm_simulator, cov=prop_cov),
                         n_prop=n_prop,
                         n_sim_dst=n_sim_dst)

    grad_fn = partial(fit_linear_sm,
                      gen_sim_fn=gen_sim_fn,
                      grad_log_prop_fn=partial(grad_log_normal, cov=prop_cov),
                      n_sim_dst=n_sim_dst,
                      n_prop=n_prop)

    _, subkey = jax.random.split(key)
    W, _, _, _ = grad_fn(subkey, theta)
    return jnp.einsum('mk,ik->im', W.T, obs_aug).sum(0)


def kde_grads(key,
              sim_fn,
              theta,
              obs,
              n_sim,
              c_k):
    """Return an SPSA gradient estimate of the summed KDE log-likelihood.

    Perturbs ``theta`` by ``+/- c_k`` along a random +/-1 direction. It then
    simulates ``n_sim`` points at each perturbed parameter and scores ``obs``
    under a KDE of each simulation set. Both KDEs share one Silverman
    bandwidth estimated from the pooled simulations. Finally it forms the
    central finite difference along the perturbation direction.

    Args:
        key: JAX PRNG key.
        sim_fn: simulator called as ``sim_fn(key, theta, n)``.
        theta: 1-D parameter vector; its last axis sets the dimension.
        obs: observations; the KDE log-density is vmapped over axis 0.
        n_sim: number of simulations drawn at each of ``theta +/- c_k*delta``.
        c_k: perturbation magnitude.

    Returns:
        jax array with one entry per parameter dimension.
    """
    kde_logdensity_batch = jax.vmap(kde_logdensity_modified, in_axes=(0, None, None))

    _, key_delta, key_p, key_m = jax.random.split(key, 4)
    n_param_dim = theta.shape[-1]
    # Rademacher (+/-1) perturbation direction.
    delta = jax.random.choice(key_delta, jnp.array([-1.0, 1.0]), shape=(n_param_dim,))
    t_plus = theta + c_k * delta
    t_minus = theta - c_k * delta

    sims_plus = sim_fn(key_p, t_plus, n_sim)
    sims_minus = sim_fn(key_m, t_minus, n_sim)

    # One bandwidth shared by both KDEs, estimated from the pooled simulations.
    h = silverman_bandwidth(jnp.vstack([sims_plus, sims_minus]))

    ll_plus = kde_logdensity_batch(obs, sims_plus, h)
    ll_minus = kde_logdensity_batch(obs, sims_minus, h)

    diff = jnp.sum(ll_plus - ll_minus)
    # For +/-1 entries, multiplying by delta equals dividing by it (SPSA).
    return delta * diff / (2.0 * c_k)


for dim in [2, 5, 10, 50, 100]:
    N_PARAM_DIM = dim
    N_DATA_DIM = N_PARAM_DIM
    MOD_SIGMA = jnp.eye(N_PARAM_DIM)

    simulator_fn = partial(mvt_norm_simulator, cov=MOD_SIGMA)

    # Same root seed for every dimension so the sweep is reproducible.
    key = jax.random.PRNGKey(0)
    key, subkey = jax.random.split(key)
    theta_true = jax.random.uniform(subkey, (N_PARAM_DIM,), minval=-10, maxval=10)

    # Draw observations from the same model the estimators use
    # (cov=MOD_SIGMA), with a freshly split key rather than the root seed.
    key, subkey = jax.random.split(key)
    obs = simulator_fn(subkey, theta_true, N_OBS)

    # Both estimators are evaluated at the sample mean of the observations.
    theta = obs.mean(axis=0)

    # JAX dispatches work asynchronously, so each timed region ends with
    # block_until_ready(); otherwise only the dispatch would be measured.
    # The first sample of each list likely includes one-time tracing /
    # compilation cost; drop it in analysis if steady-state times are wanted.
    kde_ns_list = []
    for _ in range(N_RUNS):
        key, subkey = jax.random.split(key)
        start_ns = time.perf_counter_ns()
        # Same total simulation budget as SM, split between theta +/- c_k*delta.
        g_hat = kde_grads(subkey,
                          simulator_fn,
                          theta,
                          obs,
                          (N_SIM_DST * N_PROP) // 2,
                          C_K)
        g_hat.block_until_ready()
        kde_ns_list.append(time.perf_counter_ns() - start_ns)

    with open(f'kde_comparison_results_ns_dim_{dim}.pkl', 'wb') as f:
        pickle.dump({'kde_ns_list': kde_ns_list,
                     'N_RUNS': N_RUNS,
                     'N_SIM_DST': N_SIM_DST,
                     'N_PROP': N_PROP,
                     'N_PARAM_DIM': N_PARAM_DIM,
                     'N_DATA_DIM': N_DATA_DIM,
                     'theta': theta,
                     'c_k': C_K,
                     'N_OBS': N_OBS,
                     }, f)

    sm_ns_list = []
    for _ in range(N_RUNS):
        key, subkey = jax.random.split(key)
        start_ns = time.perf_counter_ns()
        grads = sm_grads(subkey,
                         simulator_fn,
                         theta,
                         obs,
                         PROP_SIGMA)
        grads.block_until_ready()
        sm_ns_list.append(time.perf_counter_ns() - start_ns)

    with open(f'sm_comparison_results_ns_dim_{dim}.pkl', 'wb') as f:
        pickle.dump({'sm_ns_list': sm_ns_list,
                     'N_RUNS': N_RUNS,
                     'N_SIM_DST': N_SIM_DST,
                     'N_PROP': N_PROP,
                     'N_PARAM_DIM': N_PARAM_DIM,
                     'N_DATA_DIM': N_DATA_DIM,
                     'theta': theta,
                     'PROP_SIGMA': PROP_SIGMA,
                     'N_OBS': N_OBS,
                     }, f)
