# ---
# jupyter:
#   jupytext:
#     text_representation:
#       extension: .py
#       format_name: light
#       format_version: '1.5'
#       jupytext_version: 1.16.2
#   kernelspec:
#     display_name: Python 3 (ipykernel)
#     language: python
#     name: python3
# ---

# # Brax evaluation
#
# All plots can be reproduced 1:1 if `XLA_FLAGS=--xla_gpu_deterministic_ops=true` is set as an environment variable.
# For this, just uncomment the first line of the next block. Since this makes the code relatively slow, we only used this setting in the evaluation runs and not for the hyperparameter search.
#
# For more info on reproducibility in jax see [this discussion](https://github.com/jax-ml/jax/discussions/10674).
#
# To make all code run well during evaluation, please set `n_devices=1` for each strategy even if you have more devices available.
# The current code will break if more devices are being used.

# +
# %env XLA_FLAGS=--xla_gpu_deterministic_ops=true
# %env XLA_PYTHON_CLIENT_PREALLOCATE=false
# %env XLA_PYTHON_CLIENT_ALLOCATOR=platform
# %env CUDA_VISIBLE_DEVICES=0

from functools import partial 
import os
import gymnax
import pickle
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import scienceplots
from evosax import OpenES, CMA_ES

from sves.benchmarks.brax_ctrl import BraxControl, run_cfg_cma, run_cfg_cma_parallel, run_cfg_cma_single, run_cfg_gf, run_cfg_oes, run_cfg_oes_parallel, run_cfg_oes_single
from sves.strategies import SV_CMA_BB, MC_SVGD, ParallelCMAES, ParallelOpenES, GF_SVGD
from sves.kernels import RBF
from sves.benchmarks.utils.eval_utils import cmapper

# Plotting settings: SciencePlots theme (without LaTeX) plus larger fonts for the figures.
plt.style.use(["science", "no-latex"])
# Optional: plt.rcParams['legend.frameon'] = True
plt.rcParams.update(
    {
        "axes.labelsize": 26,   # axis labels
        "xtick.labelsize": 22,  # x-axis tick labels
        "ytick.labelsize": 22,  # y-axis tick labels
        "lines.linewidth": 2,
        "legend.fontsize": 18,
    }
)

# Ask JAX for its CUDA devices; the bare expression below shows them as the cell output.
devices = jax.devices("cuda")
devices
# -

# ## Paper plot
#
# Uses 10 rep & 1000 iterations.

# +
# --- Task specification -------------------------------------------------------
env_name = "hopper"
max_episode_len = 1000
num_layers = 2
num_hidden = 16

# --- Experiment budget --------------------------------------------------------
rng = jax.random.PRNGKey(0)
nrep = 10                # number of repetitions (one seed each)
npop = 4                 # first size argument passed to the multi-population strategies
subpopsize = 16          # second size argument; single-population baselines get npop * subpopsize
num_generations = 1000

# --- Output location ----------------------------------------------------------
res_path = f"data/paper_res/{env_name}/"
os.makedirs(res_path, exist_ok=True)

# --- Hyperparameter grids (the strategy configurations index into these) ------
adam_lrs = jnp.array([0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1.0])
kernel_vals = jnp.linspace(0.001, 30.0, num=10)
sigmas = jnp.linspace(0.05, 1.0, num=10)
sigmas_gf = jnp.linspace(0.01, 1.0, num=10)
elite_ratios = jnp.linspace(0.05, 0.5, num=10)

# Select the list of strategy factories for the chosen environment.
# Every entry is a one-argument callable: it receives placeholder parameters `ph`
# (forwarded as `pholder_params`) and returns a configured strategy instance.
# Hyperparameters are picked by index from the grids defined above
# (`kernel_vals`, `sigmas`, `sigmas_gf`, `elite_ratios`, `adam_lrs`).
# NOTE(review): the indices are presumably the winners of the grid search in the
# "Hyperparam tuning" section -- confirm before changing any grid, since altering a
# grid silently changes the values selected here.
# The lambdas read `npop`, `subpopsize` and `num_generations` from module scope at
# call time, not at definition time.
match env_name:
    case "hopper":
        strategy_clss = [
            lambda ph: GF_SVGD(npop * subpopsize, RBF(kernel_vals[-5]), sigma_init=sigmas_gf[0], lrate_init=adam_lrs[-3], num_iters=num_generations, pholder_params=ph, opt_name="adam"),
            lambda ph: SV_CMA_BB(npop, subpopsize, kernel=RBF(kernel_vals[1]), num_iters=num_generations, pholder_params=ph, elite_ratio=elite_ratios[-4], sigma_init=sigmas[0]),
            lambda ph: MC_SVGD(npop, subpopsize, kernel=RBF(kernel_vals[-1]), num_iters=num_generations, pholder_params=ph, sigma_init=sigmas[2], lrate_init=adam_lrs[-3]),
            lambda ph: CMA_ES(npop * subpopsize, elite_ratio=elite_ratios[-2], sigma_init=sigmas[1], pholder_params=ph),
            lambda ph: ParallelCMAES(npop, subpopsize, elite_ratio=elite_ratios[-1], sigma_init=sigmas[0], pholder_params=ph),
            lambda ph: ParallelOpenES(npop, subpopsize, lrate_init=adam_lrs[-3], sigma_init=sigmas[3], pholder_params=ph),
            lambda ph: OpenES(npop * subpopsize, lrate_init=adam_lrs[-3], sigma_init=sigmas[2], pholder_params=ph)
        ]
    case "walker2d":
        strategy_clss = [
            lambda ph: GF_SVGD(npop * subpopsize, RBF(kernel_vals[3]), sigma_init=sigmas_gf[0], lrate_init=adam_lrs[-3], num_iters=num_generations, pholder_params=ph, opt_name="adam"),
            lambda ph: SV_CMA_BB(npop, subpopsize, kernel=RBF(kernel_vals[3]), num_iters=num_generations, pholder_params=ph, elite_ratio=elite_ratios[-1], sigma_init=sigmas[-3]),
            lambda ph: MC_SVGD(npop, subpopsize, kernel=RBF(kernel_vals[-1]), num_iters=num_generations, pholder_params=ph, sigma_init=sigmas[1], lrate_init=adam_lrs[3]),
            lambda ph: CMA_ES(npop * subpopsize, elite_ratio=elite_ratios[-2], sigma_init=sigmas[-1], pholder_params=ph),
            lambda ph: ParallelCMAES(npop, subpopsize, elite_ratio=elite_ratios[-3], sigma_init=sigmas[3], pholder_params=ph),
            lambda ph: ParallelOpenES(npop, subpopsize, lrate_init=adam_lrs[-2], sigma_init=sigmas[-4], pholder_params=ph),
            lambda ph: OpenES(npop * subpopsize, lrate_init=adam_lrs[3], sigma_init=sigmas[1], pholder_params=ph)
        ]
    case "halfcheetah":
        strategy_clss = [
            lambda ph: GF_SVGD(npop * subpopsize, RBF(kernel_vals[2]), sigma_init=sigmas_gf[0], lrate_init=adam_lrs[3], num_iters=num_generations, pholder_params=ph, opt_name="adam"),
            lambda ph: SV_CMA_BB(npop, subpopsize, kernel=RBF(kernel_vals[5]), num_iters=num_generations, pholder_params=ph, elite_ratio=elite_ratios[-4], sigma_init=sigmas[-4]),
            lambda ph: MC_SVGD(npop, subpopsize, kernel=RBF(kernel_vals[-1]), num_iters=num_generations, pholder_params=ph, sigma_init=sigmas[0], lrate_init=adam_lrs[3]),
            lambda ph: CMA_ES(npop * subpopsize, elite_ratio=elite_ratios[4], sigma_init=sigmas[2], pholder_params=ph),
            lambda ph: ParallelCMAES(npop, subpopsize, elite_ratio=elite_ratios[-3], sigma_init=sigmas[4], pholder_params=ph),
            lambda ph: ParallelOpenES(npop, subpopsize, lrate_init=adam_lrs[-2], sigma_init=sigmas[-3], pholder_params=ph),
            lambda ph: OpenES(npop * subpopsize, lrate_init=adam_lrs[-3], sigma_init=sigmas[1], pholder_params=ph)
        ]
    # The "*_ablated" cases only compare the three CMA-based strategies.
    case "hopper_ablated":
        strategy_clss = [
            lambda ph: SV_CMA_BB(npop, subpopsize, kernel=RBF(kernel_vals[0]), num_iters=num_generations, pholder_params=ph, elite_ratio=elite_ratios[-5], sigma_init=sigmas[0]),
            lambda ph: CMA_ES(npop * subpopsize, elite_ratio=elite_ratios[4], sigma_init=sigmas[1], pholder_params=ph),
            lambda ph: ParallelCMAES(npop, subpopsize, elite_ratio=elite_ratios[-3], sigma_init=sigmas[0], pholder_params=ph),
        ]
    case "walker2d_ablated":
        # Rebinds the module-level `num_generations`; everything that reads it afterwards
        # (the factories above, the training loop, the result file names) sees 1_500.
        num_generations = 1_500  #Run this for longer as described in paper
        strategy_clss = [
            lambda ph: SV_CMA_BB(npop, subpopsize, kernel=RBF(kernel_vals[0]), num_iters=num_generations, pholder_params=ph, elite_ratio=elite_ratios[-4], sigma_init=sigmas[0]),
            lambda ph: CMA_ES(npop * subpopsize, elite_ratio=elite_ratios[-5], sigma_init=sigmas[1], pholder_params=ph),
            lambda ph: ParallelCMAES(npop, subpopsize, elite_ratio=elite_ratios[-3], sigma_init=sigmas[0], pholder_params=ph),
        ]
    case _:
        # Any other environment name has no tuned configuration.
        raise NotImplementedError("Only the above environments are supported!")

# -

for strategy_cls in strategy_clss:
    seeds = jax.random.split(rng, nrep)
    bench = BraxControl(env_name)
    strategy = strategy_cls(jnp.zeros(1,))   # dummy init to get name and other attriubtes for fname
    results = jax.vmap(
        lambda seed: bench.train(seed, strategy_cls, num_layers, num_hidden, num_generations)
    )(seeds)
    fname = f"{res_path}{strategy.strategy_name}{nrep}rep{num_generations}iter{npop}pop{subpopsize}members"
    if strategy.strategy_name in ["GF_SVGD", "MC SVGD"]:
        fname += f"{strategy.sigma_init}sigma{strategy.lrate_init}elite{strategy.kernel.bandwidth}kernel.pkl"
    elif strategy.strategy_name == "CMA_ES":
        fname += f"{strategy.sigma_init}sigma{strategy.elite_ratio}elite.pkl"
    elif strategy.strategy_name == "Parallel CMA-ES":
        fname += f"{strategy.base_strategy.sigma_init}sigma{strategy.base_strategy.elite_ratio}elite.pkl"
    elif strategy.strategy_name == "Parallel MC SVGD":
        fname += f"{strategy.base_strategy.sigma_init}sigma{strategy.base_strategy.lrate_init}elite.pkl"
    elif strategy.strategy_name == "OpenES":
        fname += f"{strategy.sigma_init}sigma{strategy.lrate_init}elite.pkl"
    else:
        fname += f"{strategy.sigma_init}sigma{strategy.elite_ratio}elite{strategy.kernel.bandwidth}kernel.pkl"
    with open(fname, "wb") as handler:
        pickle.dump(results, handler, protocol=pickle.HIGHEST_PROTOCOL)

# +
# Plot performance
# Loads the pickled training results of every configured strategy and plots the mean of
# "fitness_max" across repetitions with a 95% confidence band (1.96 * standard error).
log_freq = 20
num_rollouts = 16
plot_data = {}
plt.figure(figsize=(6, 4.5))
for strategy_cls in strategy_clss:
    strategy = strategy_cls(jnp.zeros(1,))  # dummy init, only needed for name and hyperparameters
    print(strategy.strategy_name)

    # Rebuild the file name exactly as the training cell does. The directory is `res_path`,
    # i.e. the location the training cell writes to; a separate hard-coded directory here
    # made a fresh top-to-bottom run fail because the files were never written there.
    fname = f"{res_path}{strategy.strategy_name}{nrep}rep{num_generations}iter{npop}pop{subpopsize}members"
    if strategy.strategy_name in ["GF_SVGD", "MC SVGD"]:
        fname += f"{strategy.sigma_init}sigma{strategy.lrate_init}elite{strategy.kernel.bandwidth}kernel.pkl"
    elif strategy.strategy_name == "CMA_ES":
        fname += f"{strategy.sigma_init}sigma{strategy.elite_ratio}elite.pkl"
    elif strategy.strategy_name == "Parallel CMA-ES":
        fname += f"{strategy.base_strategy.sigma_init}sigma{strategy.base_strategy.elite_ratio}elite.pkl"
    elif strategy.strategy_name == "Parallel MC SVGD":
        fname += f"{strategy.base_strategy.sigma_init}sigma{strategy.base_strategy.lrate_init}elite.pkl"
    elif strategy.strategy_name == "OpenES":
        fname += f"{strategy.sigma_init}sigma{strategy.lrate_init}elite.pkl"
    else:
        fname += f"{strategy.sigma_init}sigma{strategy.elite_ratio}elite{strategy.kernel.bandwidth}kernel.pkl"

    # Only the read itself needs the open file; the conversion happens afterwards.
    with open(fname, "rb") as handler:
        data = pickle.load(handler)
    plot_data[strategy.strategy_name] = {key: jnp.array(value) for key, value in data.items()}

    # Mean and 95% confidence half-width over the first axis (one row per repetition).
    means = plot_data[strategy.strategy_name]["fitness_max"]
    means_mean = jnp.mean(means, axis=0)
    means_std = jnp.std(means, axis=0) * 1.96 / jnp.sqrt(means.shape[0])

    # x-axis in units of 1e5 rollouts.
    # NOTE(review): assumes one logged value every `log_freq` generations and `num_rollouts`
    # rollouts per population member -- confirm against BraxControl.train.
    t = jnp.arange(means.shape[1]) * log_freq * num_rollouts * npop * subpopsize / 10**5
    style = cmapper(strategy.strategy_name)
    plt.plot(t, means_mean, label=style["name"], color=style["color"], linestyle=style["linestyle"], linewidth=2.5)
    plt.fill_between(t, means_mean-means_std, means_mean+means_std, alpha=.2, color=style["color"])

#plt.legend()
plt.ylabel("Mean Return")
plt.xlabel(r"Rollouts ($\times 10^5$)")
# -



# ## Hyperparam tuning
#
# As with other methods, we do a grid search. For this, it makes sense to turn off the deterministic operations because they make the code very slow.
#
# Also we use only 5rep and 500 iterations as stated in the paper.

# +
# Hyper params
# Settings for the grid search: fewer repetitions and generations than the paper runs.
rng = jax.random.PRNGKey(0)
nrep = 5
npop = 4
subpopsize = 16
num_generations = 500

# Task specifications
env_name = "walker2d_ablated"
num_layers=2
num_hidden=16
max_episode_len = 1_000

# `env_name` changed above, so the output directory has to be refreshed as well. Otherwise
# `res_path` still points at the environment chosen in the paper-plot section and the cells
# that save via `res_path` would file their results under the wrong environment.
res_path = f"data/paper_res/{env_name}/"
os.makedirs(res_path, exist_ok=True)

# hyperparam tuning
# Grids that are swept by the cells below.
adam_lrs = jnp.array([0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1.0])
kernel_vals = jnp.linspace(1e-3, 30., 10)
sigmas = jnp.linspace(.05, 1., 10)
elite_ratios = jnp.linspace(.05, .5, 10)
sigmas_gf = jnp.linspace(.01, 1., 10)
# -

# #### CMA

# +
# Set up
# Grid search for SV_CMA_BB over kernel value x sigma for the selected elite ratio(s);
# one pickle per configuration.
bench = BraxControl(env_name)
strategy = SV_CMA_BB(1, 1, kernel=RBF(), num_iters=1, num_dims=1)  # placeholder just for the attributes

eval_cma = lambda er, sig, kw: run_cfg_cma(rng, npop, subpopsize, er, sig, kw, nrep, num_generations, env_name, num_layers, num_hidden)

# Create the output directory before the sweep starts, so a missing directory cannot
# make the run fail only after the computation has finished.
out_dir = f"data/paper_res/{env_name}/"
os.makedirs(out_dir, exist_ok=True)
fname_prefix = f"{out_dir}{strategy.strategy_name}{nrep}rep{num_generations}iter{npop}pop{subpopsize}members"

# NOTE(review): only the elite ratio at index 6 is swept here, whereas the other CMA cells
# iterate over all of `elite_ratios` -- confirm this restriction is intended.
for er in elite_ratios[6:7]:  # Unfortunately this hyperparam cannot be mapped over so easily in CMA-ES because it will be used as fixed to index arrays
    # Outer vmap: kernel values, inner vmap: sigmas -> leading axes (kernel, sigma).
    res = jax.vmap(lambda kw:
        jax.vmap(lambda sig: eval_cma(er, sig, kw))(sigmas)
    )(kernel_vals)
    res_mean = jnp.array(res["fitness_mean"])
    res_max = jnp.array(res["fitness_max"])

    # Unstack the two leading axes again and store every configuration separately.
    for rmaxi, rmeani, kw in zip(res_max, res_mean, kernel_vals):
        for rmaxj, rmeanj, sig in zip(rmaxi, rmeani, sigmas):
            results = {"fitness_max": rmaxj, "fitness_mean": rmeanj}
            with open(f"{fname_prefix}{sig}sigma{er}elite{kw}kernel.pkl", "wb") as handler:
                pickle.dump(results, handler, protocol=pickle.HIGHEST_PROTOCOL)
                
# -

# Parallel

# +
# Set up
# Grid search for Parallel CMA-ES over elite ratio x sigma; one pickle per configuration.
strategy = ParallelCMAES(1, 1, num_dims=1)  # placeholder, only its name is used for the file names

eval_cma = lambda er, sig: run_cfg_cma_parallel(rng, npop, subpopsize, er, sig, nrep, env_name, num_layers, num_hidden, num_generations)

# Write to the directory that the tuning-plot cell reads from (`data/paper_res/<env>/`)
# and create it up front, so the sweep cannot fail at save time because it is missing.
out_dir = f"data/paper_res/{env_name}/"
os.makedirs(out_dir, exist_ok=True)
fname_prefix = f"{out_dir}{strategy.strategy_name}{nrep}rep{num_generations}iter{npop}pop{subpopsize}members"

for er in elite_ratios:  # Unfortunately this hyperparam cannot be mapped over so easily in CMA-ES because it will be used as fixed to index arrays
    # Vectorise over sigma -> leading axis indexes `sigmas`.
    res = jax.vmap(lambda sig: eval_cma(er, sig))(sigmas)
    res_mean = jnp.array(res["fitness_mean"])
    res_max = jnp.array(res["fitness_max"])

    for rmaxi, rmeani, sig in zip(res_max, res_mean, sigmas):
        results = {"fitness_max": rmaxi, "fitness_mean": rmeani}
        with open(f"{fname_prefix}{sig}sigma{er}elite.pkl", "wb") as handler:
            pickle.dump(results, handler, protocol=pickle.HIGHEST_PROTOCOL)
                
# -

# Single

# +
# Set up
# Grid search for single-population CMA-ES over elite ratio x sigma; one pickle per configuration.
strategy = CMA_ES(1, num_dims=1)  # placeholder, only its name is used for the file names

eval_cma = lambda er, sig: run_cfg_cma_single(rng, npop, subpopsize, er, sig, nrep, env_name, num_layers, num_hidden, num_generations)

# Write to the directory that the tuning-plot cell reads from (`data/paper_res/<env>/`)
# and create it up front, so the sweep cannot fail at save time because it is missing.
out_dir = f"data/paper_res/{env_name}/"
os.makedirs(out_dir, exist_ok=True)
fname_prefix = f"{out_dir}{strategy.strategy_name}{nrep}rep{num_generations}iter{npop}pop{subpopsize}members"

for er in elite_ratios:  # Unfortunately this hyperparam cannot be mapped over so easily in CMA-ES because it will be used as fixed to index arrays
    # Vectorise over sigma -> leading axis indexes `sigmas`.
    res = jax.vmap(lambda sig: eval_cma(er, sig))(sigmas)
    res_mean = jnp.array(res["fitness_mean"])
    res_max = jnp.array(res["fitness_max"])

    for rmaxi, rmeani, sig in zip(res_max, res_mean, sigmas):
        results = {"fitness_max": rmaxi, "fitness_mean": rmeani}
        with open(f"{fname_prefix}{sig}sigma{er}elite.pkl", "wb") as handler:
            pickle.dump(results, handler, protocol=pickle.HIGHEST_PROTOCOL)
                
# -

# #### OpenES

# +
# Set up
# Grid search for MC SVGD over learning rate x kernel value x sigma; one pickle per configuration.
strategy = MC_SVGD(1, 2, kernel=RBF(), num_iters=1, num_dims=1, sigma_init=1)  # placeholder, only its name is used

eval_oes = lambda lr, sig, kw: run_cfg_oes(rng, npop, subpopsize, lr, sig, kw, nrep, num_generations, env_name, num_layers, num_hidden)

# Derive the output directory from the *current* `env_name`. `res_path` was built in the
# paper-plot section and may still point at a different environment.
out_dir = f"data/paper_res/{env_name}/"
os.makedirs(out_dir, exist_ok=True)
fname_prefix = f"{out_dir}{strategy.strategy_name}{nrep}rep{num_generations}iter{npop}pop{subpopsize}members"

for lr in adam_lrs:
    # Outer vmap: kernel values, inner vmap: sigmas -> leading axes (kernel, sigma).
    res = jax.vmap(lambda kw:
        jax.vmap(lambda sig: eval_oes(lr, sig, kw))(sigmas)
    )(kernel_vals)
    res_mean = jnp.array(res["fitness_mean"])
    res_max = jnp.array(res["fitness_max"])

    for rmaxi, rmeani, kw in zip(res_max, res_mean, kernel_vals):
        for rmaxj, rmeanj, sig in zip(rmaxi, rmeani, sigmas):
            results = {"fitness_max": rmaxj, "fitness_mean": rmeanj}
            # The learning rate is stored in the "elite" slot of the shared file-name scheme.
            with open(f"{fname_prefix}{sig}sigma{lr}elite{kw}kernel.pkl", "wb") as handler:
                pickle.dump(results, handler, protocol=pickle.HIGHEST_PROTOCOL)
                
# -

# Parallel

# +
# Set up
# Grid search for Parallel OpenES over sigma x learning rate; one pickle per configuration.
strategy = ParallelOpenES(npop, subpopsize, num_dims=1, sigma_init=1)  # placeholder, only its name is used

# Write to the directory that the tuning-plot cell reads from (`data/paper_res/<env>/`)
# and create it before the sweep, so the run cannot fail at save time because it is missing.
out_dir = f"data/paper_res/{env_name}/"
os.makedirs(out_dir, exist_ok=True)
fname_prefix = f"{out_dir}{strategy.strategy_name}{nrep}rep{num_generations}iter{npop}pop{subpopsize}members"

eval_oes = lambda lr, sig: run_cfg_oes_parallel(rng, npop, subpopsize, lr, sig, nrep, env_name, num_layers, num_hidden, num_generations)
# Outer vmap: sigmas, inner vmap: learning rates -> leading axes (sigma, lr).
res = jax.vmap(lambda sig: jax.vmap(lambda lr: eval_oes(lr, sig))(adam_lrs))(sigmas)
res_mean = jnp.array(res["fitness_mean"])
res_max = jnp.array(res["fitness_max"])

for rmaxi, rmeani, sig in zip(res_max, res_mean, sigmas):
    for rmaxj, rmeanj, lr in zip(rmaxi, rmeani, adam_lrs):
        results = {"fitness_max": rmaxj, "fitness_mean": rmeanj}
        # The learning rate is stored in the "elite" slot of the shared file-name scheme.
        with open(f"{fname_prefix}{sig}sigma{lr}elite.pkl", "wb") as handler:
            pickle.dump(results, handler, protocol=pickle.HIGHEST_PROTOCOL)

# -

# Single

# +
# Set up
# Grid search for single-population OpenES over sigma x learning rate; one pickle per configuration.
strategy = OpenES(2, num_dims=1, sigma_init=1)  # placeholder, only its name is used

# Write to the directory that the tuning-plot cell reads from (`data/paper_res/<env>/`)
# and create it before the sweep, so the run cannot fail at save time because it is missing.
out_dir = f"data/paper_res/{env_name}/"
os.makedirs(out_dir, exist_ok=True)
fname_prefix = f"{out_dir}{strategy.strategy_name}{nrep}rep{num_generations}iter{npop}pop{subpopsize}members"

eval_oes = lambda lr, sig: run_cfg_oes_single(rng, npop, subpopsize, lr, sig, nrep, env_name, num_layers, num_hidden, num_generations)
# Outer vmap: sigmas, inner vmap: learning rates -> leading axes (sigma, lr).
res = jax.vmap(lambda sig: jax.vmap(lambda lr: eval_oes(lr, sig))(adam_lrs))(sigmas)
res_mean = jnp.array(res["fitness_mean"])
res_max = jnp.array(res["fitness_max"])

for rmaxi, rmeani, sig in zip(res_max, res_mean, sigmas):
    for rmaxj, rmeanj, lr in zip(rmaxi, rmeani, adam_lrs):
        results = {"fitness_max": rmaxj, "fitness_mean": rmeanj}
        # The learning rate is stored in the "elite" slot of the shared file-name scheme.
        with open(f"{fname_prefix}{sig}sigma{lr}elite.pkl", "wb") as handler:
            pickle.dump(results, handler, protocol=pickle.HIGHEST_PROTOCOL)

# -

# #### GF SVGD

# +
# Set up
# Grid search for GF_SVGD over sigma x kernel value x learning rate; one pickle per configuration.
bench = BraxControl(env_name)
strategy = GF_SVGD(2, kernel=RBF(), num_iters=1, num_dims=1)  # placeholder, only its name is used

# NOTE(review): `run_cfg_gf` is not given `num_generations`, yet the file names below encode
# it -- confirm that the helper's own iteration count matches.
eval_gf = lambda sig, lr, kw: run_cfg_gf(rng, npop, subpopsize, lr, sig, kw, nrep, env_name, num_layers, num_hidden)

# Derive the output directory from the *current* `env_name`. `res_path` was built in the
# paper-plot section and may still point at a different environment.
out_dir = f"data/paper_res/{env_name}/"
os.makedirs(out_dir, exist_ok=True)
fname_prefix = f"{out_dir}{strategy.strategy_name}{nrep}rep{num_generations}iter{npop}pop{subpopsize}members"

# NOTE(review): only the upper half of the kernel grid (`kernel_vals[5:]`) is swept here.
kernel_subset = kernel_vals[5:]

for sig in sigmas_gf:
    # `eval_gf` is declared as (sig, lr, kw). Passing (lr, sig, kw) swapped sigma and the
    # learning rate in the actual run while the file name kept the unswapped labels.
    res = jax.vmap(lambda kw:
        jax.vmap(lambda lr: eval_gf(sig, lr, kw))(adam_lrs)
    )(kernel_subset)
    res_mean = jnp.array(res["fitness_mean"])
    res_max = jnp.array(res["fitness_max"])

    # Leading axes are (kernel, lr); store every configuration separately.
    for rmaxi, rmeani, kw in zip(res_max, res_mean, kernel_subset):
        for rmaxj, rmeanj, lr in zip(rmaxi, rmeani, adam_lrs):
            results = {"fitness_max": rmaxj, "fitness_mean": rmeanj}
            # The learning rate is stored in the "elite" slot of the shared file-name scheme.
            with open(f"{fname_prefix}{sig}sigma{lr}elite{kw}kernel.pkl", "wb") as handler:
                pickle.dump(results, handler, protocol=pickle.HIGHEST_PROTOCOL)
                
# -

# #### Plotting

# Collect every result file that exists for the listed strategies into `plot_data`.
# These three values select which files are looked up (they are part of the file names).
# NOTE(review): the tuning cells in this notebook are configured with nrep=5 and
# num_generations=500 -- confirm that 10 / 1000 is what should be loaded here.
nrep = 10
num_generations = 1000
env_name = "walker2d_ablated"
# Placeholder instances; only `strategy_name` is read from them.
strategies = [
    #GF_SVGD(1, RBF(), num_iters=1, num_dims=1, opt_name="adam"),
    SV_CMA_BB(1, 1, kernel=RBF(.1), num_iters=1, num_dims=1, elite_ratio=0.25, sigma_init=1),
    #MC_SVGD(1, 2, kernel=RBF(1), num_iters=1, num_dims=1, sigma_init=1),
    CMA_ES(1, num_dims=1, elite_ratio=0.25, sigma_init=1),
    #OpenES(2, 1, 1, sigma_init=1),
    #ParallelCMAES(1, 1, num_dims=1, elite_ratio=0.25, sigma_init=1),
    #ParallelOpenES(1, 2, num_dims=1, sigma_init=1)
]

# Strategies whose file names carry a kernel value vs. those whose names do not.
kernel_based_names = ["BB_SVGD_ES", "MC SVGD", "GF_SVGD"]
kernel_free_names = ["Parallel CMA-ES", "Parallel MC SVGD", "CMA_ES", "OpenES"]

# The candidate grids do not change inside the loops, so build them once.
# Elite ratios and learning rates share the "elite" slot of the file name.
elite_or_lr_vals = jnp.concatenate([elite_ratios, adam_lrs])
sigma_vals = jnp.concatenate([sigmas, sigmas_gf])

plot_data = {}
for strategy in strategies:
    name = strategy.strategy_name
    if name in kernel_based_names:
        has_kernel = True
    elif name in kernel_free_names:
        has_kernel = False
    else:
        # Previously this case fell through with an undefined (or stale) file name and the
        # resulting error was swallowed; skip it explicitly and say so.
        print(f"Skipping {name}: no result-file naming scheme known for this strategy")
        continue

    prefix = f"data/paper_res/{env_name}/{name}{nrep}rep{num_generations}iter{npop}pop{subpopsize}members"
    # Without a kernel the file name does not depend on `kw`; one pass is enough.
    kw_candidates = kernel_vals if has_kernel else kernel_vals[:1]

    for kw in kw_candidates:
        for er in elite_or_lr_vals:
            for sig in sigma_vals:
                if has_kernel:
                    fname = f"{prefix}{sig}sigma{er}elite{kw}kernel.pkl"
                    dictkey = f"{name}|{er}|{sig}|{kw}"
                else:
                    fname = f"{prefix}{sig}sigma{er}elite.pkl"
                    dictkey = f"{name}|{er}|{sig}"

                # Load
                # Most grid combinations were never run, so a missing file is expected and
                # skipped. Any other failure (e.g. a corrupt pickle) is left to surface.
                try:
                    with open(fname, "rb") as handler:
                        data = pickle.load(handler)
                except FileNotFoundError:
                    continue
                plot_data[dictkey] = {key: jnp.array(value) for key, value in data.items()}
len(plot_data)

# +
# Iterate solutions and get the best one
# For every strategy, keep the configuration whose mean "fitness_max" curve has the highest
# final value, then plot the winning curves with a 95% confidence band.
val_interval = int(num_generations // 20)
# "name"/"std"/"cfg" start as None so strategies without any loaded result can be detected
# later instead of raising a KeyError.
best = {
    s.strategy_name: {"val": jnp.ones(val_interval) * -10_000, "name": None, "std": None, "cfg": None}
    for s in strategies
}
num_rollouts = 16
#env_steps = env_params.max_steps_in_episode

plotted_any = False
for resname, res_data in plot_data.items():
    strategy_name = resname.split("|")[0]
    means = jnp.array(res_data["fitness_max"])
    means_mean = jnp.mean(means, axis=0)
    means_std = jnp.std(means, axis=0) * 1.96 / jnp.sqrt(means.shape[0])

    # Overview plot: only show configurations that end at a return of at least 200.
    if means_mean[-1] >= 200:
        plt.plot(means_mean, color=cmapper(strategy_name)["color"], label=resname + " " + str(means_mean[-1]))
        plotted_any = True

    # With idx = -1 the comparison only looks at the last logged value.
    idx = -1
    if jnp.sum(means_mean[idx:]) >= jnp.sum(best[strategy_name]["val"][idx:]):
        best[strategy_name]["val"] = means_mean
        best[strategy_name]["name"] = strategy_name
        best[strategy_name]["std"] = means_std
        best[strategy_name]["cfg"] = resname

# Calling legend() without any labelled line only produces a warning, so skip it then.
if plotted_any:
    plt.legend()

# Plot performance on val set
plt.figure(figsize=(6, 4.5))
for b in best.keys():
    res = best[b]
    if res["name"] is None:
        # No result file was found for this strategy; nothing to draw.
        continue
    # The x-axis is derived from the curve itself rather than from whichever result
    # happened to be loaded last.
    # NOTE(review): unlike the paper plot, this does not multiply by `num_rollouts` --
    # confirm which unit the "Rollouts" label is meant to show.
    t = jnp.arange(res["val"].shape[0]) * val_interval * npop * subpopsize
    style = cmapper(res["name"])
    plt.plot(t, res["val"], label=res["name"], color=style["color"], linestyle=style["linestyle"], linewidth=2)
    plt.fill_between(t, res["val"]-res["std"], res["val"]+res["std"], alpha=.2, color=style["color"])

#plt.ylim(-1500, -100)
plt.ylabel("Mean Return")
plt.xlabel("Rollouts")
plt.legend();
#plt.savefig(f"data/lr_eval/{dataset}{nrep}rep{num_generations}iter{npop}pop{popsize}val.svg", format="svg")

print("Best results at:")
for s in best.keys():
    if best[s]["cfg"] is None:
        print(f"{s}: no results found")
        continue
    print(best[s]["cfg"], best[s]["val"][-1])
# -

# Display the hyperparameter grids as the cell output (useful for mapping the indices used
# in the strategy configurations back to concrete values).
kernel_vals, elite_ratios, sigmas, adam_lrs, sigmas_gf


