# ---
# jupyter:
#   jupytext:
#     text_representation:
#       extension: .py
#       format_name: light
#       format_version: '1.5'
#       jupytext_version: 1.16.2
#   kernelspec:
#     display_name: Python 3 (ipykernel)
#     language: python
#     name: python3
# ---

# +
# %matplotlib inline
# %env XLA_PYTHON_CLIENT_PREALLOCATE=false
# %env XLA_PYTHON_CLIENT_ALLOCATOR=platform
# %env CUDA_VISIBLE_DEVICES=0
                                
import pickle
import os
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import scienceplots

from sves.benchmarks.utils.eval_utils import cmapper
from sves.strategies import SV_CMA_BB, OG_SVGD, GF_SVGD, MC_SVGD
from sves.kernels import RBF
from sves.benchmarks.lr import SVGDLogisticRegression, run_cfg_cma, run_cfg_gf, run_cfg_oes, run_cfg_og

# Global figure styling for every plot in this notebook.
plt.style.use(['science','no-latex'])
plt.rcParams.update({
    "axes.labelsize": 26,    # axis labels
    "xtick.labelsize": 22,   # x-axis tick labels
    "ytick.labelsize": 22,   # y-axis tick labels
    "lines.linewidth": 3,
    "legend.fontsize": 18,
})

# All runs below use the CUDA devices; the bare name displays them in the notebook.
devices = jax.devices("cuda")
devices
# -

# ## Paper plot
#
# Here we evaluate the best configuration of each strategy on the test data and plot the results.

# Choose your dataset: the strategy configs below exist for
# "covtype", "spam" and "credit".
dataset = "credit"

# Every result pickle lives in this folder. The trailing slash matters because
# file names are built by plain string concatenation.
res_path = "data/paper_res/log_reg/"
if not os.path.isdir(res_path):
    os.makedirs(res_path, exist_ok=True)

# +
# Set up the evaluation shared by all strategies.
device = devices[0]
rng = jax.random.PRNGKey(0)
npop = 8            # passed as the first SV_CMA_BB / MC_SVGD argument (presumably number of sub-populations)
subpopsize = 32     # members per sub-population; the SVGD baselines use npop * subpopsize particles
batch_size = 128
nrep = 10           # number of independent seeds per strategy
num_generations = 1_000
n_val = 40          # number of evaluation checkpoints recorded during training
val_interval = num_generations // n_val  # iterations between two checkpoints (used for the plot x-axis)
log_reg = SVGDLogisticRegression(dataset=dataset, batch_size=batch_size)

# Per-dataset strategy configurations. Only the selected dataset's strategies are built.
if dataset == "covtype":
    strategies = [
        OG_SVGD(npop * subpopsize, RBF(0.445), num_iters=num_generations, num_dims=log_reg.dim, opt_name="adam", lrate_init=0.01),
        GF_SVGD(npop * subpopsize, RBF(1.), sigma_init=2.278, num_iters=num_generations, num_dims=log_reg.dim, opt_name="adam", lrate_init=.05),
        SV_CMA_BB(npop, subpopsize, RBF(0.778), num_dims=log_reg.dim, elite_ratio=0.4, sigma_init=0.15, num_iters=num_generations),
        MC_SVGD(npop, subpopsize, RBF(0.334), num_dims=log_reg.dim, lrate_init=.01, sigma_init=0.2, num_iters=num_generations),
    ]
elif dataset == "spam":
    strategies = [
        OG_SVGD(npop * subpopsize, RBF(.112), num_iters=num_generations, num_dims=log_reg.dim, opt_name="adam", lrate_init=.01),
        GF_SVGD(npop * subpopsize, RBF(1.), sigma_init=0.564, num_iters=num_generations, num_dims=log_reg.dim, opt_name="adam", lrate_init=.05),
        SV_CMA_BB(npop, subpopsize, RBF(.445), num_dims=log_reg.dim, elite_ratio=0.3, sigma_init=0.2, num_iters=num_generations),
        MC_SVGD(npop, subpopsize, RBF(.112), num_dims=log_reg.dim, lrate_init=.01, sigma_init=0.2, num_iters=num_generations),
    ]
elif dataset == "credit":
    strategies = [
        OG_SVGD(npop * subpopsize, RBF(0.001), num_iters=num_generations, num_dims=log_reg.dim, opt_name="adam", lrate_init=.1),
        GF_SVGD(npop * subpopsize, RBF(0.667), sigma_init=3.337, num_iters=num_generations, num_dims=log_reg.dim, opt_name="adam", lrate_init=.01),
        SV_CMA_BB(npop, subpopsize, RBF(0.001), num_dims=log_reg.dim, elite_ratio=0.3, sigma_init=0.35, num_iters=num_generations),
        MC_SVGD(npop, subpopsize, RBF(0.001), num_dims=log_reg.dim, lrate_init=.005, sigma_init=0.05, num_iters=num_generations),
    ]
else:
    raise ValueError(f"Dataset must be one of 'covtype', 'spam' or 'credit', got {dataset!r}")
# -

# Make sure the mode is set to 'test'
# Train every strategy on the test split over `nrep` seeds (vmapped) and cache
# the results in one pickle per strategy. Every strategy gets the same seeds,
# so all methods see identical randomness.
seeds = jax.random.split(rng, nrep)
for strategy in strategies:
    def _train_on_test(seed):
        return log_reg.train(seed, strategy, device=device, mode="test", n_iter=num_generations, n_val=n_val)

    results = jax.vmap(_train_on_test)(seeds)
    test_pkl = (
        f"{res_path}{dataset}{strategy.strategy_name}{nrep}rep"
        f"{num_generations}iter{npop}pop{subpopsize}members_test.pkl"
    )
    with open(test_pkl, "wb") as handler:
        pickle.dump(results, handler, protocol=pickle.HIGHEST_PROTOCOL)

# Plot the performance on the test data.

# +
# Reload the cached test-set runs. For each strategy, draw the mean curve and a
# band of 1.96 * std / sqrt(#rows) around it; rows are repetitions after the transpose.
plot_data = {}
plt.figure(figsize=(5, 4.))
for strategy in strategies:
    sname = strategy.strategy_name
    test_pkl = (
        f"{res_path}{dataset}{sname}{nrep}rep"
        f"{num_generations}iter{npop}pop{subpopsize}members_test.pkl"
    )
    with open(test_pkl, "rb") as handler:
        data = pickle.load(handler)
    plot_data[f"{sname}"] = {k: jnp.array(v) for k, v in data.items()}

    # Transposed so statistics are taken over axis 0. This assumes "test_nll"
    # is stored as (n_val, nrep) — TODO confirm.
    means = jnp.array(plot_data[sname]["test_nll"]).T
    n_rows = means.shape[0]
    means_mean = means.mean(axis=0)
    means_std = 1.96 * means.std(axis=0) / jnp.sqrt(n_rows)

    # x-axis: training iteration of each evaluation checkpoint.
    t = val_interval * jnp.arange(1, means.shape[1] + 1)
    style = cmapper(sname)
    color = style["color"]
    plt.plot(t, means_mean, label=style["name"], color=color, linestyle=style["linestyle"], linewidth=3)
    plt.fill_between(t, means_mean - means_std, means_mean + means_std, alpha=.2, color=color)

# NOTE(review): the curve comes from the "test_nll" entry but the axis is
# labelled "Test Accuracy" — confirm which metric log_reg.train stores there.
plt.xlabel("Num. Iter")
plt.ylabel("Test Accuracy")
#plt.legend();
#plt.ylim(.71, 1.015)
# -



# ## Hyperparam tuning
#
# NOTE: the sweeps below must run the benchmark in validation mode, so the selected configuration can later be evaluated on the test data.

# +
# Shared settings for the hyperparameter sweeps below. These overwrite the
# values used for the paper plot above.
dataset = "covtype"
device = devices[0]
rng = jax.random.PRNGKey(123)
nrep = 5
npop, subpopsize = 8, 32
batch_size = 128
num_generations = 1_000
n_val = 40
val_interval = num_generations // n_val

# Search grids (hyperparam tuning)
adam_lrs = jnp.array([0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1.0])  # learning rates (OpenES, OG and GF sweeps)
kernel_vals = jnp.linspace(1e-3, 1., 10)   # kernel value `kw` handed to every run_cfg_* helper
sigmas = jnp.linspace(.05, .5, 10)         # sigma grid for the CMA-ES and OpenES sweeps
sigmas_gf = jnp.linspace(.01, 5., 10)      # sigma grid for the GF sweep
elite_ratios = jnp.linspace(.05, .5, 10)   # CMA-ES elite ratios
# -

# ### CMA-ES methods

# +
# Set up
log_reg = SVGDLogisticRegression(dataset=dataset, batch_size=batch_size)
# Dummy instance: only its `strategy_name` is used, to build the file names.
strategy = SV_CMA_BB(1, 1, kernel=RBF(), num_iters=1, num_dims=1)

eval_cma = lambda er, sig, kw: run_cfg_cma(rng, npop, subpopsize, er, sig, kw, nrep, num_generations, log_reg, device)

# The elite ratio is swept in a Python loop instead of being vmapped: CMA-ES
# uses it as a fixed value to index arrays, so it cannot be a traced value.
for er in elite_ratios:
    # vmapped (kernel value x sigma) grid for this elite ratio.
    res = jax.vmap(lambda kw:
        jax.vmap(lambda sig: eval_cma(er, sig, kw))(sigmas)
    )(kernel_vals)
    res_acc = jnp.moveaxis(jnp.array(res["accuracy"]), 0, -2)
    res_logp = jnp.moveaxis(jnp.array(res["test_logp"]), 0, -2)

    # NOTE(review): files are named with `log_reg.name`, while the plotting cell
    # loads them with `dataset`; this assumes both strings are equal — confirm.
    for racci, rlogi, kw in zip(res_acc, res_logp, kernel_vals):
        for raccj, rlogj, sig in zip(racci, rlogi, sigmas):
            results = {"accuracy": raccj, "test_logp": rlogj}
            with open(f"{res_path}{log_reg.name}{strategy.strategy_name}{nrep}rep{num_generations}iter{npop}pop{subpopsize}members{sig}sigma{er}elite{kw}kernel.pkl", "wb") as handler:
                pickle.dump(results, handler, protocol=pickle.HIGHEST_PROTOCOL)

    # Free this grid's buffers before computing the next one. Otherwise two full
    # result grids are held in memory at the same time.
    del res, res_acc, res_logp
# -

# ### OpenES-based methods

# +
# Set up
log_reg = SVGDLogisticRegression(dataset=dataset, batch_size=batch_size)
# Dummy instance: only its `strategy_name` is used, to build the file names.
strategy = MC_SVGD(1, 2, kernel=RBF(), num_iters=1, num_dims=1, sigma_init=1)

eval_oes = lambda lr, sig, kw: run_cfg_oes(rng, npop, subpopsize, lr, sig, kw, nrep, num_generations, log_reg, device)

# Full (kernel value x sigma x learning rate) grid, built from nested vmaps.
sweep_lr = lambda sig, kw: jax.vmap(lambda lr: eval_oes(lr, sig, kw))(adam_lrs)
sweep_sig = lambda kw: jax.vmap(lambda sig: sweep_lr(sig, kw))(sigmas)
res = jax.vmap(sweep_sig)(kernel_vals)
res_acc = jnp.moveaxis(jnp.array(res["accuracy"]), 0, -2)
res_logp = jnp.moveaxis(jnp.array(res["test_logp"]), 0, -2)

# One pickle per grid cell. The learning rate goes in the "elite" slot of the
# file name; the plotting cell's loader also iterates adam_lrs for that slot.
for acc_k, logp_k, kw in zip(res_acc, res_logp, kernel_vals):
    for acc_ks, logp_ks, sig in zip(acc_k, logp_k, sigmas):
        for acc_ksl, logp_ksl, lr in zip(acc_ks, logp_ks, adam_lrs):
            results = {"accuracy": acc_ksl, "test_logp": logp_ksl}
            out_pkl = f"{res_path}{log_reg.name}{strategy.strategy_name}{nrep}rep{num_generations}iter{npop}pop{subpopsize}members{sig}sigma{lr}elite{kw}kernel.pkl"
            with open(out_pkl, "wb") as handler:
                pickle.dump(results, handler, protocol=pickle.HIGHEST_PROTOCOL)

# Free memory held by the sweep.
del res, res_acc, res_logp, results
# -

# ### Adam-based methods

# +
# Set up
log_reg = SVGDLogisticRegression(dataset=dataset, batch_size=batch_size)
# Dummy instance: only its `strategy_name` is used, to build the file names.
strategy = OG_SVGD(1, RBF(), num_iters=1, num_dims=1, opt_name="adam")

# This sweep has no sigma. Only the first five kernel values are crossed with
# every learning rate.
og_kernel_vals = kernel_vals[:5]
eval_og = lambda lr, kw: run_cfg_og(rng, npop, subpopsize, lr, kw, nrep, num_generations, log_reg, device)
sweep_lrs = lambda kw: jax.vmap(lambda lr: eval_og(lr, kw))(adam_lrs)
res = jax.vmap(sweep_lrs)(og_kernel_vals)
res_acc = jnp.moveaxis(jnp.array(res["accuracy"]), 0, -2)
res_logp = jnp.moveaxis(jnp.array(res["test_logp"]), 0, -2)

# One pickle per (kernel, lr) pair. The name has no "sigma" field, and the
# learning rate goes in the "elite" slot.
for acc_k, logp_k, kw in zip(res_acc, res_logp, og_kernel_vals):
    for acc_kl, logp_kl, lr in zip(acc_k, logp_k, adam_lrs):
        results = {"accuracy": acc_kl, "test_logp": logp_kl}
        out_pkl = f"{res_path}{log_reg.name}{strategy.strategy_name}{nrep}rep{num_generations}iter{npop}pop{subpopsize}members{lr}elite{kw}kernel.pkl"
        with open(out_pkl, "wb") as handler:
            pickle.dump(results, handler, protocol=pickle.HIGHEST_PROTOCOL)

# free mem
del res, res_acc, res_logp, results

# +
# Set up
log_reg = SVGDLogisticRegression(dataset=dataset, batch_size=batch_size)
# Dummy instance: only its `strategy_name` is used, to build the file names.
strategy = GF_SVGD(1, RBF(), num_iters=1, num_dims=1, opt_name="adam")

# Only the upper half of the GF sigma grid is swept here.
gf_sigmas = sigmas_gf[5:]

# NOTE(review): run_cfg_gf receives (sig, lr) in the opposite order to
# run_cfg_oes. This presumably matches its signature — confirm.
eval_gf = lambda lr, sig, kw: run_cfg_gf(rng, npop, subpopsize, sig, lr, kw, nrep, num_generations, log_reg, device)

for kw in kernel_vals:
    # vmapped (learning rate x sigma) grid for this kernel value.
    res = jax.vmap(lambda lr:
        jax.vmap(lambda sig: eval_gf(lr, sig, kw))(gf_sigmas)
    )(adam_lrs)
    res_acc = jnp.moveaxis(jnp.array(res["accuracy"]), 0, -2)
    res_logp = jnp.moveaxis(jnp.array(res["test_logp"]), 0, -2)

    for racci, rlogi, lr in zip(res_acc, res_logp, adam_lrs):
        for raccj, rlogj, sig in zip(racci, rlogi, gf_sigmas):
            results = {"accuracy": raccj, "test_logp": rlogj}
            with open(f"{res_path}{log_reg.name}{strategy.strategy_name}{nrep}rep{num_generations}iter{npop}pop{subpopsize}members{sig}sigma{lr}elite{kw}kernel.pkl", "wb") as handler:
                pickle.dump(results, handler, protocol=pickle.HIGHEST_PROTOCOL)

    # free mem: release this kernel value's grid before sweeping the next one,
    # so two full result grids are never held at the same time.
    del res, res_acc, res_logp
# -

# #### Plotting

# +
# Dummy strategy instances: only their `strategy_name` is used, to rebuild the
# result file names written by the sweeps above.
strategies = [
    OG_SVGD(1, RBF(), num_iters=1, num_dims=1, opt_name="adam"),
    GF_SVGD(1, RBF(), num_iters=1, num_dims=1, opt_name="adam"),
    SV_CMA_BB(1, 1, kernel=RBF(.1), num_iters=1, num_dims=1, elite_ratio=0.25, sigma_init=1),
    MC_SVGD(1, 2, kernel=RBF(1), num_iters=1, num_dims=1, sigma_init=1),
]


def _load_result_pickle(path):
    """Load a tuning-result pickle and convert each entry to a ``jnp`` array.

    Returns ``None`` when the file does not exist. Most probed combinations of
    strategy / elite-or-lr / sigma / kernel were never run, so a missing file is
    expected and skipped. Any other error (e.g. a corrupt pickle) is raised.
    """
    try:
        with open(path, "rb") as handler:
            data = pickle.load(handler)
    except FileNotFoundError:
        return None
    return {key: jnp.array(value) for key, value in data.items()}


# Candidate values for the "elite" slot (elite ratio for CMA, learning rate for
# the Adam/OpenES sweeps) and for the "sigma" slot. They are concatenated once,
# outside the loops.
er_grid = jnp.concatenate([elite_ratios, adam_lrs])
sigma_grid = jnp.concatenate([sigmas, sigmas_gf])

plot_data = {}
for strategy in strategies:
    name = strategy.strategy_name
    prefix = f"{res_path}{dataset}{name}{nrep}rep{num_generations}iter{npop}pop{subpopsize}members"
    for kw in kernel_vals:
        for er in er_grid:
            # Sweeps without a sigma hyperparameter save files with no "sigma" field.
            # Load such a file once per (strategy, er, kw), not once per sigma.
            loaded = _load_result_pickle(f"{prefix}{er}elite{kw}kernel.pkl")
            if loaded is not None:
                plot_data[f"{name}|{er}|{kw}"] = loaded
            for sigma in sigma_grid:
                loaded = _load_result_pickle(f"{prefix}{sigma}sigma{er}elite{kw}kernel.pkl")
                if loaded is not None:
                    plot_data[f"{name}|{er}|{sigma}|{kw}"] = loaded


# +
# For every strategy, pick the tuning configuration with the highest score at
# the last validation checkpoint, then plot its curve.
#
# Score = test_logp / batch_size + accuracy. Each stored array is transposed so
# that statistics are taken over axis 0. This assumes arrays are stored as
# (n_val, nrep) — TODO confirm.
best = {s.strategy_name: {"val": [-100_000]} for s in strategies}

for resname, resdata in plot_data.items():
    strategy_name = resname.split("|")[0]
    scores = jnp.array(resdata["test_logp"]).T / batch_size + jnp.array(resdata["accuracy"]).T
    scores_mean = jnp.mean(scores, axis=0)
    # 1.96 * standard error over axis 0 (95% normal confidence half-width).
    scores_ci = jnp.std(scores, axis=0) * 1.96 / jnp.sqrt(scores.shape[0])

    # ">=" lets a later configuration win ties.
    if scores_mean[-1] >= best[strategy_name]["val"][-1]:
        best[strategy_name].update(val=scores_mean, name=strategy_name, std=scores_ci, cfg=resname)

# Plot performance on val set
plt.figure(figsize=(6, 4.5))
for res in best.values():
    # Strategies with no loaded result only hold the placeholder "val" entry.
    if not res.get("name"):
        continue
    # x-axis built from this curve's own length, not from a leftover loop variable.
    t = jnp.arange(1, len(res["val"]) + 1) * val_interval
    style = cmapper(res["name"])
    plt.plot(t, res["val"], label=res["name"], color=style["color"])
    plt.fill_between(t, res["val"] - res["std"], res["val"] + res["std"], alpha=.2, color=style["color"])

# The plotted value is the combined score above, not the raw accuracy.
plt.ylabel("Val. score (logp/batch + acc.)")
plt.xlabel("Iteration")
plt.legend();

print("Best results at:")
for res in best.values():
    print(res.get("cfg", "<no results found>"))
# -


