# ---
# jupyter:
#   jupytext:
#     text_representation:
#       extension: .py
#       format_name: light
#       format_version: '1.5'
#       jupytext_version: 1.16.2
#   kernelspec:
#     display_name: Python 3 (ipykernel)
#     language: python
#     name: python3
# ---

# +
# %env CUDA_VISIBLE_DEVICES=0

import os
import pickle
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import scienceplots
from evosax import OpenES, CMA_ES

from sves.benchmarks.classic_control import ClassicControl, run_cfg_cma, run_cfg_cma_parallel, run_cfg_cma_single, run_cfg_gf, run_cfg_oes, run_cfg_oes_parallel, run_cfg_oes_single
from sves.strategies import SV_CMA_BB, MC_SVGD, ParallelCMAES, ParallelOpenES, GF_SVGD
from sves.kernels import RBF
from sves.benchmarks.utils.eval_utils import cmapper

# Figure theme: the 'science' and 'no-latex' styles (presumably registered by
# the scienceplots import above).
plt.style.use(["science", "no-latex"])
# Optional toggle kept from the original: plt.rcParams['legend.frameon'] = True

# Font sizes and line width shared by every figure in this notebook.
plt.rcParams.update(
    {
        "axes.labelsize": 26,    # axis labels
        "xtick.labelsize": 22,   # x-axis tick labels
        "ytick.labelsize": 22,   # y-axis tick labels
        "lines.linewidth": 2,
        "legend.fontsize": 18,
    }
)

# Query the CUDA devices visible to JAX and display them as the cell output.
devices = jax.devices("cuda")
devices

# +
# Environment to evaluate; must be one of the names handled by the strategy
# table below, and also determines the results directory.
env_name = "CartPole-v1"

# Hyper params
rng = jax.random.PRNGKey(0)  # base key, split into `nrep` seeds in the training loop
nrep = 10   # number of seeds per strategy (10 for eval)
npop = 4  # first size argument of the population-based strategies (presumably number of subpopulations/particles)
subpopsize = 16  # second size argument (presumably members per subpopulation); single-population baselines get npop * subpopsize
num_generations = 200  # passed as num_iters to the SVGD-style strategies and to bench.train
num_layers=2  # forwarded to bench.train (presumably policy network depth -- TODO confirm)
num_hidden=16  # forwarded to bench.train (presumably policy hidden width -- TODO confirm)

# hyperparam tuning
# Candidate grids; the strategy table below selects tuned values by indexing
# into these arrays.
adam_lrs = jnp.array([0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1.0])
kernel_vals = jnp.linspace(1e-3, 30., 10)
sigmas = jnp.linspace(.05, 1., 10)
elite_ratios = jnp.linspace(.05, .5, 10)
sigmas_gf = jnp.linspace(.01, 1., 10)

# Get storage path
# Result pickles are written to / read from a per-environment directory.
res_path = f"data/paper_res/{env_name}/"
os.makedirs(res_path, exist_ok=True)

# Pick methods
# Each entry is a factory taking placeholder parameters `ph` and returning a
# configured strategy. Hyperparameters are chosen by indexing into the grids
# defined above (presumably the winners of the grid search further down).
# The factories read npop, subpopsize, num_generations and the grids from
# module scope at call time, not at definition time.
# max_episode_len is forwarded to ClassicControl; None is used except for
# MountainCarContinuous-v0, which uses 500.
if env_name == "CartPole-v1":
    max_episode_len = None
    strategy_clss = [
        lambda ph: GF_SVGD(npop * subpopsize, RBF(kernel_vals[4]), sigma_init=sigmas_gf[4], lrate_init=adam_lrs[-3], num_iters=num_generations, pholder_params=ph, opt_name="adam"),
        lambda ph: SV_CMA_BB(npop, subpopsize, kernel=RBF(kernel_vals[-1]), num_iters=num_generations, pholder_params=ph, elite_ratio=elite_ratios[3], sigma_init=sigmas[-2]),
        lambda ph: MC_SVGD(npop, subpopsize, kernel=RBF(kernel_vals[-1]), num_iters=num_generations, pholder_params=ph, sigma_init=sigmas[-1], lrate_init=adam_lrs[-1]),
        lambda ph: CMA_ES(npop * subpopsize, elite_ratio=elite_ratios[-4], sigma_init=sigmas[-2], pholder_params=ph),
        lambda ph: OpenES(npop * subpopsize, lrate_init=adam_lrs[-2], sigma_init=sigmas[2], pholder_params=ph),
        lambda ph: ParallelCMAES(npop, subpopsize, elite_ratio=elite_ratios[3], sigma_init=sigmas[-4], pholder_params=ph),
        lambda ph: ParallelOpenES(npop, subpopsize, sigma_init=sigmas[3], pholder_params=ph, lrate_init=adam_lrs[-1])
    ]
elif env_name == "Pendulum-v1":
    max_episode_len = None
    strategy_clss = [
        lambda ph: GF_SVGD(npop * subpopsize, RBF(kernel_vals[-5]), sigma_init=sigmas_gf[3], lrate_init=adam_lrs[3], num_iters=num_generations, pholder_params=ph, opt_name="adam"),
        lambda ph: SV_CMA_BB(npop, subpopsize, kernel=RBF(kernel_vals[1]), num_iters=num_generations, pholder_params=ph, elite_ratio=elite_ratios[2], sigma_init=sigmas[4]),
        lambda ph: MC_SVGD(npop, subpopsize, kernel=RBF(kernel_vals[-1]), num_iters=num_generations, pholder_params=ph, sigma_init=sigmas[0], lrate_init=adam_lrs[-3]),
        lambda ph: CMA_ES(npop * subpopsize, elite_ratio=elite_ratios[3], sigma_init=sigmas[-2], pholder_params=ph),
        lambda ph: OpenES(npop * subpopsize, lrate_init=adam_lrs[-1], sigma_init=sigmas[-2], pholder_params=ph),
        lambda ph: ParallelCMAES(npop, subpopsize, elite_ratio=elite_ratios[2], sigma_init=sigmas[4], pholder_params=ph),
        lambda ph: ParallelOpenES(npop, subpopsize, sigma_init=sigmas[-3], pholder_params=ph, lrate_init=adam_lrs[-2])
    ]
elif env_name == "MountainCarContinuous-v0":
    max_episode_len = 500
    strategy_clss = [
        lambda ph: GF_SVGD(npop * subpopsize, RBF(kernel_vals[-3]), sigma_init=sigmas_gf[-5], lrate_init=adam_lrs[3], num_iters=num_generations, pholder_params=ph, opt_name="adam"),
        lambda ph: SV_CMA_BB(npop, subpopsize, kernel=RBF(kernel_vals[-1]), num_iters=num_generations, pholder_params=ph, elite_ratio=elite_ratios[2], sigma_init=sigmas[-4]),
        lambda ph: MC_SVGD(npop, subpopsize, kernel=RBF(kernel_vals[-1]), num_iters=num_generations, pholder_params=ph, sigma_init=sigmas[-4], lrate_init=adam_lrs[-1]),
        lambda ph: CMA_ES(npop * subpopsize, elite_ratio=elite_ratios[2], sigma_init=sigmas[-1], pholder_params=ph),
        lambda ph: OpenES(npop * subpopsize, sigma_init=sigmas[-1], pholder_params=ph, lrate_init=adam_lrs[-1]),
        lambda ph: ParallelCMAES(npop, subpopsize, elite_ratio=elite_ratios[2], sigma_init=sigmas[4], pholder_params=ph),
        lambda ph: ParallelOpenES(npop, subpopsize, sigma_init=sigmas[4], pholder_params=ph, lrate_init=adam_lrs[-2])
    ]
else:
    # Fail fast on an environment without a tuned strategy table.
    raise NotImplementedError("Must be one of the above environments!")
# -

# Train every selected strategy on `nrep` seeds (vmapped) and pickle the
# results under a filename that encodes the strategy's hyperparameters.
for strategy_cls in strategy_clss:
    # The same base key is split on every iteration, so all strategies are
    # trained on identical seeds.
    seeds = jax.random.split(rng, nrep)
    bench = ClassicControl(env_name, max_episode_len)
    # Throwaway instance: only its name and hyperparameter attributes are
    # read, to build the output filename.
    strategy = strategy_cls(jnp.zeros(1,))

    def _train_one(seed):
        # One full training run for a single seed.
        return bench.train(seed, strategy_cls, num_layers, num_hidden, num_generations)

    results = jax.vmap(_train_one)(seeds)

    name = strategy.strategy_name
    prefix = f"{res_path}{name}{nrep}rep{num_generations}iter{npop}pop{subpopsize}members"
    # NOTE: the "elite" tag in the suffix carries the learning rate for the
    # gradient-based strategies; this is the on-disk naming convention that
    # the plotting cells read back, so it must stay as is.
    if name == "OG_SVGD":
        suffix = f"{strategy.lrate_init}elite{strategy.kernel.bandwidth}kernel.pkl"
    elif name in ["GF_SVGD", "MC SVGD"]:
        suffix = f"{strategy.sigma_init}sigma{strategy.lrate_init}elite{strategy.kernel.bandwidth}kernel.pkl"
    elif name == "CMA_ES":
        suffix = f"{strategy.sigma_init}sigma{strategy.elite_ratio}elite.pkl"
    elif name == "Parallel CMA-ES":
        suffix = f"{strategy.base_strategy.sigma_init}sigma{strategy.base_strategy.elite_ratio}elite.pkl"
    elif name == "Parallel MC SVGD":
        suffix = f"{strategy.base_strategy.sigma_init}sigma{strategy.base_strategy.lrate_init}elite.pkl"
    elif name == "OpenES":
        suffix = f"{strategy.sigma_init}sigma{strategy.lrate_init}elite.pkl"
    else:
        # Any other strategy is assumed to expose sigma_init, elite_ratio and a kernel.
        suffix = f"{strategy.sigma_init}sigma{strategy.elite_ratio}elite{strategy.kernel.bandwidth}kernel.pkl"
    fname = prefix + suffix

    with open(fname, "wb") as handler:
        pickle.dump(results, handler, protocol=pickle.HIGHEST_PROTOCOL)

# +
# Plot performance on test set
# For every selected strategy, load the pickle written by the training cell
# and plot the across-run mean of "fitness_max" with a confidence band.
log_freq = 5        # x-axis scaling factor (presumably generations between logged evaluations -- TODO confirm)
num_rollouts = 16   # x-axis scaling factor (presumably rollouts per evaluation -- TODO confirm)
plot_data = {}
plt.figure(figsize=(6, 4.5))
for strategy_cls in strategy_clss:
    # Throwaway instance: only used to rebuild the filename of the saved results.
    strategy = strategy_cls(jnp.zeros(1,))
    fname = f"{res_path}{strategy.strategy_name}{nrep}rep{num_generations}iter{npop}pop{subpopsize}members"
    print(strategy.strategy_name)
    # The suffix must mirror the naming used when the results were saved.
    # (A second, unreachable "MC SVGD" branch was removed: that name is
    # already handled together with "GF_SVGD".)
    if strategy.strategy_name == "OG_SVGD":
        fname += f"{strategy.lrate_init}elite{strategy.kernel.bandwidth}kernel.pkl"
    elif strategy.strategy_name in ["GF_SVGD", "MC SVGD"]:
        fname += f"{strategy.sigma_init}sigma{strategy.lrate_init}elite{strategy.kernel.bandwidth}kernel.pkl"
    elif strategy.strategy_name == "CMA_ES":
        fname += f"{strategy.sigma_init}sigma{strategy.elite_ratio}elite.pkl"
    elif strategy.strategy_name == "Parallel CMA-ES":
        fname += f"{strategy.base_strategy.sigma_init}sigma{strategy.base_strategy.elite_ratio}elite.pkl"
    elif strategy.strategy_name == "Parallel MC SVGD":
        fname += f"{strategy.base_strategy.sigma_init}sigma{strategy.base_strategy.lrate_init}elite.pkl"
    elif strategy.strategy_name == "OpenES":
        fname += f"{strategy.sigma_init}sigma{strategy.lrate_init}elite.pkl"
    else:
        fname += f"{strategy.sigma_init}sigma{strategy.elite_ratio}elite{strategy.kernel.bandwidth}kernel.pkl"

    # NOTE: pickle executes arbitrary code on load; only read files produced
    # by this notebook.
    with open(fname, "rb") as handler:
        data = pickle.load(handler)
    plot_data[strategy.strategy_name] = {key: jnp.array(data[key]) for key in data.keys()}

    # Get means
    # Mean over axis 0 (presumably the repetition axis) and the half-width of
    # a normal-approximation 95% confidence interval: 1.96 * std / sqrt(n).
    means = plot_data[strategy.strategy_name]["fitness_max"]
    means_mean = jnp.mean(means, axis=0)
    means_std = jnp.std(means, axis=0) * 1.96 / jnp.sqrt(means.shape[0])

    # Plot
    # x-axis: logged step index scaled to thousands of rollouts.
    t = jnp.arange(means.shape[1]) * log_freq * num_rollouts * npop * subpopsize / 1000
    style = cmapper(strategy.strategy_name)
    plt.plot(t, means_mean, label=style["name"], color=style["color"], linestyle=style["linestyle"], linewidth=2.5)
    plt.fill_between(t, means_mean-means_std, means_mean+means_std, alpha=.2, color=style["color"])

plt.ylabel("Mean Return")
plt.xlabel(r"Rollouts ($\times 10^3$)")
# -



# ## Grid search

# +
# Hyper params
# Settings for the grid search; these rebind the module-level names used by
# the evaluation cells above.
rng = jax.random.PRNGKey(0)
nrep = 5  # fewer repetitions than the final evaluation
npop = 4
subpopsize = 16
num_generations = 200

# hyperparam tuning
# Grids swept by the grid-search cells below.
adam_lrs = jnp.array([0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1.0])
kernel_vals = jnp.linspace(1e-3, 30., 10)
sigmas = jnp.linspace(.05, 1., 10)
elite_ratios = jnp.linspace(.05, .5, 10)
sigmas_gf = jnp.linspace(.01, 1., 10)

# Task specifications
env_name = "Pendulum-v1"
num_layers = 2
num_hidden = 16
max_episode_len = None if env_name != "MountainCarContinuous-v0" else 500

# Get storage path
# env_name is rebound in this cell, so the results directory must be derived
# again; otherwise the grid-search pickles would land in the directory of the
# previously selected environment (or res_path would be undefined when this
# section is run on its own).
res_path = f"data/paper_res/{env_name}/"
os.makedirs(res_path, exist_ok=True)
# -

# ### CMA

# +
# Set up
# Dummy instance: only its strategy_name is used, to build the result filenames.
strategy = SV_CMA_BB(1, 1, kernel=RBF(), num_iters=1, num_dims=1)


def eval_cma(er, sig, kw):
    # Evaluate one (elite ratio, sigma, kernel bandwidth) configuration.
    return run_cfg_cma(rng, npop, subpopsize, er, sig, kw, nrep, num_generations, env_name, max_episode_len, num_layers, num_hidden)


# The elite ratio is looped over in Python rather than vmapped: CMA-ES uses
# it as a fixed value to index arrays, so it cannot be mapped over as easily.
for er in elite_ratios:

    def _sweep_sigmas(kw):
        # Inner sweep over all sigmas for one kernel bandwidth.
        return jax.vmap(lambda sig: eval_cma(er, sig, kw))(sigmas)

    res = jax.vmap(_sweep_sigmas)(kernel_vals)
    res_mean = jnp.array(res["fitness_mean"])
    res_max = jnp.array(res["fitness_max"])

    # Leading axes of the results are (kernel bandwidth, sigma); write one
    # pickle per configuration.
    for i, kw in enumerate(kernel_vals):
        for j, sig in enumerate(sigmas):
            results = {"fitness_max": res_max[i, j], "fitness_mean": res_mean[i, j]}
            out_path = f"{res_path}{strategy.strategy_name}{nrep}rep{num_generations}iter{npop}pop{subpopsize}members{sig}sigma{er}elite{kw}kernel.pkl"
            with open(out_path, "wb") as handler:
                pickle.dump(results, handler, protocol=pickle.HIGHEST_PROTOCOL)
                
# -

# Parallel

# +
# Set up
# Dummy instance: only its strategy_name is used, to build the result filenames.
strategy = ParallelCMAES(1, 1, num_dims=1)


def eval_cma(er, sig):
    # Evaluate one (elite ratio, sigma) configuration.
    # NOTE(review): unlike the other run_cfg_* calls in this notebook,
    # num_generations is not passed here although it appears in the output
    # filename -- confirm against run_cfg_cma_parallel's signature.
    return run_cfg_cma_parallel(rng, npop, subpopsize, er, sig, nrep, env_name, max_episode_len, num_layers, num_hidden)


# The elite ratio is looped over in Python rather than vmapped: CMA-ES uses
# it as a fixed value to index arrays, so it cannot be mapped over as easily.
for er in elite_ratios:
    res = jax.vmap(lambda sig: eval_cma(er, sig))(sigmas)
    res_mean = jnp.array(res["fitness_mean"])
    res_max = jnp.array(res["fitness_max"])

    # Leading axis of the results is sigma; write one pickle per configuration.
    for j, sig in enumerate(sigmas):
        results = {"fitness_max": res_max[j], "fitness_mean": res_mean[j]}
        out_path = f"{res_path}{strategy.strategy_name}{nrep}rep{num_generations}iter{npop}pop{subpopsize}members{sig}sigma{er}elite.pkl"
        with open(out_path, "wb") as handler:
            pickle.dump(results, handler, protocol=pickle.HIGHEST_PROTOCOL)
                
# -

# Single

# +
# Set up
# Dummy instance: only its strategy_name is used, to build the result filenames.
strategy = CMA_ES(1, num_dims=1)


def eval_cma(er, sig):
    # Evaluate one (elite ratio, sigma) configuration of single-population CMA-ES.
    return run_cfg_cma_single(rng, npop, subpopsize, er, sig, nrep, env_name, max_episode_len, num_layers, num_hidden, num_generations)


# The elite ratio is looped over in Python rather than vmapped: CMA-ES uses
# it as a fixed value to index arrays, so it cannot be mapped over as easily.
for er in elite_ratios:
    res = jax.vmap(lambda sig: eval_cma(er, sig))(sigmas)
    res_mean = jnp.array(res["fitness_mean"])
    res_max = jnp.array(res["fitness_max"])

    # Leading axis of the results is sigma; write one pickle per configuration.
    for j, sig in enumerate(sigmas):
        results = {"fitness_max": res_max[j], "fitness_mean": res_mean[j]}
        out_path = f"{res_path}{strategy.strategy_name}{nrep}rep{num_generations}iter{npop}pop{subpopsize}members{sig}sigma{er}elite.pkl"
        with open(out_path, "wb") as handler:
            pickle.dump(results, handler, protocol=pickle.HIGHEST_PROTOCOL)
                
# -

# ### OpenES

# +
# Set up
# Dummy instance: only its strategy_name is used, to build the result filenames.
strategy = MC_SVGD(1, 2, kernel=RBF(), num_iters=1, num_dims=1, sigma_init=1)


def eval_oes(lr, sig, kw):
    # Evaluate one (learning rate, sigma, kernel bandwidth) configuration.
    return run_cfg_oes(rng, npop, subpopsize, lr, sig, kw, nrep, num_generations, env_name, max_episode_len, num_layers, num_hidden)


def _sweep_lrs(sig, kw):
    # Innermost sweep: all learning rates for a fixed sigma and bandwidth.
    return jax.vmap(lambda lr: eval_oes(lr, sig, kw))(adam_lrs)


def _sweep_sigmas(kw):
    # Middle sweep: all sigmas for a fixed bandwidth.
    return jax.vmap(lambda sig: _sweep_lrs(sig, kw))(sigmas)


# The full grid is evaluated in one nested vmap; the leading result axes are
# (kernel bandwidth, sigma, learning rate).
res = jax.vmap(_sweep_sigmas)(kernel_vals)
res_mean = jnp.array(res["fitness_mean"])
res_max = jnp.array(res["fitness_max"])

# Write one pickle per configuration. The "elite" tag in the filename carries
# the learning rate (on-disk naming convention shared with the loader).
for i, kw in enumerate(kernel_vals):
    for j, sig in enumerate(sigmas):
        for k, lr in enumerate(adam_lrs):
            results = {"fitness_max": res_max[i, j, k], "fitness_mean": res_mean[i, j, k]}
            out_path = f"{res_path}{strategy.strategy_name}{nrep}rep{num_generations}iter{npop}pop{subpopsize}members{sig}sigma{lr}elite{kw}kernel.pkl"
            with open(out_path, "wb") as handler:
                pickle.dump(results, handler, protocol=pickle.HIGHEST_PROTOCOL)

# -

# Parallel

# +
# Set up
# Dummy instance: only its strategy_name is used, to build the result filenames.
strategy = ParallelOpenES(1, 2, num_dims=1, sigma_init=1)


def eval_oes(lr, sig):
    # Evaluate one (learning rate, sigma) configuration of parallel OpenES.
    return run_cfg_oes_parallel(rng, npop, subpopsize, lr, sig, nrep, env_name, max_episode_len, num_layers, num_hidden, num_generations)


def _sweep_lrs(sig):
    # Inner sweep: all learning rates for a fixed sigma.
    return jax.vmap(lambda lr: eval_oes(lr, sig))(adam_lrs)


# Leading result axes are (sigma, learning rate).
res = jax.vmap(_sweep_lrs)(sigmas)
res_mean = jnp.array(res["fitness_mean"])
res_max = jnp.array(res["fitness_max"])

# Write one pickle per configuration. The "elite" tag in the filename carries
# the learning rate (on-disk naming convention shared with the loader).
for i, sig in enumerate(sigmas):
    for j, lr in enumerate(adam_lrs):
        results = {"fitness_max": res_max[i, j], "fitness_mean": res_mean[i, j]}
        out_path = f"{res_path}{strategy.strategy_name}{nrep}rep{num_generations}iter{npop}pop{subpopsize}members{sig}sigma{lr}elite.pkl"
        with open(out_path, "wb") as handler:
            pickle.dump(results, handler, protocol=pickle.HIGHEST_PROTOCOL)

# -

# Single

# +
# Set up
# Dummy instance: only its strategy_name is used, to build the result filenames.
strategy = OpenES(2, num_dims=1, sigma_init=1)


def eval_oes(lr, sig):
    # Evaluate one (learning rate, sigma) configuration of single-population OpenES.
    return run_cfg_oes_single(rng, npop, subpopsize, lr, sig, nrep, env_name, max_episode_len, num_layers, num_hidden, num_generations)


def _sweep_lrs(sig):
    # Inner sweep: all learning rates for a fixed sigma.
    return jax.vmap(lambda lr: eval_oes(lr, sig))(adam_lrs)


# Leading result axes are (sigma, learning rate).
res = jax.vmap(_sweep_lrs)(sigmas)
res_mean = jnp.array(res["fitness_mean"])
res_max = jnp.array(res["fitness_max"])

# Write one pickle per configuration. The "elite" tag in the filename carries
# the learning rate (on-disk naming convention shared with the loader).
for i, sig in enumerate(sigmas):
    for j, lr in enumerate(adam_lrs):
        results = {"fitness_max": res_max[i, j], "fitness_mean": res_mean[i, j]}
        out_path = f"{res_path}{strategy.strategy_name}{nrep}rep{num_generations}iter{npop}pop{subpopsize}members{sig}sigma{lr}elite.pkl"
        with open(out_path, "wb") as handler:
            pickle.dump(results, handler, protocol=pickle.HIGHEST_PROTOCOL)

# -

# ### GF SVGD

# +
# Set up
# Dummy instance: only its strategy_name is used, to build the result filenames.
strategy = GF_SVGD(2, kernel=RBF(), num_iters=1, num_dims=1)

# Evaluate one (sigma, learning rate, kernel bandwidth) configuration.
eval_gf = lambda sig, lr, kw: run_cfg_gf(rng, npop, subpopsize, sig, lr, kw, nrep, env_name, max_episode_len, num_layers, num_hidden, num_generations)
# Nested vmap over the full grid; the leading result axes are
# (sigma, kernel bandwidth, learning rate).
res = jax.vmap(lambda sig: jax.vmap(lambda kw: jax.vmap(lambda lr: eval_gf(sig, lr, kw))(adam_lrs))(kernel_vals))(sigmas_gf)
res_mean = jnp.array(res["fitness_mean"])
res_max = jnp.array(res["fitness_max"])

# Write one pickle per configuration. The "elite" tag in the filename carries
# the learning rate (on-disk naming convention shared with the loader).
for rmaxi, rmeani, sig in zip(res_max, res_mean, sigmas_gf):
    # Pair the max slices with the matching mean slices. Zipping rmaxi with
    # itself here would store fitness_max under the "fitness_mean" key.
    for rmaxj, rmeanj, kw in zip(rmaxi, rmeani, kernel_vals):
        for rmaxk, rmeank, lr in zip(rmaxj, rmeanj, adam_lrs):
            results = {"fitness_max": rmaxk, "fitness_mean": rmeank}
            with open(f"{res_path}{strategy.strategy_name}{nrep}rep{num_generations}iter{npop}pop{subpopsize}members{sig}sigma{lr}elite{kw}kernel.pkl", "wb") as handler:
                pickle.dump(results, handler, protocol=pickle.HIGHEST_PROTOCOL)

# -

# ### Plotting

# +
# Dummy instances: only their strategy_name is used, to locate result files.
# NOTE(review): OpenES is built here with three positional arguments, while
# the grid-search cell uses OpenES(2, num_dims=1, sigma_init=1) -- confirm
# the third positional argument is intended.
strategies = [
    GF_SVGD(1, RBF(), num_iters=1, num_dims=1, opt_name="adam"),
    SV_CMA_BB(1, 1, kernel=RBF(.1), num_iters=1, num_dims=1, elite_ratio=0.25, sigma_init=1),
    MC_SVGD(1, 2, kernel=RBF(1), num_iters=1, num_dims=1, sigma_init=1),
    CMA_ES(1, num_dims=1, elite_ratio=0.25, sigma_init=1),
    OpenES(2, 1, 1, sigma_init=1),
    ParallelCMAES(1, 1, num_dims=1, elite_ratio=0.25, sigma_init=1),
    ParallelOpenES(1, 2, num_dims=1, sigma_init=1)
]

# Strategies whose result filenames carry no kernel bandwidth. Every other
# strategy uses the "...{kw}kernel.pkl" form, mirroring the fallback branch
# used when results are saved, so an unlisted name can no longer silently
# reuse the previous iteration's fname/dictkey.
kernel_free_names = ["Parallel CMA-ES", "Parallel MC SVGD", "CMA_ES", "OpenES"]

# The "elite" filename slot holds either an elite ratio or a learning rate,
# and sigma comes from either grid, so probe the union of both; most
# combinations do not exist on disk and are skipped.
elite_or_lr_vals = jnp.concatenate([elite_ratios, adam_lrs])
sigma_vals = jnp.concatenate([sigmas, sigmas_gf])

plot_data = {}
for strategy in strategies:
    name = strategy.strategy_name
    prefix = f"{res_path}{name}{nrep}rep{num_generations}iter{npop}pop{subpopsize}members"
    uses_kernel = name not in kernel_free_names
    # Kernel-free strategies have one file per (sigma, elite/lr) pair, so
    # there is no point re-reading it once per kernel value.
    for kw in (kernel_vals if uses_kernel else [None]):
        for er in elite_or_lr_vals:
            for sig in sigma_vals:
                if uses_kernel:
                    fname = f"{prefix}{sig}sigma{er}elite{kw}kernel.pkl"
                    dictkey = f"{name}|{er}|{sig}|{kw}"
                else:
                    fname = f"{prefix}{sig}sigma{er}elite.pkl"
                    dictkey = f"{name}|{er}|{sig}"

                # Load
                # NOTE: pickle executes arbitrary code on load; only read
                # files produced by this notebook.
                try:
                    with open(fname, "rb") as handler:
                        data = pickle.load(handler)
                except FileNotFoundError:
                    # Expected: this grid combination was never run.
                    continue
                except (OSError, pickle.UnpicklingError, EOFError) as err:
                    # Keep best-effort loading, but do not hide broken files.
                    print(f"Skipping unreadable result file {fname}: {err}")
                    continue

                plot_data[dictkey] = {key: jnp.array(data[key]) for key in data.keys()}

print(f"Loaded {len(plot_data)} runs")

# +
# Iterate solutions and get the best one
# For each strategy, keep the loaded configuration whose mean "fitness_max"
# curve has the highest final value, then plot the winners.
val_interval = int(num_generations // 5)
# Every entry starts with a very low sentinel curve and explicit None
# placeholders, so strategies without any loaded run can be detected below
# instead of raising KeyError.
best = {
    s.strategy_name: {"val": jnp.ones(val_interval) * -10_000, "name": None, "std": None, "cfg": None}
    for s in strategies
}

for resname in plot_data.keys():
    # Keys look like "<strategy name>|<elite or lr>|<sigma>[|<kernel>]".
    strategy_name = resname.split("|")[0]
    means = jnp.array(plot_data[resname]["fitness_max"])
    means_mean = jnp.mean(means, axis=0)
    # Half-width of a normal-approximation 95% confidence interval over axis 0.
    means_std = jnp.std(means, axis=0) * 1.96 / jnp.sqrt(means.shape[0])

    # Compare on the last logged value only; ">=" lets later configurations
    # win ties.
    idx = -1
    if jnp.sum(means_mean[idx:]) >= jnp.sum(best[strategy_name]["val"][idx:]):
        best[strategy_name]["val"] = means_mean
        best[strategy_name]["name"] = strategy_name
        best[strategy_name]["std"] = means_std
        best[strategy_name]["cfg"] = resname

# Plot performance on val set
# NOTE(review): the x-axis is scaled by val_interval here, whereas the
# test-set plot earlier in this notebook scales by log_freq * num_rollouts --
# confirm which is intended.
plt.figure(figsize=(6, 4.5))
for b in best.keys():
    res = best[b]
    if not res["name"]:
        # No run was loaded for this strategy; nothing to draw.
        continue
    # Derive the x-axis from this curve's own length rather than from a
    # variable left over from the loop above.
    t = jnp.arange(res["val"].shape[0]) * val_interval * npop * subpopsize / 1000
    style = cmapper(res["name"])
    plt.plot(t, res["val"], label=res["name"], color=style["color"], linestyle=style["linestyle"], linewidth=2)
    plt.fill_between(t, res["val"]-res["std"], res["val"]+res["std"], alpha=.2, color=style["color"])

plt.ylabel("Mean Return")
plt.xlabel(r"Rollouts ($\times 10^3$)")
plt.legend();

print("Best results at:")
for s in best.keys():
    if best[s]["cfg"] is None:
        print(f"{s}: no results loaded")
    else:
        print(best[s]["cfg"], best[s]["val"][-1])
# -

# Display the hyperparameter grids so that the indices used in the strategy
# table can be mapped back to concrete values.
kernel_vals, elite_ratios, sigmas, adam_lrs, sigmas_gf


