# ---
# jupyter:
#   jupytext:
#     text_representation:
#       extension: .py
#       format_name: light
#       format_version: '1.5'
#       jupytext_version: 1.16.2
#   kernelspec:
#     display_name: Python 3 (ipykernel)
#     language: python
#     name: python3
# ---

# # Motion Planning Evaluation
#
# This notebook contains code to reproduce our evaluation on the motion planning problem.
# The benchmark is inspired by the excellent paper *Path Signatures for Diversity in Probabilistic Trajectory Optimisation* (Barcelos et al., 2024). Their code can be found [here](https://github.com/lubaroli/sigsvgd/tree/master).
#

# +
# Choose physical device here
# %matplotlib inline
# %env XLA_PYTHON_CLIENT_PREALLOCATE=false
# %env XLA_PYTHON_CLIENT_ALLOCATOR=platform
# %env CUDA_VISIBLE_DEVICES=0
                                
import os
import pickle
import jax
import matplotlib.pyplot as plt
import jax.numpy as jnp
import numpy as onp
import scienceplots
import seaborn as sns

from sves.kernels import RBF
from sves.strategies import OG_SVGD, GF_SVGD, SV_CMA_BB, MC_SVGD
from sves.benchmarks.utils.eval_utils import cmapper, compute_metrics
from sves.benchmarks.sampling_eval import run_cfg_cma, run_cfg_oes, run_cfg_og, run_cfg_gf, eval_sampling
from sves.benchmarks.simple_motion_plan import RamosPaper

# +
# Directory that receives every pickled result written by this notebook
res_path = "data/paper_res/ramos/"
os.makedirs(res_path, exist_ok=True)

# Global matplotlib styling shared by all figures below
plt.style.use(['science', 'no-latex'])
plt.rcParams.update({
    'legend.frameon': True,
    'axes.labelsize': 26,    # axis labels
    'xtick.labelsize': 22,   # x-axis tick labels
    'ytick.labelsize': 22,   # y-axis tick labels
    'lines.linewidth': 2,
    'legend.fontsize': 18,
})

# Query the CUDA devices visible to JAX; the list is shown as the cell output
devices = jax.devices("cuda")
devices
# -

# ## Paper plot
#
# The following code will produce the plot in the paper which is featured in Figure 2.

# +
# Benchmark problem and PRNG keys
rng = jax.random.PRNGKey(2)
rng, rng_init, rng_sample = jax.random.split(rng, 3)
bench = RamosPaper(n_via=5)
objective_fn, score_fn = bench.get_objective_derivative()
ub, lb = bench.upper_bounds, bench.lower_bounds

# Experiment size
npop, subpopsize = 100, 4
nrep = 10  # 10 seeds for eval
num_generations = 1_000

# Samples drawn from the benchmark itself with a fixed key
gt_samples = bench.sample(jax.random.PRNGKey(0), 256)

# Candidate values of the hyperparameter grid search
kernel_vals = jnp.linspace(1e-2, 3., 10)
sigmas = jnp.linspace(.05, .5, 10)
sigmas_gf = jnp.linspace(.01, 4., 10)
adam_lrs = jnp.array([0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1.0])
elite_ratios = jnp.linspace(.15, .5, 10)

# Strategies with the hyperparameters of the paper, given as indices into the
# grids defined above
strategies = [
    OG_SVGD(
        npop * subpopsize,
        RBF(kernel_vals[0]),
        num_iters=num_generations,
        num_dims=bench.dim,
        opt_name="adam",
        lrate_init=adam_lrs[2],
    ),
    GF_SVGD(
        npop * subpopsize,
        RBF(kernel_vals[2]),
        sigma_init=sigmas_gf[-4],
        num_iters=num_generations,
        num_dims=bench.dim,
        opt_name="adam",
        lrate_init=adam_lrs[0],
    ),
    SV_CMA_BB(
        npop,
        subpopsize,
        kernel=RBF(kernel_vals[0]),
        num_iters=num_generations,
        num_dims=bench.dim,
        elite_ratio=elite_ratios[-1],
        sigma_init=sigmas[1],
    ),
    MC_SVGD(
        npop,
        subpopsize,
        kernel=RBF(kernel_vals[0]),
        num_iters=num_generations,
        num_dims=bench.dim,
        sigma_init=sigmas[1],
        lrate_init=adam_lrs[3],
    ),
]
# -

# First we run the experiments.

# Every strategy is evaluated on the same `nrep` seeds derived from `rng`
seeds = jax.random.split(rng, nrep)
for strategy in strategies:
    # Vectorise the complete sampling run over the seed axis
    results = jax.vmap(
        lambda seed: eval_sampling(seed, strategy, bench, num_generations, cb_freq=500)
    )(seeds)
    # The file name encodes the run configuration; which hyperparameters
    # appear in it depends on the strategy
    name = strategy.strategy_name
    if name == "OG_SVGD":
        hparam_tag = f"{strategy.lrate_init}elite"
    elif name in ("GF_SVGD", "MC SVGD"):
        hparam_tag = f"{strategy.sigma_init}sigma{strategy.lrate_init}elite"
    else:
        hparam_tag = f"{strategy.sigma_init}sigma{strategy.elite_ratio}elite"
    fname = (
        f"{res_path}{bench.name}{name}{nrep}rep{num_generations}iter{npop}pop{subpopsize}members"
        f"{hparam_tag}{strategy.kernel.bandwidth}kernel.pkl"
    )
    with open(fname, "wb") as handler:
        pickle.dump(results, handler, protocol=pickle.HIGHEST_PROTOCOL)

# Plot results.

# +
# Load the saved runs and compute metrics.
# The reference samples do not depend on the strategy, so they are drawn once
# instead of once per loop iteration.
gt_samples = bench.sample(jax.random.PRNGKey(0), 256)
paper_plot = {}
for strategy in strategies:
    name = strategy.strategy_name
    # Rebuild exactly the file name that was used when the run was written
    if name == "OG_SVGD":
        hparam_tag = f"{strategy.lrate_init}elite"
    elif name in ("GF_SVGD", "MC SVGD"):
        hparam_tag = f"{strategy.sigma_init}sigma{strategy.lrate_init}elite"
    else:
        hparam_tag = f"{strategy.sigma_init}sigma{strategy.elite_ratio}elite"
    fn = (
        f"{res_path}{bench.name}{name}{nrep}rep{num_generations}iter{npop}pop{subpopsize}members"
        f"{hparam_tag}{strategy.kernel.bandwidth}kernel.pkl"
    )

    # NOTE: pickle must only be used on files produced by this notebook.
    with open(fn, "rb") as handler:
        data = pickle.load(handler)
    # The file is closed before the metrics are computed
    paper_plot[name] = compute_metrics(data, gt_samples)

# Plot one curve per strategy with a band of 1.96 * std / sqrt(nrep)
# (assumes the second entry of "mmds" is the std over the seeds -- TODO confirm)
plt.figure(figsize=(6, 4.5))
for strategy in paper_plot:
    mmds, mmds_std = paper_plot[strategy]["mmds"]
    mmd_mean = mmds
    mmd_std = mmds_std * 1.96 / jnp.sqrt(nrep)

    t = jnp.arange(mmds.shape[0])
    # Keys may carry "|"-separated hyperparameters; the first field is the name
    sname = strategy.split("|")[0]
    style = cmapper(sname)
    plt.plot(t, mmd_mean, label=style['name'], color=style['color'], linestyle=style['linestyle'])
    plt.fill_between(t, mmd_mean - mmd_std, mmd_mean + mmd_std, alpha=.2, color=style['color'])

plt.ylabel("log10 MMD")
plt.xlabel("Num. iter.")
# -



# ## Run grid search
#
# The following code can be used to run the grid search for hyperparameter tuning.

# +
# Fresh PRNG keys and benchmark instance for the grid search
rng = jax.random.PRNGKey(2)
rng, rng_init, rng_sample = jax.random.split(rng, 3)

bench = RamosPaper(n_via=5)
objective_fn, score_fn = bench.get_objective_derivative()
ub, lb = bench.upper_bounds, bench.lower_bounds

# Population layout and run length
npop, subpopsize = 100, 4
nrep = 5  # only 5 rep for faster grid search
num_generations = 1_000
gt_samples = bench.sample(jax.random.PRNGKey(0), 256)

# Grids that are swept by the cells below
kernel_vals = jnp.linspace(1e-2, 3., 10)
sigmas = jnp.linspace(.05, .5, 10)
sigmas_gf = jnp.linspace(.01, 4., 10)
adam_lrs = jnp.array([0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1.0])
elite_ratios = jnp.linspace(.15, .5, 10)

# Shown as cell output: result of the RBF median heuristic on the reference samples
RBF().median_heuristic(gt_samples)
# -

# #### CMA-based

# +
# Evaluate one (elite ratio, sigma, kernel bandwidth) configuration
def eval_cma(er, sig, kw):
    return run_cfg_cma(rng, npop, subpopsize, er, sig, kw, nrep, num_generations, bench)


# The elite ratio is looped over in Python: unfortunately this hyperparam cannot
# be mapped over so easily in CMA-ES because it will be used as fixed to index arrays
for er in elite_ratios:
    def _sweep_sigmas(kw):
        # Inner vmap: all sigmas for one kernel bandwidth
        return jax.vmap(lambda sig: eval_cma(er, sig, kw))(sigmas)

    # Outer vmap: all kernel bandwidths
    res = jax.vmap(_sweep_sigmas)(kernel_vals)

    # One pickle file per (kernel bandwidth, sigma) pair
    for kw, per_kernel in zip(kernel_vals, res):
        for sig, results in zip(sigmas, per_kernel):
            out_path = f"{res_path}{bench.name}BB_SVGD_ES{nrep}rep{num_generations}iter{npop}pop{subpopsize}members{sig}sigma{er}elite{kw}kernel.pkl"
            with open(out_path, "wb") as handler:
                pickle.dump(results, handler, protocol=pickle.HIGHEST_PROTOCOL)
                
# -

# #### OpenES based

# +
# Evaluate one (learning rate, sigma, kernel bandwidth) configuration
def eval_oes(lr, sig, kw):
    return run_cfg_oes(rng, npop, subpopsize, lr, sig, kw, nrep, num_generations, bench)


for lr in adam_lrs:
    def _sweep_sigmas(kw):
        # Inner vmap: all sigmas for one kernel bandwidth
        return jax.vmap(lambda sig: eval_oes(lr, sig, kw))(sigmas)

    # Outer vmap: all kernel bandwidths
    res = jax.vmap(_sweep_sigmas)(kernel_vals)

    # One pickle file per (kernel bandwidth, sigma) pair; the learning rate
    # takes the "elite" slot of the shared file-name scheme
    for kw, per_kernel in zip(kernel_vals, res):
        for sig, results in zip(sigmas, per_kernel):
            out_path = f"{res_path}{bench.name}MC SVGD{nrep}rep{num_generations}iter{npop}pop{subpopsize}members{sig}sigma{lr}elite{kw}kernel.pkl"
            with open(out_path, "wb") as handler:
                pickle.dump(results, handler, protocol=pickle.HIGHEST_PROTOCOL)
                
# -

# #### Adam-based

# +
# Evaluate one (learning rate, kernel bandwidth) configuration
def eval_og(lr, kw):
    return run_cfg_og(rng, npop, subpopsize, lr, kw, nrep, num_generations, bench)


def _sweep_lrs(kw):
    # Inner vmap: all learning rates for one kernel bandwidth
    return jax.vmap(lambda lr: eval_og(lr, kw))(adam_lrs)


# Outer vmap: all kernel bandwidths
res = jax.vmap(_sweep_lrs)(kernel_vals)

# One pickle file per (kernel bandwidth, learning rate) pair; these file names
# carry no sigma field
for kw, per_kernel in zip(kernel_vals, res):
    for lr, results in zip(adam_lrs, per_kernel):
        out_path = f"{res_path}{bench.name}OG_SVGD{nrep}rep{num_generations}iter{npop}pop{subpopsize}members{lr}elite{kw}kernel.pkl"
        with open(out_path, "wb") as handler:
            pickle.dump(results, handler, protocol=pickle.HIGHEST_PROTOCOL)
                

# +
# Evaluate one (learning rate, sigma, kernel bandwidth) configuration
def eval_gf(lr, sig, kw):
    return run_cfg_gf(
        rng,
        npop,
        subpopsize,
        lr=lr,
        sig=sig,
        kw=kw,
        nrep=nrep,
        num_generations=num_generations,
        bench=bench,
    )


# Sigma is looped over in Python; bandwidths and learning rates are vmapped
for sig in sigmas_gf:
    def _sweep_lrs(kw):
        # Inner vmap: all learning rates for one kernel bandwidth
        return jax.vmap(lambda lr: eval_gf(lr, sig, kw))(adam_lrs)

    res = jax.vmap(_sweep_lrs)(kernel_vals)

    # One pickle file per (kernel bandwidth, learning rate) pair
    for kw, per_kernel in zip(kernel_vals, res):
        for lr, run in zip(adam_lrs, per_kernel):
            out_path = f"{res_path}{bench.name}GF_SVGD{nrep}rep{num_generations}iter{npop}pop{subpopsize}members{sig}sigma{lr}elite{kw}kernel.pkl"
            with open(out_path, "wb") as handler:
                pickle.dump(run, handler, protocol=pickle.HIGHEST_PROTOCOL)
                
# -

# #### Inspect results

# +
# Dummy instances that are only used for their `strategy_name`; every strategy
# listed here is included in the hyperparameter search below.
strategies = [
    OG_SVGD(1, RBF(), num_iters=1, num_dims=1, opt_name="adam"),
    #GF_SVGD(1, RBF(), num_iters=1, num_dims=1, opt_name="adam"),
    #SV_CMA_BB(1, 1, kernel=RBF(), num_iters=1, num_dims=1, elite_ratio=0.25, sigma_init=1),
    #MC_SVGD(1, 2, kernel=RBF(), num_iters=1, num_dims=1, sigma_init=1),
]

# Load data
plot_data = {}
file_template = res_path + "{bench_name}{strategy_name}{nrep}rep{num_generations}iter{npop}pop{subpopsize}members{sig}sigma{er}elite{kw}kernel.pkl"
# OG_SVGD file names carry no sigma field. Keep a separate template for them
# instead of overwriting `file_template`, which would corrupt the file names of
# every other strategy processed afterwards.
og_file_template = file_template.replace("{sig}sigma", "")


def _load_metrics(file_path, stride=None):
    """Unpickle one result file and return `compute_metrics` for it.

    If `stride` is given, only every `stride`-th entry along the third axis of
    the loaded data is kept before the metrics are computed.
    NOTE: pickle must only be used on files produced by this notebook.
    """
    with open(file_path, "rb") as handler:
        data = pickle.load(handler)
    if stride is not None:
        data = data[:, :, ::stride]
    return compute_metrics(data, gt_samples)


for strategy in strategies:
    strategy_name = strategy.strategy_name
    common_params = {
        'bench_name': bench.name,
        'strategy_name': strategy_name,
        'nrep': nrep,
        'num_generations': num_generations,
        'npop': npop,
        'subpopsize': subpopsize
    }

    # Sigma grid of the strategy (None: no sigma in file name or result key)
    if strategy_name == "OG_SVGD":
        sig_values = [None]
    elif strategy_name == "GF_SVGD":
        sig_values = sigmas_gf
    else:
        sig_values = sigmas
    # OG_SVGD and GF_SVGD results are subsampled with a stride of `subpopsize`
    stride = subpopsize if strategy_name in ("OG_SVGD", "GF_SVGD") else None

    # `er` runs over both grids because the "elite" slot of the file name holds
    # either an elite ratio or a learning rate, depending on the strategy
    for er in jnp.concatenate([elite_ratios, adam_lrs]):
        for kw in kernel_vals:
            for sig in sig_values:
                if sig is None:
                    file_path = og_file_template.format(er=er, kw=kw, **common_params)
                    result_key = f"{strategy_name}|{er}|{kw}"
                else:
                    file_path = file_template.format(sig=sig, er=er, kw=kw, **common_params)
                    result_key = f"{strategy_name}|{er}|{sig}|{kw}"

                # Handle each file on its own so that one missing file does not
                # drop the remaining sigmas of the same (er, kw) pair
                try:
                    plot_data[result_key] = _load_metrics(file_path, stride)
                except FileNotFoundError:
                    # Expected: this combination was not part of the sweep
                    continue
                except Exception as e:
                    # Still best-effort, but no longer silent
                    print(f"Skipping {file_path}: {e}")
                    continue

print(f"Loaded {len(plot_data)} episodes")

# +
# For every strategy, keep the configuration with the lowest final value of the
# chosen metric. Entries start with an infinite sentinel so that any loaded run
# replaces it.
key = "mmds"
best = {s.strategy_name: {"val": [float("inf")]} for s in strategies}
for cfg, metrics in plot_data.items():
    # Result keys look like "<strategy_name>|<hyperparameters...>"
    strategy_name = cfg.split("|")[0]
    mmd_mean, mmds_std = metrics[key]
    mmd_std = mmds_std * 1.96 / jnp.sqrt(nrep)

    # setdefault also covers results of a strategy that is not in `strategies`
    incumbent = best.setdefault(strategy_name, {"val": [float("inf")]})
    if mmd_mean[-1] < incumbent["val"][-1]:
        incumbent.update(val=mmd_mean, name=strategy_name, std=mmd_std, cfg=cfg)

# Plot the best run of every strategy for which at least one result was loaded
plt.figure(figsize=(6, 4.5))
for res in best.values():
    if "cfg" not in res:
        # Nothing was loaded for this strategy; skip it instead of raising
        continue
    # Derive the x-axis from the curve itself rather than from a leaked loop variable
    t = jnp.arange(res["val"].shape[0])
    style = cmapper(res["name"])
    plt.plot(t, res["val"], label=style['name'], color=style['color'], linestyle=style['linestyle'])
    plt.fill_between(t, res["val"] - res["std"], res["val"] + res["std"], alpha=.2, color=style['color'])

plt.ylabel("log10 MMD")
plt.xlabel("N. iter.")
plt.legend(fontsize=8)

print("Best results at:")
for strategy_name, res in best.items():
    print(res.get("cfg", f"{strategy_name}: no results loaded"))
# -

# Display the hyperparameter grids as cell output, for reference
kernel_vals, sigmas, elite_ratios, adam_lrs, sigmas_gf

# ## Particle scaling
#
# Additional code to run the benchmark at multiple numbers of particles. The results will be used in `meta_plot.ipynb` to create Fig. 5.
# To account for the kernel being optimized for a subpop size of 4 and 100 particles, we adjust it by linear scaling.
# This slightly improves the results across all methods.

# Number of seeds per configuration in the scaling experiment
nrep = 10
# Subpopulation sizes to evaluate
subpops = jnp.array([4, 8, 16])
# Particle counts 2**4 ... 2**9, built with NumPy (onp) rather than jax.numpy
n_particles = onp.power(2, onp.arange(4, 10))   # Needs onp
# Shown as cell output for reference
n_particles, subpops, sigmas[1], kernel_vals[-4]

# CMA-based strategy, starting from the hyperparameters of the paper plot
sig = sigmas[1]
for subs in subpops:
    # The elite ratio was tuned for a subpopulation size of 4; scale it inversely
    ergo = elite_ratios[-1] * (4 / subs)
    for npart in n_particles:
        # The kernel bandwidth was tuned for 100 particles; scale it inversely
        kerva = kernel_vals[0] * (100 / npart)
        res = run_cfg_cma(rng, npart, subs, er=ergo, sig=sig, kw=kerva, nrep=nrep, num_generations=num_generations, bench=bench)
        out_path = f"{res_path}{bench.name}BB_SVGD_ES{nrep}rep{num_generations}iter{npart}pop{subs}members{sig}sigma{ergo}elite{kerva}kernel.pkl"
        with open(out_path, "wb") as handler:
            pickle.dump(res, handler, protocol=pickle.HIGHEST_PROTOCOL)
            

# OpenES-based strategy ("MC SVGD"), starting from the hyperparameters of the paper plot
sig, lr = sigmas[1], adam_lrs[3]
for subs in subpops:
    for npart in n_particles:
        # The kernel bandwidth was tuned for 100 particles; scale it inversely
        kerva = kernel_vals[0] * (100 / npart)
        # The subpopulation size is passed as a plain Python int
        res = run_cfg_oes(rng, npart, int(subs), lr, sig, kerva, nrep, num_generations, bench)
        out_path = f"{res_path}{bench.name}MC SVGD{nrep}rep{num_generations}iter{npart}pop{subs}members{sig}sigma{lr}elite{kerva}kernel.pkl"
        with open(out_path, "wb") as handler:
            pickle.dump(res, handler, protocol=pickle.HIGHEST_PROTOCOL)
            

# OG_SVGD, starting from the hyperparameters of the paper plot; subpopulation size is fixed to 1
lr = adam_lrs[2]
for npart in n_particles:
    # The kernel bandwidth was tuned for 100 particles; scale it inversely
    kerva = kernel_vals[0] * (100 / npart)
    res = run_cfg_og(rng, npart, 1, lr, kerva, nrep, num_generations, bench)
    out_path = f"{res_path}{bench.name}OG_SVGD{nrep}rep{num_generations}iter{npart}pop{1}members{lr}elite{kerva}kernel.pkl"
    with open(out_path, "wb") as handler:
        pickle.dump(res, handler, protocol=pickle.HIGHEST_PROTOCOL)

# GF_SVGD, starting from the hyperparameters of the paper plot; subpopulation size is fixed to 1
sig = sigmas_gf[-4]
lr = adam_lrs[0]
for npart in n_particles:
    # The kernel bandwidth was tuned for 100 particles; scale it inversely
    kerva = kernel_vals[2] * (100 / npart)
    # Pass the hyperparameters by keyword, with the same names as the grid-search
    # call of run_cfg_gf, so that lr and sig cannot be swapped by position
    res = run_cfg_gf(rng, npart, 1, lr=lr, sig=sig, kw=kerva, nrep=nrep, num_generations=num_generations, bench=bench)
    with open(f"{res_path}{bench.name}GF_SVGD{nrep}rep{num_generations}iter{npart}pop{1}members{sig}sigma{lr}elite{kerva}kernel.pkl", "wb") as handler:
        pickle.dump(res, handler, protocol=pickle.HIGHEST_PROTOCOL)


