# ---
# jupyter:
#   jupytext:
#     text_representation:
#       extension: .py
#       format_name: light
#       format_version: '1.5'
#       jupytext_version: 1.16.2
#   kernelspec:
#     display_name: Python 3 (ipykernel)
#     language: python
#     name: python3
# ---

# +
# Choose physical device here
# %matplotlib inline
# %env XLA_PYTHON_CLIENT_PREALLOCATE=false
# %env XLA_PYTHON_CLIENT_ALLOCATOR=platform
# %env CUDA_VISIBLE_DEVICES=0

import os
import pickle
import jax
import matplotlib.pyplot as plt
import jax.numpy as jnp
import numpy as onp
import scienceplots
import seaborn as sns

from sves.kernels import RBF
from sves.strategies import OG_SVGD, GF_SVGD, SV_CMA_BB, MC_SVGD, ParallelCMAES
from sves.benchmarks.sampling_eval import run_cfg_cma, run_cfg_oes, run_cfg_og, run_cfg_gf, eval_sampling
from sves.benchmarks.utils.eval_utils import cmapper, compute_metrics
from sves.benchmarks.synthetic_benchmarks import DoubleBanana

# Global plotting style: SciencePlots theme without LaTeX rendering, with
# framed legends and enlarged label / tick / legend fonts.
plt.style.use(['science', 'no-latex'])
plt.rcParams.update({
    'legend.frameon': True,
    'axes.labelsize': 26,    # axis labels
    'xtick.labelsize': 22,   # x-axis tick labels
    'ytick.labelsize': 22,   # y-axis tick labels
    'lines.linewidth': 2,
    'legend.fontsize': 18,
})

# +
# Directory that every result pickle in this notebook is written to / read from.
res_path = "data/paper_res/banana/"
os.makedirs(res_path, exist_ok=True)

# Plotting format options (identical to the first cell).
# NOTE(review): presumably re-applied because the notebook inline backend may
# reset rcParams set in the cell that activates it — confirm before removing.
plt.style.use(['science', 'no-latex'])
plt.rcParams.update({
    'legend.frameon': True,
    'axes.labelsize': 26,    # axis labels
    'xtick.labelsize': 22,   # x-axis tick labels
    'ytick.labelsize': 22,   # y-axis tick labels
    'lines.linewidth': 2,
    'legend.fontsize': 18,
})

# CUDA devices visible to JAX; the bare name is displayed as the cell output.
devices = jax.devices("cuda")
devices
# -

# ## Paper plot
#
# The following code reproduces the plot shown in Figure 2 of the paper.

# +
# Benchmark: double-banana target, initialised from a fixed PRNG key.
rng = jax.random.PRNGKey(2)
rng, rng_init, rng_sample = jax.random.split(rng, 3)
bench = DoubleBanana(rng_init)

objective_fn, score_fn = bench.get_objective_derivative()
ub, lb = bench.upper_bounds, bench.lower_bounds

# Run configuration for the final evaluation.
npop, subpopsize = 100, 4
nrep = 10  # 10 repetitions for the evaluation
num_generations = 1_000

# Reference samples from the target, used when computing metrics.
gt_samples = bench.sample(jax.random.PRNGKey(0), 256)

# Hyperparameter grids (the paper configurations below index into them).
kernel_vals = jnp.linspace(1e-4, .1, 10)                          # RBF bandwidths
sigmas = jnp.linspace(.05, .5, 10)                                # initial sigmas
sigmas_gf = jnp.linspace(.01, 2., 10)                             # initial sigmas for GF_SVGD
adam_lrs = jnp.array([0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1.0])  # learning rates
elite_ratios = jnp.linspace(.15, .5, 10)                          # elite ratios for SV_CMA_BB

# Strategies with the hyperparameters reported in the paper.
strategies = [
    OG_SVGD(
        npop * subpopsize, RBF(kernel_vals[0]),
        num_iters=num_generations, num_dims=bench.dim,
        opt_name="adam", lrate_init=adam_lrs[-1],
    ),
    GF_SVGD(
        npop * subpopsize, RBF(kernel_vals[1]),
        sigma_init=sigmas_gf[-5], num_iters=num_generations, num_dims=bench.dim,
        opt_name="adam", lrate_init=adam_lrs[0],
    ),
    SV_CMA_BB(
        npop, subpopsize, kernel=RBF(kernel_vals[1]),
        num_iters=num_generations, num_dims=bench.dim,
        elite_ratio=elite_ratios[-1], sigma_init=sigmas[-1],
    ),
    MC_SVGD(
        npop, subpopsize, kernel=RBF(kernel_vals[0]),
        num_iters=num_generations, num_dims=bench.dim,
        sigma_init=sigmas[2], lrate_init=adam_lrs[0],
    ),
]
# -

# First, run the experiments and save the results.

for strategy in strategies:
    seeds = jax.random.split(rng, nrep)
    results = jax.vmap(
        lambda seed: eval_sampling(seed, strategy, bench, num_generations, cb_freq=500)
    )(seeds)
    fname = f"{res_path}{bench.name}{strategy.strategy_name}{nrep}rep{num_generations}iter{npop}pop{subpopsize}members"
    if strategy.strategy_name == "OG_SVGD":
        fname += f"{strategy.lrate_init}elite{strategy.kernel.bandwidth}kernel.pkl"
    elif strategy.strategy_name in ["GF_SVGD", "MC SVGD"]:
        fname += f"{strategy.sigma_init}sigma{strategy.lrate_init}elite{strategy.kernel.bandwidth}kernel.pkl"
    else:
        fname += f"{strategy.sigma_init}sigma{strategy.elite_ratio}elite{strategy.kernel.bandwidth}kernel.pkl"
    with open(fname, "wb") as handler:
        pickle.dump(results, handler, protocol=pickle.HIGHEST_PROTOCOL)

# +
# Load the pickled runs of the paper configurations & compute metrics.
# Reference samples are drawn once (same key and size as in the setup cell)
# rather than re-drawn for every strategy.
gt_samples = bench.sample(jax.random.PRNGKey(0), 256)
paper_plot = {}
for strategy in strategies:
    name = strategy.strategy_name
    # Must mirror the file naming used when the results were saved.
    if name == "OG_SVGD":
        hparams = f"{strategy.lrate_init}elite"
    elif name in ["GF_SVGD", "MC SVGD"]:
        hparams = f"{strategy.sigma_init}sigma{strategy.lrate_init}elite"
    else:
        hparams = f"{strategy.sigma_init}sigma{strategy.elite_ratio}elite"
    fn = (
        f"{res_path}{bench.name}{name}{nrep}rep{num_generations}iter"
        f"{npop}pop{subpopsize}members{hparams}{strategy.kernel.bandwidth}kernel.pkl"
    )

    with open(fn, "rb") as handler:
        particles = pickle.load(handler)
    paper_plot[name] = compute_metrics(particles, gt_samples)

# Plot the mean log10 MMD per strategy with a band of 1.96 * std / sqrt(nrep)
# (a 95% CI of the mean, assuming mmds_std is the across-seed std — confirm).
plt.figure(figsize=(6, 4.5))
for strategy, metrics in paper_plot.items():
    mmd_mean, mmds_std = metrics["mmds"]
    mmd_std = mmds_std * 1.96 / jnp.sqrt(nrep)

    t = jnp.arange(mmd_mean.shape[0])
    style = cmapper(strategy.split("|")[0])
    plt.plot(t, mmd_mean, label=style['name'], color=style['color'], linestyle=style['linestyle'])
    plt.fill_between(t, mmd_mean - mmd_std, mmd_mean + mmd_std, alpha=.2, color=style['color'])

plt.ylabel("log10 MMD")
plt.xlabel("Num. iter.")
# Each curve carries a label; without this call no legend is drawn.
plt.legend()
plt.tight_layout()
# -



# ## Supplementary: Run grid search
#
# First, define the setup. The grid search uses 5 seeds, as specified in the paper; the final test runs use 10.

# +
# Same benchmark as the paper-plot cell, but with 5 seeds per configuration
# for the grid search.
rng = jax.random.PRNGKey(2)
rng, rng_init, rng_sample = jax.random.split(rng, 3)
bench = DoubleBanana(rng_init)

objective_fn, score_fn = bench.get_objective_derivative()
ub, lb = bench.upper_bounds, bench.lower_bounds

npop, subpopsize = 100, 4
nrep = 5
num_generations = 1_000

# Reference samples from the target, used when computing metrics.
gt_samples = bench.sample(jax.random.PRNGKey(0), 256)

# Hyperparameter grids searched below.
kernel_vals = jnp.linspace(1e-4, .1, 10)                          # RBF bandwidths
sigmas = jnp.linspace(.05, .5, 10)                                # initial sigmas
sigmas_gf = jnp.linspace(.01, 2., 10)                             # initial sigmas for GF_SVGD
adam_lrs = jnp.array([0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1.0])  # learning rates
elite_ratios = jnp.linspace(.15, .5, 10)                          # elite ratios for SV_CMA_BB

# Cell output: presumably the median-heuristic RBF bandwidth of the reference
# samples, as a sanity check against the `kernel_vals` grid.
RBF().median_heuristic(gt_samples)
# -

# #### CMA-based

# +
# Evaluate one (elite ratio, sigma, bandwidth) configuration over `nrep` seeds.
def eval_cma(er, sig, kw):
    return run_cfg_cma(rng, npop, subpopsize, er, sig, kw, nrep, num_generations, bench)


# The elite ratio stays a Python loop: per the original note, CMA-ES uses it
# as a fixed value to index arrays, so it cannot be mapped over.
# Bandwidth (outer) and sigma (inner) are vmapped.
for er in elite_ratios:
    res = jax.vmap(
        lambda kw: jax.vmap(lambda sig: eval_cma(er, sig, kw))(sigmas)
    )(kernel_vals)

    for res_kw, kw in zip(res, kernel_vals):
        for res_sig, sig in zip(res_kw, sigmas):
            out_path = (
                f"{res_path}{bench.name}BB_SVGD_ES{nrep}rep{num_generations}iter"
                f"{npop}pop{subpopsize}members{sig}sigma{er}elite{kw}kernel.pkl"
            )
            with open(out_path, "wb") as handler:
                pickle.dump(res_sig, handler, protocol=pickle.HIGHEST_PROTOCOL)
                
# -

# #### OpenES based

# +
# Evaluate one (learning rate, sigma, bandwidth) configuration over `nrep` seeds.
def eval_oes(lr, sig, kw):
    return run_cfg_oes(rng, npop, subpopsize, lr, sig, kw, nrep, num_generations, bench)


for lr in adam_lrs:
    # Bandwidth (outer) and sigma (inner) are vmapped.
    res = jax.vmap(
        lambda kw: jax.vmap(lambda sig: eval_oes(lr, sig, kw))(sigmas)
    )(kernel_vals)

    for res_kw, kw in zip(res, kernel_vals):
        for res_sig, sig in zip(res_kw, sigmas):
            # The learning rate occupies the "<value>elite" slot of the file name.
            out_path = (
                f"{res_path}{bench.name}MC SVGD{nrep}rep{num_generations}iter"
                f"{npop}pop{subpopsize}members{sig}sigma{lr}elite{kw}kernel.pkl"
            )
            with open(out_path, "wb") as handler:
                pickle.dump(res_sig, handler, protocol=pickle.HIGHEST_PROTOCOL)
                
# -

# #### Adam-based
#
# First the gradient-based SVGD (OG_SVGD), then GF-SVGD.

# +
# Evaluate one (learning rate, bandwidth) configuration over `nrep` seeds.
def eval_og(lr, kw):
    return run_cfg_og(rng, npop, subpopsize, lr, kw, nrep, num_generations, bench)


# Bandwidth (outer) and learning rate (inner) are vmapped.
res = jax.vmap(
    lambda kw: jax.vmap(lambda lr: eval_og(lr, kw))(adam_lrs)
)(kernel_vals)

for res_kw, kw in zip(res, kernel_vals):
    for res_lr, lr in zip(res_kw, adam_lrs):
        # OG_SVGD has no sigma; the learning rate fills the "<value>elite" slot.
        out_path = (
            f"{res_path}{bench.name}OG_SVGD{nrep}rep{num_generations}iter"
            f"{npop}pop{subpopsize}members{lr}elite{kw}kernel.pkl"
        )
        with open(out_path, "wb") as handler:
            pickle.dump(res_lr, handler, protocol=pickle.HIGHEST_PROTOCOL)
                

# +
# Evaluate one (learning rate, sigma, bandwidth) configuration over `nrep` seeds.
def eval_gf(lr, sig, kw):
    return run_cfg_gf(rng, npop, subpopsize, lr=lr, sig=sig, kw=kw, nrep=nrep, num_generations=num_generations, bench=bench)


for lr in adam_lrs:
    # Bandwidth (outer) and GF-specific sigma grid (inner) are vmapped.
    res = jax.vmap(
        lambda kw: jax.vmap(lambda sig: eval_gf(lr, sig, kw))(sigmas_gf)
    )(kernel_vals)

    for res_kw, kw in zip(res, kernel_vals):
        for res_sig, sig in zip(res_kw, sigmas_gf):
            # The learning rate occupies the "<value>elite" slot of the file name.
            out_path = (
                f"{res_path}{bench.name}GF_SVGD{nrep}rep{num_generations}iter"
                f"{npop}pop{subpopsize}members{sig}sigma{lr}elite{kw}kernel.pkl"
            )
            with open(out_path, "wb") as handler:
                pickle.dump(res_sig, handler, protocol=pickle.HIGHEST_PROTOCOL)
                
# -

# ### Plot
#
# Assuming all grid-search results have been saved, this code plots the best configuration for each method.

# +
# Placeholder instances: only `strategy_name` is read from them, to select
# each method's result files. Uncomment entries to include more methods.
strategies = [
    OG_SVGD(1, RBF(), num_iters=1, num_dims=1, opt_name="adam"),
    #GF_SVGD(1, RBF(), num_iters=1, num_dims=1, opt_name="adam"),
    #SV_CMA_BB(1, 1, kernel=RBF(), num_iters=1, num_dims=1, elite_ratio=0.25, sigma_init=1),
    #MC_SVGD(1, 2, kernel=RBF(), num_iters=1, num_dims=1, sigma_init=1),
]

plot_data = {}
file_template = res_path + "{bench_name}{strategy_name}{nrep}rep{num_generations}iter{npop}pop{subpopsize}members{sig}sigma{er}elite{kw}kernel.pkl"
# OG_SVGD has no sigma, so its file names omit the "<sig>sigma" segment.
# Keep a separate template instead of rewriting `file_template` in place, so
# strategies processed after OG_SVGD still resolve their own files.
og_file_template = file_template.replace("{sig}sigma", "")


def _load_metrics(file_path, subsample=None):
    """Load a result pickle and compute its metrics against ``gt_samples``.

    Args:
        file_path: Path of a pickle written by one of the grid-search cells.
        subsample: If not None, keep every ``subsample``-th entry along
            axis 2 before computing metrics (done for OG_SVGD and GF_SVGD).

    Returns:
        The output of ``compute_metrics``, or ``None`` when the file does not
        exist (that hyperparameter combination was never run).
    """
    try:
        with open(file_path, "rb") as handler:
            data = pickle.load(handler)
    except FileNotFoundError:
        return None
    if subsample is not None:
        data = data[:, :, ::subsample]
    return compute_metrics(data, gt_samples)


for strategy in strategies:
    strategy_name = strategy.strategy_name
    common_params = {
        'bench_name': bench.name,
        'strategy_name': strategy_name,
        'nrep': nrep,
        'num_generations': num_generations,
        'npop': npop,
        'subpopsize': subpopsize,
    }

    # `er` spans both elite ratios and learning rates (both are stored in the
    # "<value>elite" slot); combinations never run have no file and are skipped.
    for er in jnp.concatenate([elite_ratios, adam_lrs]):
        for kw in kernel_vals:
            if strategy_name == "OG_SVGD":
                file_path = og_file_template.format(er=er, kw=kw, **common_params)
                metrics = _load_metrics(file_path, subsample=subpopsize)
                if metrics is not None:
                    plot_data[f"{strategy_name}|{er}|{kw}"] = metrics
                continue

            if strategy_name == "GF_SVGD":
                sig_grid, subsample = sigmas_gf, subpopsize
            else:
                sig_grid, subsample = sigmas, None
            # Each sigma file is tried on its own, so one missing file does
            # not skip the remaining sigmas for this (er, kw) pair.
            for sig in sig_grid:
                file_path = file_template.format(sig=sig, er=er, kw=kw, **common_params)
                metrics = _load_metrics(file_path, subsample=subsample)
                if metrics is not None:
                    plot_data[f"{strategy_name}|{er}|{sig}|{kw}"] = metrics

#len(plot_data)

# +
# For each strategy, keep the configuration with the lowest final mean metric
# and plot it with a 1.96 * std / sqrt(nrep) band.
best = {s.strategy_name: {"val": [float("inf")]} for s in strategies}
key = "mmds"
plt.figure(figsize=(6, 4.5))
for strategy in plot_data.keys():
    strategy_name = strategy.split("|")[0]
    mmds, mmds_std = plot_data[strategy][key]
    mmd_mean = mmds
    mmd_std = mmds_std * 1.96 / jnp.sqrt(nrep)

    if mmd_mean[-1] < best[strategy_name]["val"][-1]:
        best[strategy_name]["val"] = mmd_mean
        best[strategy_name]["name"] = strategy_name
        best[strategy_name]["std"] = mmd_std
        best[strategy_name]["cfg"] = strategy

# Plot. Strategies for which no result was loaded never get a "name" entry
# and are skipped instead of raising KeyError.
for res in best.values():
    if not res.get("name"):
        continue
    # Time axis taken from this curve's own length.
    t = jnp.arange(len(res["val"]))
    style = cmapper(res["name"])
    plt.plot(t, res["val"], label=style['name'], color=style['color'], linestyle=style['linestyle'])
    plt.fill_between(t, res["val"] - res["std"], res["val"] + res["std"], alpha=.2, color=style['color'])

plt.ylabel("log10 MMD")
plt.xlabel("N. iter.")
plt.legend(fontsize=8)

print("Best results at:")
for s, res in best.items():
    print(res.get("cfg", f"{s}: no results found"))
# -

# ## Particle scaling
#

# Particle-scaling study: 10 seeds per run, three subpopulation sizes and
# particle counts 2**4 ... 2**9.
nrep = 10
subpops = jnp.array([4, 8, 16])
n_particles = onp.power(2, onp.arange(4, 10))  # must be a NumPy (onp) array, not jnp
n_particles

# SV_CMA_BB scaling: start from the paper configuration and rescale the elite
# ratio with the subpopulation size and the bandwidth with the particle count.
sig = sigmas[-1]
for subs in subpops:
    # Equals elite_ratios[-1] at 4 members per subpopulation.
    ergo = elite_ratios[-1] * (4 / subs)
    for npart in n_particles:
        # Equals kernel_vals[1] at 100 particles.
        kerva = kernel_vals[1] * (100 / npart)
        print(ergo, kerva)
        res = run_cfg_cma(
            rng, npart, subs,
            er=ergo, sig=sig, kw=kerva,
            nrep=nrep, num_generations=num_generations, bench=bench,
        )
        out_path = (
            f"{res_path}{bench.name}BB_SVGD_ES{nrep}rep{num_generations}iter"
            f"{npart}pop{subs}members{sig}sigma{ergo}elite{kerva}kernel.pkl"
        )
        with open(out_path, "wb") as handler:
            pickle.dump(res, handler, protocol=pickle.HIGHEST_PROTOCOL)
            

# MC SVGD scaling at the paper sigma / learning rate; the bandwidth is
# rescaled so that kernel_vals[0] corresponds to 100 particles.
sig, lr = sigmas[2], adam_lrs[0]
for subs in subpops:
    for npart in n_particles:
        kerva = kernel_vals[0] * (100 / npart)
        res = run_cfg_oes(rng, npart, int(subs), lr, sig, kerva, nrep, num_generations, bench)
        out_path = (
            f"{res_path}{bench.name}MC SVGD{nrep}rep{num_generations}iter"
            f"{npart}pop{subs}members{sig}sigma{lr}elite{kerva}kernel.pkl"
        )
        with open(out_path, "wb") as handler:
            pickle.dump(res, handler, protocol=pickle.HIGHEST_PROTOCOL)
            

# OG_SVGD scaling (one member per subpopulation) at the paper learning rate;
# the bandwidth is rescaled so that kernel_vals[0] corresponds to 100 particles.
lr = adam_lrs[-1]
for npart in n_particles:
    kerva = kernel_vals[0] * (100 / npart)
    res = run_cfg_og(rng, npart, 1, lr, kerva, nrep, num_generations, bench)
    out_path = (
        f"{res_path}{bench.name}OG_SVGD{nrep}rep{num_generations}iter"
        f"{npart}pop1members{lr}elite{kerva}kernel.pkl"
    )
    with open(out_path, "wb") as handler:
        pickle.dump(res, handler, protocol=pickle.HIGHEST_PROTOCOL)

# GF_SVGD scaling (one member per subpopulation) at the paper sigma / learning
# rate; the bandwidth is rescaled so that kernel_vals[1] corresponds to 100 particles.
sig = sigmas_gf[-5]
lr = adam_lrs[0]
for npart in n_particles:
    kerva = kernel_vals[1] * (100 / npart)
    # Pass hyperparameters by keyword, as in the GF grid-search cell, so `lr`
    # and `sig` cannot be swapped positionally (the sibling run_cfg_* helpers
    # take them in (lr, sig) order).
    res = run_cfg_gf(
        rng, npart, 1,
        lr=lr, sig=sig, kw=kerva,
        nrep=nrep, num_generations=num_generations, bench=bench,
    )
    out_path = (
        f"{res_path}{bench.name}GF_SVGD{nrep}rep{num_generations}iter"
        f"{npart}pop1members{sig}sigma{lr}elite{kerva}kernel.pkl"
    )
    with open(out_path, "wb") as handler:
        pickle.dump(res, handler, protocol=pickle.HIGHEST_PROTOCOL)


