import math
from typing import Iterable, Optional, Union
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from torch.utils.data import DataLoader

from vis_datasets.wrappers.data_sample import DataSample


# def imshow(img: torch.Tensor) -> None:
#     print("img:", img)
#     img = img / 2 + 0.5 # unnormalize
#     npimg = img.numpy()
#     plt.imshow(np.transpose(npimg, (1, 2, 0)))
#     plt.show()


# def show_images(
#     dataloader: DataLoader,
#     class_labels: tuple[str, ...],
#     n_images: int = 4,
# ) -> None:
#     dataiter = iter(dataloader)
#     batch = next(dataiter)
#     images, labels = batch.x, batch.y
#     images = images[:n_images]
#     labels = labels[:n_images]

#     imshow(torchvision.utils.make_grid(images))
#     print(" ".join(class_labels[i] for i in labels))


import matplotlib.pyplot as plt
import numpy as np
import torch


def _img_is_color(img: np.ndarray) -> bool:

    if len(img.shape) == 3:
        # Check the color channels to see if they're all the same.
        c1, c2, c3 = img[:, : , 0], img[:, :, 1], img[:, :, 2]
        if (c1 == c2).all() and (c2 == c3).all():
            return True

    return False

def _to_display_array(img: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
    """Convert one image to a numpy array in a layout ``imshow`` accepts."""
    if isinstance(img, torch.Tensor):
        arr = img.detach().cpu().numpy()
        # Assumes tensors follow the PyTorch (C, H, W) convention, as the
        # DataLoader path does; move the channel axis last for matplotlib.
        return np.transpose(arr, (1, 2, 0)) if arr.ndim == 3 else arr
    # numpy arrays are assumed to already be (H, W) or (H, W, C).
    return np.asarray(img)


def show_images(
    images: Union[
        DataLoader[DataSample],
        list[torch.Tensor],
        list[np.ndarray],
    ],
    n_images: int = 8,
    titles: Optional[Iterable[str]] = None,
    cmaps: Optional[list] = None,
    show_grid: bool = False,
    n_cols: int = 8,
    figsize: tuple[int, int] = (20, 10),
    title_fontsize: int = 30,
    show_titles: bool = True,
) -> plt.Figure:
    '''
    Shows a grid of images with optional titles and returns the figure.
    Roughly inspired by:
    https://stackoverflow.com/questions/41793931/plotting-images-side-by-side-using-matplotlib

    Parameters:
    ----------
    images: pytorch DataLoader, list of pytorch Tensors or numpy ndarrays
        Images to display. For a DataLoader only the first batch is used: its
        ``input`` attribute supplies the images (transposed from (N, C, H, W)
        to (N, H, W, C)) and its ``target`` attribute supplies default titles.
        Tensors in a list are assumed to be (C, H, W) or (H, W); numpy arrays
        in a list are passed to ``imshow`` as-is.
    n_images: int
        Maximum number of images to show.
    titles: iterable of str or None
        Optional titles, one per image. Images without a title get "".
    cmaps: list or None
        Optional list of cmap values, one per image. If None, the cmap is
        inferred per image ("gray" unless the image is detected as color).
    show_grid: boolean
        Currently has no effect (axes are always turned off); kept for
        API compatibility.
    n_cols: int
        Maximum number of columns in the grid.
    figsize: tuple of width, height
        Value to be passed to pyplot.subplots().
    title_fontsize: int
        Value to be passed to set_title().
    show_titles: boolean
        If False, no titles are drawn.

    Returns:
    -------
    matplotlib Figure containing the grid.

    Raises:
    ------
    ValueError
        If there are no images to show (e.g. ``n_images <= 0``).
    '''
    if isinstance(images, DataLoader):
        batch = next(iter(images))
        inputs = batch.input[:n_images].detach().cpu().numpy()
        # (N, C, H, W) -> (N, H, W, C), the channel-last layout imshow expects.
        image_list = list(np.transpose(inputs, (0, 2, 3, 1)))
        if titles is None:
            titles = [str(yi.item()) for yi in batch.target[:n_images]]
    else:
        image_list = [_to_display_array(img) for img in list(images)[:n_images]]

    # Size the grid by the images actually available, which may be fewer
    # than n_images if the batch/list is short.
    n_shown = len(image_list)
    if n_shown == 0:
        raise ValueError("show_images: no images to show")

    title_list = [] if titles is None else list(titles)
    # Pad missing titles so zip() below never silently drops an image.
    title_list += [""] * (n_shown - len(title_list))

    n_cols = max(1, min(n_shown, n_cols))
    n_rows = math.ceil(n_shown / n_cols)

    fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize)

    # plt.subplots returns a bare Axes for a 1x1 grid and an ndarray otherwise.
    list_axes: list[plt.Axes] = (
        list(axes.flat) if isinstance(axes, np.ndarray) else [axes]
    )

    for i, (img, title, ax) in enumerate(zip(image_list, title_list, list_axes)):
        if cmaps is not None:
            cmap = cmaps[i]
        else:
            cmap = None if _img_is_color(img) else "gray"
        ax.imshow(img, cmap=cmap, vmin=0, vmax=1)
        if show_titles:
            ax.set_title(title, fontsize=title_fontsize)
        ax.axis("off")

    # Hide grid cells without an image (the last row may be partial).
    for ax in list_axes[n_shown:]:
        ax.set_visible(False)

    fig.tight_layout()
    plt.show()
    return fig
