# ---
# jupyter:
#   jupytext:
#     text_representation:
#       extension: .py
#       format_name: light
#       format_version: '1.5'
#       jupytext_version: 1.16.2
#   kernelspec:
#     display_name: Python 3 (ipykernel)
#     language: python
#     name: python3
# ---

# # Gaussian Mixture Sampling Evaluation
#
# This notebook contains code to reproduce our evaluation on the Gaussian Mixture sampling problem.

# +
import os
# %matplotlib inline
# %env XLA_PYTHON_CLIENT_PREALLOCATE=false
# %env XLA_PYTHON_CLIENT_ALLOCATOR=platform

import pickle
import jax
import matplotlib.pyplot as plt
import jax.numpy as jnp
import numpy as onp
import scienceplots
import seaborn as sns

from sves.kernels import RBF
from sves.strategies import OG_SVGD, GF_SVGD, SV_CMA_BB, MC_SVGD, ParallelCMAES
from sves.benchmarks.sampling_eval import run_cfg_cma, run_cfg_oes, run_cfg_og, run_cfg_gf, eval_sampling
from sves.benchmarks.utils.eval_utils import cmapper, compute_metrics
from sves.benchmarks.synthetic_benchmarks import GMM

# +
# Fake a distributed session so that this process only uses one specific GPU
# (local device id 1).
jax.distributed.initialize(
    coordinator_address="127.0.0.1:1215",
    num_processes=1,
    process_id=0,
    local_device_ids=[1],
)
devices = jax.devices("cuda")

# Create the results directory
res_path = "data/paper_res/2dgauss/"
os.makedirs(res_path, exist_ok=True)

# Plotting format options; the style sheet is applied first so that the
# explicit overrides below take precedence over it.
plt.style.use(['science', 'no-latex'])
plt.rcParams.update({
    'legend.frameon': True,
    'axes.labelsize': 26,    # axis labels
    'xtick.labelsize': 22,   # x-axis tick labels
    'ytick.labelsize': 22,   # y-axis tick labels
    'lines.linewidth': 2,
    'legend.fontsize': 18,
})

devices
# -

# Results for the paper are read/written under paper_res2. Create this
# directory as well: the setup cell only created the paper_res path, so
# without this the pickle.dump calls below fail on a fresh checkout.
res_path = "data/paper_res2/2dgauss/"
os.makedirs(res_path, exist_ok=True)


# ## Paper plot
#
# The following code produces the plot shown in Figure 2 of the paper.

# +
# Seed 2 produces a nice-looking GMM density, but any seed works.
rng = jax.random.PRNGKey(2)
rng, rng_init, rng_sample = jax.random.split(rng, 3)

# Benchmark: 2-D Gaussian mixture with 4 modes, bounds lb=-6 / ub=6.
bench = GMM(rng_init, dim=2, lb=-6., ub=6., n_modes=4)
objective_fn, score_fn = bench.get_objective_derivative()
ub, lb = bench.upper_bounds, bench.lower_bounds

# Run configuration
npop, subpopsize = 100, 4
nrep = 10  # 10 runs for the final evaluation
num_generations = 1_000
gt_samples = bench.sample(jax.random.PRNGKey(0), 256)  # reference samples for the metrics

# Hyperparameter grids; the paper configurations below index into them.
kernel_vals = jnp.linspace(1e-3, 1., 10)
sigmas = jnp.linspace(.05, .5, 10)
sigmas_gf = jnp.linspace(.1, 6., 10)
adam_lrs = jnp.array([0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1.0])
elite_ratios = jnp.linspace(.15, .5, 10)

# Strategies with the hyperparameters used in the paper (uncomment to evaluate the others)
strategies = [
    #OG_SVGD(npop * subpopsize, RBF(kernel_vals[2]), num_iters=num_generations, num_dims=bench.dim, opt_name="adam", lrate_init=adam_lrs[3]),
    #GF_SVGD(npop * subpopsize, RBF(kernel_vals[-2]), sigma_init=sigmas_gf[4], num_iters=num_generations, num_dims=bench.dim, opt_name="adam", lrate_init=adam_lrs[-1]),
    SV_CMA_BB(
        npop,
        subpopsize,
        kernel=RBF(kernel_vals[-2]),
        num_iters=num_generations,
        num_dims=bench.dim,
        elite_ratio=elite_ratios[-1],
        sigma_init=sigmas[-1],
    ),
    #MC_SVGD(npop, subpopsize, kernel=RBF(kernel_vals[0]), num_iters=num_generations, num_dims=bench.dim, sigma_init=sigmas[1], lrate_init=adam_lrs[-2]),
]
# -

# First we run the experiments.

nrep = 10
for strategy in strategies:
    run_keys = jax.random.split(rng, nrep)
    # One eval_sampling run per key, vectorised over the nrep seeds
    results = jax.vmap(
        lambda key: eval_sampling(key, strategy, bench, num_generations, cb_freq=500)
    )(run_keys)

    # The file name encodes the run setup plus the strategy's hyperparameters.
    # For the gradient-based methods the learning rate sits under the "elite" tag.
    name = strategy.strategy_name
    fname = (f"{res_path}{bench.name}{name}{nrep}rep{num_generations}iter"
             f"{npop}pop{subpopsize}members")
    if name == "OG_SVGD":
        hyper_tag = f"{strategy.lrate_init}elite"
    elif name in ["GF_SVGD", "MC SVGD"]:
        hyper_tag = f"{strategy.sigma_init}sigma{strategy.lrate_init}elite"
    else:
        hyper_tag = f"{strategy.sigma_init}sigma{strategy.elite_ratio}elite"
    fname += f"{hyper_tag}{strategy.kernel.bandwidth}kernel.pkl"

    with open(fname, "wb") as handler:
        pickle.dump(results, handler, protocol=pickle.HIGHEST_PROTOCOL)

# Now we plot.

# +
# Load results & compute metrics.
# The reference samples do not depend on the strategy, so draw them once.
gt_samples = bench.sample(jax.random.PRNGKey(0), 256)
paper_plot = {}
for strategy in strategies:
    name = strategy.strategy_name
    prefix = f"{res_path}{bench.name}{name}{nrep}rep{num_generations}iter{npop}pop{subpopsize}members"
    # Same naming scheme as used when the results were written
    if name == "OG_SVGD":
        hyper_tag = f"{strategy.lrate_init}elite"
    elif name in ["GF_SVGD", "MC SVGD"]:
        hyper_tag = f"{strategy.sigma_init}sigma{strategy.lrate_init}elite"
    else:
        hyper_tag = f"{strategy.sigma_init}sigma{strategy.elite_ratio}elite"
    fn = f"{prefix}{hyper_tag}{strategy.kernel.bandwidth}kernel.pkl"

    with open(fn, "rb") as handler:
        particles = pickle.load(handler)
    # Compute metrics after the file is closed
    paper_plot[name] = compute_metrics(particles, gt_samples)

# Plot res
plt.figure(figsize=(6, 4.5))
for cfg_name, metrics in paper_plot.items():
    mmd_mean, mmds_std = metrics["mmds"]
    # Assuming mmds_std is the std across the nrep runs, this is a
    # normal-approximation 95% confidence half-width of the mean.
    mmd_std = mmds_std * 1.96 / jnp.sqrt(nrep)

    t = jnp.arange(mmd_mean.shape[0])
    style = cmapper(cfg_name.split("|")[0])
    plt.plot(t, mmd_mean, label=style['name'], color=style['color'], linestyle=style['linestyle'])
    plt.fill_between(t, mmd_mean - mmd_std, mmd_mean + mmd_std, alpha=.2, color=style['color'])

plt.ylim(-4.5, 1.)
plt.ylabel("log10 MMD")
plt.xlabel("Num. iter.")
leg = plt.legend()
leg.get_frame().set_linewidth(0.0)
plt.tight_layout()
#plt.savefig(f"data/paper_res2/imgs/gmm/{bench.name}_mmd_convergence3.svg", format="svg")
# -



# ## Supplementary: Run grid search
#
# First, define the setup. As specified in the paper, we use 5 seeds for hyperparameter tuning; the final test runs use 10.

# +
# Seed 2 produces a nice-looking GMM density, but any seed works;
# just make sure it is consistent with the paper plot!
rng = jax.random.PRNGKey(2)
rng, rng_init, rng_sample = jax.random.split(rng, 3)

# Benchmark: 2-D Gaussian mixture with 4 modes, bounds lb=-6 / ub=6.
bench = GMM(rng_init, dim=2, lb=-6., ub=6., n_modes=4)
objective_fn, score_fn = bench.get_objective_derivative()
ub, lb = bench.upper_bounds, bench.lower_bounds

# Run configuration
npop, subpopsize = 100, 4
nrep = 5  # 5 runs for hyperparameter tuning
num_generations = 1_000
gt_samples = bench.sample(jax.random.PRNGKey(0), 256)  # reference samples for the metrics

# Hyperparameter grids searched below
kernel_vals = jnp.linspace(1e-3, 1., 10)
sigmas = jnp.linspace(.05, .5, 10)
sigmas_gf = jnp.linspace(.1, 6., 10)
adam_lrs = jnp.array([0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1.0])
elite_ratios = jnp.linspace(.15, .5, 10)

# Display RBF's median heuristic on the reference samples
# (presumably a bandwidth estimate to compare against kernel_vals).
RBF().median_heuristic(gt_samples)
# -

# #### CMA-based

# +
# Bind the fixed setup so that only (elite ratio, sigma, kernel width) vary
eval_cma = lambda er, sig, kw: run_cfg_cma(rng, npop, subpopsize, er, sig, kw, nrep, num_generations, bench)

# The elite ratio is used as a fixed value to index arrays inside CMA-ES, so it
# cannot easily be vmapped: loop over it in Python and vmap over kernel widths and sigmas.
for er in elite_ratios:
    per_kernel = lambda kw: jax.vmap(lambda sig: eval_cma(er, sig, kw))(sigmas)
    res = jax.vmap(per_kernel)(kernel_vals)  # outer axis: kernel width, inner: sigma
    for kw, res_kw in zip(kernel_vals, res):
        for sig, run in zip(sigmas, res_kw):
            out_file = f"{res_path}{bench.name}BB_SVGD_ES{nrep}rep{num_generations}iter{npop}pop{subpopsize}members{sig}sigma{er}elite{kw}kernel.pkl"
            with open(out_file, "wb") as handler:
                pickle.dump(run, handler, protocol=pickle.HIGHEST_PROTOCOL)
# -

# #### OpenES based

# +
# Bind the fixed setup so that only (learning rate, sigma, kernel width) vary
eval_oes = lambda lr, sig, kw: run_cfg_oes(rng, npop, subpopsize, lr, sig, kw, nrep, num_generations, bench)

for lr in adam_lrs:
    per_kernel = lambda kw: jax.vmap(lambda sig: eval_oes(lr, sig, kw))(sigmas)
    res = jax.vmap(per_kernel)(kernel_vals)  # outer axis: kernel width, inner: sigma

    for kw, res_kw in zip(kernel_vals, res):
        for sig, run in zip(sigmas, res_kw):
            # Saved under the "MC SVGD" name; the learning rate goes in the "elite" tag
            out_file = f"{res_path}{bench.name}MC SVGD{nrep}rep{num_generations}iter{npop}pop{subpopsize}members{sig}sigma{lr}elite{kw}kernel.pkl"
            with open(out_file, "wb") as handler:
                pickle.dump(run, handler, protocol=pickle.HIGHEST_PROTOCOL)
# -

# #### Adam-based
#
# First OG-SVGD, then GF-SVGD.

# +
# Bind the fixed setup so that only (learning rate, kernel width) vary
eval_og = lambda lr, kw: run_cfg_og(rng, npop, subpopsize, lr, kw, nrep, num_generations, bench)

per_kernel = lambda kw: jax.vmap(lambda lr: eval_og(lr, kw))(adam_lrs)
res = jax.vmap(per_kernel)(kernel_vals)  # outer axis: kernel width, inner: learning rate

for kw, res_kw in zip(kernel_vals, res):
    for lr, run in zip(adam_lrs, res_kw):
        # The learning rate goes in the "elite" tag of the file name
        out_file = f"{res_path}{bench.name}OG_SVGD{nrep}rep{num_generations}iter{npop}pop{subpopsize}members{lr}elite{kw}kernel.pkl"
        with open(out_file, "wb") as handler:
            pickle.dump(run, handler, protocol=pickle.HIGHEST_PROTOCOL)

# +
# Bind the fixed setup so that only (learning rate, sigma, kernel width) vary
eval_gf = lambda lr, sig, kw: run_cfg_gf(rng, npop, subpopsize, lr, sig, kw, nrep, num_generations, bench)

for lr in adam_lrs:
    per_kernel = lambda kw: jax.vmap(lambda sig: eval_gf(lr, sig, kw))(sigmas_gf)
    res = jax.vmap(per_kernel)(kernel_vals)  # outer axis: kernel width, inner: sigma

    for kw, res_kw in zip(kernel_vals, res):
        for sig, run in zip(sigmas_gf, res_kw):
            # The learning rate goes in the "elite" tag of the file name
            out_file = f"{res_path}{bench.name}GF_SVGD{nrep}rep{num_generations}iter{npop}pop{subpopsize}members{sig}sigma{lr}elite{kw}kernel.pkl"
            with open(out_file, "wb") as handler:
                pickle.dump(run, handler, protocol=pickle.HIGHEST_PROTOCOL)
# -

# ### Plot
#
# Once all grid-search results are on disk, this code plots the best configuration per method.

# +
strategies = [
    OG_SVGD(1, RBF(), num_iters=1, num_dims=1, opt_name="adam"),
    #GF_SVGD(1, RBF(), num_iters=1, num_dims=1, opt_name="adam"),
    #SV_CMA_BB(1, 1, kernel=RBF(), num_iters=1, num_dims=1, elite_ratio=0.25, sigma_init=1),
    #MC_SVGD(1, 2, kernel=RBF(), num_iters=1, num_dims=1, sigma_init=1),
]

plot_data = {}
file_template = res_path + "{bench_name}{strategy_name}{nrep}rep{num_generations}iter{npop}pop{subpopsize}members{sig}sigma{er}elite{kw}kernel.pkl"
# OG-SVGD has no sigma hyperparameter, so its files carry no "sigma" tag. Use a
# separate template instead of mutating file_template, which previously broke
# the lookups of every strategy processed after OG_SVGD.
og_file_template = file_template.replace("sigma", "")


def _try_load_metrics(file_path, stride=None):
    """Return compute_metrics of the pickled particles at file_path, or None if absent.

    Not every (strategy, hyperparameter) combination exists on disk, because
    the loops below scan the union of elite ratios and learning rates for
    every strategy. Only a missing file is skipped; other errors propagate.

    Args:
        file_path: Path of the pickled results.
        stride: If given, keep only data[:, :, ::stride] before computing metrics.
    """
    try:
        with open(file_path, "rb") as handler:
            data = pickle.load(handler)
    except FileNotFoundError:
        return None
    if stride is not None:
        data = data[:, :, ::stride]
    return compute_metrics(data, gt_samples)


for strategy in strategies:
    strategy_name = strategy.strategy_name
    common_params = {
        'bench_name': bench.name,
        'strategy_name': strategy_name,
        'nrep': nrep,
        'num_generations': num_generations,
        'npop': npop,
        'subpopsize': subpopsize
    }

    for er in jnp.concatenate([elite_ratios, adam_lrs]):
        for kw in kernel_vals:
            if strategy_name == "OG_SVGD":
                # Subsampling for non-ES methods because otherwise we will have too much in memory; only done for model selection
                file_path = og_file_template.format(sig="", er=er, kw=kw, **common_params)
                metrics = _try_load_metrics(file_path, stride=subpopsize)
                if metrics is not None:
                    plot_data[f"{strategy_name}|{er}|{kw}"] = metrics

            elif strategy_name == "GF_SVGD":
                for sig in sigmas_gf:
                    file_path = file_template.format(sig=sig, er=er, kw=kw, **common_params)
                    metrics = _try_load_metrics(file_path, stride=subpopsize)
                    if metrics is not None:
                        plot_data[f"{strategy_name}|{er}|{sig}|{kw}"] = metrics

            else:
                for sig in sigmas:
                    file_path = file_template.format(sig=sig, er=er, kw=kw, **common_params)
                    metrics = _try_load_metrics(file_path)
                    if metrics is not None:
                        plot_data[f"{strategy_name}|{er}|{sig}|{kw}"] = metrics

# +
# Iterate solutions and keep, per strategy, the configuration with the lowest
# final mean MMD.
best = {s.strategy_name: {"val": [100]} for s in strategies}
key = "mmds"

plt.figure(figsize=(6, 4.5))
for cfg_key, metrics in plot_data.items():
    strategy_name = cfg_key.split("|")[0]
    mmd_mean, mmds_std = metrics[key]
    # Assuming mmds_std is the std across the nrep runs: normal-approx. 95% CI half-width
    mmd_std = mmds_std * 1.96 / jnp.sqrt(nrep)

    if mmd_mean[-1] < best[strategy_name]["val"][-1]:
        best[strategy_name]["val"] = mmd_mean
        best[strategy_name]["name"] = strategy_name
        best[strategy_name]["std"] = mmd_std
        best[strategy_name]["cfg"] = cfg_key

# Plot
for res in best.values():
    sname = res.get("name")
    if not sname:
        # No results were loaded for this strategy; nothing to plot.
        continue
    # Per-result time axis instead of reusing the last loop iteration's length
    t = jnp.arange(res["val"].shape[0])
    style = cmapper(sname)
    plt.plot(t, res["val"], label=style['name'], color=style['color'], linestyle=style['linestyle'])
    plt.fill_between(t, res["val"] - res["std"], res["val"] + res["std"], alpha=.2, color=style['color'])

plt.ylabel("log10 MMD")
plt.xlabel("N. iter.")
#plt.ylim(-2.5, .1)
plt.legend(fontsize=8)

print("Best results at:")
for s, res in best.items():
    print(res.get("cfg", f"{s}: no results found"))
# -

# Display the hyperparameter grids for reference
(kernel_vals, sigmas, elite_ratios, adam_lrs, sigmas_gf)

# ## Particle scaling
#
# For this we again run the methods.
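# Because the kernel bandwidth was tuned for 100 particles with a subpopulation size of 4, we rescale it inversely with the number of particles (and, for the CMA-based method, rescale the elite ratio inversely with the subpopulation size).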
# This slightly improves the results across all methods.

# Scaling sweep: 10 runs per setting
nrep = 10
subpops = jnp.array([4, 8, 16])
# Particle counts 16, 32, ..., 512. Built with NumPy (onp) rather than jnp,
# as the original noted this was needed.
n_particles = 2 ** onp.arange(4, 10)
n_particles, subpops

# CMA-based scaling with the paper's sigma. The elite ratio and kernel width
# were tuned at subpopsize 4 / 100 particles, so they are rescaled per setting.
sig = sigmas[-1]
# Pass a Python int subpopulation size, as the MC-SVGD sweep does, instead of a
# jnp scalar (jnp arrays are unhashable, so they fail wherever the size is used
# as a static argument).
for subs in map(int, subpops):
    for npart in n_particles:
        ergo = elite_ratios[-1] * (4 / subs)
        kerva = kernel_vals[-2] * (100 / npart)
        print(ergo, kerva)
        res = run_cfg_cma(
            rng,
            npart,
            subs,
            er=ergo,
            sig=sig,
            kw=kerva,
            nrep=nrep,
            num_generations=num_generations,
            bench=bench
        )
        with open(f"{res_path}{bench.name}BB_SVGD_ES{nrep}rep{num_generations}iter{npart}pop{subs}members{sig}sigma{ergo}elite{kerva}kernel.pkl", "wb") as handler:
            pickle.dump(res, handler, protocol=pickle.HIGHEST_PROTOCOL)

# MC-SVGD (OpenES-based) scaling with the paper's sigma / learning rate
sig, lr = sigmas[1], adam_lrs[-2]
for subs in subpops:
    for npart in n_particles:
        # Rescale the bandwidth that was tuned for 100 particles
        kerva = kernel_vals[0] * (100 / npart)
        res = run_cfg_oes(rng, npart, int(subs), lr, sig, kerva, nrep, num_generations, bench)
        out_file = f"{res_path}{bench.name}MC SVGD{nrep}rep{num_generations}iter{npart}pop{subs}members{sig}sigma{lr}elite{kerva}kernel.pkl"
        with open(out_file, "wb") as handler:
            pickle.dump(res, handler, protocol=pickle.HIGHEST_PROTOCOL)

# OG-SVGD scaling with the paper's learning rate
lr = adam_lrs[3]
for npart in n_particles:
    # Rescale the bandwidth that was tuned for 100 particles
    kerva = kernel_vals[2] * (100 / npart)
    # Subpopulation size 1 (presumably so npart is the total particle count)
    res = run_cfg_og(rng, npart, 1, lr, kerva, nrep, num_generations, bench)
    out_file = f"{res_path}{bench.name}OG_SVGD{nrep}rep{num_generations}iter{npart}pop{1}members{lr}elite{kerva}kernel.pkl"
    with open(out_file, "wb") as handler:
        pickle.dump(res, handler, protocol=pickle.HIGHEST_PROTOCOL)

# GF-SVGD scaling with the paper's sigma / learning rate
sig = sigmas_gf[4]
lr = adam_lrs[-1]
for npart in n_particles:
    # Rescale the bandwidth that was tuned for 100 particles
    kerva = kernel_vals[-2] * (100 / npart)
    # Pass (lr, sig) in the same order as the grid search's eval_gf
    # (run_cfg_gf(rng, npop, subpopsize, lr, sig, kw, ...)); the previous
    # (sig, lr) order silently swapped the tuned learning rate and sigma.
    # NOTE(review): confirm against run_cfg_gf's signature.
    res = run_cfg_gf(rng, npart, 1, lr, sig, kerva, nrep, num_generations, bench)
    with open(f"{res_path}{bench.name}GF_SVGD{nrep}rep{num_generations}iter{npart}pop{1}members{sig}sigma{lr}elite{kerva}kernel.pkl", "wb") as handler:
        pickle.dump(res, handler, protocol=pickle.HIGHEST_PROTOCOL)


