import jax
import jax.numpy as jnp
from functools import partial
from approxml.utils import gen_simulation_samples, grad_log_normal
from approxml.scorematching import fit_linear_sm
from approxml.aml_kde import silverman_bandwidth, kde_logdensity_modified
from approxml.simulators import mvt_norm_simulator
import pickle

# ---------------------------------------------------------------------------
# Settings that do not depend on the sweep variables.  They are defined once
# here instead of being rebuilt for every (hp, sim_budget) combination.
# ---------------------------------------------------------------------------
N_RUNS = 100                      # repeated gradient estimates per configuration
N_SIM_DST = 1                     # forwarded as `n_sim_dst` to the fitting helpers
N_PARAM_DIM = 2
N_DATA_DIM = N_PARAM_DIM
MOD_SIGMA = jnp.eye(N_PARAM_DIM)  # covariance bound into `simulator_fn`
N_OBS = 10                        # number of observed data points

simulator_fn = partial(mvt_norm_simulator, cov=MOD_SIGMA)

theta_true = jnp.ones(N_PARAM_DIM)

# A JAX key must not be consumed by a sampler *and* split afterwards: both
# operations derive their bits from the same key.  Split the seed key once so
# the observed data and the Monte Carlo runs use distinct keys.
obs_key, run_key = jax.random.split(jax.random.PRNGKey(0))

# NOTE(review): the observations are drawn through `mvt_norm_simulator`
# without an explicit `cov`, i.e. with whatever default that function has --
# confirm it agrees with MOD_SIGMA used by `simulator_fn`.
obs = mvt_norm_simulator(obs_key, theta_true, N_OBS)

# Point at which both estimators evaluate the gradient: the sample mean of
# the observations.
theta = obs.mean(0)

# `hp` drives both estimators' tuning parameter: the proposal scale of the
# linear-fit estimator (PROP_SIGMA) and the perturbation size of the KDE
# finite-difference estimator (C_K).
for hp in [1e-4, 1e-3, 1e-2, 1e-1, 0.5, 1.0, 1.5, 2.0, 5.0, 10.0, 50.0, 100.0]:
    for sim_budget in [1e2, 5e2, 1e3, 5e3, 1e4, 5e4, 1e5, 5e5, 1e6]:
        N_PROP = int(sim_budget)
        PROP_SIGMA = hp
        C_K = hp

        # Every configuration restarts from the same run key, so all
        # configurations consume the same sequence of sub-keys.
        key = run_key

        @partial(jax.jit,static_argnums=(1, 4))
        def sm_grads(key,
                    sim_fn,
                    theta,
                    obs,
                    prop_sigma,
                    ):
            """Gradient estimate from the linear map returned by `fit_linear_sm`.

            The observations are extended with a trailing column of ones, the
            fitted matrix is applied to every extended observation, and the
            results are summed over observations.

            Args:
                key: JAX PRNG key; it is split once and the second half is
                    handed to the fit.
                sim_fn: simulator forwarded to `gen_simulation_samples` as
                    `simulator_fn` (static under `jit`).
                theta: parameter value passed to the fit.
                obs: 2-D array of observations, one row per observation.
                prop_sigma: scalar multiplying the identity matrix to form the
                    covariance given to the proposal simulator and to
                    `grad_log_normal` (static under `jit`).

            Returns:
                1-D array: the per-observation outputs summed over axis 0.

            Note:
                `N_PARAM_DIM`, `N_PROP` and `N_SIM_DST` are read from the
                enclosing scope when the function is traced.
            """
            # One covariance shared by the proposal sampler and its log-density
            # gradient.
            proposal_cov = prop_sigma * jnp.eye(N_PARAM_DIM)

            sample_generator = partial(
                gen_simulation_samples,
                simulator_fn=sim_fn,
                prop_sim_fn=partial(mvt_norm_simulator, cov=proposal_cov),
                n_prop=N_PROP,
                n_sim_dst=N_SIM_DST,
            )
            linear_fit = partial(
                fit_linear_sm,
                gen_sim_fn=sample_generator,
                grad_log_prop_fn=partial(grad_log_normal, cov=proposal_cov),
                n_sim_dst=N_SIM_DST,
                n_prop=N_PROP,
            )

            # Only the second key of the split is consumed.
            _, fit_key = jax.random.split(key)
            weights, _, _, _ = linear_fit(fit_key, theta)

            # Append a constant-one feature so the last row of `weights` acts
            # on it.
            ones_col = jnp.ones_like(obs[..., :1])
            obs_with_ones = jnp.concatenate([obs, ones_col], axis=-1)

            # Row i of `per_obs` is the fitted map applied to observation i.
            per_obs = jnp.einsum('mk,ik->im', weights.T, obs_with_ones)
            return per_obs.sum(0)

        @partial(jax.jit,static_argnums=(1, 4, 5))
        def kde_grads(key,
                    sim_fn,
                    theta,
                    obs,
                    n_sim,
                    c_k):
            """Two-sided finite-difference gradient estimate along a random +/-1 direction.

            Simulations are drawn at `theta + c_k * delta` and
            `theta - c_k * delta`, where every entry of `delta` is sampled from
            {-1, +1}.  The log-density of the observations under each simulated
            set is evaluated with `kde_logdensity_modified`, and the summed
            difference is scaled by `delta / (2 * c_k)`.

            Args:
                key: JAX PRNG key; split into four, of which the last three are
                    used (direction, "plus" simulations, "minus" simulations).
                sim_fn: callable `(key, theta, n) -> simulations` (static under
                    `jit`).
                theta: parameter value at which the gradient is estimated.
                obs: observations; the density is evaluated for each entry
                    along axis 0.
                n_sim: number of simulations requested at each perturbed point
                    (static under `jit`).
                c_k: perturbation size (static under `jit`).

            Returns:
                Array of shape `(N_PARAM_DIM,)` holding the gradient estimate.

            Note:
                `N_PARAM_DIM` is read from the enclosing scope when the
                function is traced.
            """
            # The first key of the split is deliberately left unused so the
            # remaining three match a four-way split.
            _, dir_key, plus_key, minus_key = jax.random.split(key, 4)

            direction = jax.random.choice(dir_key, jnp.array([-1.0, 1.0]), shape=(N_PARAM_DIM,))
            step = c_k * direction

            draws_plus = sim_fn(plus_key, theta + step, n_sim)
            draws_minus = sim_fn(minus_key, theta - step, n_sim)

            # A single bandwidth, computed from the pooled simulations, is used
            # for both density estimates.
            bandwidth = silverman_bandwidth(jnp.vstack([draws_plus, draws_minus]))

            # Evaluate the density once per observation (mapped over axis 0 of
            # `obs`).
            logdensity_per_obs = jax.vmap(kde_logdensity_modified, in_axes=(0, None, None))
            loglik_gap = jnp.sum(
                logdensity_per_obs(obs, draws_plus, bandwidth)
                - logdensity_per_obs(obs, draws_minus, bandwidth)
            )

            return direction * loglik_gap / (2.0 * c_k)

        # ------------------------------------------------------------------
        # KDE finite-difference estimator: N_RUNS independent estimates.
        # `kde_grads` simulates at two perturbed points, so each point gets
        # half of the N_SIM_DST * N_PROP budget (rounded down).
        # ------------------------------------------------------------------
        kde_grad_list = []
        for _ in range(N_RUNS):
            key, subkey = jax.random.split(key)
            kde_grad_list.append(kde_grads(subkey,
                    simulator_fn,
                    theta,
                    obs,
                    (N_SIM_DST * N_PROP) // 2,
                    C_K))

        # Both result files store the same experiment metadata (including
        # N_OBS and N_DATA_DIM) next to their estimator-specific entries, so
        # each file can be interpreted without the other.
        with open(f'kde_comparison_results_budget_{sim_budget}_hp_{hp}.pkl', 'wb') as f:
            pickle.dump({'kde_grad_list': kde_grad_list,
                        'N_RUNS': N_RUNS,
                        'N_SIM_DST': N_SIM_DST,
                        'N_PROP': N_PROP,
                        'N_PARAM_DIM': N_PARAM_DIM,
                        'N_DATA_DIM': N_DATA_DIM,
                        'theta': theta,
                        'c_k': C_K,
                        'N_OBS': N_OBS,
                        }, f)

        # ------------------------------------------------------------------
        # Linear-fit estimator: N_RUNS independent estimates.  The key chain
        # continues from where the KDE runs stopped.
        # ------------------------------------------------------------------
        sm_grad_list = []
        for _ in range(N_RUNS):
            key, subkey = jax.random.split(key)
            sm_grad_list.append(sm_grads(subkey,
                    simulator_fn,
                    theta,
                    obs,
                    PROP_SIGMA))

        with open(f'sm_comparison_results_budget_{sim_budget}_hp_{hp}.pkl', 'wb') as f:
            pickle.dump({'sm_grad_list': sm_grad_list,
                        'N_RUNS': N_RUNS,
                        'N_SIM_DST': N_SIM_DST,
                        'N_PROP': N_PROP,
                        'N_PARAM_DIM': N_PARAM_DIM,
                        'N_DATA_DIM': N_DATA_DIM,
                        'theta': theta,
                        'PROP_SIGMA': PROP_SIGMA,
                        'N_OBS': N_OBS,
                        }, f)
