import torch
import numpy as np
from PIL import Image
import os
import yaml


def load_image(img: "Image.Image", device, resize_dims=(512, 512)):
    """Convert a PIL image into a normalized ``(1, 3, H, W)`` float32 tensor.

    Args:
        img: Source PIL image. It is converted to RGB first, so the result
            always has 3 channels.
        device: Target device handed to ``Tensor.to`` (e.g. ``"cpu"``).
        resize_dims: Size handed to ``Image.resize`` -- PIL expects
            ``(width, height)``. Pass ``None`` to keep the original size.

    Returns:
        Tensor of shape ``(1, 3, H, W)``, dtype float32, values in
        ``[-1, 1]``, placed on ``device``.
    """
    # The annotation is a string because ``Image`` in this module is the PIL
    # *module*; the image class is ``Image.Image``.
    img = img.convert("RGB")
    if resize_dims is not None:
        img = img.resize(resize_dims)
    # uint8 [0, 255] -> float32 [-1, 1].
    pixels = 2.0 * np.array(img).astype(np.float32) / 255.0 - 1.0
    # HWC -> NCHW: add a batch axis, then move channels before spatial dims.
    return torch.from_numpy(pixels).unsqueeze(0).permute(0, 3, 1, 2).to(device)


def load_image_batch(imags, device, resize_dims=(512, 512)):
    """Load several PIL images and stack them into a single batch tensor.

    Each element of ``imags`` is processed by ``load_image`` (each result has
    a leading batch axis of size 1). The results are concatenated along that
    axis in input order. ``torch.cat`` raises if ``imags`` is empty or if the
    per-image tensors differ in shape.
    """
    per_image = [load_image(picture, device, resize_dims) for picture in imags]
    return torch.cat(per_image, dim=0)


def output_image(img):
    """Convert an image tensor with values in ``[-1, 1]`` to a PIL image.

    Intended as the inverse of ``load_image``.

    Args:
        img: Tensor laid out channel-first, either ``(1, C, H, W)`` or
            ``(C, H, W)``. A leading size-1 axis is squeezed away. It may
            live on any device and may require grad.

    Returns:
        PIL image built from the ``(H, W, C)`` uint8 array.
    """
    # detach(): .numpy() refuses tensors that require grad.
    # float(): NumPy has no bfloat16, so .numpy() would fail on it.
    arr = img.detach().squeeze(0).permute(1, 2, 0).float().cpu().numpy()
    arr = np.clip((arr + 1.0) / 2.0, 0.0, 1.0)
    # Round rather than truncate: float error (e.g. 0.99999994) would
    # otherwise drop a pixel level and break the load_image round-trip.
    arr = np.round(255.0 * arr).astype(np.uint8)
    return Image.fromarray(arr)


def output_image_batch(imgs):
    """Convert a batch of image tensors in ``[-1, 1]`` to a list of PIL images.

    Batched counterpart of ``output_image``.

    Args:
        imgs: Tensor of shape ``(N, C, H, W)``. It may live on any device and
            may require grad.

    Returns:
        List of ``N`` PIL images in batch order. An empty list is returned
        when ``N == 0``.
    """
    # detach(): .numpy() refuses tensors that require grad.
    # float(): NumPy has no bfloat16, so .numpy() would fail on it.
    arr = imgs.detach().permute(0, 2, 3, 1).float().cpu().numpy()
    arr = np.clip((arr + 1.0) / 2.0, 0.0, 1.0)
    # Round rather than truncate so values from load_image map back exactly.
    arr = np.round(255.0 * arr).astype(np.uint8)
    return [Image.fromarray(frame) for frame in arr]


def display_alongside(
    img_list, resize_dims=(512, 512), padding=0, frame_color=(255, 255, 255)
):
    """Tile images left-to-right in a single row on an RGB canvas.

    Each image is resized to ``resize_dims`` and placed in a cell of
    ``resize_dims`` plus ``padding`` on every side. Uncovered canvas area is
    filled with ``frame_color``.
    """
    cell_width = resize_dims[0] + 2 * padding
    cell_height = resize_dims[1] + 2 * padding
    canvas = Image.new(
        "RGB", (cell_width * len(img_list), cell_height), frame_color
    )
    x = padding  # left edge of the current image; advances one cell per image
    for picture in img_list:
        canvas.paste(picture.resize(resize_dims), (x, padding))
        x += cell_width
    return canvas


def display_in_two_rows(
    img_list, resize_dims=(512, 512), padding=5, frame_color=(255, 255, 255)
):
    """Tile images row-major into a two-row grid on an RGB canvas.

    The first row holds ``ceil(n / 2)`` images and the second holds the rest.
    The canvas is always two cells tall. Each image is resized to
    ``resize_dims`` and surrounded by ``padding`` of ``frame_color``.
    """
    count = len(img_list)
    per_row = (count + 1) // 2  # ceil(count / 2)
    cell_w = resize_dims[0] + 2 * padding
    cell_h = resize_dims[1] + 2 * padding
    canvas = Image.new("RGB", (cell_w * per_row, cell_h * 2), frame_color)
    for idx, picture in enumerate(img_list):
        row, col = divmod(idx, per_row)
        top_left = (col * cell_w + padding, row * cell_h + padding)
        canvas.paste(picture.resize(resize_dims), top_left)
    return canvas
