import jax
import pickle
import jax.numpy as jnp
from approxml.simulators import mvt_norm_simulator
from tqdm import tqdm
from itertools import product
from approxml.aml_kde import spsa_aml_iid

# Experiment driver: for every (parameter dimension, simulation budget) pair,
# pick SPSA-AML step-size hyperparameters (a, c) by a pilot grid search, then
# refit on N_RUNS freshly simulated datasets with the chosen (a, c) and pickle
# the estimates. One pickle file is written per pair.

N_OBS = 100   # observations per dataset
N_ITER = 100  # `iterations` passed to spsa_aml_iid
N_RUNS = 100  # independent replications in the evaluation phase

PARAM_DIMS = [2, 5, 10, 20, 50, 100]
SIM_BUDGETS = [1e2, 5e2, 1e3, 5e3, 1e4, 5e4]

a_grid = [1e-2, 1e-1, 1, 1e1, 1e2, 1e3]
c_grid = [1e-2, 1e-1, 1, 1e1, 1e2, 1e3]


def sim_fn(key, mu, n_sim):
    """Draw ``n_sim`` samples from ``mvt_norm_simulator`` at parameter ``mu``.

    The key is split and only the second half is used, exactly as in earlier
    versions of this script, so random streams (and saved results) match.
    Defined once, rather than re-created inside the loops, so the same
    function object is passed to every ``spsa_aml_iid`` call.
    """
    _, subkey = jax.random.split(key)
    return mvt_norm_simulator(subkey, mu, n_sim)


def _finite_or_inf(err):
    """Return ``err`` as a float, mapping NaN/inf to +inf so it sorts last."""
    err = float(err)
    return err if jnp.isfinite(err) else float("inf")


def _pilot_grid_search(obs, n_param_dim, n_sim):
    """Fit once per (a, c) grid point; return ``(a, c, err)`` tuples, best first.

    ``err`` is the mean row-wise Euclidean distance between ``obs`` and a
    fresh simulation of N_OBS rows at the fitted parameter. Non-finite errors
    (e.g. if a fit produced NaN/inf parameters) are ranked last: sorting raw
    NaN values gives an undefined order and could put a NaN candidate first.
    """
    # NOTE(review): row i of the simulation and row i of `obs` are independent
    # draws, so this is a noisy proxy for parameter error; consider comparing
    # summary statistics (e.g. sample means) instead.
    results = []
    key = jax.random.PRNGKey(0)
    for a_val, c_val in product(a_grid, c_grid):
        print(f"Pilot run: a: {a_val}, c: {c_val}")
        key, subkey = jax.random.split(key)
        theta_hat = spsa_aml_iid(
            subkey,
            sim_fn,
            obs,
            jnp.zeros(n_param_dim),
            iterations=N_ITER,
            n_sim=n_sim,
            a=a_val,
            c=c_val,
        )

        key, subkey = jax.random.split(key)
        s_sims = sim_fn(subkey, theta_hat, N_OBS)
        err = float(jnp.linalg.norm(s_sims - obs, axis=1).mean())
        results.append((a_val, c_val, err))

    results.sort(key=lambda tup: _finite_or_inf(tup[2]))
    return results


def _replicate(true_parameters, n_param_dim, n_sim, a, c):
    """Fit SPSA-AML with fixed (a, c) on N_RUNS fresh datasets; return estimates."""
    thetas = []
    key = jax.random.PRNGKey(42)
    for _ in tqdm(range(N_RUNS)):
        key, subkey = jax.random.split(key)
        run_obs = mvt_norm_simulator(subkey, true_parameters, N_OBS)

        key, subkey = jax.random.split(key)
        thetas.append(
            spsa_aml_iid(
                subkey,
                sim_fn,
                run_obs,
                jnp.zeros(n_param_dim),
                iterations=N_ITER,
                n_sim=n_sim,
                a=a,
                c=c,
            )
        )
    return thetas


for param_dim in PARAM_DIMS:
    N_PARAM_DIM = param_dim
    true_parameters = jnp.ones(N_PARAM_DIM)
    # The pilot dataset depends only on the dimension (fixed key), so build it
    # once per dimension instead of once per budget.
    obs = mvt_norm_simulator(jax.random.PRNGKey(0), true_parameters, N_OBS)

    for sim_budget in SIM_BUDGETS:
        # Half of the nominal budget is passed to spsa_aml_iid as `n_sim`.
        N_SIM_DST = int(sim_budget) // 2

        results = _pilot_grid_search(obs, N_PARAM_DIM, N_SIM_DST)
        best_a, best_c, best_err = results[0]

        print("Grid search results (top 5):")
        for a_val, c_val, err in results[:5]:
            print(f"  a={a_val:.3f}, c={c_val:.3f}  -->  error = {err:.4f}")

        print(f"\nBEST:  a = {best_a},  c = {best_c},  error = {best_err:.4f}")

        thetas = _replicate(true_parameters, N_PARAM_DIM, N_SIM_DST, best_a, best_c)
        res_list = [{'avg_theta': theta_hat} for theta_hat in thetas]
        avg_thetas_list = jnp.array(thetas)
        print(f"Avg error: {jnp.linalg.norm(avg_thetas_list - true_parameters, axis=-1).mean()}")

        # Filename keeps the float formatting of sim_budget (e.g. "100.0") so
        # existing downstream loaders still find the files.
        with open(f'1_mvt_gaussian_aml_iid_sim_bud_{sim_budget}_dim_{N_PARAM_DIM}.pkl', 'wb') as f:
            pickle.dump({
                'res_list': res_list,
                'avg_thetas_list': avg_thetas_list,
                'true_parameters': true_parameters,
                'N_OBS': N_OBS,
                'sim_budget': sim_budget
            }, f)
