import numpy as np
import matplotlib.pyplot as plt

from steering import SourceTemperingSampler, FKSampler
import torch
from scipy.stats import wasserstein_distance
import ot
from scipy.stats import norm

from utils import (
    log_p_tilde,
    unnormalized_pdf,
    normalization_constant,
    metropolis_hastings_nd,
)


def _uniform_w2_distance(samples, target):
    """Return the 2-Wasserstein distance between two empirical sample sets.

    Each set gets its own uniform weight vector, sized from the cost matrix,
    so the two sets are allowed to contain different numbers of points.
    """
    cost = ot.dist(samples, target, metric="euclidean") ** 2
    n_src, n_tgt = cost.shape
    src_weights = np.ones(n_src) / n_src
    tgt_weights = np.ones(n_tgt) / n_tgt
    return np.sqrt(ot.emd2(src_weights, tgt_weights, cost))


def _rescale_to_0_2(values):
    """Min-max rescale a numpy array to [0, 2]; a constant array maps to zeros."""
    lo, hi = values.min(), values.max()
    if hi == lo:
        # Avoid 0/0 -> NaN when the reward is flat over the grid.
        return np.zeros_like(values)
    return (values - lo) / (hi - lo) * 2


def _draw_histogram_figure(
    data, bins, color, curves, xlim, save_path, stem, legend_loc=None, symlog=False
):
    """Draw one density histogram with overlaid curves, optionally save, then close.

    ``curves`` is an iterable of ``(x, y, plot_kwargs)`` triples drawn in order.
    """
    plt.figure(figsize=(6, 4))
    plt.hist(data[:, 0], bins=bins, density=True, alpha=0.7, color=color)
    for xs, ys, style in curves:
        plt.plot(xs, ys, **style)

    plt.xlim(xlim)
    plt.xlabel("x")
    plt.ylabel("Density")
    plt.grid(alpha=0.2)
    if legend_loc is None:
        plt.legend()
    else:
        plt.legend(loc=legend_loc)
    if symlog:
        plt.yscale("symlog", linthresh=1e-1)
    plt.tight_layout()
    # The displayed range is fixed to [-3, 3]; ``xlim`` only sets the curve grid.
    plt.xlim(-3, 3)

    if save_path is not None:
        for ext in ("png", "pdf"):
            plt.savefig(f"{save_path}{stem}.{ext}", dpi=300, bbox_inches="tight")
    plt.clf()
    plt.close()


@torch.no_grad()
def plot_fk_vs_st_histograms(
    model,
    reward_fn,
    fk_sampler,
    st_sampler,
    base_width,
    beta,
    batch_size,
    device,
    xlim=(-4, 4),
    save_path=None,
):
    """Compare FK and SPT samples against the target density in 1-D histograms.

    Prints the 2-Wasserstein distance from each sampler's output to reference
    samples produced by ``metropolis_hastings_nd``, then draws three figures
    per sampler: histogram vs. rescaled reward, histogram vs. normalised pdf,
    and histogram vs. pdf on a symlog y-axis (with a dashed two-Gaussian
    mixture labelled ``p(x)``).

    Gradient tracking is disabled only for the duration of the call (via the
    ``torch.no_grad`` decorator); the caller's grad mode is restored on return.

    Args:
        model: Object whose ``eval()`` method is called before sampling.
        reward_fn: Callable invoked as ``reward_fn(x, base_width)``.
        fk_sampler: Object whose ``sample()`` returns a pair; the second item
            is used as the FK samples.
        st_sampler: Object whose ``sample(n_iterations=100, batch_size=...)``
            returns a pair; the second item is used as the SPT samples.
        base_width: Width forwarded to the reward, pdf and reference sampler.
        beta: Value forwarded to the pdf, normaliser and reference sampler.
        batch_size: Number of reference samples and SPT batch size.
        device: Torch device on which the evaluation grid is created.
        xlim: ``(low, high)`` range of the 500-point grid used for the curves.
            The displayed x-range is always ``(-3, 3)``.
        save_path: String prefix concatenated directly with each file name.
            When ``None`` the figures are drawn and closed without saving.
    """
    model.eval()

    # Reference samples from the target distribution.
    samples_target = metropolis_hastings_nd(
        log_p=log_p_tilde,
        x0=np.array([1.5]),
        beta=beta,
        base_width=base_width,
        n_samples=batch_size,
    )

    # FK sample distribution.
    _, x_fk = fk_sampler.sample()
    fk_w2 = _uniform_w2_distance(x_fk.cpu().numpy(), samples_target)
    print(
        f"FK W Distance: {fk_w2}"
    )

    # SPT sample distribution.
    _, x_st = st_sampler.sample(
        n_iterations=100,
        batch_size=batch_size,
    )
    st_w2 = _uniform_w2_distance(x_st.cpu().numpy(), samples_target)
    print(
        f"SPT W Distance: {st_w2}"
    )

    named_samples = (
        ("FK", x_fk.squeeze(0).cpu()),
        ("SPT", x_st.squeeze(0).cpu()),
    )

    # Set matplotlib figure params for large plots
    plt.rcParams.update(
        {
            # Axis titles
            "axes.titlesize": 25,  # title of the axes
            "axes.labelsize": 25,  # x and y axis labels
            # Tick labels
            "xtick.labelsize": 25,
            "ytick.labelsize": 25,
            # Legend
            "legend.fontsize": 18,
            "legend.title_fontsize": 18,
            # General font size for other text
            "font.size": 16,
        }
    )

    r_color = "#000000"  # black reference curves

    # Fine grid on which the reward / pdf curves are evaluated.
    x_grid = torch.linspace(xlim[0], xlim[1], 500, device=device)
    grid_np = x_grid.cpu().numpy()
    bin_width = 1 / 40
    x_min, x_max = -3, 3
    bins = np.arange(x_min, x_max + bin_width, bin_width)
    K = normalization_constant(beta, base_width)

    # The curves do not depend on the sampler, so evaluate each of them once.
    reward_curve = _rescale_to_0_2(
        reward_fn(x_grid.unsqueeze(-1), base_width).cpu().numpy()
    )
    pdf_curve = (
        unnormalized_pdf(x_grid.unsqueeze(-1).cpu(), beta, base_width)
        .cpu()
        .numpy()
        / K
    )
    mixture_curve = (
        norm.pdf(grid_np, -1.5, .2) * .9 + norm.pdf(grid_np, 1.5, .2) * .1
    )

    reward_line = (
        grid_np,
        reward_curve,
        {"color": r_color, "linewidth": 2, "label": "Reward Function"},
    )
    pdf_line = (
        grid_np,
        pdf_curve,
        {"color": r_color, "linewidth": 2, "label": r"$\tilde p(x)$"},
    )
    mixture_line = (
        grid_np,
        mixture_curve,
        {"label": r"$p(x)$", "linestyle": "--", "alpha": .8, "color": "C2"},
    )

    linear_colors = {"FK": "orange", "SPT": "green"}
    semilog_colors = {"FK": "C0", "SPT": "C1"}

    # First plot the histograms against the reward function
    for fname, data in named_samples:
        _draw_histogram_figure(
            data,
            bins,
            linear_colors[fname],
            [reward_line],
            xlim,
            save_path,
            f"{fname}_reward",
        )

    # Next plot the histograms against the pdf, normal scale
    for fname, data in named_samples:
        _draw_histogram_figure(
            data,
            bins,
            linear_colors[fname],
            [pdf_line],
            xlim,
            save_path,
            f"{fname}_pdf",
            legend_loc="upper left",
        )

    # Next plot against the pdf on a semi-log (symlog) scale
    for fname, data in named_samples:
        _draw_histogram_figure(
            data,
            bins,
            semilog_colors[fname],
            [pdf_line, mixture_line],
            xlim,
            save_path,
            f"{fname}_semilog_pdf",
            legend_loc="upper left",
            symlog=True,
        )


@torch.no_grad()
def plot_fk_vs_st_betas(
    model,
    reward_fn,
    sampler_params,
    fk_params,
    base_sigma,
    base_width,
    marginal_prob_std_fn,
    st_betas=(1.0, 100.0),
    batch_size=4096,
    device="cpu",
    bins=50,
    xlim=(-4, 4),
    save_path=None,
):
    """Plot FK and SPT sample histograms at the first and last value of ``st_betas``.

    Four figures are produced, in this order: FK at ``st_betas[0]``, FK at
    ``st_betas[-1]``, SPT at ``st_betas[0]``, SPT at ``st_betas[-1]``. Each
    shows a density histogram of the first sample coordinate with the reward
    function (min-max rescaled to ``[0, 2]``) drawn on top.

    Only the two endpoint betas are sampled, so the datasets always line up
    with the four fixed titles and file names. Gradient tracking is disabled
    only for the duration of the call and restored on return.

    Args:
        model: Object whose ``eval()`` is called; forwarded to both samplers.
        reward_fn: Callable invoked as ``reward_fn(x, base_width)``.
        sampler_params: Extra keyword arguments for ``SourceTemperingSampler``.
        fk_params: Extra keyword arguments for ``FKSampler``. Any ``lmbda``
            entry is overridden by the current beta.
        base_sigma: Passed as ``sigma`` to ``FKSampler`` and, through
            ``marginal_prob_std_fn(1, base_sigma)``, to the SPT sampler.
        base_width: Width forwarded to ``reward_fn``.
        marginal_prob_std_fn: Callable returning an object with ``.item()``.
        st_betas: Sequence of betas; only the first and last are plotted.
        batch_size: Batch size passed to the SPT sampler's ``sample``.
        device: Torch device on which the reward grid is created.
        bins: Currently ignored; histograms always use bins of width 1/20
            over ``[-3, 3]``. Kept so existing callers keep working.
        xlim: ``(low, high)`` range of the 500-point reward grid. The
            displayed x-range is always ``(-4, 4)``.
        save_path: String prefix concatenated directly with each file name.
            When ``None`` the figures are drawn and closed without saving.
    """
    model.eval()

    # This sigma is computed inside FK steering, but must be pre-computed for SPT
    st_sigma = marginal_prob_std_fn(1, base_sigma).item()
    fk_sigma = base_sigma

    def width_reward(x):
        return reward_fn(x, base_width)

    # Titles and file names below cover exactly the two endpoint betas.
    plotted_betas = (st_betas[0], st_betas[-1])

    # Get the FK Steering samples
    datasets = []
    for beta in plotted_betas:
        # Merge first so that a caller-supplied ``lmbda`` is overridden instead
        # of raising "got multiple values for keyword argument".
        fk_kwargs = {**fk_params, "lmbda": beta}
        fk_sampler = FKSampler(
            model=model,
            reward_fn=width_reward,
            sigma=fk_sigma,
            **fk_kwargs,
        )
        _, x_fk = fk_sampler.sample()
        datasets.append(x_fk.squeeze(0).cpu())

    # Get the SPT samples
    for beta in plotted_betas:
        st_sampler = SourceTemperingSampler(
            model=model,
            reward_fn=width_reward,
            beta=beta,
            sigma=st_sigma,
            **sampler_params,
        )
        _, x_st = st_sampler.sample(
            n_iterations=100,
            batch_size=batch_size,
        )
        datasets.append(x_st.squeeze(0).cpu())

    r_color = "#000000"  # black reward curve
    colors = ["#DD8452", "#DD8452", "#55A868", "#55A868"]

    plt.rcParams.update(
        {
            "font.size": 16,
            "axes.titlesize": 16,
            "axes.labelsize": 14,
            "xtick.labelsize": 12,
            "ytick.labelsize": 12,
        }
    )

    titles = [
        f"FK (β = {st_betas[0]}, Exploration)",
        f"FK (β = {st_betas[-1]}, Exploitation)",
        f"SPT (β = {st_betas[0]}, Exploration)",
        f"SPT (β = {st_betas[-1]}, Exploitation)",
    ]
    file_names = [
        "FK_Exploration",
        "FK_Exploitation",
        "SPT_Exploration",
        "SPT_Exploitation",
    ]

    # Fine grid for the reward function; the curve is the same in every figure,
    # so it is evaluated once.
    x_grid = torch.linspace(xlim[0], xlim[1], 500, device=device)
    grid_np = x_grid.cpu().numpy()
    rewards = reward_fn(x_grid.unsqueeze(-1), base_width).cpu().numpy()
    rewards = (rewards - rewards.min()) / (rewards.max() - rewards.min()) * 2

    bin_width = 1 / 20
    x_min, x_max = -3, 3
    hist_bins = np.arange(x_min, x_max + bin_width, bin_width)

    # Plot everything
    for data, title, color, fname in zip(datasets, titles, colors, file_names):
        plt.figure(figsize=(6, 4))
        # Plot histogram
        plt.hist(data[:, 0], bins=hist_bins, density=True, alpha=0.7, color=color)

        # Plot reward function
        plt.plot(
            grid_np,
            rewards,
            color=r_color,
            linewidth=2,
            label="Reward Function",
        )

        plt.title(title, fontsize=25)
        plt.xlim(xlim)
        plt.xlabel("x")
        plt.ylabel("Density")
        plt.grid(alpha=0.2)
        plt.legend()
        plt.tight_layout()
        plt.ylim(0, 4)
        plt.xlim(-4, 4)

        if save_path is not None:
            plt.savefig(f"{save_path}{fname}.png", dpi=300, bbox_inches="tight")
            plt.savefig(f"{save_path}{fname}.pdf", dpi=300, bbox_inches="tight")
        plt.clf()
        plt.close()


@torch.no_grad()
def plot_fk_vs_st_exploration(
    model,
    reward_fn,
    st_uncond,
    sampler_params,
    fk_sampler,
    base_sigma,
    marginal_prob_std_fn,
    st_betas=(1.0, 100.0),
    batch_size=4096,
    device="cpu",
    save_path="results/imgs/",
    grid_size=200,  # resolution for reward contour
    base_width=0.1,
):
    """Scatter 2-D samples from four samplers over a filled reward contour.

    Draws and saves four figures: unconditional transport, FK steering, SPT at
    ``st_betas[0]`` and SPT at ``st_betas[1]``. Every figure is written as both
    ``.pdf`` and ``.png`` under ``save_path``. The current matplotlib figure
    is cleared and reused for each plot and is left open on return.

    Gradient tracking is disabled only for the duration of the call and
    restored on return.

    Args:
        model: Object whose ``eval()`` is called; forwarded to the SPT sampler.
        reward_fn: Callable invoked as ``reward_fn(x, base_width)``.
        st_uncond: Object whose ``transport`` maps source noise to samples.
        sampler_params: Keyword arguments for ``SourceTemperingSampler``; must
            contain ``"data_dim"``, which sets the noise dimension.
        fk_sampler: Object whose ``sample()`` returns a pair; the second item
            is used as the FK samples.
        base_sigma: Input to ``marginal_prob_std_fn(1, base_sigma)``.
        marginal_prob_std_fn: Callable returning an object with ``.item()``.
        st_betas: Betas to sample with SPT; the first two are plotted.
        batch_size: Number of source noise draws and SPT batch size.
        device: Torch device for the noise and the reward grid.
        save_path: String prefix concatenated directly with each file name.
        grid_size: Points per axis of the reward contour grid on [-2, 2]^2.
        base_width: Width passed to ``reward_fn`` (default ``0.1``).
    """
    model.eval()

    # This sigma is computed inside FK steering, but must be pre-computed for SPT
    st_sigma = marginal_prob_std_fn(1, base_sigma).item()

    # Get the unconditional distribution
    z = torch.randn(batch_size, sampler_params["data_dim"], device=device) * st_sigma
    x_uncond = st_uncond.transport(z.unsqueeze(0)).squeeze(0).cpu()

    # Transport using FK Steering (only beta=100 is graphed, both produced similar plots)
    _, x_fk = fk_sampler.sample()
    x_fk = x_fk.squeeze(0).cpu()

    def width_reward(x):
        return reward_fn(x, base_width)

    # Transport using the different values of beta
    st_samples = []
    for beta in st_betas:
        st = SourceTemperingSampler(
            model=model,
            reward_fn=width_reward,
            beta=beta,
            sigma=st_sigma,
            **sampler_params,
        )
        _, x = st.sample(
            n_iterations=100,
            batch_size=batch_size,
        )
        st_samples.append(x.cpu())

    datasets = [
        x_uncond,
        x_fk,
        st_samples[0],
        st_samples[1],
    ]

    # Evaluate the reward on a regular grid for the background contour.
    x_grid = torch.linspace(-2, 2, grid_size)
    y_grid = torch.linspace(-2, 2, grid_size)
    X, Y = torch.meshgrid(x_grid, y_grid, indexing="ij")
    XY = torch.stack([X.flatten(), Y.flatten()], dim=-1).to(device)
    Z = reward_fn(XY, base_width).reshape(grid_size, grid_size).cpu()

    titles = [
        "Unconditional SPT",
        "FK steering",
        f"SPT (β = {st_betas[0]}, Exploration)",
        f"SPT (β = {st_betas[1]}, Exploitation)",
    ]
    plt.rcParams.update(
        {
            "font.size": 18,
            "axes.titlesize": 18,
            "axes.labelsize": 16,
            "xtick.labelsize": 14,
            "ytick.labelsize": 14,
        }
    )
    colors = [
        "#4C72B0",  # blue
        "#DD8452",  # orange
        "#55A868",  # green
        "#C44E52",  # red
    ]
    file_names = [
        "Unconditional_SPT_2",
        "FK_steering_2",
        "SPT_Exploration_2",
        "SPT_Exploitation_2",
    ]
    # Plot everything
    for data, title, color, fname in zip(datasets, titles, colors, file_names):
        plt.clf()
        # Reward contour
        plt.contourf(X, Y, Z, levels=50, cmap="cividis")
        # Scatter points on top
        plt.scatter(
            data[:, 0],
            data[:, 1],
            s=5,
            alpha=0.8,
            color=color,
            edgecolors="white",
            linewidths=0.2,
        )
        plt.title(title, fontsize=25)
        plt.grid(alpha=0.2)
        plt.xlim(-2, 2)
        plt.ylim(-2, 2)
        plt.tight_layout()

        # Save both PDF and PNG
        plt.savefig(f"{save_path}{fname}.pdf", dpi=300, bbox_inches="tight")
        plt.savefig(f"{save_path}{fname}.png", dpi=300, bbox_inches="tight")


def plot_wasserstein():
    """Plot the stored FK and SPT distance curves against omega^2.

    Reads a two-row array from ``results/wasserstein.npy`` (first row FK,
    second row SPT), keeps the last 15 entries of each row, and pairs them
    with the last 15 points of a 50-point log-spaced grid from 1e-6 to 1e-1.
    The figure is written to ``results/imgs/`` as PNG and PDF.
    """
    large_fonts = {
        # Axis titles and axis labels
        "axes.titlesize": 25,
        "axes.labelsize": 25,
        # Tick labels
        "xtick.labelsize": 25,
        "ytick.labelsize": 25,
        # Legend
        "legend.fontsize": 18,
        "legend.title_fontsize": 18,
        # Everything else
        "font.size": 16,
    }
    plt.rcParams.update(large_fonts)
    plt.figure(figsize=(6.6, 4.45))

    # The stored arrays may be longer; only the trailing points are shown.
    keep = 15
    stored = np.load("results/wasserstein.npy")
    fk_all, spt_all = stored
    omega_sq = np.logspace(-6, -1, 50)[-keep:]

    # One (values, scatter options) entry per method, drawn line-then-markers.
    series = (
        (fk_all[-keep:], {"label": "FK Steering", "s": 95, "marker": "*"}),
        (spt_all[-keep:], {"label": "SPT (Ours)", "s": 75}),
    )
    for values, marker_opts in series:
        plt.plot(omega_sq, values, lw=4)
        plt.scatter(omega_sq, values, **marker_opts)

    plt.legend(loc="upper right")
    plt.xscale("log")
    plt.ylabel(r"Distance")
    plt.xlabel(r"$\omega^2$")
    plt.tight_layout()

    # Keep a raster and a vector copy of the figure.
    for extension in ("png", "pdf"):
        plt.savefig(f"results/imgs/wasserstein_small.{extension}")
