import jax
import pickle
import jax.numpy as jnp
from functools import partial
from approxml.utils import gen_simulation_samples, grad_log_normal
from approxml.scorematching import fit_linear_sm
from approxml.simulators import mvt_norm_simulator
from tqdm import tqdm
import optax
from itertools import product

# Sweep over every (parameter dimension, simulation budget) combination; the
# rest of the loop body runs one complete experiment per combination.
for param_dim in [2, 5, 10, 20, 50, 100]:
    for sim_budget in [1e2, 5e2, 1e3, 5e3, 1e4, 5e4]:
        # --- Sizes for this experiment cell ---------------------------------
        N_PARAM_DIM = param_dim
        N_DATA_DIM = N_PARAM_DIM          # data dimension is set equal to the parameter dimension
        N_OBS = 100                       # number of observations drawn per data set
        N_ITER = 100                      # optimisation steps performed by run_sgd
        N_RUNS = 100                      # number of repeated runs after the pilot grid search
        N_PROP = int(sim_budget)          # budget values are floats; n_prop is passed as an int
        N_SIM_DST = 1                     # forwarded as n_sim_dst to the approxml helpers

        # --- Model pieces -----------------------------------------------------
        MOD_SIGMA = jnp.eye(N_PARAM_DIM)  # identity covariance handed to the simulator
        simulator_fn = partial(mvt_norm_simulator, cov=MOD_SIGMA)

        # Ground truth is the all-ones vector; optimisation starts from zeros.
        true_parameters = jnp.ones(N_PARAM_DIM)
        init_theta = jnp.zeros(N_PARAM_DIM)

        # Pilot data set used by the grid search, drawn with a fixed seed.
        # NOTE(review): this call relies on mvt_norm_simulator's default cov,
        # whereas simulator_fn passes MOD_SIGMA explicitly - confirm they agree.
        obs = mvt_norm_simulator(jax.random.PRNGKey(0), true_parameters, N_OBS)

        key, subkey = jax.random.split(jax.random.PRNGKey(0))

        def robbins_monro_schedule(
            init_value=1.0,
            alpha=1.0,
            offset=10.0,
        ):
            """Build a polynomially decaying step-size schedule.

            The returned callable maps a step counter ``count`` to
            ``init_value / (offset + count + 1) ** alpha``.

            Args:
                init_value: Numerator of the step size.
                alpha: Decay exponent applied to the shifted step counter.
                offset: Constant added to the (1-based) step counter.

            Returns:
                A function ``count -> step size``.
            """
            def _step_size_at(count: int) -> jnp.ndarray:
                # ``count`` is zero-based, hence the +1 before adding the offset.
                shifted_step = offset + (count + 1)
                decay = shifted_step ** alpha
                return init_value / decay

            return _step_size_at
        
        def run_sgd(key, 
                    theta_init, 
                    obs,
                    sigma,
                    simulator_fn,
                    learning_rate
                    ):
            """Run ``N_ITER`` Adam steps on ``theta`` using score-matching gradients.

            At every step ``fit_linear_sm`` is called at the current parameters
            with a Gaussian proposal of covariance ``sigma * I``; the first
            element of its result (``W``) is contracted with the augmented
            observations and summed over observations to form the gradient.

            Args:
                key: JAX PRNG key; split once per step.
                theta_init: Starting parameter vector.
                obs: Observed data; a column of ones is appended on the last axis.
                sigma: Scalar multiplying the identity to form the proposal
                    covariance.
                simulator_fn: Simulator forwarded to ``gen_simulation_samples``.
                learning_rate: ``init_value`` of the Robbins-Monro step-size
                    schedule.

            Returns:
                dict with
                  * ``'theta_values'``: parameters after each step, stacked,
                  * ``'grad_values'``: gradient fed to the optimiser at each step,
                  * ``'sigma_values'``: ``sigma`` repeated once per step,
                  * ``'avg_theta'``: mean of the last 50 parameter iterates.
            """
            # Append a constant-one column (presumably so the fitted linear map
            # W can carry an intercept term - confirm against fit_linear_sm).
            obs_aug = jnp.concatenate([obs, jnp.ones_like(obs[..., :1])], axis=-1)

            # The decaying schedule is handed to Adam as its learning rate, so it
            # scales the update *after* Adam's normalisation. Scaling the raw
            # gradients before Adam instead is cancelled by the m / sqrt(v)
            # normalisation, which would make `learning_rate` ineffective.
            schedule_fn = robbins_monro_schedule(init_value=learning_rate)
            optimizer = optax.adam(learning_rate=schedule_fn)
            opt_state = optimizer.init(theta_init)

            # Everything below is independent of the iterate, so build it once.
            prop_cov = sigma * jnp.eye(N_PARAM_DIM)
            gen_sim_fn = partial(gen_simulation_samples,
                simulator_fn=simulator_fn,
                prop_sim_fn=partial(mvt_norm_simulator, cov=prop_cov),
                n_prop=N_PROP,
                n_sim_dst=N_SIM_DST)
            grad_fn = partial(fit_linear_sm,
                gen_sim_fn=gen_sim_fn,
                grad_log_prop_fn=partial(grad_log_normal, cov=prop_cov),
                n_sim_dst=N_SIM_DST,
                n_prop=N_PROP,
                lamb=1e0)

            @jax.jit
            def update(params, opt_state, key):
                key, subkey = jax.random.split(key)
                W, _, _, _ = grad_fn(subkey, params)
                # Contract W with each augmented observation, then sum over
                # observations to get one gradient entry per parameter.
                grads = jnp.einsum('mk,ik->im', W.T, obs_aug).sum(0)
                # Negate: the optimiser takes descent steps, so feeding the
                # negated quantity moves params along the un-negated direction.
                grads = -grads
                updates, opt_state = optimizer.update(grads, opt_state, params)
                params = optax.apply_updates(params, updates)
                return params, opt_state, key, grads

            theta_values = []
            grad_values = []

            # JAX arrays are immutable, so no defensive copies are needed.
            theta = theta_init
            for _ in range(N_ITER):
                theta, opt_state, key, grads = update(theta, opt_state, key)
                theta_values.append(theta)
                grad_values.append(grads)

            theta_path = jnp.array(theta_values)
            # Average the tail of the trajectory to damp step-to-step noise.
            avg_theta = theta_path[-50:, :].mean(axis=0)

            return {
                'theta_values': theta_path,
                'grad_values': jnp.array(grad_values),
                # sigma is constant during a run; keep the per-step record so
                # the result layout stays the same.
                'sigma_values': jnp.array([sigma] * N_ITER),
                'avg_theta': avg_theta,
            }

        # Pilot grid search: run the optimiser once per (learning rate, proposal
        # scale) pair on the pilot data set and keep the pair whose averaged
        # iterate reproduces the observations best.
        lr_grid = [1e-3, 1e-2, 1e-1, 1e1, 5e1]
        sigma_grid = [1e-2, 1e-1, 1.0, 1e1, 5e1]

        results = []

        key = jax.random.PRNGKey(0)
        for lr_val, sigma_val in product(lr_grid, sigma_grid):
            print(f"Pilot run: lr: {lr_val}, sigma: {sigma_val}")
            key, subkey = jax.random.split(key)
            res_dict = run_sgd(
                subkey,
                init_theta,
                obs,
                sigma_val,
                simulator_fn,
                lr_val
            )

            # Score the run: simulate a fresh data set at the averaged iterate
            # and take the mean row-wise distance to the observed data.
            key, subkey = jax.random.split(key)
            s_sims = simulator_fn(subkey, res_dict['avg_theta'], N_OBS)
            # Store a plain Python float so sorting and formatting do not
            # depend on JAX scalar semantics.
            err = float(jnp.linalg.norm(s_sims - obs, axis=1).mean())
            results.append((lr_val, sigma_val, err))

        # A diverged run yields NaN. NaN compares False against everything, so
        # sorting on the raw value could leave a NaN entry in first place and
        # select it as "best". Rank NaN as +inf instead (x != x only for NaN).
        results.sort(key=lambda tup: float('inf') if tup[2] != tup[2] else tup[2])
        best_lr, best_sigma, best_err = results[0]

        print("Grid search results (top 5):")
        for lr_val, sigma_val, err in results[:5]:
            print(f"  lr={lr_val:.3f}, sigma={sigma_val:.3f}  -->  error = {err:.4f}")

        print(f"\nBEST:  lr = {best_lr},  sigma = {best_sigma},  error = {best_err:.4f}")

        # Repeated experiment: N_RUNS independent data sets, each optimised with
        # the hyper-parameters selected by the pilot grid search.
        res_list = []

        key = jax.random.PRNGKey(42)
        for _ in tqdm(range(N_RUNS)):
            # First split: key for drawing this run's observations.
            key, obs_key = jax.random.split(key)
            obs = mvt_norm_simulator(obs_key, true_parameters, N_OBS)

            # Second split: key consumed by the optimiser itself.
            key, run_key = jax.random.split(key)
            res_list.append(
                run_sgd(run_key, init_theta, obs, best_sigma, simulator_fn, best_lr)
            )

        # One averaged estimate per run, stacked along the first axis.
        avg_thetas_list = jnp.array([run['avg_theta'] for run in res_list])

        # Mean Euclidean distance between the per-run estimates and the truth.
        mean_param_error = jnp.linalg.norm(avg_thetas_list - true_parameters, axis=-1).mean()
        print(f"Avg error: {mean_param_error}")

        # Persist everything for this (budget, dimension) cell.
        payload = {
            'res_list': res_list,
            'avg_thetas_list': avg_thetas_list,
            'true_parameters': true_parameters,
            'N_OBS': N_OBS,
            'sim_budget': sim_budget
        }
        with open(f'1_mvt_gaussian_sim_bud_{sim_budget}_dim_{N_PARAM_DIM}.pkl', 'wb') as f:
            pickle.dump(payload, f)
