from os import name
from typing import NamedTuple, Optional

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image


@torch.no_grad()
def load_image(path: str) -> torch.Tensor:
    """Load an image file into a float tensor of shape ``(1, C, H, W)``.

    Pixel values are divided by 255, so 8-bit data lands in ``[0, 1]``.
    Single-band images (which PIL/NumPy return as 2-D ``(H, W)`` arrays) get a
    channel axis of size 1 instead of failing on the permute.

    Args:
        path: filesystem path of the image to read.

    Returns:
        A ``float32`` tensor laid out as ``(1, C, H, W)``.
    """
    # Use a context manager so the file handle is released; np.array() forces
    # the pixel data to be decoded while the file is still open.
    with Image.open(path) as img:
        pixels = np.array(img)

    if pixels.ndim == 2:
        # Grayscale / single-band image: add a trailing channel axis.
        pixels = pixels[:, :, np.newaxis]

    # (H, W, C) -> (C, H, W) -> (1, C, H, W), then scale to float.
    return torch.from_numpy(pixels).permute(2, 0, 1).unsqueeze(0).float() / 255.0


@torch.no_grad()
def display_image(
    image: torch.Tensor,
    title: str = None,
    grayscale: bool = False,
    add_resolution: bool = False,
):
    fig = plt.figure()

    if not grayscale:
        plt.imshow(image.squeeze(0).detach().cpu().numpy().transpose(1, 2, 0))
    else:
        plt.imshow(
            image.squeeze(0).detach().cpu().numpy().transpose(1, 2, 0).mean(axis=2),
            cmap="gray",
        )

    img_name = ""
    if title is not None:
        img_name = title
    if add_resolution:
        img_name += f" ({'x'.join(map(str, list(image.shape)))})"
    plt.title(img_name)
    plt.axis("off")
    plt.show()
    plt.close(fig)


def display_row(
    images: list[torch.Tensor],
    title: str = None,
    subtitles: list[str] = None,
    grayscale: bool = False,
    add_resolution: bool = False,
):
    fig = plt.figure(figsize=(6 * len(images), 6))
    if title is not None:
        plt.suptitle(title)

    for i, image in enumerate(images):
        plt.subplot(1, len(images), i + 1)
        if not grayscale:
            plt.imshow(image.squeeze(0).detach().cpu().numpy().transpose(1, 2, 0))
        else:
            plt.imshow(
                image.squeeze(0).detach().cpu().numpy().transpose(1, 2, 0).mean(axis=2),
                cmap="gray",
            )

        img_name = ""
        if subtitles is not None:
            img_name += subtitles[i]
        if add_resolution:
            img_name += f" ({'x'.join(map(str, list(image.shape)))})"
        plt.title(img_name)
        plt.axis("off")

    plt.show()
    plt.close(fig)
