from src.parametrization import SoftmaxMLP, tree_random_alphas_like
import jax
import jax.numpy as jnp
import optax
from src.alg import Alg, update_fn_min_max
from src.loss import loss_bilinear 
from src.utils import params_dict_and_name, dist_equilibrium
from copy import deepcopy
from src.plot import figure_rps_rand_init
from tqdm import tqdm

# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
NUM_ITER = 1000   # outer training steps per trajectory
SURR_STEP = 0.1   # surr_step passed to every algorithm in the run list
INNER_ITER = 1    # NOTE(review): the run list below passes inner_iter explicitly instead
NUM_TRAJ = 100    # number of random initialisations averaged per experiment

seed = 5
key = jax.random.key(seed)

# run() takes its own start_learning_rate argument (default 1).
start_learning_rate = 1

# Rock-paper-scissors payoff matrix (antisymmetric: each action beats one
# action and loses to another).
A = jnp.array([
    [0, -1., 1],
    [1, 0, -1],
    [-1, 1, 0],
])
# The meaning of the 0.2 argument is defined by src.loss.loss_bilinear.
loss_xy = loss_bilinear(A, .2)

# Uniform mixed strategy for both players; distances are measured to this profile.
equilibrium = [jnp.array([1./3]*3), jnp.array([1./3]*3)]


def inits_from_key(key, num_traj=NUM_TRAJ):
    """Draw ``num_traj`` random initialisations from a single PRNG key.

    The key is split once into an "alpha" stream and a "parameter" stream.
    Each stream is then split into one key per trajectory.

    Args:
        key: JAX PRNG key.
        num_traj: number of independent initialisations to draw.

    Returns:
        ``(alphas, params)``, two lists of length ``num_traj``.
        ``alphas[i]`` is produced by ``tree_random_alphas_like`` from a
        template of two ``(array([4, 5]), array([3, 4]))`` tuples with bound
        arguments ``-1`` and ``1``. The template arrays are presumably the
        layer shapes of each player's SoftmaxMLP (TODO confirm).
        ``params[i]`` is a list of two standard-normal vectors of length 5.
    """
    alpha_root, param_root = jax.random.split(key, 2)

    def alpha_template():
        # Built fresh for every draw so no structure is shared between samples.
        return [(jnp.array([4, 5]), jnp.array([3, 4])) for _ in range(2)]

    alphas = [
        tree_random_alphas_like(alpha_key, alpha_template(), -1, 1)
        for alpha_key in jax.random.split(alpha_root, num_traj)
    ]
    params = [
        [jax.random.normal(player_key, (5,)) for player_key in key_pair]
        for key_pair in jax.random.split(param_root, (num_traj, 2))
    ]
    return alphas, params

# One shared set of random initialisations (model alphas plus two parameter
# vectors per trajectory), reused by every experiment through run().
alphas, params = inits_from_key(key, num_traj=NUM_TRAJ)

def train_n_steps(params, opt_state, apply_fn, update_fn, num_iter):
    """Apply ``update_fn`` ``num_iter`` times under ``jax.lax.scan``.

    Args:
        params: initial parameter pytree (the scan carry).
        opt_state: optimizer state. A deep copy of it is passed, unchanged, to
            every call of ``update_fn``.
        apply_fn: maps parameters to the value recorded at each step.
        update_fn: ``update_fn(params, opt_state)`` returning a 3-tuple whose
            first element is the next parameters. The other two are discarded.
        num_iter: number of scan steps.

    Returns:
        ``(final_params, recorded)`` from ``jax.lax.scan``. ``final_params``
        holds the parameters after ``num_iter`` updates. ``recorded`` holds
        ``apply_fn`` evaluated on the parameters *before* each update,
        stacked along a new leading axis.
    """
    frozen_opt_state = deepcopy(opt_state)

    def body(carry, _unused_x):
        # NOTE(review): every step receives the same frozen state, and whatever
        # update_fn returns besides the new params is dropped. Any optimizer
        # state therefore never advances. That is fine for a stateless
        # optimizer; revisit before switching to momentum/Adam.
        next_carry, _second, _third = update_fn(carry, frozen_opt_state)
        recorded = apply_fn(carry)
        return next_carry, recorded

    return jax.lax.scan(jax.jit(body), params, None, length=num_iter)

def run(alg, inner_iter, num_iter, surr_step, start_learning_rate=1, **kwargs):
    """Train every random initialisation with one algorithm and summarise it.

    Reads the module-level initialisations ``alphas`` / ``params``, the game
    loss ``loss_xy`` and the reference profile ``equilibrium``.

    Args:
        alg: algorithm selector forwarded to ``update_fn_min_max``.
        inner_iter: forwarded to ``update_fn_min_max``.
        num_iter: number of scan steps per trajectory.
        surr_step: forwarded to ``update_fn_min_max``.
        start_learning_rate: learning rate of the plain SGD optimizer.
        **kwargs: extra algorithm options (e.g. ``lm_reg``) forwarded to
            ``update_fn_min_max``.

    Returns:
        ``(mean_traj_distance, ci)``. The first is the mean over trajectories
        (axis 0) of the values returned by ``dist_equilibrium``. The second is
        the half-width of a 95% normal-approximation confidence interval of
        that mean (1.96 * std / sqrt(n)).
    """
    optimizer = optax.sgd(start_learning_rate)
    results = []
    for alpha, p in zip(alphas, params):
        model = SoftmaxMLP(alpha)
        # Initialise the optimizer state from *this* trajectory's parameters,
        # not from the module-level list of all trajectories. The state must
        # mirror the pytree actually being optimised. That does not matter for
        # stateless SGD, but it breaks any stateful optimizer.
        opt_state = optimizer.init(p)
        update_fn = jax.jit(
            update_fn_min_max(alg, loss_xy, surr_step, model.apply_fn,
                              inner_iter, optimizer, **kwargs))
        _, latent_trajectory = train_n_steps(p, opt_state, model.apply_fn,
                                             update_fn, num_iter=num_iter)
        # The scanned latents are a sequence of per-player arrays, stacked
        # here along axis 1. Presumably the result is (num_iter, 2, 3); TODO confirm.
        latent_trajectory = jnp.stack(latent_trajectory, axis=1)
        results.append(dist_equilibrium(latent_trajectory, equilibrium))

    distances = jnp.array(results)
    mean_traj_distance = jnp.mean(distances, axis=0)
    # NOTE(review): jnp.std defaults to ddof=0 (population std). The standard
    # error of a mean is usually computed with ddof=1; the two differ only by
    # sqrt(n / (n - 1)).
    ci = 1.96 * jnp.std(distances, axis=0) / jnp.sqrt(distances.shape[0])
    return mean_traj_distance, ci


# Run list: (kwargs for run(), display name) pairs built by params_dict_and_name.
experiments = [
    params_dict_and_name(alg=Alg.GN, inner_iter=1, num_iter=NUM_ITER, surr_step=SURR_STEP),
    params_dict_and_name(alg=Alg.LM, inner_iter=1, num_iter=NUM_ITER, lm_reg=1e-2, surr_step=SURR_STEP),
    params_dict_and_name(alg=Alg.LM, inner_iter=5, num_iter=NUM_ITER, lm_reg=1e-2, surr_step=SURR_STEP),
    params_dict_and_name(alg=Alg.SURR, inner_iter=1, num_iter=NUM_ITER, surr_step=SURR_STEP),
    params_dict_and_name(alg=Alg.SURR, inner_iter=10, num_iter=NUM_ITER, surr_step=SURR_STEP),
    params_dict_and_name(alg=Alg.SURR, inner_iter=100, num_iter=NUM_ITER, surr_step=SURR_STEP),
]

# Collect the mean distance curve, CI half-width and label of each experiment.
distances = []
cis = []
names = []
for e, name in tqdm(experiments):
    distance, ci = run(**e)
    distances.append(distance)
    cis.append(ci)
    names.append(name)

# Plot against the same `equilibrium` profile that run() measured distances to,
# rather than a duplicated literal that could silently drift out of sync.
figure_rps_rand_init(distances, cis, equilibrium=equilibrium, names=names,
                     filename='results/rps_rand_dist.pdf')
