# Script to generate the energy-distance plots for the uniform experiment.
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Method display labels paired with the file-name stem used for each saved
# figure. The order must stay the same as in uniform_script.py: result arrays
# below are indexed by position in these lists (plain variants first, then
# the Coin variants).
_mied_table = [
    ("MIED (Riesz)", "mied_riesz"),
    ("MIED (Laplace)", "mied_laplace"),
    ("MIED (Gaussian)", "mied_gaussian"),
    ("Coin MIED (Riesz)", "coin_mied_riesz"),
    ("Coin MIED (Laplace)", "coin_mied_laplace"),
    ("Coin MIED (Gaussian)", "coin_mied_gaussian"),
]
mied_methods = [label for label, _ in _mied_table]
mied_methods_fname = [stem for _, stem in _mied_table]

_msvgd_table = [
    ("MSVGD", "msvgd"),
    ("Coin MSVGD", "coin_msvgd"),
]
msvgd_methods = [label for label, _ in _msvgd_table]
msvgd_methods_fname = [stem for _, stem in _msvgd_table]

# Load the saved energy-distance arrays: the first covers the MIED / Coin MIED
# methods, the second the MSVGD / Coin MSVGD methods.
mied_results, msvgd_results = (np.load(path) for path in (
    "results/uniform/energy_dist.npy",
    "results/uniform/energy_dist_svgd.npy",
))

# Plot colour per method label: each plain method gets its own colour,
# while every Coin variant shares "C0".
color_map = {
    "MSVGD": "C1",
    "MIED (Riesz)": "C2",
    "MIED (Laplace)": "C3",
    "MIED (Gaussian)": "C4",
}
color_map.update(dict.fromkeys(
    ["Coin MIED (Riesz)", "Coin MIED (Laplace)", "Coin MIED (Gaussian)", "Coin MSVGD"],
    "C0",
))

# Learning-rate grid (20 log-spaced values in [1e-5, 1]); must match the grid
# used in uniform_script.py.
search_lr = np.logspace(start=-5, stop=0, num=20)

# Plot energy distance vs learning rate, one figure per MIED kernel.
# Rows of mied_results follow mied_methods: the plain MIED variants first,
# then their Coin counterparts. The array is 3-D ([method, lr index, :]); the
# last axis is aggregated with mean/min/max (presumably repeated runs).
n_kernels = len(mied_methods) // 2
for ii in range(n_kernels):
    f, ax = plt.subplots(1, 1, figsize=(5, 4))

    # Coin MIED: only learning-rate slot 0 is read. It is drawn as a
    # horizontal mean line with a min-max band spanning the full lr range.
    coin_method = mied_methods[ii + n_kernels]
    coin_runs = mied_results[ii + n_kernels, 0, :]
    ax.axhline(np.mean(coin_runs), label=coin_method, color=color_map[coin_method])
    ax.fill_between(search_lr, np.min(coin_runs), np.max(coin_runs), alpha=0.2,
                    color=color_map[coin_method])

    # Plain MIED: mean and min-max band at each learning rate. This assumes
    # axis 1 of mied_results lines up with search_lr.
    method = mied_methods[ii]
    runs = mied_results[ii]
    mean = np.mean(runs, axis=1)
    lower = np.min(runs, axis=1)
    upper = np.max(runs, axis=1)
    # ".--" means point markers on a dashed line. The previous ".-" fmt plus
    # linestyle="dashed" defined the linestyle twice and made matplotlib warn.
    ax.plot(search_lr, mean, ".--", label=method, color=color_map[method])
    ax.fill_between(search_lr, lower, upper, color=color_map[method], alpha=0.2)

    ax.set_yscale("log")
    ax.set_xscale("log")
    ax.set_ylim(5e-4, 1e0)
    ax.tick_params(axis='both', labelsize=18)
    ax.set_xlabel("Learning Rate", fontsize=18)
    ax.set_ylabel("Energy Distance", fontsize=18)
    ax.legend(prop={'size': 16}, loc='upper right')
    ax.margins(x=0)
    fname = "results/uniform/eds_vs_lr_" + mied_methods_fname[ii] + ".pdf"
    f.savefig(fname, bbox_inches="tight", dpi=300)
    plt.show()

# Plot energy distance vs learning rate for MSVGD, with Coin MSVGD as a
# horizontal reference. Rows of msvgd_results follow msvgd_methods:
# row 0 is MSVGD and row 1 is Coin MSVGD.
f, ax = plt.subplots(1, 1, figsize=(5, 4))

# Coin MSVGD: only learning-rate slot 0 is read. It is drawn as a horizontal
# mean line with a min-max band spanning the full lr range.
coin_method = msvgd_methods[1]
coin_runs = msvgd_results[1, 0, :]
ax.axhline(np.mean(coin_runs), label=coin_method, color=color_map[coin_method])
ax.fill_between(search_lr, np.min(coin_runs), np.max(coin_runs), alpha=0.2,
                color=color_map[coin_method])

# MSVGD: mean and min-max band over the last axis at each learning rate.
# This assumes axis 1 of msvgd_results lines up with search_lr.
method = msvgd_methods[0]
runs = msvgd_results[0]
mean = np.mean(runs, axis=1)
lower = np.min(runs, axis=1)
upper = np.max(runs, axis=1)
# ".--" means point markers on a dashed line. It avoids matplotlib's
# redundant-linestyle warning triggered by ".-" combined with linestyle="dashed".
ax.plot(search_lr, mean, ".--", label=method, color=color_map[method])
ax.fill_between(search_lr, lower, upper, color=color_map[method], alpha=0.2)

ax.set_yscale("log")
ax.set_xscale("log")
ax.set_ylim(5e-4, 1e0)
ax.tick_params(axis='both', labelsize=18)
ax.set_xlabel("Learning Rate", fontsize=18)
ax.set_ylabel("Energy Distance", fontsize=18)
ax.legend(prop={'size': 16}, loc='upper right')
ax.margins(x=0)
fname = "results/uniform/eds_vs_lr_" + msvgd_methods_fname[0] + ".pdf"
f.savefig(fname, bbox_inches="tight", dpi=300)
plt.show()


# Plot energy distance vs iteration for every method. The per-method series
# are read from the columns of the exported CSV that are listed in csv_names.
fname = "results/uniform/energy_dist_vs_iter.csv"
results_df = pd.read_csv(fname, sep=',', header='infer')
csv_names = ['coin-msvgd-uniform - metrics.energy_dist',
             'coin-mied-riesz-uniform - metrics.energy_dist',
             'coin-mied-laplace-uniform - metrics.energy_dist',
             'coin-mied-gaussian-uniform - metrics.energy_dist',
             'msvgd-uniform - metrics.energy_dist',
             'mied-riesz-uniform - metrics.energy_dist',
             'mied-laplace-uniform - metrics.energy_dist',
             'mied-gaussian-uniform - metrics.energy_dist']

# Legend labels, paired position-by-position with csv_names.
names = ["Coin MSVGD", "Coin MIED (Riesz)", "Coin MIED (Laplace)", "Coin MIED (Gaussian)",
         "MSVGD",  "MIED (Riesz)", "MIED (Laplace)", "MIED (Gaussian)"]

# Select the columns in csv_names order as a (rows, methods) float array.
# The row count comes from the CSV itself. The old hard-coded n_iter = 250
# preallocation raised a broadcast error for any other length.
results = results_df[csv_names].to_numpy(dtype=float)
n_iter = results.shape[0]

f, ax = plt.subplots(1, 1, figsize=(12, 5))
for i, label in enumerate(names):
    # Decreasing zorder, so methods earlier in the list are drawn on top.
    ax.plot(results[:, i], label=label, zorder=10 - i)

ax.set_xlabel("Iterations", fontsize=16)
ax.set_ylabel("Energy Distance", fontsize=16)
ax.set_yscale("log")
ax.legend(ncol=2, prop={'size': 16})
ax.tick_params(axis='both', labelsize=16)
f.savefig("results/uniform/eds_vs_iter.pdf", bbox_inches="tight", dpi=300)
plt.show()