import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import wandb
from einops import rearrange
from sklearn.manifold import TSNE
from torchvision.utils import make_grid
from umap import UMAP
from matplotlib.offsetbox import OffsetImage, AnnotationBbox

from experiments.neural_datasets.inr_utils import make_image_grid


def scatter_plot_sdf(coords, sdf, ax):
    """Draw a colour-coded 3D scatter of SDF samples on an existing axes.

    Args:
        coords: Point coordinates, indexed as ``coords[:, k]`` for ``k`` in
            ``0..2`` (one row per point).
        sdf: Per-point values forwarded to ``scatter`` as the colour argument.
        ax: Axes exposing the 3D API (``set_zlabel`` / ``set_zlim``) to draw on.
    """
    # NOTE: The second and third coordinate columns are swapped on purpose so
    # that the shapes show upright.
    xs, ys, zs = coords[:, 0], coords[:, 2], coords[:, 1]
    ax.scatter(xs, ys, zs, c=sdf, s=5)

    axis_labels = (
        (ax.set_xlabel, 'X'),
        (ax.set_ylabel, 'Y'),
        (ax.set_zlabel, 'Z'),
    )
    for set_label, text in axis_labels:
        set_label(text)

    # Fix every axis to the same [-1, 1] range after plotting.
    for set_limits in (ax.set_xlim, ax.set_ylim, ax.set_zlim):
        set_limits(-1, 1)


def wandb_scatter_plot_sdf(coords, sdf):
    """Render an SDF scatter plot and wrap it as a ``wandb.Image``.

    Args:
        coords: Point coordinates, forwarded to ``scatter_plot_sdf``.
        sdf: Per-point values used for colouring, forwarded to
            ``scatter_plot_sdf``.

    Returns:
        A ``wandb.Image`` built from the rendered figure.
    """
    fig = plt.figure(figsize=(10, 10))
    try:
        # NOTE(review): subplot spec 121 puts the axes in the left half of the
        # figure; kept as-is to preserve the logged layout -- confirm intent.
        ax = fig.add_subplot(121, projection='3d')
        scatter_plot_sdf(coords, sdf, ax)
        # wandb.Image rasterises the matplotlib figure when it is constructed,
        # so the figure can be released right afterwards.
        return wandb.Image(fig)
    finally:
        # Figures created through pyplot stay registered until closed;
        # without this every logging call leaks one figure.
        plt.close(fig)


def wandb_scatter_plot_gt_pred_sdf(gt_coords, gt_sdf, pred_coords, pred_sdf):
    """Render ground-truth and predicted SDF scatters side by side.

    Args:
        gt_coords: Coordinates of the ground-truth points.
        gt_sdf: Values used to colour the ground-truth points.
        pred_coords: Coordinates of the predicted points.
        pred_sdf: Values used to colour the predicted points.

    Returns:
        A ``wandb.Image`` with the ground truth on the left subplot and the
        prediction on the right subplot.
    """
    fig = plt.figure(figsize=(20, 10))
    try:
        ax1 = fig.add_subplot(121, projection='3d')
        ax2 = fig.add_subplot(122, projection='3d')
        scatter_plot_sdf(gt_coords, gt_sdf, ax=ax1)
        scatter_plot_sdf(pred_coords, pred_sdf, ax=ax2)
        # wandb.Image rasterises the matplotlib figure when it is constructed,
        # so the figure can be released right afterwards.
        return wandb.Image(fig)
    finally:
        # Figures created through pyplot stay registered until closed;
        # without this every logging call leaks one figure.
        plt.close(fig)


def wandb_embeddings_table(model, dataset, random_indices, images):
    """Build a ``wandb.Table`` with one row per selected sample.

    The table has an ``image`` column, a ``target`` column holding the label
    as a string, and one ``d_<i>`` column per flattened embedding dimension.

    Args:
        model: Object exposing ``_get_embeddings()`` which returns a tensor.
        dataset: Object exposing a ``labels`` tensor.
        random_indices: Indices selecting rows of both the embeddings and the
            labels.
        images: Tensor of images; iterated along its first dimension after
            conversion to NumPy (presumably aligned with ``random_indices`` --
            TODO confirm against callers).

    Returns:
        A ``wandb.Table`` built from the assembled data frame.
    """
    # detach() first: Tensor.numpy() raises for tensors that require grad,
    # which is the case for learned embeddings outside of a no_grad context.
    embeddings = (
        model._get_embeddings().detach().cpu()[random_indices].flatten(1).numpy()
    )
    labels = dataset.labels[random_indices].numpy()

    columns = {
        "image": [wandb.Image(img) for img in images.numpy()],
        "target": [str(label) for label in labels],
    }
    # One column per embedding dimension, appended after the fixed columns.
    for i in range(embeddings.shape[1]):
        columns[f"d_{i}"] = embeddings[:, i]

    df = pd.DataFrame(data=columns)
    return wandb.Table(data=df, columns=list(df.columns))


def plot_2d_embeddings(embedded, labels, class_names):
    """Scatter 2D points on a new figure, coloured by label.

    Args:
        embedded: Array indexed as ``embedded[:, 0]`` / ``embedded[:, 1]`` for
            the x and y coordinates.
        labels: Per-point values passed to ``scatter`` as the colour argument
            and mapped through the ``tab10`` colormap.
        class_names: Currently unused; kept so callers do not need to change.

    Returns:
        The ``(fig, ax)`` pair of the newly created 10x10 figure.
    """
    # NOTE: No legend is drawn; ``class_names`` would be needed for a
    # per-class legend but is not consumed at the moment.
    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    xs = embedded[:, 0]
    ys = embedded[:, 1]
    ax.scatter(xs, ys, c=labels, cmap="tab10")
    return fig, ax


def plot_2d_embeddings_with_images(embedded, images, zoom=0.5):
    """Place each image at its 2D embedding position on a new figure.

    Args:
        embedded: Iterable of points; the first two entries of each point are
            used as the x and y position.
        images: Iterable of images, paired with ``embedded`` via ``zip``.
        zoom: Scaling factor handed to ``OffsetImage``.

    Returns:
        The ``(fig, ax)`` pair of the newly created figure.
    """
    fig, ax = plt.subplots()

    anchors = []
    for point, image in zip(embedded, images):
        anchor = (point[0], point[1])
        thumbnail = OffsetImage(image, zoom=zoom)
        ax.add_artist(
            AnnotationBbox(thumbnail, anchor, xycoords="data", frameon=False)
        )
        anchors.append(anchor)

    if anchors:
        # Artists added via add_artist do not contribute to the data limits,
        # so register the anchor points explicitly. This is done once for all
        # points instead of once per point, followed by a single autoscale.
        ax.update_datalim(anchors)
        ax.autoscale()

    return fig, ax


def wandb_tsne(model, dataset, class_names, random_indices):
    """Compute a t-SNE projection of the selected embeddings and plot it.

    Returns:
        The matplotlib figure produced by ``plot_2d_embeddings`` (not a
        ``wandb`` object, despite the function name).
    """
    points, point_labels = compute_tsne(model, dataset, random_indices)
    figure, _ = plot_2d_embeddings(points, point_labels, class_names)
    return figure


def wandb_tsne_with_images(model, dataset, random_indices, images):
    """Compute a t-SNE projection and draw each image at its projected point.

    Args:
        model: Forwarded to ``compute_tsne``.
        dataset: Forwarded to ``compute_tsne``.
        random_indices: Forwarded to ``compute_tsne``.
        images: Tensor of images; converted with ``.numpy()`` before plotting.

    Returns:
        The matplotlib figure produced by ``plot_2d_embeddings_with_images``.
    """
    # The labels returned alongside the projection are not needed here.
    points = compute_tsne(model, dataset, random_indices)[0]
    figure, _ = plot_2d_embeddings_with_images(points, images.numpy())
    return figure


def plot_and_save_tsne(model, dataset, random_indices, class_names, file_name):
    """Compute a t-SNE projection, plot it and write the plot to disk.

    Args:
        model: Forwarded to ``compute_tsne``.
        dataset: Forwarded to ``compute_tsne``.
        random_indices: Forwarded to ``compute_tsne``.
        class_names: Forwarded to ``plot_2d_embeddings``.
        file_name: Destination passed to ``Figure.savefig``.
    """
    embedded, labels = compute_tsne(model, dataset, random_indices)
    fig, _ = plot_2d_embeddings(embedded, labels, class_names)
    try:
        # Save this specific figure rather than whatever pyplot currently
        # considers the "current" figure.
        fig.savefig(file_name)
    finally:
        # Release the figure even when saving fails.
        plt.close(fig)


def compute_tsne(model, dataset, random_indices):
    """Project the selected embeddings to 2D with t-SNE.

    Args:
        model: Object exposing ``_get_embeddings()`` which returns a tensor.
        dataset: Object exposing a ``labels`` tensor.
        random_indices: Indices selecting rows of both the embeddings and the
            labels.

    Returns:
        A tuple ``(embedded, labels)``: the 2-component t-SNE output and the
        labels of the selected samples as a NumPy array.
    """
    # NOTE: If we have more than one embedding per image, we concatenate them
    # (flatten(1)) before projecting.
    # detach() first: Tensor.numpy() raises for tensors that require grad.
    embeddings = (
        model._get_embeddings().detach().cpu()[random_indices].flatten(1).numpy()
    )
    embedded = TSNE(n_components=2).fit_transform(embeddings)
    labels = dataset.labels[random_indices].numpy()

    return embedded, labels


def wandb_umap(model, dataset, class_names, random_indices):
    """Compute a UMAP projection of the selected embeddings and plot it.

    Returns:
        The matplotlib figure produced by ``plot_2d_embeddings`` (not a
        ``wandb`` object, despite the function name).
    """
    points, point_labels = compute_umap(model, dataset, random_indices)
    figure, _ = plot_2d_embeddings(points, point_labels, class_names)
    return figure


def compute_umap(model, dataset, random_indices):
    """Project the selected embeddings with UMAP.

    UMAP is configured with ``n_neighbors=10``, ``min_dist=0.1`` and the
    ``correlation`` metric; the number of output components is left at the
    library default.

    Args:
        model: Object exposing ``_get_embeddings()`` which returns a tensor.
        dataset: Object exposing a ``labels`` tensor.
        random_indices: Indices selecting rows of both the embeddings and the
            labels.

    Returns:
        A tuple ``(embedded, labels)``: the UMAP output and the labels of the
        selected samples as a NumPy array.
    """
    # NOTE: If we have more than one embedding per image, we concatenate them
    # (flatten(1)) before projecting.
    # detach() first: Tensor.numpy() raises for tensors that require grad.
    embeddings = (
        model._get_embeddings().detach().cpu()[random_indices].flatten(1).numpy()
    )
    reducer = UMAP(n_neighbors=10, min_dist=0.1, metric="correlation")
    embedded = reducer.fit_transform(embeddings)
    labels = dataset.labels[random_indices].numpy()

    return embedded, labels


def wandb_image_grid(images, num_images, im_size, title=None):
    """Tile the given images into one grid and wrap it as a ``wandb.Image``.

    Args:
        images: Forwarded to ``_get_image_grid``.
        num_images: Forwarded to ``_get_image_grid``.
        im_size: Forwarded to ``_get_image_grid``.
        title: Optional caption attached to the resulting image.
    """
    return wandb.Image(
        _get_image_grid(images, num_images, im_size),
        caption=title,
    )


def plot_image_grid(images, num_images, im_size, file_name, title=None):
    """Tile the given images into one grid and save it as a figure.

    Args:
        images: Forwarded to ``_get_image_grid``.
        num_images: Forwarded to ``_get_image_grid``.
        im_size: Forwarded to ``_get_image_grid``.
        file_name: Destination passed to ``Figure.savefig``.
        title: Optional axes title; omitted when ``None``.
    """
    grid_img = _get_image_grid(images, num_images, im_size)
    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    try:
        ax.imshow(grid_img)
        if title is not None:
            ax.set_title(title)
        # Save this specific figure rather than whatever pyplot currently
        # considers the "current" figure.
        fig.savefig(file_name)
    finally:
        # Release the figure even when drawing or saving fails.
        plt.close(fig)


def _get_image_grid(images, num_images, im_size):
    """Reshape flat per-pixel values into images and tile them into one grid.

    Args:
        images: Tensor laid out as ``(b * h * w, c)``, as required by the
            rearrange pattern below.
        num_images: Number of images ``b``.
        im_size: Sequence whose first two entries are the height and width.

    Returns:
        The grid as a channels-last NumPy array.
    """
    height, width = im_size[0], im_size[1]
    batch = rearrange(
        images.detach(),
        "(b h w) c -> b c h w",
        b=num_images,
        h=height,
        w=width,
    )
    tiled = make_grid(batch)
    # make_grid output is channels-first; move channels last for plotting.
    return tiled.permute(1, 2, 0).cpu().numpy()


def wandb_embedding_interpolation(model, im_size, title=None):
    """Render an embedding interpolation strip and wrap it as a ``wandb.Image``.

    Args:
        model: Forwarded to ``_get_embedding_interpolation``.
        im_size: Forwarded to ``_get_embedding_interpolation``.
        title: Optional caption attached to the resulting image.
    """
    return wandb.Image(
        _get_embedding_interpolation(model, im_size),
        caption=title,
    )


def plot_embedding_interpolation(model, im_size, file_name, title=None):
    """Render an embedding interpolation strip and save it as a figure.

    Args:
        model: Forwarded to ``_get_embedding_interpolation``.
        im_size: Forwarded to ``_get_embedding_interpolation``.
        file_name: Destination passed to ``Figure.savefig``.
        title: Optional axes title; omitted when ``None``.
    """
    grid_img = _get_embedding_interpolation(model, im_size)
    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    try:
        ax.imshow(grid_img)
        if title is not None:
            ax.set_title(title)
        # Save this specific figure rather than whatever pyplot currently
        # considers the "current" figure.
        fig.savefig(file_name)
    finally:
        # Release the figure even when drawing or saving fails.
        plt.close(fig)


def _get_embedding_interpolation(model, im_size, interpolation_steps=9):
    """Decode a linear interpolation between the model's first two embeddings.

    The first two entries of ``model._get_embeddings()`` are blended with
    ``interpolation_steps`` evenly spaced weights in ``[0, 1]``; each blend is
    decoded with ``model.forward_embeddings`` and the results are tiled into a
    single row.

    Args:
        model: Object exposing ``_get_embeddings()``, ``forward_embeddings()``,
            ``shared_hidden_embeddings``, ``num_hidden`` and the usual
            ``eval()`` / ``train()`` / ``training`` module API.
        im_size: Sequence whose first two entries are the image height and
            width.
        interpolation_steps: Number of blends (and images) to produce.

    Returns:
        The image row as a channels-last NumPy array.
    """
    model_embeddings = model._get_embeddings()
    emb_0 = model_embeddings[0]
    emb_1 = model_embeddings[1]

    # Nothing here is used for training, so skip autograd bookkeeping for the
    # blending as well as for the forward pass.
    with torch.no_grad():
        interpolated_embeddings = torch.stack(
            [
                emb_0 * (1 - float(alpha)) + emb_1 * float(alpha)
                for alpha in np.linspace(0, 1, interpolation_steps)
            ],
            dim=0,
        )
        # presumably one coordinate grid per interpolated image, shaped
        # (steps, points_per_image, ...) -- TODO confirm in inr_utils.
        inputs = make_image_grid(im_size, interpolated_embeddings.shape[0])
        # Repeat each blended embedding once per input point, then merge the
        # image and point dimensions so embeddings and inputs line up row-wise.
        interpolated_embeddings = (
            interpolated_embeddings.unsqueeze(1)
            .repeat(1, inputs.shape[1], 1, 1)
            .flatten(0, 1)
        )
        inputs = inputs.flatten(0, 1)
        inputs = inputs.to(model_embeddings.device)

        # Evaluate in eval mode, but restore the previous mode afterwards so a
        # visualisation call in the middle of training does not silently leave
        # the model in eval mode.
        was_training = model.training
        model.eval()
        try:
            # With shared hidden embeddings every slot is an output embedding;
            # otherwise the first ``num_hidden`` slots are the hidden ones.
            hidden_chunk = 0 if model.shared_hidden_embeddings else model.num_hidden
            hidden_embeddings = (
                interpolated_embeddings[:, :hidden_chunk] if hidden_chunk > 0 else None
            )
            out_embeddings = interpolated_embeddings[:, hidden_chunk:]
            interpolated_images = model.forward_embeddings(
                inputs, hidden_embeddings, out_embeddings
            )
        finally:
            model.train(was_training)

    grid_img = make_grid(
        rearrange(
            interpolated_images,
            "(b h w) c -> b c h w",
            b=interpolation_steps,
            h=im_size[0],
            w=im_size[1],
        ),
        # Put all interpolation steps on a single row.
        nrow=interpolation_steps,
    )
    grid_img = grid_img.permute(1, 2, 0).cpu().numpy()
    return grid_img


def wandb_embedding_cutmix(model, im_size, title=None):
    """Render an embedding cut-mix strip and wrap it as a ``wandb.Image``.

    Args:
        model: Forwarded to ``_cutmix``.
        im_size: Forwarded to ``_cutmix``.
        title: Optional caption attached to the resulting image.
    """
    return wandb.Image(_cutmix(model, im_size), caption=title)


def _cutmix(model, im_size):
    """Decode random element-wise mixes of the model's first two embeddings.

    With ``n = emb_0.shape[0]``, ``n + 1`` mixed embeddings are produced: the
    first is ``emb_0``, the last is ``emb_1``, and entry ``i`` in between
    takes each element from ``emb_0`` with probability ``1 - i / n`` and from
    ``emb_1`` otherwise. Each mix is decoded with ``model.forward_embeddings``
    and the results are tiled into a single row.

    The masks are sampled from the global torch RNG, so the output differs
    between calls unless the RNG is seeded.

    Args:
        model: Object exposing ``_get_embeddings()``, ``forward_embeddings()``,
            ``shared_hidden_embeddings``, ``num_hidden`` and the usual
            ``eval()`` / ``train()`` / ``training`` module API.
        im_size: Sequence whose first two entries are the image height and
            width.

    Returns:
        The image row as a channels-last NumPy array.
    """
    model_embeddings = model._get_embeddings()
    emb_0 = model_embeddings[0]
    emb_1 = model_embeddings[1]

    num_slots, emb_dim = emb_0.shape[0], emb_0.shape[1]
    num_images = num_slots + 1

    # Nothing here is used for training, so skip autograd bookkeeping for the
    # mixing as well as for the forward pass.
    with torch.no_grad():
        cutmix_embeddings = torch.zeros(
            (num_images, num_slots, emb_dim),
            device=emb_0.device,
        )
        cutmix_embeddings[0] = emb_0
        cutmix_embeddings[-1] = emb_1
        for i in range(1, num_slots):
            # Probability of keeping an element of emb_0 shrinks with i.
            theta = 1 - i / num_slots
            dist = torch.distributions.Bernoulli(
                torch.tensor([[theta]]).repeat((num_slots, emb_dim))
            )
            perm = dist.sample().bool().to(emb_0.device)
            cutmix_embeddings[i] = perm * emb_0 + (~perm) * emb_1

        # presumably one coordinate grid per mixed image, shaped
        # (num_images, points_per_image, ...) -- TODO confirm in inr_utils.
        inputs = make_image_grid(im_size, cutmix_embeddings.shape[0])
        # Repeat each mixed embedding once per input point, then merge the
        # image and point dimensions so embeddings and inputs line up row-wise.
        cutmix_embeddings = (
            cutmix_embeddings.unsqueeze(1)
            .repeat(1, inputs.shape[1], 1, 1)
            .flatten(0, 1)
        )
        inputs = inputs.flatten(0, 1)
        inputs = inputs.to(model_embeddings.device)

        # Evaluate in eval mode, but restore the previous mode afterwards so a
        # visualisation call in the middle of training does not silently leave
        # the model in eval mode.
        was_training = model.training
        model.eval()
        try:
            # With shared hidden embeddings every slot is an output embedding;
            # otherwise the first ``num_hidden`` slots are the hidden ones.
            hidden_chunk = 0 if model.shared_hidden_embeddings else model.num_hidden
            hidden_embeddings = (
                cutmix_embeddings[:, :hidden_chunk] if hidden_chunk > 0 else None
            )
            out_embeddings = cutmix_embeddings[:, hidden_chunk:]
            interpolated_images = model.forward_embeddings(
                inputs, hidden_embeddings, out_embeddings
            )
        finally:
            model.train(was_training)

    grid_img = make_grid(
        rearrange(
            interpolated_images,
            "(b h w) c -> b c h w",
            b=num_images,
            h=im_size[0],
            w=im_size[1],
        ),
        # Put all mixes on a single row.
        nrow=num_images,
    )
    grid_img = grid_img.permute(1, 2, 0).cpu().numpy()
    return grid_img
