import seaborn as sns
import wandb
import torch
import matplotlib.pyplot as plt
import itertools
import numpy as np
import PIL

from einops import rearrange
from torchvision.utils import make_grid


def get_images(data_tensor):
    """Draw a batch of images as one tiled grid on a new figure.

    The batch is tiled with torchvision's ``make_grid`` using ``nrow=4`` and
    shown with ``imshow`` on an axis whose frame and ticks are hidden. No value
    rescaling is applied here, so ``data_tensor`` is presumably already in a
    range ``imshow`` can display (e.g. [0, 1]) -- confirm against the caller.

    Args:
        data_tensor: Image batch accepted by ``torchvision.utils.make_grid``.

    Returns:
        The ``(fig, ax)`` pair holding the rendered grid.
    """
    grid = make_grid(data_tensor, nrow=4)
    fig, ax = plt.subplots()
    # Move the leading axis of the grid to the end before handing it to imshow.
    channel_last = grid.permute(1, 2, 0)
    ax.imshow(channel_last)
    ax.axis("off")
    return fig, ax


def get_wandb_heatmap(tensor_info, cmap="coolwarm", vmin=0, vmax=2):
    """Render a 2-D tensor as a seaborn heatmap and wrap it as a ``wandb.Image``.

    The tensor is moved to the CPU and transposed, so an input of shape
    ``(time, dim)`` is drawn with one row per dimension and one column per time
    step. An extra first row labelled ``"mean"`` holds the average over
    dimensions for every time step.

    Args:
        tensor_info: 2-D torch tensor (first axis becomes the x axis).
        cmap: Colormap name forwarded to ``sns.heatmap``.
        vmin: Lower end of the colour scale.
        vmax: Upper end of the colour scale; the colormap is centred on
            ``(vmin + vmax) / 2``.

    Returns:
        A ``wandb.Image`` built from the figure. The figure itself is closed
        before returning so repeated calls do not accumulate open pyplot
        figures.
    """
    array_info = tensor_info.detach().to("cpu").numpy().transpose()
    dim_sz, time_sz = array_info.shape
    # Height grows with the number of rows (dimensions plus the "mean" row).
    fig, ax = plt.subplots(figsize=(10, max(4, 0.3125 * (dim_sz + 1))))
    try:
        sns.heatmap(
            np.concatenate([array_info.mean(0)[None, :], array_info]),
            xticklabels=[round(i / time_sz, 2) for i in range(0, time_sz)],
            yticklabels=["mean"] + [f"#{i}" for i in range(dim_sz)],
            ax=ax,
            cmap=cmap,
            center=(vmin + vmax) / 2,
            vmin=vmin,
            vmax=vmax,
        )
        ax.set(xlabel="Time fraction", ylabel="Dimension")
        # Act on the heatmap axis explicitly instead of pyplot's "current" axis.
        ax.locator_params(axis="x", nbins=10)
        return wandb.Image(fig)
    finally:
        # The figure is not returned to the caller, so nobody else could close it.
        plt.close(fig)


def get_figure(bounds=(-10.0, 10.0)):
    """Create a 16x16 inch figure with one hidden, fixed-range square axis.

    Args:
        bounds: ``(low, high)`` pair applied to both the x and the y axis.

    Returns:
        The ``(fig, ax)`` pair. Autoscaling is turned off so later plotting
        calls do not change the limits.
    """
    fig, ax = plt.subplots(1, figsize=(16, 16))
    ax.axis("off")
    ax.set_autoscale_on(False)
    low, high = bounds[0], bounds[1]
    for set_limits in (ax.set_xlim, ax.set_ylim):
        set_limits([low, high])
    return fig, ax


def get_figures(n=3, m=1, bounds=(-10.0, 10.0)):
    """Create an ``n`` x ``m`` grid of fixed-range axes.

    Each subplot gets a 16x16 inch cell, has autoscaling disabled, and uses
    ``bounds`` for both its x and y limits.

    Args:
        n: Number of subplot rows.
        m: Number of subplot columns.
        bounds: ``(low, high)`` pair applied to both axes of every subplot.

    Returns:
        ``(fig, axs)`` exactly as produced by ``plt.subplots(n, m, ...)``: a
        single axis, a 1-D array, or a 2-D array depending on ``n`` and ``m``.
    """
    fig, axs = plt.subplots(n, m, figsize=(16 * m, 16 * n))
    low, high = bounds[0], bounds[1]
    # plt.subplots squeezes its result, so flatten whatever came back in order
    # to visit every axis uniformly.
    for ax in np.atleast_1d(axs).ravel():
        ax.set_autoscale_on(False)
        ax.set_xlim([low, high])
        ax.set_ylim([low, high])
    return fig, axs


def plot_contours(
    log_prob,
    ax=None,
    bounds=(-10.0, 10.0),
    grid_width_n_points=200,
    n_contour_levels=50,
    log_prob_min=-1000.0,
    device=torch.device("cuda"),
):
    """Plot contours of a log-probability function defined on 2-D points.

    A square grid of ``grid_width_n_points ** 2`` points spanning ``bounds`` on
    both axes is evaluated with ``log_prob`` and the values, clamped from below
    at ``log_prob_min``, are drawn with ``ax.contour``.

    Args:
        log_prob: Callable taking an ``(N, 2)`` tensor on ``device`` and
            returning ``N`` values (anything reshapeable to the grid).
        ax: Axis to draw on; a new figure is created when ``None``.
        bounds: ``(low, high)`` range used for both coordinates.
        grid_width_n_points: Number of grid points per axis.
        n_contour_levels: Passed as ``levels`` to ``contour``; when falsy the
            matplotlib default is used.
        log_prob_min: Floor applied to the evaluated values before plotting.
        device: Device the grid is moved to before calling ``log_prob``.

    Returns:
        None. The contours are drawn on ``ax`` as a side effect.
    """
    if ax is None:
        _, ax = plt.subplots(1)

    axis_points = torch.linspace(bounds[0], bounds[1], grid_width_n_points)
    # cartesian_prod yields the pairs in itertools.product order, but without
    # materialising grid_width_n_points**2 Python tuples of scalar tensors.
    x_points = torch.cartesian_prod(axis_points, axis_points)

    log_p_x = log_prob(x_points.to(device)).detach().cpu()
    log_p_x = torch.clamp_min(log_p_x, log_prob_min)

    grid_shape = (grid_width_n_points, grid_width_n_points)
    grid_dim1 = x_points[:, 0].reshape(grid_shape).numpy()
    grid_dim2 = x_points[:, 1].reshape(grid_shape).numpy()
    grid_values = log_p_x.reshape(grid_shape).numpy()

    contour_kwargs = {"levels": n_contour_levels} if n_contour_levels else {}
    ax.contour(grid_dim1, grid_dim2, grid_values, **contour_kwargs)


def plot_black_contours(
    log_prob,
    ax=None,
    bounds=(-10.0, 10.0),
    grid_width_n_points=200,
    n_contour_levels=50,
    log_prob_min=-1000.0,
    device=torch.device("cuda"),
):
    """Call ``ax.contour`` on a constant (all-zero) field over a square grid.

    The signature mirrors ``plot_contours``, but ``log_prob``, ``log_prob_min``
    and ``device`` are accepted only for that parity and are never used: the
    plotted values are a constant array, so nothing is evaluated.
    NOTE(review): a constant field presumably yields no visible contour lines;
    confirm the intended purpose with the callers.

    Args:
        log_prob: Unused.
        ax: Axis to draw on; a new figure is created when ``None``.
        bounds: ``(low, high)`` range used for both coordinates.
        grid_width_n_points: Number of grid points per axis.
        n_contour_levels: Passed as ``levels`` to ``contour``; when falsy the
            matplotlib default is used.
        log_prob_min: Unused.
        device: Unused.

    Returns:
        None. ``ax.contour`` is invoked as a side effect.
    """
    if ax is None:
        _, ax = plt.subplots(1)

    axis_points = torch.linspace(bounds[0], bounds[1], grid_width_n_points)
    # cartesian_prod yields the pairs in itertools.product order, but without
    # materialising grid_width_n_points**2 Python tuples of scalar tensors.
    x_points = torch.cartesian_prod(axis_points, axis_points)

    grid_shape = (grid_width_n_points, grid_width_n_points)
    constant_value = 0
    constant_field = np.ones(grid_shape) * constant_value
    grid_dim1 = x_points[:, 0].reshape(grid_shape).numpy()
    grid_dim2 = x_points[:, 1].reshape(grid_shape).numpy()

    contour_kwargs = {"levels": n_contour_levels} if n_contour_levels else {}
    ax.contour(grid_dim1, grid_dim2, constant_field, **contour_kwargs)


def plot_samples(samples, ax=None, bounds=(-10.0, 10.0), alpha=0.5, color=None, size=10, marker="o"):
    """Scatter the first two coordinates of ``samples`` on ``ax``.

    Values are clamped into ``bounds`` first, so out-of-range points are drawn
    on the border of the plotting window instead of being dropped.

    Args:
        samples: Torch tensor indexed as ``samples[:, 0]`` / ``samples[:, 1]``.
        ax: Axis to draw on; a new figure is created when ``None``.
        bounds: ``(low, high)`` clamp range applied to every coordinate.
        alpha: Marker opacity.
        color: Forwarded to ``scatter`` as ``c``.
        size: Forwarded to ``scatter`` as ``s``.
        marker: Marker style.

    Returns:
        None. The points are drawn on ``ax`` as a side effect.
    """
    if ax is None:
        _, ax = plt.subplots(1)
    low, high = bounds[0], bounds[1]
    clipped = torch.clamp(samples, low, high).cpu().detach()
    xs = clipped[:, 0]
    ys = clipped[:, 1]
    ax.scatter(xs, ys, alpha=alpha, marker=marker, s=size, c=color)


def plot_kde(samples, ax=None, bounds=(-10.0, 10.0)):
    """Draw a filled 2-D kernel density estimate of ``samples`` on ``ax``.

    Args:
        samples: Torch tensor indexed as ``samples[:, 0]`` / ``samples[:, 1]``.
        ax: Axis to draw on; a new figure is created when ``None``.
        bounds: Forwarded to ``sns.kdeplot`` as ``clip``.

    Returns:
        None. The density is drawn on ``ax`` as a side effect.
    """
    if ax is None:
        _, ax = plt.subplots(1)
    points = samples.cpu().detach()
    first_coord, second_coord = points[:, 0], points[:, 1]
    sns.kdeplot(x=first_coord, y=second_coord, cmap="Blues", fill=True, ax=ax, clip=bounds)


def viz_many_well(mw_energy, samples=None, num_samples=5000):
    """Build sample, KDE and contour figures for two 2-D projections of samples.

    Two coordinate pairs are visualised: columns ``(0, 2)`` (tagged "x13") and
    columns ``(1, 2)`` (tagged "x23"). For each pair three figures are created,
    in this order: a scatter of the samples, a KDE, and the samples overlaid on
    contours of ``-mw_energy.energy`` evaluated with every other coordinate
    fixed at zero.

    Args:
        mw_energy: Object providing ``sample``, ``energy``, ``data_ndim`` and
            ``device``.
        samples: Tensor of samples; drawn via ``mw_energy.sample(num_samples)``
            when ``None``.
        num_samples: Number of samples to draw when ``samples`` is ``None``.

    Returns:
        A 12-tuple ``(fig_samples, ax_samples, fig_kde, ax_kde, fig_contour,
        ax_contour)`` for "x13" followed by the same six entries for "x23".
    """
    lim = 3
    alpha = 0.8
    n_contour_levels = 20

    if samples is None:
        samples = mw_energy.sample(num_samples)

    def make_logp(first_dim, second_dim):
        # Embed the 2-D grid into the full space (remaining coordinates stay at
        # zero) and return the negated energy on the CPU.
        def logp_func(x_2d):
            full = torch.zeros((x_2d.shape[0], mw_energy.data_ndim)).to(mw_energy.device)
            full[:, first_dim] = x_2d[:, 0]
            full[:, second_dim] = x_2d[:, 1]
            return -mw_energy.energy(full).detach().cpu()

        return logp_func

    projections = (
        ("x13", slice(0, 3, 2), (0, 2)),
        ("x23", slice(1, 3), (1, 2)),
    )

    results = []
    for tag, columns, (first_dim, second_dim) in projections:
        pair = samples[:, columns]
        pair_cpu = pair.detach().cpu()
        results.extend(viz_sample2d(pair_cpu, "samples", f"dist{tag}.png", lim=3))
        results.extend(viz_kde2d(pair_cpu, "kde", f"kde{tag}.png", lim=3))
        results.extend(
            viz_contour_sample2d(
                pair,
                f"contour{tag}.png",
                make_logp(first_dim, second_dim),
                lim=lim,
                alpha=alpha,
                n_contour_levels=n_contour_levels,
            )
        )
    return tuple(results)


def traj_plot1d(traj_len, samples, xlabel, ylabel, title="", fsave="img.png"):
    """Plot sub-sampled trajectories as line plots and save them to ``fsave``.

    ``samples`` is rearranged from ``(t, b, d)`` to ``(b, t, d)`` and
    ``traj_len`` evenly spaced time indices are kept. Each batch element is
    then flattened and drawn against ``0 .. traj_len - 1``, which presumably
    expects ``d == 1`` -- confirm against the caller.

    Args:
        traj_len: Number of time steps kept per trajectory.
        samples: Tensor laid out as ``(time, batch, dim)``.
        xlabel: Label for the x axis.
        ylabel: Label for the y axis.
        title: Figure title.
        fsave: Path the figure is written to.

    Returns:
        None. The figure is saved and then closed.
    """
    trajectories = rearrange(samples, "t b d -> b t d").cpu()
    n_steps = trajectories.shape[1]
    kept_steps = np.linspace(0, n_steps, traj_len, endpoint=False, dtype=int)
    trajectories = trajectories[:, kept_steps]

    plt.figure()
    for idx, trajectory in enumerate(trajectories):
        values = trajectory.flatten()
        plt.plot(np.arange(traj_len), values, marker="x", label=f"sample {idx}")
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.savefig(fsave)
    plt.close()


########### 2D plot
def viz_sample2d(points, title, fsave, lim=7.0, sample_num=50000):
    """Plot the first two coordinates of up to ``sample_num`` points as dots.

    Args:
        points: Array-like indexed as ``points[:sample_num, 0]`` / ``[..., 1]``.
        title: Axis title; skipped when ``None``.
        fsave: Unused by this function (nothing is written to disk here).
        lim: Both axes are limited to ``[-lim, lim]``.
        sample_num: Maximum number of leading points to draw.

    Returns:
        The ``(fig, ax)`` pair of the new 7x7 inch figure.
    """
    fig, ax = plt.subplots(1, 1, figsize=(7, 7))
    if title is not None:
        ax.set_title(title)
    shown = points[:sample_num]
    # linewidth=0 with a "." marker turns the line plot into a scatter of dots.
    dot_style = {"linewidth": 0, "marker": ".", "markersize": 1}
    ax.plot(shown[:, 0], shown[:, 1], **dot_style)
    ax.set_xlim(-lim, lim)
    ax.set_ylim(-lim, lim)
    return fig, ax


def viz_kde2d(points, title, fname, lim=7.0, sample_num=2000):
    """Draw a filled 2-D KDE of the first ``sample_num`` points.

    Args:
        points: Array-like indexed as ``points[:sample_num, 0]`` / ``[..., 1]``.
        title: Axis title; skipped when ``None``.
        fname: Unused by this function (nothing is written to disk here).
        lim: Both axes are limited to ``[-lim, lim]``.
        sample_num: Maximum number of leading points used for the estimate.

    Returns:
        The ``(fig, ax)`` pair of the new 7x7 inch, 200 dpi figure.
    """
    fig, ax = plt.subplots(1, 1, figsize=(7, 7), dpi=200)
    if title is not None:
        ax.set_title(title)
    xs = points[:sample_num, 0]
    ys = points[:sample_num, 1]
    sns.kdeplot(x=xs, y=ys, cmap="coolwarm", fill=True, ax=ax)
    for set_limits in (ax.set_xlim, ax.set_ylim):
        set_limits(-lim, lim)
    return fig, ax


def viz_coutour_with_ax(
    ax,
    log_prob_func,
    lim=3.0,
    n_contour_levels=None,
    grid_width_n_points=100,
    log_prob_min=-1000.0,
):
    """Draw contours of ``log_prob_func`` over ``[-lim, lim]^2`` on ``ax``.

    A square grid of ``grid_width_n_points ** 2`` points is evaluated with
    ``log_prob_func`` (on the CPU; the callable is responsible for any device
    transfer), clamped from below at ``log_prob_min`` and passed to
    ``ax.contour``.

    Args:
        ax: Axis to draw on.
        log_prob_func: Callable taking an ``(N, 2)`` tensor and returning ``N``
            values (anything reshapeable to the grid).
        lim: Half-width of the plotted square.
        n_contour_levels: Passed as ``levels`` to ``contour``; when falsy the
            matplotlib default is used.
        grid_width_n_points: Number of grid points per axis (default keeps the
            previously hard-coded 100).
        log_prob_min: Floor applied to the evaluated values (default keeps the
            previously hard-coded -1000).

    Returns:
        None. The contours are drawn on ``ax`` as a side effect.
    """
    axis_points = torch.linspace(-lim, lim, grid_width_n_points)
    # cartesian_prod yields the pairs in itertools.product order, but without
    # materialising grid_width_n_points**2 Python tuples of scalar tensors.
    x_points = torch.cartesian_prod(axis_points, axis_points)

    log_p_x = log_prob_func(x_points).detach().cpu()
    log_p_x = torch.clamp_min(log_p_x, log_prob_min)

    grid_shape = (grid_width_n_points, grid_width_n_points)
    grid_dim1 = x_points[:, 0].reshape(grid_shape).numpy()
    grid_dim2 = x_points[:, 1].reshape(grid_shape).numpy()
    grid_values = log_p_x.reshape(grid_shape).numpy()

    contour_kwargs = {"levels": n_contour_levels} if n_contour_levels else {}
    ax.contour(grid_dim1, grid_dim2, grid_values, **contour_kwargs)


def viz_contour_sample2d(points, fname, log_prob_func, lim=3.0, alpha=0.7, n_contour_levels=None):
    """Overlay 2-D samples on contours of ``log_prob_func``.

    Args:
        points: Torch tensor indexed as ``points[:, 0]`` / ``points[:, 1]``.
        fname: Unused by this function (nothing is written to disk here).
        log_prob_func: Callable forwarded to ``viz_coutour_with_ax``.
        lim: Half-width of the plotted square; samples are clamped into
            ``[-lim, lim]`` so outliers end up on the border.
        alpha: Opacity of the sample dots.
        n_contour_levels: Forwarded to ``viz_coutour_with_ax``.

    Returns:
        The ``(fig, ax)`` pair of the new 7x7 inch figure.
    """
    fig, ax = plt.subplots(1, 1, figsize=(7, 7))

    # Contours first, then the samples drawn on the same axis.
    viz_coutour_with_ax(ax, log_prob_func, lim=lim, n_contour_levels=n_contour_levels)

    clipped = torch.clamp(points, -lim, lim).cpu().detach()
    xs, ys = clipped[:, 0], clipped[:, 1]
    ax.plot(xs, ys, linewidth=0, marker=".", markersize=1.5, alpha=alpha)

    return fig, ax


def fig2wandb(fig):
    """Rasterise a matplotlib figure and wrap it as a ``wandb.Image``.

    The pixels are read through ``canvas.buffer_rgba()`` rather than
    ``canvas.tostring_rgb()``, which newer Matplotlib releases deprecate and
    remove. The image size is taken from the buffer itself, so it cannot
    disagree with the value reported by ``get_width_height()``.

    Args:
        fig: Matplotlib figure whose canvas supports ``buffer_rgba()``
            (Agg-based canvases).

    Returns:
        A ``wandb.Image`` holding an RGB copy of the rendered figure.
    """
    # Import the submodule explicitly: a bare ``import PIL`` does not guarantee
    # that ``PIL.Image`` has been loaded.
    from PIL import Image

    fig.canvas.draw()
    rgba = np.asarray(fig.canvas.buffer_rgba())
    # Drop the alpha channel to keep the RGB output the old code produced.
    rgb = np.ascontiguousarray(rgba[..., :3])
    pil_image = Image.fromarray(rgb)
    return wandb.Image(pil_image)


def _export_figures(figures, name, save=True):
    """Optionally save figures as PDFs, convert them to wandb images, close them.

    Args:
        figures: Mapping from a short suffix to a matplotlib figure.
        name: Path prefix; each figure is saved as ``{name}{suffix}.pdf``.
        save: When false the PDFs are not written.

    Returns:
        Dict mapping ``visualization/{suffix}`` to a ``wandb.Image``, in the
        iteration order of ``figures``.
    """
    if save:
        for suffix, fig in figures.items():
            fig.savefig(f"{name}{suffix}.pdf", bbox_inches="tight")
    logged = {f"visualization/{suffix}": fig2wandb(fig) for suffix, fig in figures.items()}
    # The figures never leave this module, so release them here; otherwise
    # every call would leave several large figures open in pyplot.
    for fig in figures.values():
        plt.close(fig)
    return logged


def plot_step(energy, gfn_model, name, perfect, data_size, lambda_discretizer, device, gm=None):
    """Create the visualisations for one evaluation step.

    Samples come from ``energy.sample`` when ``perfect`` is true and from
    ``gfn_model.sample`` otherwise. What is plotted depends on the energy:

    * ``energy.is_many_well``: the six figures of ``viz_many_well`` are saved
      as PDFs and returned.
    * ``energy.is_gan``: a scatter of the samples and a grid of the first 16
      generated objects are returned (nothing is written to disk).
    * ``energy.data_ndim != 2``: nothing is plotted, ``{}`` is returned.
    * otherwise (2-D energy): contour and KDE figures of model samples and
      ground-truth samples are saved as PDFs and returned.

    Args:
        energy: Energy object; the attributes used here are ``sample``,
            ``log_reward``, ``is_gan``, ``is_many_well``, ``data_ndim`` and,
            depending on the branch, ``sample_objects``, ``generate``,
            ``bounds``, ``nmode`` and ``means``.
        gfn_model: Sampler called as
            ``gfn_model.sample(batch_size, lambda_discretizer, energy.log_reward)``.
        name: Path prefix for the saved PDF files.
        perfect: Use ground-truth samples instead of model samples.
        data_size: Number of samples to draw.
        lambda_discretizer: Forwarded to ``gfn_model.sample``.
        device: Device used when evaluating ``energy.log_reward`` for contours.
        gm: Optional object with ``predict`` and ``means_`` (presumably a fitted
            mixture model -- confirm with the caller). When truthy, samples are
            coloured by predicted component and the component means are marked.

    Returns:
        Dict mapping ``visualization/...`` keys to ``wandb.Image`` objects.
    """
    batch_size = data_size
    if perfect:
        samples = energy.sample(batch_size)
        if energy.is_gan:
            objects = energy.sample_objects(batch_size)
        else:
            objects = samples
    else:
        samples = gfn_model.sample(batch_size, lambda_discretizer, energy.log_reward)
        if energy.is_gan:
            objects = energy.generate(samples)

    if energy.is_many_well:
        (
            fig_samples_x13,
            _,
            fig_kde_x13,
            _,
            fig_contour_x13,
            _,
            fig_samples_x23,
            _,
            fig_kde_x23,
            _,
            fig_contour_x23,
            _,
        ) = viz_many_well(energy, samples)

        return _export_figures(
            {
                "contourx13": fig_contour_x13,
                "contourx23": fig_contour_x23,
                "kdex13": fig_kde_x13,
                "kdex23": fig_kde_x23,
                "samplesx13": fig_samples_x13,
                "samplesx23": fig_samples_x23,
            },
            name,
        )

    if energy.is_gan:
        fig_samples, _ = viz_sample2d(samples.detach().cpu(), "samples", "samplesx.png", lim=3)
        fig_gan_objects, _ = get_images(objects[:16].detach().cpu())
        # This branch has never written PDFs; only the wandb images are produced.
        return _export_figures({"samples": fig_samples, "images": fig_gan_objects}, name, save=False)

    if energy.data_ndim != 2:
        return {}

    gt_states = energy.sample(batch_size)

    fig_contour, ax_contour = get_figure(bounds=energy.bounds)
    fig_contour_with_gt, ax_contour_with_gt = get_figure(bounds=energy.bounds)
    fig_kde, ax_kde = get_figure(bounds=energy.bounds)
    fig_kde_overlay, ax_kde_overlay = get_figure(bounds=energy.bounds)

    plot_contours(energy.log_reward, ax=ax_contour, bounds=energy.bounds, n_contour_levels=150, device=device)
    plot_contours(energy.log_reward, ax=ax_contour_with_gt, bounds=energy.bounds, n_contour_levels=150, device=device)
    plot_kde(gt_states, ax=ax_kde_overlay, bounds=energy.bounds)
    plot_kde(samples, ax=ax_kde, bounds=energy.bounds)

    # Default colours must exist before the scatter calls below; previously
    # they were only bound when ``gm`` was given, so calling without ``gm``
    # raised UnboundLocalError.
    color = None
    gt_color = None
    if gm:
        colormap = plt.cm.viridis(np.linspace(0, 1, energy.nmode))
        # detach() first: .numpy() refuses tensors that still track gradients.
        classes = gm.predict(samples.detach().cpu().numpy())
        color = [colormap[c] for c in classes]
        gt_classes = gm.predict(gt_states.detach().cpu().numpy())
        gt_color = [colormap[c] for c in gt_classes]

    plot_samples(samples, ax=ax_contour, bounds=energy.bounds, color=color, alpha=0.3)
    plot_samples(gt_states, ax=ax_contour_with_gt, bounds=energy.bounds, color=gt_color, alpha=0.3)
    plot_samples(samples, ax=ax_kde_overlay, bounds=energy.bounds, color=color, alpha=0.3)

    if gm:
        # Fitted component means in red, the energy's own means in green.
        plot_samples(
            torch.from_numpy(gm.means_),
            ax=ax_contour,
            bounds=energy.bounds,
            color="red",
            alpha=1,
            marker="x",
            size=20,
        )
        plot_samples(energy.means.cpu(), ax=ax_contour, bounds=energy.bounds, color="green", alpha=1, marker="x", size=20)

    return _export_figures(
        {
            "contour": fig_contour,
            "contour_with_gt": fig_contour_with_gt,
            "kde_overlay": fig_kde_overlay,
            "kde": fig_kde,
        },
        name,
    )
