from copy import deepcopy
import jax
import jax.numpy as jnp
import optax
import time
from src.alg import update_fn_min_max, Alg

def train(params, model, update_fn, opt_state, num_iter, restart_opt=True, skip_compile=True):
    """Run ``update_fn`` for ``num_iter`` outer iterations and record the run.

    Args:
        params: Initial parameters; passed to ``model.apply_fn`` and ``update_fn``.
        model: Object exposing ``apply_fn(params)``.
        update_fn: Callable ``(params, opt_state) -> (params, opt_state,
            inner_loop_data)``.
        opt_state: Initial optimizer state.
        num_iter: Number of outer iterations to run.
        restart_opt: If true, every iteration starts from the snapshot of the
            initial optimizer state instead of the state returned by the
            previous iteration.
        skip_compile: If true, the clock is restarted after the first
            iteration, so the first call to ``update_fn`` (which includes
            compilation when ``update_fn`` is jitted) is excluded from the
            recorded times.

    Returns:
        A tuple ``(trajectory, p_trajectory, data, running_time)``. All but
        ``data`` have ``num_iter + 1`` entries, index 0 being the initial
        state; ``data`` has ``num_iter`` entries.
        ``trajectory[k]`` is ``model.apply_fn(p_trajectory[k])``,
        ``data[k]`` is the ``inner_loop_data`` of iteration ``k`` and
        ``running_time[k]`` is the elapsed wall-clock time in seconds.
    """
    init_opt_state = deepcopy(opt_state)
    data = []
    p_trajectory = [params]
    trajectory = [model.apply_fn(params)]
    running_time = [0.]
    start_time = time.perf_counter()
    for i in range(num_iter):
        if restart_opt:
            opt_state = init_opt_state
        params, opt_state, inner_loop_data = update_fn(params, opt_state)
        # Evaluate the model on the *updated* parameters so that trajectory[k]
        # corresponds to p_trajectory[k]. Evaluating before the update
        # duplicated the initial output and never recorded the final one.
        model_out = model.apply_fn(params)
        # JAX dispatches work asynchronously; wait for the results so the
        # clock measures computation time rather than dispatch time.
        jax.block_until_ready((params, model_out))
        data.append(inner_loop_data)
        trajectory.append(model_out)
        p_trajectory.append(params)
        if i == 0 and skip_compile:
            # Restart the clock so the first (compiling) call is not counted.
            start_time = time.perf_counter()
        running_time.append(time.perf_counter() - start_time)
    return trajectory, p_trajectory, data, running_time

def experiment_run_method(model, params, loss_xy):
    """Return a ``run`` closure bound to ``model``, ``params`` and ``loss_xy``.

    The returned ``run(alg, inner_iter, num_iter, surr_step,
    start_learning_rate=1, benchmark=False, **kwargs)`` does the following:

    1. creates an ``optax.sgd`` optimizer with ``start_learning_rate`` and
       initialises its state from ``params``;
    2. builds the update function with ``update_fn_min_max`` (extra keyword
       arguments are forwarded to it) and wraps it in ``jax.jit``;
    3. calls ``train`` for ``num_iter`` iterations;
    4. if ``benchmark`` is true, times 10000 blocking calls of the jitted
       update on the initial ``params`` and optimizer state with ``timeit``.

    ``run`` returns ``(results, time_benchmark)`` where ``results`` is the
    value returned by ``train`` and ``time_benchmark`` is the ``timeit``
    total, or ``None`` when ``benchmark`` is false.
    """

    def run(alg, inner_iter, num_iter, surr_step, start_learning_rate=1, benchmark=False, **kwargs):
        sgd = optax.sgd(start_learning_rate)
        initial_state = sgd.init(params)
        step = jax.jit(
            update_fn_min_max(
                alg, loss_xy, surr_step, model.apply_fn, inner_iter, sgd, **kwargs
            )
        )
        results = train(params, model, step, initial_state, num_iter=num_iter)

        if not benchmark:
            return results, None

        import timeit

        def timed_call():
            # Block so each timed call covers the computation itself.
            jax.block_until_ready(step(params, initial_state))

        return results, timeit.timeit(timed_call, number=10000)

    return run

def params_dict_and_name(**params):
    """Build a display label for one run configuration.

    Args:
        **params: Run configuration. Must contain ``alg`` (an object with a
            ``value`` attribute), ``inner_iter`` (formatted as an integer)
            and ``surr_step``. ``start_learning_rate`` and ``lm_reg`` are
            optional and only appear in the label when truthy.

    Returns:
        ``(params, name)``: the keyword arguments as a dict, unchanged, and
        the label string (contains ``$...$`` math markup).
    """
    name = params['alg'].value
    inner_iter = params['inner_iter']
    surr_step = params['surr_step']

    # All labels use raw f-strings: "\e" and "\l" are not valid escape
    # sequences, and non-raw literals containing them emit a SyntaxWarning on
    # recent Python versions. The resulting text is unchanged.
    if inner_iter == 1:
        # NOTE(review): this compares the enum's ``.value`` with the enum
        # member itself, which can only match if ``Alg`` members compare equal
        # to their values (e.g. a ``str`` mixin) -- confirm against src.alg.
        if name == Alg.GN:
            return params, rf"PHGD($\eta$={surr_step})"
        if name == Alg.SURR:
            return params, rf"GDA($\eta$={surr_step})"

    parts = [rf"inner={inner_iter:d}", rf"$\eta$={surr_step}"]
    start_learning_rate = params.get('start_learning_rate')
    if start_learning_rate:
        parts.append(rf"$\eta_{{alg}}$={start_learning_rate}")
    lm_reg = params.get('lm_reg')
    if lm_reg:
        parts.append(rf"$\lambda$={lm_reg}")
    return params, name + "(" + ",".join(parts) + ")"


def run_experiments(run_method, runs):
    """Execute every configured run with benchmarking enabled and collect results.

    Args:
        run_method: Callable invoked as ``run_method(**config, benchmark=True)``
            and returning ``((latents, parameters, data, loop_time),
            time_benchmark)``.
        runs: Iterable of ``(config, name)`` pairs, ``config`` being a dict of
            keyword arguments for ``run_method``.

    Returns:
        ``(names, trajectories, data_list, times_list)`` as four lists in
        run order. Each ``trajectories`` entry is ``(parameters, latents)``;
        ``times_list`` holds the ``time_benchmark`` values. The per-run
        ``loop_time`` is not returned.
    """
    outcomes = []
    for config, label in runs:
        (latents, parameters, data, _loop_time), bench = run_method(**config, benchmark=True)
        outcomes.append((label, (parameters, latents), data, bench))

    names = [entry[0] for entry in outcomes]
    trajectories = [entry[1] for entry in outcomes]
    data_list = [entry[2] for entry in outcomes]
    times_list = [entry[3] for entry in outcomes]
    return names, trajectories, data_list, times_list

@jax.jit
def dist_equilibrium(latents_list, equilibrium):
    """Return half the squared distance of each latent to ``equilibrium``.

    Args:
        latents_list: Sequence of latents; stacked along a new leading axis
            with ``jnp.array``.
        equilibrium: Reference point; converted with ``jnp.array`` and
            squeezed of all singleton axes.

    Returns:
        A 1-D array with one entry per element of ``latents_list`` holding
        ``0.5 * sum(|latent - equilibrium| ** 2)``.
    """
    latents = jnp.array(latents_list)
    # Drop singleton axes everywhere except the leading (batch) axis. A bare
    # ``squeeze()`` also removes the batch axis when the list has exactly one
    # entry, which makes ``vmap`` map over the wrong axis.
    kept_shape = latents.shape[:1] + tuple(d for d in latents.shape[1:] if d != 1)
    latents = latents.reshape(kept_shape)
    target = jnp.array(equilibrium).squeeze()

    def distance(p):
        return 0.5 * jnp.sum(jnp.abs(p - target) ** 2)

    # The enclosing jit already compiles this; no inner jit is needed.
    return jax.vmap(distance, in_axes=0)(latents)