# ---
# jupyter:
#   jupytext:
#     text_representation:
#       extension: .py
#       format_name: light
#       format_version: '1.5'
#       jupytext_version: 1.16.2
#   kernelspec:
#     display_name: Python 3 (ipykernel)
#     language: python
#     name: python3
# ---

# # Particle scaling
#
# This notebook contains the code to reproduce the plots for the particle scaling (Fig. 5).
#
# The code simply loads all runs for the sampling tasks and aggregates them.

# +
# %env CUDA_VISIBLE_DEVICES=0

import pickle
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
from flax.struct import dataclass
import scienceplots
import numpy as onp
import seaborn as sns

from sves.strategies import SV_CMA_BB, MC_SVGD, GF_SVGD, OG_SVGD
from sves.kernels import RBF
from sves.benchmarks.simple_motion_plan import RamosPaper
from sves.benchmarks.synthetic_benchmarks import GMM, DoubleBanana
from sves.benchmarks.utils.eval_utils import get_mmd, cmapper
# -

# First we set the plotting style and define the settings shared by all experiments.

# First plot: rcParams shared by the MMD-convergence plot and the line plot.
plt.style.use(['science', 'no-latex'])
plt.rcParams.update({
    'legend.frameon': True,
    'axes.labelsize': 26,   # axis labels
    'xtick.labelsize': 22,  # x-axis tick labels
    'ytick.labelsize': 22,  # y-axis tick labels
    'lines.linewidth': 2,
    'legend.fontsize': 18,
})

# +
# Hyperparameters of sampling tasks
# These must match the settings the stored runs were produced with: they are
# interpolated into the result filenames (`{nrep}rep{num_generations}iter{npop}pop`,
# `{subpopsize}members`) when the .pkl files are loaded below.
nrep = 10  # number of repetitions ("rep" in the filenames)
num_generations = 1000  # number of iterations ("iter" in the filenames)
npop = 100  # population size of the MMD-convergence runs ("pop" in the filenames)
subpopsize = 4  # subpopulation size ("members" in the filenames)
rng = jax.random.PRNGKey(2)
# Fixed seed; presumably required so GMM / DoubleBanana are re-created with the
# same parameters as in the stored runs — TODO confirm.
rng, rng_init, rng_samples = jax.random.split(rng, 3)
benchmarks = [GMM(rng_init), DoubleBanana(rng_samples), RamosPaper(n_via=5)]
res_path = "data/paper_res/"  # root directory of the pickled results

# Ranges that were used during grid search. Necessary because filenames carry exact float representations
# (`configs` below indexes into these arrays instead of hard-coding floats, so the
# formatted values reproduce the filenames exactly).
gmm_kernel_vals = jnp.linspace(1e-3, 1., 10)
gmm_sigmas = jnp.linspace(.05, .5, 10)
gmm_sigmas_gf = jnp.linspace(.1, 6., 10)
adam_lrs = jnp.array([0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1.0])
gmm_elite_ratios = jnp.linspace(.15, .5, 10)
banana_kernel_vals = jnp.linspace(1e-4, .1, 10)
banana_sigmas = jnp.linspace(.05, .5, 10)
banana_elite_ratios = jnp.linspace(.15, .5, 10)
banana_sigmas_gf = jnp.linspace(.01, 2., 10)
ramos_kernel_vals = jnp.linspace(1e-2, 3., 10)
ramos_sigmas = jnp.linspace(.05, .5, 10)
ramos_elite_ratios = jnp.linspace(.15, .5, 10)
ramos_sigmas_gf = jnp.linspace(.01, 4., 10)


# Configs
@dataclass
class Conf:
    """Hyperparameters (picked from the grid-search ranges) of one strategy run.

    Only the fields a strategy's placeholder constructor takes are set; the
    rest keep their default of 1. Although annotated as float, the values
    assigned in `configs` are JAX array scalars indexed from the grids.
    """
    er: float = 1.   # elite ratio, passed as `elite_ratio` (SV_CMA_BB only)
    sig: float = 1.  # passed as `sigma_init`
    kw: float = 1.   # RBF kernel argument, read back as `kernel.bandwidth`
    lr: float = 1.   # learning rate from `adam_lrs`, passed as `lrate_init`
    
# Selected hyperparameters per benchmark, indexed from the grid-search ranges above.
# Outer keys must equal `bench.name`. Inner keys are looked up by literal name
# when the placeholder strategies are built (note the space in "MC SVGD").
# "dir_name" is the subdirectory of `res_path` holding that benchmark's results.
configs = {
    "GMM": {
        "SV_CMA_BB": Conf(er=gmm_elite_ratios[-1], sig=gmm_sigmas[-1], kw=gmm_kernel_vals[-2]), 
        "MC SVGD": Conf(lr=adam_lrs[-2], sig=gmm_sigmas[1], kw=gmm_kernel_vals[0]), 
        "GF_SVGD": Conf(lr=adam_lrs[-1], sig=gmm_sigmas_gf[4], kw=gmm_kernel_vals[-2]), 
        "OG_SVGD": Conf(lr=adam_lrs[3], kw=gmm_kernel_vals[2]),
        "dir_name": "2dgauss"
    },
    "Double Banana": {
        "SV_CMA_BB": Conf(er=banana_elite_ratios[-1], sig=banana_sigmas[-1], kw=banana_kernel_vals[1]), 
        "MC SVGD": Conf(lr=adam_lrs[0], sig=banana_sigmas[2], kw=banana_kernel_vals[0]), 
        "GF_SVGD": Conf(lr=adam_lrs[0], sig=banana_sigmas_gf[-5], kw=banana_kernel_vals[1]), 
        "OG_SVGD": Conf(lr=adam_lrs[-1], kw=banana_kernel_vals[0]),
        "dir_name": "banana"
    },
    "F_RAMOS": {
        "SV_CMA_BB": Conf(er=ramos_elite_ratios[-1], sig=ramos_sigmas[1], kw=ramos_kernel_vals[0]), 
        "MC SVGD": Conf(lr=adam_lrs[3], sig=ramos_sigmas[1], kw=ramos_kernel_vals[0]), 
        "GF_SVGD": Conf(lr=adam_lrs[0], sig=ramos_sigmas_gf[-4], kw=ramos_kernel_vals[2]), 
        "OG_SVGD": Conf(lr=adam_lrs[2], kw=ramos_kernel_vals[0]),
        "dir_name": "ramos"
    },
}

# Placeholder strategy constructors. In this notebook the resulting objects are
# never run; they only supply `strategy_name` and the hyperparameter attributes
# (`lrate_init`, `sigma_init`, `elite_ratio`, `kernel.bandwidth`) needed to
# rebuild the result filenames.
def svcma(er, sig, kw):
    """Return an SV_CMA_BB placeholder with elite ratio `er`, initial sigma `sig` and RBF argument `kw`."""
    return SV_CMA_BB(1, 1, RBF(kw), elite_ratio=er, sigma_init=sig, num_dims=2)


def mcsvgd(lr, sig, kw):
    """Return an MC_SVGD placeholder with learning rate `lr`, initial sigma `sig` and RBF argument `kw`."""
    return MC_SVGD(1, 2, RBF(kw), lrate_init=lr, sigma_init=sig, num_dims=2)


def gfsvgd(lr, sig, kw):
    """Return a GF_SVGD placeholder with learning rate `lr`, initial sigma `sig` and RBF argument `kw`."""
    return GF_SVGD(1, RBF(kw), sigma_init=sig, lrate_init=lr, num_dims=2)


def ogsvgd(lr, kw):
    """Return an OG_SVGD placeholder with learning rate `lr` and RBF argument `kw`."""
    return OG_SVGD(1, RBF(kw), lrate_init=lr, num_dims=2)
# -

# ## MMD convergence
#
# First we plot the MMD convergence aggregated across all tasks. This is Fig. 2d.

# +
# Load the MMD trajectories of every strategy on every benchmark.
# res[strategy_name] pools, across all benchmarks, the per-run MMD curves
# returned by `get_mmd`.
# NOTE(review): "BB_SVGD_ES" is presumably SV_CMA_BB's `strategy_name` — confirm.
res = {
    "OG_SVGD": [],
    "GF_SVGD": [],
    "BB_SVGD_ES": [],
    "MC SVGD": [],
}

for bench in benchmarks:
    cfg = configs[bench.name]
    gt_samples = bench.sample(jax.random.PRNGKey(0), 256)
    strategies = [
        ogsvgd(cfg["OG_SVGD"].lr, cfg["OG_SVGD"].kw),
        gfsvgd(cfg["GF_SVGD"].lr, cfg["GF_SVGD"].sig, cfg["GF_SVGD"].kw),
        svcma(cfg["SV_CMA_BB"].er, cfg["SV_CMA_BB"].sig, cfg["SV_CMA_BB"].kw),
        mcsvgd(cfg["MC SVGD"].lr, cfg["MC SVGD"].sig, cfg["MC SVGD"].kw),
    ]
    for strategy in strategies:
        # The filename encodes the run's hyperparameters; it must be rebuilt with
        # exactly the same value formatting used when the results were saved.
        fname = f"{res_path}{cfg['dir_name']}/{bench.name}{strategy.strategy_name}{nrep}rep{num_generations}iter{npop}pop"
        if strategy.strategy_name == "OG_SVGD":
            fname += f"{subpopsize}members{strategy.lrate_init}elite{strategy.kernel.bandwidth}kernel.pkl"
        elif strategy.strategy_name in ["MC SVGD", "GF_SVGD"]:
            fname += f"{subpopsize}members{strategy.sigma_init}sigma{strategy.lrate_init}elite{strategy.kernel.bandwidth}kernel.pkl"
        else:
            fname += f"{subpopsize}members{strategy.sigma_init}sigma{strategy.elite_ratio}elite{strategy.kernel.bandwidth}kernel.pkl"

        # NOTE: pickle.load can execute arbitrary code; only load trusted result files.
        with open(fname, "rb") as handler:
            data = pickle.load(handler)
        res[strategy.strategy_name].extend(get_mmd(data, gt_samples))
        del data  # free the loaded runs before reading the next file

# -

# Plot.

# +
# Mean log10 MMD per iteration for each strategy, with a 95% confidence band.
plt.figure(figsize=(6, 4.5))
for strat, curves in res.items():
    # Rows are individual runs (pooled over benchmarks); the last axis is the iteration.
    mmds = jnp.array(curves)
    log10_mmd = jnp.log10(jnp.abs(mmds))
    mmd_avg = log10_mmd.mean(axis=0)
    # Normal-approximation 95% CI half-width of the mean, kept per iteration.
    # Averaging it over iterations would collapse the band to a constant width.
    mmd_ci = 1.96 * jnp.std(log10_mmd, axis=0) / jnp.sqrt(mmds.shape[0])
    t = jnp.arange(mmds.shape[-1])

    style = cmapper(strat)
    plt.plot(t, mmd_avg, label=style["name"], color=style["color"], linestyle=style["linestyle"])
    plt.fill_between(t, mmd_avg + mmd_ci, mmd_avg - mmd_ci, alpha=.2, color=style["color"])

plt.ylabel("log10 MMD")
plt.xlabel("Num. iter.")
plt.legend(frameon=False)  # curves carry labels; without this they are never shown
# -

# ## Scaling plot
#
# Now we reproduce the plots from Fig. 5.
#
# #### Line plot

# Grid of the scaling study: subpopulation sizes (heatmap rows) and
# total particle counts 16, 32, ..., 512 (heatmap columns / line-plot x axis).
subpops = jnp.asarray([4, 8, 16])
# Kept as a NumPy (not JAX) array, as the original note "Needs onp" demands;
# its values are formatted into result filenames below — presumably the reason, TODO confirm.
n_particles = onp.power(2, onp.arange(4, 10))
n_particles, subpops

# +
def _scaling_fname(bench, strategy, subpop, n_part):
    """Return the result-file path of one particle-scaling run.

    The name must match, character for character, the one used when the run
    was saved, so every value is formatted exactly as in the original loop.
    """
    fname = f"{res_path}{configs[bench.name]['dir_name']}/{bench.name}{strategy.strategy_name}{nrep}rep{num_generations}iter{n_part}pop"
    # Kernel bandwidth rescaled by 100 / n_part (presumably relative to the
    # npop=100 tuning runs — TODO confirm).
    bandwidth = strategy.kernel.bandwidth * 100 / n_part
    name = strategy.strategy_name
    if name == "OG_SVGD":
        fname += f"1members{strategy.lrate_init}elite{bandwidth}kernel.pkl"
    elif name == "GF_SVGD":
        fname += f"1members{strategy.sigma_init}sigma{strategy.lrate_init}elite{bandwidth}kernel.pkl"
    elif name == "MC SVGD":
        fname += f"{subpop}members{strategy.sigma_init}sigma{strategy.lrate_init}elite{bandwidth}kernel.pkl"
    else:
        # Elite ratio rescaled by 4 / subpop (presumably relative to the
        # subpopsize=4 tuning runs — TODO confirm).
        fname += f"{subpop}members{strategy.sigma_init}sigma{strategy.elite_ratio * 4 / subpop}elite{bandwidth}kernel.pkl"
    return fname


def _final_population_errors(fname, gt_samples):
    """Load one result file and return (mmd, mean_sq_err, var_sq_err) of its last entry along axis 1.

    Axis 1 is presumably the generation axis, so this scores the final
    population of every repetition — TODO confirm.
    """
    # NOTE: pickle.load can execute arbitrary code; only load trusted result files.
    with open(fname, "rb") as handler:
        data = pickle.load(handler)
    data = data[:, -1:]
    mmd = onp.array(get_mmd(data, gt_samples))
    mean_err = (jnp.mean(data, axis=(-2, -1)) - jnp.mean(gt_samples)) ** 2
    var_err = (jnp.var(data, axis=(-2, -1)) - jnp.var(gt_samples)) ** 2
    return mmd, mean_err, var_err


# res[strategy][stat] is a list with one cell per (subpop, n_part) pair,
# filled subpop-major; each cell pools the per-run errors of all benchmarks.
res = {
    "OG_SVGD": {"mmds": [], "means": [], "vars": []},
    "GF_SVGD": {"mmds": [], "means": [], "vars": []},
    "BB_SVGD_ES": {"mmds": [], "means": [], "vars": []},
    "MC SVGD": {"mmds": [], "means": [], "vars": []},
}

# Ground-truth samples and placeholder strategies depend only on the benchmark,
# so build them once instead of once per (subpop, n_part). The sampler is
# called with a fixed key, presumably making it deterministic — TODO confirm.
bench_setups = []
for bench in benchmarks:
    cfg = configs[bench.name]
    bench_setups.append((
        bench,
        bench.sample(jax.random.PRNGKey(0), 256),
        [
            ogsvgd(cfg["OG_SVGD"].lr, cfg["OG_SVGD"].kw),
            gfsvgd(cfg["GF_SVGD"].lr, cfg["GF_SVGD"].sig, cfg["GF_SVGD"].kw),
            svcma(cfg["SV_CMA_BB"].er, cfg["SV_CMA_BB"].sig, cfg["SV_CMA_BB"].kw),
            mcsvgd(cfg["MC SVGD"].lr, cfg["MC SVGD"].sig, cfg["MC SVGD"].kw),
        ],
    ))

for subpop in subpops:
    for n_part in n_particles:
        # Open a fresh cell for this (subpop, n_part) pair in every strategy/statistic.
        for stats in res.values():
            for cells in stats.values():
                cells.append([])
        for bench, gt_samples, strategies in bench_setups:
            for strategy in strategies:
                mmd, mean_err, var_err = _final_population_errors(
                    _scaling_fname(bench, strategy, subpop, n_part), gt_samples
                )
                stats = res[strategy.strategy_name]
                stats["mmds"][-1].extend(mmd)
                stats["means"][-1].extend(mean_err)
                stats["vars"][-1].extend(var_err)
                

# +
key = "mmds"  # statistic to plot: "mmds", "means" or "vars"
# res[...][key] has one cell per (subpop, n_part) pair, filled subpop-major.
# The line plot shows a single subpop row: index 1, i.e. subpops[1] (= 8).
# NOTE(review): an earlier comment claimed this selected subpop size 4;
# confirm which subpopulation size Fig. 5 is meant to show.
line_subpop_idx = 1
n_cols = n_particles.shape[0]

# Single figure: one line (plus 95% CI band) per strategy.
plt.figure(figsize=(5, 4.))

for strategy, stats in res.items():
    cells = jnp.array(stats[key])
    row = cells[line_subpop_idx * n_cols:(line_subpop_idx + 1) * n_cols]
    log_row = jnp.log10(jnp.abs(row))
    # Mean and normal-approximation 95% CI over all pooled runs of each cell.
    row_ci = jnp.std(log_row, axis=1).flatten() * 1.96 / jnp.sqrt(log_row.shape[1])
    row_mean = jnp.mean(log_row, axis=1).flatten()

    style = cmapper(strategy)
    plt.plot(n_particles, row_mean, label=style['name'], color=style['color'], linestyle=style["linestyle"], marker="o")
    plt.fill_between(n_particles, row_mean - row_ci, row_mean + row_ci, alpha=.2, color=style['color'])

plt.ylabel("log10 MMD" if key == "mmds" else "log10 MSE")
plt.xlabel("Pop. size")
plt.xscale("log")
plt.xticks(n_particles, [str(n) for n in n_particles]);
plt.minorticks_off()
plt.legend(frameon=False)
plt.ylim(-3.45, -1.1)

# -

# Second plot: smaller fonts for the compact heatmap figure.
plt.style.use(['science', 'no-latex'])
plt.rcParams.update({
    'legend.frameon': True,
    'axes.labelsize': 20,   # axis labels
    'xtick.labelsize': 16,  # x-axis tick labels
    'ytick.labelsize': 16,  # y-axis tick labels
    'lines.linewidth': 2,
    'legend.fontsize': 18,
})

# +
key = "mmds"  # statistic to plot: "mmds", "means" or "vars"
# Only these strategies were run with varying subpopulation sizes; the other
# two always load "1members" files, so their heatmap rows would be identical.
res2 = {name: stats for (name, stats) in res.items() if name in ["BB_SVGD_ES", "MC SVGD"]}

# Create subplots with one row and one heatmap per strategy
fig, axs = plt.subplots(1, len(res2), figsize=(2 * len(res2), 1.5))

for ax_idx, (strategy, stats) in enumerate(res2.items()):
    # Average each (subpop, n_part) cell over its pooled runs, arrange the cells
    # as a (subpop, n_particles) grid (they were filled subpop-major), then log10.
    cell_means = jnp.mean(jnp.array(stats[key]), axis=1)
    grid = jnp.log10(jnp.abs(cell_means.reshape(len(subpops), -1)))

    ax = sns.heatmap(grid, xticklabels=n_particles, yticklabels=subpops, vmin=-2.9, vmax=-1.1, ax=axs[ax_idx], cbar=False)

    # Customize ticks and title
    style = cmapper(strategy)
    ax.tick_params(axis='both', which='both', length=0)
    ax.set_xticklabels(n_particles, rotation=90, fontsize=16)
    ax.set_yticklabels(subpops, rotation=0, fontsize=16)
    ax.set_title(style["name"], fontsize=14)

# Shared colorbar; all heatmaps use the same vmin/vmax, so the last one is representative.
cbar = fig.colorbar(axs[-1].collections[0], ax=axs, location='right', fraction=0.03, pad=0.02)
cbar.ax.tick_params(labelsize=14)
# The plotted values are log10-transformed, so label the scale accordingly.
cbar.set_label("log10 MMD" if key == "mmds" else "log10 MSE", fontsize=14)

# Set common labels
fig.text(0.5, -0.35, 'N. particles', ha='center', fontsize=16)  # X-axis label for the entire figure
fig.text(-0.02, 0.5, 'Subpop. size', va='center', rotation='vertical', fontsize=16)  # Y-axis label for the entire figure
# -


