import pickle
from functools import partial

import jax
import jax.numpy as jnp
import optax
from tqdm import tqdm

from approxml.scorematching import cross_val_sm, fit_linear_sm
from approxml.simulators import mvt_norm_simulator
from approxml.utils import gen_simulation_samples, grad_log_normal

# ---------------------------------------------------------------------------
# Experiment configuration. Everything here is identical for every N_OBS in the
# sweep, so it is built once instead of on every loop iteration.
# ---------------------------------------------------------------------------
N_OBS_GRID = [1, 5, 10, 50, 100, 200, 300, 400, 500]
N_PARAM_DIM = 5
N_DATA_DIM = N_PARAM_DIM
N_ITER = 100      # optimisation steps per run
N_RUNS = 100      # independent replications per N_OBS (coverage is averaged over these)
N_PROP = 50000    # passed as n_prop to gen_simulation_samples / fit_linear_sm
N_SIM_DST = 1     # passed as n_sim_dst to gen_simulation_samples / fit_linear_sm
MOD_SIGMA = jnp.eye(N_PARAM_DIM)
LR = 1e-3
CV_SIGMA_GRID = [0.01, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5]  # candidates for cross_val_sm

true_parameters = jnp.ones(N_PARAM_DIM)
init_theta = jnp.zeros(N_PARAM_DIM)
simulator_fn = partial(mvt_norm_simulator, cov=MOD_SIGMA)


def robbins_monro_schedule(
    init_value=1.0,
    alpha=0.75,
    offset=10.0
):
    """Return a step-size schedule ``init_value / (offset + count + 1) ** alpha``."""
    def schedule(count: int) -> jnp.ndarray:
        t = offset + (count + 1)
        return init_value / (t ** alpha)

    return schedule


def get_95_ci(avg_theta, obs_aug, subkey):
    """Compute Wald-type 95% intervals at ``avg_theta``.

    Refits the linear score-matching model at ``avg_theta`` with a fixed
    proposal covariance ``0.1 * I`` and ``n_prop=100000``. It forms the
    per-observation vectors ``W.T @ obs_aug[i]``, sums their outer products
    and inverts that sum as the asymptotic covariance estimate.

    Args:
        avg_theta: parameter vector at which the intervals are centred.
        obs_aug: observations with a trailing column of ones appended.
        subkey: PRNG key for the refit.

    Returns:
        ``(ci, asy_cov, avg_theta)``. ``ci`` stacks ``[lower, upper]`` along
        axis 1, i.e. one row per parameter.
    """
    # NOTE(review): the proposal scale (0.1) and n_prop (100000) here are fixed
    # and differ from the annealed values used during optimisation — confirm
    # this is intentional.
    current_prop_cov = 0.1 * jnp.eye(N_PARAM_DIM)
    current_gen_sim_fn = partial(gen_simulation_samples,
        simulator_fn=simulator_fn,
        prop_sim_fn=partial(mvt_norm_simulator, cov=current_prop_cov),
        n_prop=100000,
        n_sim_dst=1)

    current_grad_fn = partial(fit_linear_sm,
        gen_sim_fn=current_gen_sim_fn,
        grad_log_prop_fn=partial(grad_log_normal, cov=current_prop_cov),
        n_sim_dst=1,
        n_prop=100000)

    W, _, _, _ = current_grad_fn(subkey, avg_theta)
    grads = jnp.einsum('mk,ik->im', W.T, obs_aug)
    grads = jax.vmap(jnp.outer, in_axes=(0, 0))(grads, grads)
    # NOTE(review): this is a sum of len(obs_aug) rank-1 matrices. When there
    # are fewer observations than parameters (e.g. N_OBS=1 with N_PARAM_DIM=5)
    # it is singular, and the inverse below is not meaningful (inf/nan CIs,
    # which then count as "not covered").
    asy_cov = grads.sum(0)
    asy_cov = jnp.linalg.inv(asy_cov)
    ci = jnp.sqrt(jnp.diag(asy_cov)) * 1.96
    ci = jnp.stack([avg_theta - ci, avg_theta + ci], axis=1)
    return ci, asy_cov, avg_theta


def run_sgd(key,
            theta_init,
            obs,
            sigma_init,
            simulator_fn,
            learning_rate
            ):
    """Optimise theta on simulated score-matching gradients and average the iterates.

    At each of ``N_ITER`` steps the linear score-matching model is refit at
    the current parameters. The proposal covariance is ``sigma * I``, where
    ``sigma`` is annealed linearly from ``sigma_init`` towards
    ``sigma_final = 0.01``. The summed per-observation quantities
    ``W.T @ [obs_i, 1]`` are negated and fed to an optax RMSProp chain. The
    iterates are then averaged and 95% intervals are computed at the average
    with ``get_95_ci``.

    Args:
        key: PRNG key driving all randomness in this run.
        theta_init: starting parameter vector (``jnp.zeros(N_PARAM_DIM)`` in
            this script).
        obs: observed data; a column of ones is appended along the last axis.
        sigma_init: initial proposal covariance scale.
        simulator_fn: simulator forwarded to ``gen_simulation_samples``.
        learning_rate: ``init_value`` of the Robbins–Monro schedule.

    Returns:
        A dict with the per-step traces (``theta_values``, ``grad_values``,
        ``sigma_values``), the averaged parameters ``avg_theta``, its interval
        ``ci`` and the covariance estimate ``asy_cov``.
    """
    sigma_final = 0.01
    # Constant-1 column appended to every observation (presumably the
    # intercept term of the linear fit — confirm against fit_linear_sm).
    obs_aug = jnp.concatenate([obs, jnp.ones_like(obs[..., :1])], axis=-1)

    def get_sigma(step):
        # Linear interpolation: sigma_init at step 0, sigma_final at step N_ITER
        # (one past the last executed step, so sigma_final is never reached).
        progress = step / N_ITER
        return sigma_init * (1.0 - progress) + sigma_final * progress

    schedule_fn = robbins_monro_schedule(init_value=learning_rate)
    # NOTE(review): the schedule is applied *before* RMSProp's normalisation
    # g / sqrt(EMA(g**2)). That normalisation is scale-invariant, so the
    # absolute magnitude of the schedule (and hence `learning_rate`) largely
    # cancels; steps stay O(1). If a decaying step size of ~learning_rate was
    # intended, the schedule belongs after `scale_by_rms`. Left unchanged here
    # because reordering would change the experiment's results.
    optimizer = optax.chain(
        optax.scale_by_schedule(schedule_fn),
        optax.rmsprop(1.0)
    )
    opt_state = optimizer.init(theta_init)

    # NOTE(review): `update` is a new closure on every run_sgd call, so
    # jax.jit traces and compiles it again for every run.
    @jax.jit
    def update(params, opt_state, key, step):
        key, subkey = jax.random.split(key)

        current_sigma = get_sigma(step)
        current_prop_cov = current_sigma * jnp.eye(N_PARAM_DIM)

        current_gen_sim_fn = partial(gen_simulation_samples,
            simulator_fn=simulator_fn,
            prop_sim_fn=partial(mvt_norm_simulator, cov=current_prop_cov),
            n_prop=N_PROP,
            n_sim_dst=N_SIM_DST)

        current_grad_fn = partial(fit_linear_sm,
            gen_sim_fn=current_gen_sim_fn,
            grad_log_prop_fn=partial(grad_log_normal, cov=current_prop_cov),
            n_sim_dst=N_SIM_DST,
            n_prop=N_PROP)

        W, _, _, _ = current_grad_fn(subkey, params)
        grads = jnp.einsum('mk,ik->im', W.T, obs_aug).sum(0)
        # optax minimises, so negate to move params along +sum_i W.T @ obs_aug[i].
        grads = - grads
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return params, opt_state, key, grads, current_sigma

    theta_values = []
    grad_values = []
    sigma_values = []

    # JAX arrays are immutable, so no defensive copies are needed.
    theta = theta_init
    for i in range(N_ITER):
        theta, opt_state, key, grads, current_sigma = update(theta, opt_state, key, i)
        theta_values.append(theta)
        grad_values.append(grads)
        sigma_values.append(current_sigma)

    key, subkey = jax.random.split(key)
    theta_trace = jnp.array(theta_values)
    avg_theta = theta_trace.mean(axis=0)  # Polyak-style average of all iterates
    ci, asy_cov, avg_theta = get_95_ci(avg_theta, obs_aug, subkey)

    return {
        'theta_values': theta_trace,
        'grad_values': jnp.array(grad_values),
        'sigma_values': jnp.array(sigma_values),
        'avg_theta': avg_theta,
        'ci': ci,
        'asy_cov': asy_cov
    }


for N_OBS in N_OBS_GRID:
    # Data used only to choose the initial proposal scale by cross-validation.
    # It is always drawn with PRNGKey(0), independent of the replications below.
    cv_obs = mvt_norm_simulator(jax.random.PRNGKey(0),
                                true_parameters,
                                N_OBS)

    _, cv_key = jax.random.split(jax.random.PRNGKey(0))
    init_sigma, _ = cross_val_sm(CV_SIGMA_GRID,
                                 init_theta,
                                 simulator_fn,
                                 N_PROP,
                                 N_SIM_DST,
                                 cv_obs,
                                 cv_key,
                                 N_PARAM_DIM,
                                 N_DATA_DIM)

    # N_RUNS replications. Each draws a fresh dataset and runs the optimiser.
    # The key chain restarts from PRNGKey(42) for every N_OBS.
    res_list = []
    key = jax.random.PRNGKey(42)
    for _ in tqdm(range(N_RUNS)):
        key, subkey = jax.random.split(key)
        obs = mvt_norm_simulator(subkey, true_parameters, N_OBS)

        key, subkey = jax.random.split(key)
        res_list.append(run_sgd(subkey, init_theta, obs, init_sigma, simulator_fn, LR))

    # Stack per-run intervals: axis 0 = run, axis 1 = parameter, axis 2 = [lower, upper].
    ci_thetas = jnp.array([res['ci'] for res in res_list])
    covered = ((ci_thetas[:, :, 0] <= true_parameters)
               & (ci_thetas[:, :, 1] >= true_parameters))
    # Per-parameter empirical coverage (fraction of runs whose CI contains the truth).
    mean_cov = covered.mean(axis=0)
    print(f"Coverage: {mean_cov}")

    with open(f'1_mvt_gaussian_cov_n_obs_{N_OBS}.pkl', 'wb') as f:
        pickle.dump({
            'res_list': res_list,
            'mean_cov': mean_cov
        }, f)
