import pickle
from functools import partial

import jax
import jax.numpy as jnp
import optax
from tqdm import tqdm

from approxml.scorematching import fit_linear_sm
from approxml.simulators import mvt_norm_simulator
from approxml.utils import gen_simulation_samples, grad_log_normal

# ---------------------------------------------------------------------------
# Experiment configuration. None of these values depend on N_OBS, so they are
# defined once instead of on every iteration of the sweep below.
# ---------------------------------------------------------------------------
N_PARAM_DIM = 5
N_ITER = 100       # optimisation steps per run_sgd call
N_RUNS = 100       # independent observed datasets per N_OBS
N_PROP = 50000     # forwarded as n_prop to gen_simulation_samples / fit_linear_sm
N_SIM_DST = 1      # forwarded as n_sim_dst to gen_simulation_samples / fit_linear_sm
MOD_SIGMA = jnp.eye(N_PARAM_DIM)  # cov handed to the model simulator
LR = 1e-3          # init_value of the Robbins-Monro schedule
SIGMA = 1e-1       # sigma_init passed to run_sgd
N_BOOT = 100       # bootstrap replicates per run

true_parameters = jnp.ones(N_PARAM_DIM)
init_theta = jnp.zeros(N_PARAM_DIM)
# A single simulator object: it is a static jit argument of _sgd_step. Reusing
# the same object therefore lets the compiled step be reused.
simulator_fn = partial(mvt_norm_simulator, cov=MOD_SIGMA)


def robbins_monro_schedule(
    init_value=1.0,
    alpha=0.75,
    offset=10.0
):
    """Return a schedule ``count -> init_value / (offset + count + 1) ** alpha``."""

    def schedule(count: int) -> jnp.ndarray:
        t = offset + (count + 1)
        return init_value / (t ** alpha)

    return schedule


def _make_optimizer(learning_rate):
    """Build the optimiser: Robbins-Monro-scaled gradients followed by RMSProp(1.0).

    Cheap and deterministic, so it can be rebuilt inside the jitted step with a
    traced ``learning_rate``. The resulting state structure matches the one
    created by ``.init`` outside jit.
    """
    # NOTE(review): scale_by_schedule is applied *before* RMSProp's RMS
    # normalisation. A constant factor in the schedule (learning_rate itself)
    # therefore cancels out up to eps, and only the schedule's slow decay
    # survives. If a 1e-3 step size was intended, the usual form is
    # optax.rmsprop(learning_rate=schedule_fn). Confirm before changing: it
    # would alter every result this script produces.
    schedule_fn = robbins_monro_schedule(init_value=learning_rate)
    return optax.chain(
        optax.scale_by_schedule(schedule_fn),
        optax.rmsprop(1.0)
    )


@partial(jax.jit, static_argnames=("simulator_fn",))
def _sgd_step(params, opt_state, key, step, obs_aug, sigma_init, sigma_final,
              learning_rate, simulator_fn):
    """Perform one stochastic optimisation step; return the new state and diagnostics.

    Defined once at module level so jax.jit's compilation cache is shared by
    every run_sgd call. The original version created a fresh jitted closure
    inside run_sgd, which re-traced and re-compiled on every call.

    Returns:
        (params, opt_state, key, grads, current_sigma)
    """
    key, subkey = jax.random.split(key)

    # Linear annealing of the proposal scale. The loop in run_sgd stops at
    # step N_ITER - 1, so progress never reaches 1 and sigma_final itself is
    # never used exactly.
    progress = step / N_ITER
    current_sigma = sigma_init * (1.0 - progress) + sigma_final * progress
    # Proposal covariance is current_sigma * I, i.e. current_sigma acts as a variance.
    current_prop_cov = current_sigma * jnp.eye(N_PARAM_DIM)

    gen_sim_fn = partial(gen_simulation_samples,
                         simulator_fn=simulator_fn,
                         prop_sim_fn=partial(mvt_norm_simulator, cov=current_prop_cov),
                         n_prop=N_PROP,
                         n_sim_dst=N_SIM_DST)

    W, _, _, _ = fit_linear_sm(subkey, params,
                               gen_sim_fn=gen_sim_fn,
                               grad_log_prop_fn=partial(grad_log_normal, cov=current_prop_cov),
                               n_sim_dst=N_SIM_DST,
                               n_prop=N_PROP)

    # Sum the linear score over observations. The sign is flipped because
    # optax updates descend.
    grads = -jnp.einsum('mk,ik->im', W.T, obs_aug).sum(0)

    optimizer = _make_optimizer(learning_rate)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, key, grads, current_sigma


def run_sgd(key,
            theta_init,
            obs,
            sigma_init,
            simulator_fn,
            learning_rate,
            sigma_final=0.01,
            ):
    """Run N_ITER optimisation steps from ``theta_init`` and summarise the trajectory.

    Args:
        key: PRNG key, threaded through every step.
        theta_init: starting parameter vector.
        obs: observed data. A column of ones is appended along the last axis,
            so the last row of the fitted W acts as a constant term.
        sigma_init: proposal scale at step 0, annealed linearly toward ``sigma_final``.
        simulator_fn: model simulator forwarded to gen_simulation_samples. It is
            a static jit argument, so it must be hashable. Pass the same object
            on every call to avoid recompilation.
        learning_rate: init_value of the Robbins-Monro schedule.
        sigma_final: target of the proposal-scale annealing (default 0.01).

    Returns:
        dict with per-step 'theta_values', 'grad_values', 'sigma_values'
        (stacked along axis 0) and 'avg_theta', the mean of all iterates.
    """
    obs_aug = jnp.concatenate([obs, jnp.ones_like(obs[..., :1])], axis=-1)
    opt_state = _make_optimizer(learning_rate).init(theta_init)

    theta = theta_init
    theta_values, grad_values, sigma_values = [], [], []
    for step in range(N_ITER):
        theta, opt_state, key, grads, current_sigma = _sgd_step(
            theta, opt_state, key, step, obs_aug, sigma_init, sigma_final,
            learning_rate, simulator_fn=simulator_fn)
        # JAX arrays are immutable, so no defensive copies are needed.
        theta_values.append(theta)
        grad_values.append(grads)
        sigma_values.append(current_sigma)

    theta_values = jnp.array(theta_values)
    return {
        'theta_values': theta_values,
        'grad_values': jnp.array(grad_values),
        'sigma_values': jnp.array(sigma_values),
        'avg_theta': theta_values.mean(axis=0),
    }


def _parametric_bootstrap(key, theta_hat, n_obs):
    """Refit run_sgd on N_BOOT datasets simulated at ``theta_hat``; stack their avg_theta."""
    boot_thetas = []
    # One independent key per replicate, split again so the data simulation and
    # the optimiser never consume the same key. The original loop reused the
    # key given to run_sgd to simulate the next replicate's data.
    for rep_key in jax.random.split(key, N_BOOT):
        sim_key, sgd_key = jax.random.split(rep_key)
        boot_obs = mvt_norm_simulator(sim_key, theta_hat, n_obs)
        boot_res_dict = run_sgd(sgd_key, init_theta, boot_obs, SIGMA, simulator_fn, LR)
        boot_thetas.append(boot_res_dict['avg_theta'])
    return jnp.array(boot_thetas)


for N_OBS in [1, 50, 100, 250, 500]:
    avg_thetas_list = []
    boot_thetas_list = []
    ci_lists = []

    key = jax.random.PRNGKey(42)
    for _ in tqdm(range(N_RUNS)):
        key, subkey = jax.random.split(key)
        obs = mvt_norm_simulator(subkey, true_parameters, N_OBS)

        key, subkey = jax.random.split(key)
        res_dict = run_sgd(subkey, init_theta, obs, SIGMA, simulator_fn, LR)
        fitted_theta = res_dict['avg_theta']
        avg_thetas_list.append(fitted_theta)

        key, subkey = jax.random.split(key)
        boot_thetas = _parametric_bootstrap(subkey, fitted_theta, N_OBS)

        # Percentile interval (2.5%, 97.5%) per parameter, stacked as [low, high]
        # along the last axis.
        ci_low, ci_hi = jnp.percentile(boot_thetas, jnp.array([2.5, 97.5]), axis=0)
        boot_thetas_list.append(boot_thetas)
        ci_lists.append(jnp.stack([ci_low, ci_hi], axis=1))

    ci_lists = jnp.array(ci_lists)

    # Fraction of runs whose interval contains the true value, per parameter.
    covered = ((ci_lists[:, :, 0] <= true_parameters)
               & (ci_lists[:, :, 1] >= true_parameters))
    mean_cov = covered.mean(axis=0)

    print(f"Coverage: {mean_cov}")

    with open(f'1_mvt_gaussian_cov_n_obs_{N_OBS}_bs.pkl', 'wb') as f:
        pickle.dump({
            'avg_thetas_list': avg_thetas_list,
            'boot_thetas_list': boot_thetas_list,
            'ci_lists': ci_lists,
            'mean_cov': mean_cov
        }, f)
