from src.parametrization import TinyMLP
import jax.numpy as jnp
from src.alg import Alg
from src.loss import reg_matching_pennies
from src.plot import  plot_benchmark_times, figure_matching_pennies
from src.utils import experiment_run_method, params_dict_and_name, run_experiments



# Experiment configuration.
NUM_ITER = 1000  # passed as num_iter to every run below
SURR_STEP = 0.01  # base surrogate step; most runs below use 10x this value
# NOTE(review): INNER_ITER and ALG are not referenced in this script; kept in
# case other code imports them — remove if truly unused.
INNER_ITER = 1
ALG = Alg.GN


# The tuple arguments are forwarded unchanged to TinyMLP; their meaning is
# defined in src.parametrization.
model = TinyMLP([(.5, 1), (.7, 1)])
# Initial parameters: two length-1 arrays (presumably one per player — confirm).
params = [jnp.array([1.25]), jnp.array([2.25])]

# NOTE(review): not referenced below (the DGN run passes its own
# start_learning_rate=0.001); kept for compatibility.
start_learning_rate = 1.


# Regularized matching-pennies loss; presumably 0.75 is the regularization
# strength — confirm against src.loss.
loss_xy = reg_matching_pennies(0.75)


# run list
# Positions matter: `subsets` in the figure call below selects runs by index.
runs = [
    params_dict_and_name(alg=Alg.GN, inner_iter=1, num_iter=NUM_ITER, surr_step=SURR_STEP),  # 0
    params_dict_and_name(alg=Alg.GN, inner_iter=5, num_iter=NUM_ITER, surr_step=SURR_STEP),  # 1
    params_dict_and_name(alg=Alg.GN, inner_iter=1, num_iter=NUM_ITER, surr_step=SURR_STEP*10),  # 2
    params_dict_and_name(alg=Alg.DGN, inner_iter=75, num_iter=NUM_ITER, surr_step=SURR_STEP*10, start_learning_rate=0.001),  # 3
    params_dict_and_name(alg=Alg.LM, inner_iter=1, num_iter=NUM_ITER, lm_reg=1e-2, surr_step=SURR_STEP*10),  # 4
    params_dict_and_name(alg=Alg.LM, inner_iter=10, num_iter=NUM_ITER, lm_reg=1e-2, surr_step=SURR_STEP*10),  # 5
    params_dict_and_name(alg=Alg.SURR, inner_iter=1, num_iter=NUM_ITER, surr_step=SURR_STEP*10),  # 6
    params_dict_and_name(alg=Alg.SURR, inner_iter=5, num_iter=NUM_ITER, surr_step=SURR_STEP*10),  # 7
    params_dict_and_name(alg=Alg.SURR, inner_iter=10, num_iter=NUM_ITER, surr_step=SURR_STEP*10),  # 8
    params_dict_and_name(alg=Alg.SURR, inner_iter=100, num_iter=NUM_ITER, surr_step=SURR_STEP*10),  # 9
]

# Build the runner exactly once, right before it is used.
run = experiment_run_method(model, params, loss_xy)

names, trajectories, data_list, times_list = run_experiments(run, runs)


# Subset 1 plots runs 0 (GN, inner=1), 3 (DGN, inner=75), 4 (LM, inner=1),
# 6 (SURR, inner=1) and 8 (SURR, inner=10); subset 2 additionally includes
# run 9 (SURR, inner=100).
# The point (0.5, 0.5) is presumably the game's equilibrium, used as a
# reference in the figure — confirm in src.plot.
figure_matching_pennies(
    trajectories,
    [jnp.array([0.5]), jnp.array([0.5])],
    data_list,
    names,
    subsets=[[0, 3, 4, 6, 8], [0, 3, 4, 6, 8, 9]],
    filename='results/figure_1_pennies.pdf',
)

def shorten_name(name: str) -> str:
    """Return a compact label for a run name, e.g. for plot tick labels.

    Names are expected to look like ``"ALG(key=value, key=value, ...)"``.
    The label is the text before the first ``(``, followed by the value of
    the ``inner`` field in parentheses when the name has one:
    ``"GN(inner=5, surr_step=0.1)"`` -> ``"GN(5)"``. Names without an
    ``inner`` field map to the bare prefix.

    The ``inner`` field may appear at any position in the argument list,
    including last (its closing parenthesis is not copied into the label).

    Args:
        name: Run name string (presumably produced by ``params_dict_and_name``
            — confirm the exact format there).

    Returns:
        The shortened label.
    """
    alg, _, rest = name.partition("(")
    # Drop the closing parenthesis so the final field does not keep a ")".
    args = rest.rsplit(")", 1)[0]
    for field in args.split(","):
        key, sep, value = field.partition("=")
        if sep and key.strip() == "inner":
            return f"{alg}({value.strip()})"
    return alg

# Plot the per-run timings, labelling each run with its shortened name.
plot_benchmark_times(
    list(map(shorten_name, names)),
    times_list,
    filename='results/time_matching_pennies.pdf',
)

