import os
import re
import pickle
import itertools
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
import matplotlib.colors as mcolors
from matplotlib.legend_handler import HandlerBase
from matplotlib.colors import ListedColormap
from tqdm import tqdm
import ot
from scipy.stats import wasserstein_distance
from adaptive_svgd.examples import ode

# Directory that receives every figure written by this script.
os.makedirs("plots", exist_ok=True)

# Global matplotlib styling: LaTeX text rendering, serif fonts and small font
# sizes so the figures match the surrounding LaTeX document.
for rc_group, rc_settings in (
    ("text", {"usetex": True}),
    ("legend", {"fontsize": 6}),
    ("font", {"family": "serif"}),
    ("font", {"size": 9}),
    ("axes", {"labelsize": 8}),
    ("text.latex", {"preamble": r"\usepackage{amsmath}"}),
    ("xtick", {"labelsize": 6}),
    ("ytick", {"labelsize": 6}),
):
    plt.rc(rc_group, **rc_settings)

print("Collecting results for ODE example...")
# Discover which runs have finished by scanning the experiment directory for
# Ad-SVGD result files named ``results_AdSVGD_M<M>_run<run>.pkl``.
# NOTE(review): only Ad-SVGD files are checked here, although the analysis
# below also loads the matching Med-SVGD and setup pickles.
exp_path = "experiments/ode"
Ms = [50, 100, 200]
finished_runs = {M: [] for M in Ms}
result_pattern = re.compile(r"results_AdSVGD_M(\d+)_run(\d+)\.pkl")
for file in os.listdir(exp_path):
    # fullmatch (not match) so that names with trailing characters, e.g.
    # "results_AdSVGD_M50_run1.pkl.bak", are not counted as finished runs.
    match = result_pattern.fullmatch(file)
    if match is None:
        continue
    M = int(match.group(1))
    run = int(match.group(2))
    if M not in finished_runs:
        # Result for an ensemble size this script does not analyse; skip it
        # instead of failing with a KeyError.
        continue
    finished_runs[M].append(run)

# calculate losses
# Only runs that finished for every M are analysed.  The list keeps the order
# of finished_runs[200] and drops runs missing for any other M.  The result
# arrays are sized by this list: sizing them by len(finished_runs[M]) while
# iterating over finished_runs[200] left uninitialised np.empty rows in the
# statistics (or raised IndexError) whenever the run counts differed.
complete_runs = [
    run for run in finished_runs[200] if all(run in finished_runs[M] for M in Ms)
]
# Each history is expected to hold 4000 snapshots; the row assignments below
# rely on that length.
dists_adaptive = {M: np.empty((len(complete_runs), 4000)) for M in Ms}
dists_median = {M: np.empty((len(complete_runs), 4000)) for M in Ms}


def _gaussian_w2_history(x_hist, target_mean, target_cov):
    """Return, per snapshot in ``x_hist``, the Bures-Wasserstein distance
    between the Gaussian with the snapshot's sample mean/covariance and the
    Gaussian with ``target_mean``/``target_cov``."""
    return [
        ot.gaussian.bures_wasserstein_distance(
            np.mean(xs, axis=0),
            target_mean,
            np.cov(xs, rowvar=False),
            target_cov,
        )
        for xs in x_hist
    ]


def _mean_with_ci(dists):
    """Return (mean, lower, upper) over axis 0, with bounds at
    mean -/+ 1.96 * std / sqrt(number of rows)."""
    center = np.mean(dists, axis=0)
    half_width = 1.96 * np.std(dists, axis=0) / np.sqrt(len(dists))
    return center, center - half_width, center + half_width


# The pickles are assumed to be this project's own experiment outputs;
# pickle.load must not be pointed at untrusted files.
for M in Ms:
    for i, run in enumerate(tqdm(complete_runs, desc=f"M={M}")):
        # posterior_cov / posterior_mean stay bound after the loop and are
        # read again by the marginal variance plot further down.
        with open(os.path.join(exp_path, f"setup_run{run}.pkl"), "rb") as f:
            _, posterior_cov, posterior_mean = pickle.load(f)
        with open(
            os.path.join(exp_path, f"results_AdSVGD_M{M}_run{run}.pkl"), "rb"
        ) as f:
            x_hist_adaptive, _ = pickle.load(f)
        dists_adaptive[M][i] = _gaussian_w2_history(
            x_hist_adaptive, posterior_mean, posterior_cov
        )
        with open(
            os.path.join(exp_path, f"results_MedSVGD_M{M}_run{run}.pkl"), "rb"
        ) as f:
            x_hist_median, _ = pickle.load(f)
        dists_median[M][i] = _gaussian_w2_history(
            x_hist_median, posterior_mean, posterior_cov
        )

# Per-M mean curve and confidence band across runs.
mean_dists_median = {}
lower_dists_median = {}
upper_dists_median = {}
mean_dists_adaptive = {}
lower_dists_adaptive = {}
upper_dists_adaptive = {}
for M in Ms:
    (
        mean_dists_median[M],
        lower_dists_median[M],
        upper_dists_median[M],
    ) = _mean_with_ci(dists_median[M])
    (
        mean_dists_adaptive[M],
        lower_dists_adaptive[M],
        upper_dists_adaptive[M],
    ) = _mean_with_ci(dists_adaptive[M])

# calculate marginal variances
# Per-dimension variance of the final particle set of every run (16 columns,
# one per dimension).  As for the losses, only runs that finished for every M
# are used, in the order of finished_runs[200].  The arrays are sized by that
# list so that no uninitialised np.empty rows enter the statistics when the
# run counts differ between the M values.
complete_runs = [
    run for run in finished_runs[200] if all(run in finished_runs[M] for M in Ms)
]
final_vars_adaptive = {M: np.empty((len(complete_runs), 16)) for M in Ms}
final_vars_median = {M: np.empty((len(complete_runs), 16)) for M in Ms}
for M in Ms:
    for i, run in enumerate(tqdm(complete_runs, desc=f"M={M}")):
        with open(
            os.path.join(exp_path, f"results_AdSVGD_M{M}_run{run}.pkl"), "rb"
        ) as f:
            x_hist_adaptive, _ = pickle.load(f)
        # last snapshot of the particle history
        final_vars_adaptive[M][i] = np.var(x_hist_adaptive[-1], axis=0)
        with open(
            os.path.join(exp_path, f"results_MedSVGD_M{M}_run{run}.pkl"), "rb"
        ) as f:
            x_hist_median, _ = pickle.load(f)
        final_vars_median[M][i] = np.var(x_hist_median[-1], axis=0)

print("Plotting marginal variances...")
# Marginal variances for M=200: posterior diagonal versus the mean final
# particle variance of both methods, with 1.96 * std / sqrt(n_runs) error bars.
# Figure size: LaTeX text width is 397.5 pt; 0.48 * 397.5 pt = 190.8 pt
# = 2.64 inch wide, height/width ratio 0.78 -> 2.06 inch high.
plt.figure(figsize=(2.64, 2.06))
dims = np.arange(1, 17)
# NOTE(review): posterior_cov is whatever was bound last before this point
# (the setup pickle of the final run processed in the loss loop) - confirm
# that a single run's posterior is the intended reference.
plt.scatter(
    dims,
    np.diag(posterior_cov),
    label="posterior",
    color="black",
    marker="x",
    zorder=10,
    alpha=0.7,
    s=12,
)
var_cmap = plt.get_cmap("viridis")
for final_vars, fmt, label, shade in (
    (final_vars_adaptive[200], "s", "Ad-SVGD, $M=200$", 0.8),
    (final_vars_median[200], "o", "Med-SVGD, $M=200$", 0.32),
):
    ci_half_width = 1.96 * np.std(final_vars, axis=0) / np.sqrt(len(final_vars))
    plt.errorbar(
        dims,
        np.mean(final_vars, axis=0),
        yerr=ci_half_width,
        fmt=fmt,
        label=label,
        color=var_cmap(shade),
        markersize=3,
        capsize=3,
    )
plt.yscale("log")
plt.grid(True, which="both", linestyle="--", linewidth=0.5, alpha=0.7)
plt.xticks(np.arange(1, 17, 3))
plt.legend()
plt.xlabel("Dimension")
plt.ylabel("Marginal variance")
plt.savefig("plots/marginal_var_plot.pdf", bbox_inches="tight", dpi=600)

print("Making loss plots...")
# Loss curves (mean with confidence band) for both methods and every M.
# Figure size: LaTeX text width is 397.5 pt; 0.48 * 397.5 pt = 190.8 pt
# = 2.64 inch wide, height/width ratio 0.78 -> 2.06 inch high.
plt.figure(figsize=(2.64, 2.06))
viridis = plt.get_cmap("viridis")
# One colour per (method, M) pair: first len(Ms) entries for Med-SVGD, the
# remaining ones for Ad-SVGD.
colors = viridis(np.linspace(0, 0.8, 2 * len(Ms)))
# x positions: snapshot index times 100
steps = np.arange(len(mean_dists_median[200])) * 100
# One line style per M value.  `markers` is defined but not applied to the
# curves.
line_styles = ["-.", "--", "-"]
markers = ["o", "s", "D"]
curve_sets = (
    ("Med-SVGD", mean_dists_median, lower_dists_median, upper_dists_median),
    ("Ad-SVGD", mean_dists_adaptive, lower_dists_adaptive, upper_dists_adaptive),
)
for method_idx, (method, centers, lowers, uppers) in enumerate(curve_sets):
    for offset, (M, line_style) in enumerate(zip(Ms, line_styles)):
        color = colors[method_idx * len(Ms) + offset]
        plt.plot(
            steps,
            centers[M],
            label=f"{method}, $M={M}$",
            linestyle=line_style,
            color=color,
            linewidth=0.7,
        )
        plt.fill_between(steps, lowers[M], uppers[M], alpha=0.2, color=color)
plt.grid(True, which="both", linestyle="--", linewidth=0.5, alpha=0.7)
plt.yscale("log")
plt.xlabel("Iteration")
plt.xticks(np.arange(0, 4001, 1000) * 100)
plt.ylabel(r"$\mathcal{W}_2\big(\mathcal{N}(\hat{\mu}, \hat{\Sigma}),\pi\big)$")
plt.legend(framealpha=0.3)
plt.savefig("plots/loss_plot.pdf", bbox_inches="tight", dpi=300)

print("Collecting bandwidth evolutions for M=200...")


def _load_bandwidth_history(run):
    """Return the second element stored in the Ad-SVGD M=200 pickle of `run`."""
    result_file = os.path.join(exp_path, f"results_AdSVGD_M200_run{run}.pkl")
    with open(result_file, "rb") as f:
        _, h_hist_adaptive = pickle.load(f)
    return h_hist_adaptive


# One (4000, 16) bandwidth history per finished M=200 run.
runs_m200 = finished_runs[200]
h_hists = np.empty((len(runs_m200), 4000, 16))
for row_idx, run in enumerate(tqdm(runs_m200, desc="M=200")):
    h_hists[row_idx] = _load_bandwidth_history(run)
# Statistics across runs (axis 0).
mean_hists = h_hists.mean(axis=0)
std_hists = h_hists.std(axis=0)

print("Making bandwidth plots...")
# Final bandwidth per dimension, averaged over runs, with
# 1.96 * std / sqrt(n_runs) error bars.
# Figure size: LaTeX text width is 397.5 pt; 0.48 * 397.5 pt = 190.8 pt
# = 2.64 inch wide, height/width ratio 0.78 -> 2.06 inch high.
plt.figure(figsize=(2.64, 2.06))
dims = np.arange(1, 17)
final_mean = mean_hists[-1, :]
final_ci = 1.96 * std_hists[-1, :] / np.sqrt(len(finished_runs[200]))
plt.errorbar(dims, final_mean, yerr=final_ci, fmt="o", markersize=3, capsize=3)
plt.xticks(dims[::3])
plt.xlabel("Dimension")
plt.ylabel("Final bandwidth $h_i$")
plt.grid(True, which="both", linestyle="--", linewidth=0.5, alpha=0.7)
plt.savefig("plots/final_bandwidth_plot.pdf", bbox_inches="tight", dpi=600)

# Bandwidth evolution: one curve per dimension in the wide left axes, plus a
# narrow discrete colour bar on the right that maps colours to dimensions.
# Figure size: LaTeX text width is 397.5 pt; 0.48 * 397.5 pt = 190.8 pt
# = 2.64 inch wide, height/width ratio 0.78 -> 2.06 inch high.
fig, axs = plt.subplots(
    1, 2, figsize=(2.64, 2.06), width_ratios=[1, 0.03], gridspec_kw={"wspace": 0.05}
)
ax_curves, ax_cbar = axs
colors = viridis(np.linspace(0, 1.0, 16))
for dim, (mean, color) in enumerate(zip(mean_hists.T, colors), start=1):
    iterations = np.arange(len(mean)) * 100
    ax_curves.plot(
        iterations, mean, color=color, linewidth=0.7, label=f"$i={dim}$"
    )
ax_curves.set_xlabel("Iteration")
ax_curves.set_ylabel("Bandwidth $h_i$")
ax_curves.grid(True, which="both", linestyle="--", linewidth=0.5, alpha=0.7)

# Shrink the colour bar axes vertically and shift it upwards.
bbox = ax_cbar.get_position()
ax_cbar.set_position([bbox.x0, bbox.y0 + 0.06, bbox.width, bbox.height * 0.8])

# Discrete colour map: bin k covers [k, k + 1) and is labelled at k + 0.5.
cmap = ListedColormap(colors)
norm = mcolors.BoundaryNorm(list(range(1, 18)), cmap.N)
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
cbar = plt.colorbar(sm, ticks=np.linspace(1.5, 16.5, 16), cax=ax_cbar)
cbar.ax.set_yticklabels([str(dim) for dim in range(1, 17)])
# Hide minor ticks on the colour bar.
cbar.ax.tick_params(which="minor", size=0, width=0)
cbar.ax.xaxis.set_label_position("top")
cbar.ax.set_xlabel("$i$")
plt.savefig("plots/bandwidth_evolution_plot.pdf", bbox_inches="tight", dpi=600)

print("Making posterior approximation plots...")
# Rebuild the ODE example to obtain `gp`, then compare the processes produced
# by the final Ad-SVGD / Med-SVGD particles of run 1 (M=200) with processes
# produced by samples drawn from the Gaussian posterior.
rng = np.random.default_rng(0)
Nx = 16
Ny = 256
No = 256
noise_cov = np.eye(No) * 0.001
prior_cov_scale = 50.0
KL_length = 512
N_posterior = 10000
# ode(...) returns nine values; only the sixth one (`gp`) is needed here.
# NOTE(review): gp presumably maps parameter samples to process values on the
# given grid points - confirm against adaptive_svgd.examples.ode.
(
    _,
    _,
    _,
    _,
    _,
    gp,
    _,
    _,
    _,
) = ode(rng, Nx, Ny, No, noise_cov, N_posterior, prior_cov_scale, KL_length)
with open(os.path.join(exp_path, "setup_run1.pkl"), "rb") as f:
    true_process, posterior_cov, posterior_mean = pickle.load(f)
posterior_sample = rng.multivariate_normal(
    mean=posterior_mean, cov=posterior_cov, size=N_posterior
)
with open(os.path.join(exp_path, "results_MedSVGD_M200_run1.pkl"), "rb") as f:
    x_hist_median, _ = pickle.load(f)
med_particles = x_hist_median[-1]
with open(os.path.join(exp_path, "results_AdSVGD_M200_run1.pkl"), "rb") as f:
    x_hist_adaptive, _ = pickle.load(f)
ad_particles = x_hist_adaptive[-1]

# Ny + 2 grid points on [0, 1]; gp is evaluated on all but the first one and
# the value 0 is prepended for s = 0 when plotting.
grid = np.linspace(0, 1, Ny + 2)
resulting_processes_ad: np.ndarray = gp(grid[1:], ad_particles)  # type: ignore
resulting_processes_med: np.ndarray = gp(grid[1:], med_particles)  # type: ignore
posterior_processes: np.ndarray = gp(grid[1:], posterior_sample)  # type: ignore

# Pointwise mean and 5% / 95% quantiles across samples (axis 0).
mean_ad = np.mean(resulting_processes_ad, axis=0)
mean_med = np.mean(resulting_processes_med, axis=0)
lower_ad = np.quantile(resulting_processes_ad, 0.05, axis=0)
upper_ad = np.quantile(resulting_processes_ad, 0.95, axis=0)
lower_med = np.quantile(resulting_processes_med, 0.05, axis=0)
upper_med = np.quantile(resulting_processes_med, 0.95, axis=0)
post_mean = np.mean(posterior_processes, axis=0)
post_lower = np.quantile(posterior_processes, 0.05, axis=0)
post_upper = np.quantile(posterior_processes, 0.95, axis=0)

# Figure size: LaTeX text width is 397.5 pt; 0.6 * 397.5 pt = 238.5 pt
# = 3.3 inch wide, height/width ratio 0.4 -> 1.32 inch high.
fig, axs = plt.subplots(1, 2, figsize=(3.3, 1.32), sharey=True)
axs[0].set_ylabel("$u(s)$")
axs[0].set_title("Ad-SVGD")
axs[1].set_title("Med-SVGD")
for ax in axs:
    ax.grid(True, which="both", linestyle="--", linewidth=0.5, alpha=0.7)
    ax.set_xlim(0, 1)
    ax.set_xlabel("$s$")
    (true_line,) = ax.plot(
        grid,
        np.hstack([0, true_process, 0]),
        alpha=0.15,
        label="true process" if ax is axs[0] else None,
        color="black",
        linewidth=0.7,
    )
    # Draw on `ax` itself: plt.fill_between targets the current axes (the
    # last subplot created), which put the posterior band twice on the right
    # panel and never on the left one.
    ax.fill_between(
        grid,
        np.hstack([0, post_lower]),
        np.hstack([0, post_upper]),
        alpha=0.1,
        label="posterior 90% confidence interval" if ax is axs[0] else None,
        color="orange",
        edgecolor="none",
    )

# Left panel: Ad-SVGD mean and 90% band.
(line_ad,) = axs[0].plot(grid, np.hstack([0, mean_ad]), label="mean", linewidth=0.7)
axs[0].fill_between(
    grid,
    np.hstack([0, lower_ad]),
    np.hstack([0, upper_ad]),
    alpha=0.3,
    label="90% confidence interval",
    edgecolor="none",
)
# Right panel: Med-SVGD mean and 90% band.
axs[1].plot(grid, np.hstack([0, mean_med]), linewidth=0.7)
axs[1].fill_between(
    grid,
    np.hstack([0, lower_med]),
    np.hstack([0, upper_med]),
    alpha=0.3,
    edgecolor="none",
)
# Posterior mean on top of both panels; the last handle is kept for the legend.
for ax in axs:
    (line_post,) = ax.plot(
        grid,
        np.hstack([0, post_mean]),
        label="posterior mean" if ax is axs[0] else None,
        color="orange",
        linestyle="--",
        linewidth=0.7,
    )


class HandlerLineWithFill(HandlerBase):
    """Legend handler that draws a line on top of a translucent filled band.

    Both the band and the line take their colour from the handle being
    rendered; the line additionally copies the handle's line style.

    Args:
        facecolor: Stored on the instance; ``create_artists`` does not read it.
        alpha: Opacity of the filled band.
        **kwargs: Forwarded to ``HandlerBase.__init__``.
    """

    def __init__(self, facecolor="blue", alpha=0.3, **kwargs):
        super().__init__(**kwargs)
        self.facecolor = facecolor
        self.alpha = alpha

    def create_artists(
        self, legend, orig_handle, xdescent, ydescent, width, height, fontsize, trans
    ):
        """Return ``[band, line]`` for the legend entry of ``orig_handle``."""
        left = xdescent
        right = xdescent + width
        mid_y = ydescent + height / 2  # vertical centre of the legend entry
        handle_color = orig_handle.get_color()

        # Filled rectangle extending 3 units above and below the centre line.
        band_corners = [
            [left, mid_y - 3],
            [left, mid_y + 3],
            [right, mid_y + 3],
            [right, mid_y - 3],
        ]
        band = plt.Polygon(
            band_corners,
            facecolor=handle_color,
            alpha=self.alpha,
            edgecolor="none",
            transform=trans,
        )

        # Line drawn over the band, 0.2 units above the centre.
        stroke = mlines.Line2D(
            [left, right],
            [mid_y + 0.2, mid_y + 0.2],
            color=handle_color,
            linestyle=orig_handle.get_linestyle(),
            transform=trans,
            linewidth=0.7,
        )

        return [band, stroke]


# Shared legend below both panels.  The approximation and posterior entries
# use the line-with-band handler, with the band opacity of the respective plot.
legend_handles = [true_line, line_ad, line_post]
legend_labels = ["true process", "posterior approximation", "posterior"]
band_handlers = {
    line_ad: HandlerLineWithFill(alpha=0.3),
    line_post: HandlerLineWithFill(alpha=0.1),
}
fig.legend(
    legend_handles,
    legend_labels,
    handler_map=band_handlers,
    loc="lower center",
    bbox_to_anchor=(0.5, -0.3),
    ncol=3,
)
plt.savefig("plots/posterior_approximation.pdf", bbox_inches="tight", dpi=600)

print("Collecting results for fixed bandwidth experiments...")
exp_path = "experiments/1d_gaussian_mixture"
Ms = [10, 20, 50, 100, 200, 500]
h_exponents = [-3, -2, -1, 0, 1, 2, 3]
# Reference sample: two normal components centred at -2 and 2 (unit scale),
# drawn with sizes int(1e5 / 3) and int(2e5 / 3).
rng = np.random.default_rng(0)
left_mode = rng.normal(-2, 1, size=int(1e5 / 3))
right_mode = rng.normal(2, 1, size=int(2e5 / 3))
samples = np.concatenate([left_mode, right_mode])

# One array of 100 per-run distances for every (M, bandwidth exponent) pair.
configs = list(itertools.product(Ms, h_exponents))
results = {config: np.empty(100) for config in configs}
for M, h_exp in tqdm(configs, total=len(configs)):
    h = 10**h_exp
    # File names encode the bandwidth with "." replaced by "_".
    h_tag = str(h).replace(".", "_")
    for run in range(1, 101):
        result_file = f"{exp_path}/results_h{h_tag}_M{M}_run{run}.pkl"
        with open(result_file, "rb") as f:
            x_hist = pickle.load(f)
        # Distance between the reference sample and the last stored snapshot.
        results[(M, h_exp)][run - 1] = wasserstein_distance(
            u_values=samples, v_values=x_hist[-1].flatten()
        )
mean_dists = {config: np.mean(dists) for config, dists in results.items()}
std_dists = {config: np.std(dists) for config, dists in results.items()}

print("Plotting fixed bandwidth results...")
# Only a subset of the collected ensemble sizes is shown.
Ms = [50, 200, 500]
# Figure size: LaTeX text width is 397.5 pt; 0.5 * 397.5 pt = 198.75 pt
# = 2.75 inch wide, height/width ratio 0.78 -> 2.145 inch high.
plt.figure(figsize=(2.75, 2.145))
viridis = plt.get_cmap("viridis")
colors = viridis(np.linspace(0, 0.8, len(Ms)))
bandwidths = [10**h_exp for h_exp in h_exponents]
for M, color in zip(Ms, colors):
    means = [mean_dists[M, h_exp] for h_exp in h_exponents]
    # Looked up alongside the means; not drawn in this figure.
    stds = [std_dists[M, h_exp] for h_exp in h_exponents]
    plt.scatter(
        bandwidths, means, label=f"$M={M}$", color=color, s=10, zorder=10, alpha=0.7
    )
plt.yscale("log")
plt.xscale("log")
plt.grid(True, which="major", linestyle="--", linewidth=0.5, alpha=0.7)
plt.legend()
plt.xlabel("Kernel bandwidth $h$")
plt.ylabel(r"$\mathcal{W}_1\big(\hat{\mu}_{\mathrm{nsteps}}, \pi\big)$")
plt.savefig("plots/fixed_bandwidth_plot.pdf", bbox_inches="tight", dpi=600)
