import matplotlib.pyplot as plt
import plotly.offline as pyo
import torch as t
from einops import rearrange
from matplotlib.figure import Figure as pltFigure

from auto_encoder import device
from data.ae_data import get_image_dataloader


def _first_eval_batch(dataset_name: str, batch_size: int) -> t.Tensor:
    """Return the image tensor from the first batch of the eval loader.

    Raises:
        ValueError: if the eval loader yields no batches.
    """
    _train_loader, eval_loader = get_image_dataloader(
        batch_size=batch_size, dataset_name=dataset_name
    )
    for batch in eval_loader:
        # Datasets name the image column differently; accept either key.
        try:
            return batch["image"]
        except KeyError:
            return batch["img"]
    raise ValueError(f"Evaluation split of {dataset_name!r} yielded no batches")


def _to_display_array(flat: t.Tensor, channels: int, height: int, width: int):
    """Reshape one flattened ``(channels height width)`` image into an imshow-ready array."""
    arr = flat.detach().cpu().numpy()
    if channels == 1:
        return arr.reshape(height, width)
    # imshow expects colour images channels-last.
    return arr.reshape(channels, height, width).transpose(1, 2, 0)


def chart_reconstructions(
    model, dataset_name: str = "ylecun/mnist", num_images: int = 10
) -> pltFigure:
    """
    Display original and reconstructed images side by side.

    Takes the first batch of the dataset's eval split and applies the same
    normalisation used here (batch-wide standardisation followed by ReLU).
    It then runs the model in eval mode without gradients and plots one
    row per image: the normalised input on the left, the reconstruction
    on the right. Calls ``plt.show()`` before returning.

    Args:
        model (Union[MoAutoEncoder, GatedAutoEncoder]): The trained autoencoder model.
            It is switched to eval mode and left there.
        dataset_name (str): The name of the dataset to use.
        num_images (int): The number of images to display. Fewer rows are
            shown if the first eval batch is smaller than this.

    Returns:
        matplotlib.figure.Figure: The figure containing the image grid.

    Raises:
        ValueError: if ``num_images`` is not positive or the eval split is empty.
    """
    if num_images < 1:
        raise ValueError(f"num_images must be positive, got {num_images}")

    images = _first_eval_batch(dataset_name, num_images).float().to(device)
    # Record the spatial layout before flattening so plots can undo it.
    _batch, channels, height, width = images.shape
    images = rearrange(
        images, "batch channels height width -> batch 1 (channels height width)"
    )
    # Standardise with batch-wide statistics, then clip negatives to zero.
    # NOTE(review): presumably mirrors the training preprocessing — confirm.
    x = t.relu((images - images.mean()) / images.std())

    # Get the reconstructions
    model.eval()
    with t.no_grad():
        x_hat = model(x, output_intermediate_activations=False)

    # A short final batch may hold fewer images than requested.
    n_rows = min(num_images, x.shape[0])
    # squeeze=False keeps `axes` 2-D even when n_rows == 1.
    fig, axes = plt.subplots(
        nrows=n_rows, ncols=2, figsize=(10, n_rows * 5), squeeze=False
    )

    for row in range(n_rows):
        panels = (
            (axes[row, 0], x[row], "Original"),
            (axes[row, 1], x_hat[row], "Reconstructed"),
        )
        for ax, flat, title in panels:
            ax.imshow(
                _to_display_array(flat, channels, height, width), cmap="gray"
            )
            ax.set_title(title)
            ax.axis("off")

    plt.tight_layout()
    plt.show()

    return fig
