from src.parametrization import SoftmaxMLP, tree_random_alphas_like
import jax
import jax.numpy as jnp
from src.alg import Alg
from src.loss import loss_bilinear 
from src.plot import figure_rps
from src.utils import experiment_run_method, params_dict_and_name, run_experiments

# Hyper-parameters shared by the experiment configurations below.
NUM_ITER = 1_000   # passed as `num_iter` to every run
SURR_STEP = 0.1    # passed as `surr_step` to every run
# NOTE(review): INNER_ITER is not referenced in the visible part of this
# script; the runs pass `inner_iter` as literals instead.
INNER_ITER = 1

# PRNG seeds: one for drawing the model's alphas, one for the initial
# parameters drawn further down.
seed_alphas = 3
seed_param_init = 150

# Draw random alphas shaped like the given template: a list with two entries,
# each a pair of integer arrays ([4, 5] and [3, 4]).
# NOTE(review): the trailing -1 and 1 are presumably the lower/upper bounds of
# the random draw -- confirm against tree_random_alphas_like.
# The PRNG key is created inline so no module-level `key` stays bound.
alphas = tree_random_alphas_like(
    jax.random.key(seed_alphas),
    [
        (jnp.array([4, 5]), jnp.array([3, 4])),
        (jnp.array([4, 5]), jnp.array([3, 4])),
    ],
    -1,
    1,
)
model = SoftmaxMLP(player_alphas=alphas)


# Initial parameters: split one PRNG key (seeded with seed_param_init) into
# two sub-keys and draw a standard-normal vector of length 5 from each.
# NOTE(review): presumably one parameter vector per player -- confirm.
keys = jax.random.split(jax.random.key(seed_param_init), 2)
params = [jax.random.normal(subkey, shape=(5,)) for subkey in keys]
# NOTE(review): start_learning_rate is not referenced in the visible part of
# this script.
start_learning_rate = 1



# 3x3 payoff matrix handed to loss_bilinear: zero diagonal and antisymmetric
# (A == -A.T). Presumably the rock-paper-scissors game, as the figure_rps plot
# produced by this script suggests.
A = jnp.array(
    [
        [0.0, -1.0, 1.0],
        [1.0, 0.0, -1.0],
        [-1.0, 1.0, 0.0],
    ]
)
# Loss built from the payoff matrix A.
# NOTE(review): the meaning of the second argument (0.2) is defined by
# loss_bilinear -- confirm there.
loss_xy = loss_bilinear(A, 0.2)

# Runner object/callable combining the model, the initial parameters and the
# loss; it is passed to run_experiments together with `runs`.
run = experiment_run_method(model, params, loss_xy)

# Run list: one GN run, LM runs with 1 and 5 inner iterations, and SURR runs
# with 1, 10 and 100 inner iterations. Keyword order is kept fixed in case
# params_dict_and_name derives the run name from it.
runs = [
    params_dict_and_name(alg=Alg.GN, inner_iter=1, num_iter=NUM_ITER, surr_step=SURR_STEP),
    *[
        params_dict_and_name(alg=Alg.LM, inner_iter=n_inner, num_iter=NUM_ITER, lm_reg=1e-2, surr_step=SURR_STEP)
        for n_inner in (1, 5)
    ],
    *[
        params_dict_and_name(alg=Alg.SURR, inner_iter=n_inner, num_iter=NUM_ITER, surr_step=SURR_STEP)
        for n_inner in (1, 10, 100)
    ],
]

# Execute every configured run. Only the first three results are used below;
# the fourth (bound to `times_listnames`) is not used in the visible part of
# this script.
names, trajectories, data_list, times_listnames = run_experiments(run, runs)

# Plot the trajectories and save the figure to disk.
# NOTE(review): the two uniform (1/3, 1/3, 1/3) vectors are presumably the
# per-player reference (equilibrium) strategies -- confirm in figure_rps.
figure_rps(
    trajectories,
    [jnp.array([1.0 / 3] * 3) for _ in range(2)],
    data_list,
    names,
    filename='results/figure_2_rps.pdf',
)