import jax
import pickle
import jax.numpy as jnp
from functools import partial
from approxml.utils import gen_simulation_samples
from approxml.scorematching import score_matching_loss
from approxml.simulators import mvt_norm_simulator
from approxml.optimisers import nn_fit
from approxml.utils import NeuralNetwork
from tqdm import tqdm
import optax
from itertools import product

for param_dim in [2, 5, 10, 20, 50, 100]:
    for sim_budget in [1e2, 5e2, 1e3, 5e3, 1e4, 5e4]:
        # One experiment is run for every (parameter dimension, simulation
        # budget) pair; everything below is re-created for each pair.
        N_OBS = 100  # third argument given to the simulator (presumably the sample count)
        N_PARAM_DIM = param_dim

        # Data set used by the pilot grid search: simulated at an all-ones
        # parameter vector with a fixed seed so every pilot run sees the same data.
        true_parameters = jnp.ones(N_PARAM_DIM)
        obs = mvt_norm_simulator(jax.random.PRNGKey(0), true_parameters, N_OBS)

        N_ITER = 100  # optimiser updates performed by one `run_sgd` call
        N_RUNS = 100  # independent repetitions after the pilot grid search
        N_PROP = int(sim_budget)  # forwarded as `n_prop` (the budget is given as a float)
        N_SIM_DST = 1  # forwarded as `n_sim_dst`
        N_DATA_DIM = N_PARAM_DIM  # network input size equals the parameter dimension here
        MOD_SIGMA = jnp.eye(N_PARAM_DIM)  # covariance handed to the model simulator

        # Starting point of every optimisation run, and the simulator with its
        # covariance bound. No PRNG key is created here: each stage further down
        # seeds its own key before drawing any randomness.
        init_theta = jnp.zeros(N_PARAM_DIM)
        simulator_fn = partial(mvt_norm_simulator, cov=MOD_SIGMA)

        def robbins_monro_schedule(
            init_value=1.0,
            alpha=1.0,
            offset=10.0,
        ):
            """Build a polynomially decaying step-size schedule.

            Args:
                init_value: Numerator of the step size.
                alpha: Exponent applied to the shifted step counter.
                offset: Constant added to the (1-based) step counter.

            Returns:
                A callable mapping an update count ``c`` to
                ``init_value / (offset + c + 1) ** alpha``.
            """
            def schedule(count: int) -> jnp.ndarray:
                # `count` starts at 0, so the first step divides by (offset + 1) ** alpha.
                shifted_step = offset + (count + 1)
                decay = shifted_step ** alpha
                return init_value / decay

            return schedule


        def _grad_fn_impl(theta_t, key, lr, obs, model, gen_sim_fn, sm_loss_fn, n_iter):
            """Fit `model` with `nn_fit`, then sum its outputs over `obs`.

            Args:
                theta_t: Current parameter value, forwarded to `nn_fit`.
                key: PRNG key, forwarded to `nn_fit`.
                lr: Learning rate of the Adam optimiser used for the fit.
                obs: Observations; the fitted model is mapped over the leading axis.
                model: Model to fit (first argument of `nn_fit`).
                gen_sim_fn: Simulation generator, forwarded to `nn_fit`.
                sm_loss_fn: Loss function, forwarded to `nn_fit`.
                n_iter: Number of fitting iterations requested from `nn_fit`.

            Returns:
                A pair ``(summed_output, fitted_model)`` where ``summed_output`` is
                the fitted model applied to every entry of `obs`, summed over
                axis 0. The caller uses it as a gradient estimate.
            """
            adam_optimiser = optax.adam(lr)
            fit_result = nn_fit(
                model,
                adam_optimiser,
                theta_t,
                key,
                gen_sim_fn,
                sm_loss_fn,
                n_iter=n_iter,
            )
            # nn_fit returns five values; only the fitted model is needed here.
            fitted_model, _loss_vals, _thetas_q, _sims_q, _next_key = fit_result

            per_obs_output = jax.vmap(fitted_model)(obs)
            return per_obs_output.sum(0), fitted_model

        # Jitted entry point for `_grad_fn_impl`. `gen_sim_fn`, `sm_loss_fn` and
        # `n_iter` are not arrays, so they must be compile-time constants. They are
        # declared static both by position and by name because the call site passes
        # them as keyword arguments. Jitting the function directly (rather than
        # through a pass-through lambda) keeps the real signature visible to JAX.
        grad_fn = jax.jit(
            _grad_fn_impl,
            static_argnums=(5, 6, 7),
            static_argnames=("gen_sim_fn", "sm_loss_fn", "n_iter"),
        )

        def run_sgd(
            key,
            theta_init,
            obs,
            sigma,
            simulator_fn,
            learning_rate,
        ):
            """Run `N_ITER` optimiser updates of theta, refitting the network each step.

            Every update builds the proposal covariance ``sigma * I``, refits the
            network through `grad_fn` (Adam rate 1e-2, 10 fitting iterations),
            negates the returned vector and passes it to the optax optimiser.

            Reads the enclosing-scope names `N_PARAM_DIM`, `N_DATA_DIM`, `N_PROP`,
            `N_SIM_DST`, `N_ITER` and `grad_fn`.

            Args:
                key: PRNG key; split once per update and once for network creation.
                theta_init: Starting parameter vector.
                obs: Observations on which the fitted network is evaluated.
                sigma: Scalar multiplier of the identity proposal covariance.
                simulator_fn: Model simulator handed to `gen_simulation_samples`.
                learning_rate: `init_value` of the step-size schedule.

            Returns:
                Dict with the per-iteration ``'theta_values'``, ``'grad_values'``
                (the negated vectors fed to the optimiser) and ``'sigma_values'``,
                plus ``'avg_theta'``, the mean of the last 50 theta iterates.
            """
            step_size_schedule = robbins_monro_schedule(init_value=learning_rate)
            # NOTE(review): the schedule rescales the gradients *before* Adam.
            # Adam's update is (up to its epsilon) invariant to a constant factor
            # on the gradients, so `learning_rate` probably has almost no effect
            # on the step size here. Confirm whether the intended form was
            # `optax.adam(step_size_schedule)`; if so the lr grid needs re-tuning.
            optimizer = optax.chain(
                optax.scale_by_schedule(step_size_schedule),
                optax.adam(1.0),
            )
            opt_state = optimizer.init(theta_init)

            @jax.jit
            def update(model, params, opt_state, key, step):
                key, subkey = jax.random.split(key)

                prop_cov = sigma * jnp.eye(N_PARAM_DIM)

                @jax.jit
                def sm_loss_fn(model, sims_q, thetas_q, theta_t):
                    # The model is mapped over the two leading axes of `sims_q`.
                    pred_q = jax.vmap(jax.vmap(model))(sims_q)
                    return score_matching_loss(pred_q, thetas_q, theta_t, N_PROP, N_SIM_DST, prop_cov)

                gen_sim_fn = partial(
                    gen_simulation_samples,
                    simulator_fn=simulator_fn,
                    prop_sim_fn=partial(mvt_norm_simulator, cov=prop_cov),
                    n_prop=N_PROP,
                    n_sim_dst=N_SIM_DST,
                )

                summed_output, fitted_model = grad_fn(
                    params,
                    subkey,
                    lr=1e-2,
                    obs=obs,
                    model=model,
                    gen_sim_fn=gen_sim_fn,
                    sm_loss_fn=sm_loss_fn,
                    n_iter=10,
                )
                # The optimiser descends, so the sign of the network output is flipped.
                descent_grads = -summed_output
                updates, opt_state = optimizer.update(descent_grads, opt_state, params)
                params = optax.apply_updates(params, updates)
                return params, opt_state, key, descent_grads, sigma, fitted_model

            theta_trace = []
            grad_trace = []
            sigma_trace = []

            theta = theta_init.copy()

            key, subkey = jax.random.split(key)
            model = NeuralNetwork(key=subkey, input_dim=N_DATA_DIM, output_dim=N_PARAM_DIM)

            for i in range(N_ITER):
                # The fitted network is carried over to the next update.
                theta, opt_state, key, grads, sigma, model = update(model, theta, opt_state, key, i)
                theta_trace.append(theta.copy())
                grad_trace.append(grads.copy())
                sigma_trace.append(sigma)

            theta_history = jnp.array(theta_trace)

            return {
                'theta_values': theta_history,
                'grad_values': jnp.array(grad_trace),
                'sigma_values': jnp.array(sigma_trace),
                'avg_theta': theta_history[-50:, :].mean(axis=0),
            }

        # Pilot grid search: one optimisation run per (learning rate, proposal
        # scale) pair on the fixed pilot data set. Each pair is scored by the
        # mean row-wise Euclidean distance between fresh simulations at the
        # averaged estimate and the observed rows.
        lr_grid = [1e-3, 1e-2, 1e-1]
        sigma_grid = [1e-2, 1e-1, 1.0]

        results = []

        key = jax.random.PRNGKey(0)
        for lr_val, sigma_val in product(lr_grid, sigma_grid):
            print(f"Pilot run: lr: {lr_val}, sigma: {sigma_val}")

            # First split drives the optimisation run ...
            key, subkey = jax.random.split(key)
            res_dict = run_sgd(
                subkey,
                init_theta,
                obs,
                sigma=sigma_val,
                simulator_fn=simulator_fn,
                learning_rate=lr_val,
            )

            # ... second split drives the simulations used for scoring.
            key, subkey = jax.random.split(key)
            s_sims = simulator_fn(subkey, res_dict['avg_theta'], N_OBS)
            row_distances = jnp.linalg.norm(s_sims - obs, axis=1)
            err = row_distances.mean()
            results.append((lr_val, sigma_val, err))

        # Smallest error first; the head of the list is the selected setting.
        results = sorted(results, key=lambda entry: entry[2])
        best_lr, best_sigma, best_err = results[0]

        print("Grid search results (top 5):")
        for lr_val, sigma_val, err in results[:5]:
            print(f"  lr={lr_val:.3f}, sigma={sigma_val:.3f}  -->  error = {err:.4f}")

        print(f"\nBEST:  lr = {best_lr},  sigma = {best_sigma},  error = {best_err:.4f}")

        # Repeated experiment: N_RUNS independent data sets, each optimised with
        # the setting selected by the pilot grid search.
        res_list = []

        key = jax.random.PRNGKey(42)
        for _ in tqdm(range(N_RUNS)):
            # Fresh observed data for this repetition.
            key, subkey = jax.random.split(key)
            obs = mvt_norm_simulator(subkey, true_parameters, N_OBS)

            key, subkey = jax.random.split(key)
            res_dict = run_sgd(subkey, init_theta, obs, best_sigma, simulator_fn, best_lr)
            res_list.append(res_dict)

        # One averaged estimate per repetition, stacked along the leading axis.
        avg_thetas_list = jnp.array([run['avg_theta'] for run in res_list])
        print(f"Avg error: {jnp.linalg.norm(avg_thetas_list - true_parameters, axis=-1).mean()}")

        # `sim_budget` is a float, so it appears in the file name as e.g. "100.0".
        payload = {
            'res_list': res_list,
            'avg_thetas_list': avg_thetas_list,
            'true_parameters': true_parameters,
            'N_OBS': N_OBS,
            'sim_budget': sim_budget
        }
        with open(f'1_mvt_gaussian_NN_small_sim_bud_{sim_budget}_dim_{N_PARAM_DIM}.pkl', 'wb') as f:
            pickle.dump(payload, f)
