from PIL import Image

def aspect_resize_pip_image(image, max_size):
    """Resize ``image`` to fit inside ``max_size`` while keeping its aspect ratio.

    Args:
        image: A PIL image; its ``width``, ``height`` and ``resize`` are used.
        max_size: ``(max_width, max_height)`` bounding box in pixels.

    Returns:
        The resized image. One side matches the corresponding side of the
        box and the other is scaled proportionally (never below 1 pixel).
    """
    max_width, max_height = max_size
    aspect_ratio = image.width / image.height

    # Decide which side is the limiting one by comparing the image's ratio
    # with the box's ratio (cross-multiplied to stay in integers). For a
    # square box this is the same test as ``aspect_ratio > 1``; for a
    # non-square box it keeps the result inside the box.
    if image.width * max_height > image.height * max_width:
        new_width = max_width
        new_height = int(max_width / aspect_ratio)
    else:
        new_width = int(max_height * aspect_ratio)
        new_height = max_height

    # Very elongated images would truncate to a 0-pixel side, which
    # ``resize`` rejects.
    new_width = max(1, new_width)
    new_height = max(1, new_height)

    # ``Image.ANTIALIAS`` was removed in Pillow 10; LANCZOS is the same
    # filter. Older Pillow versions have no ``Image.Resampling`` enum, so
    # fall back to the module-level constant there.
    resampling = getattr(Image, "Resampling", Image)
    return image.resize((new_width, new_height), resampling.LANCZOS)

def pip_images_make_grid(pip_images, cols, rows=None, cell_size=256, fill=(128, 128, 128)):
    """Tile images into a grid of square cells on a single RGB canvas.

    Args:
        pip_images: Sequence of PIL images, placed left-to-right, top-to-bottom.
        cols: Number of columns in the grid.
        rows: Number of rows; when ``None`` the smallest row count that
            holds every image is used.
        cell_size: Side length of each square cell in pixels.
        fill: Background colour passed to ``Image.new``.

    Returns:
        A new RGB image of size ``(cell_size * cols, cell_size * rows)`` with
        each input resized by ``aspect_resize_pip_image`` and centred in its
        cell. Images that do not fit in ``rows * cols`` cells are dropped and
        a notice is printed.
    """
    total_images = len(pip_images)
    # Ceiling division via negated floor division.
    row_count = -(-total_images // cols) if rows is None else rows
    num_cells = row_count * cols
    if total_images > num_cells:
        print(f"Image Grid: {total_images-num_cells} images is skipped")
    shown = pip_images[:num_cells]

    side = cell_size
    canvas = Image.new('RGB', (side * cols, side * row_count), fill)

    for index, picture in enumerate(shown):
        grid_row, grid_col = divmod(index, cols)
        thumb = aspect_resize_pip_image(picture, (side, side))
        # Offset by half the leftover space so the thumbnail sits centred.
        left = grid_col * side + (side - thumb.width) // 2
        top = grid_row * side + (side - thumb.height) // 2
        canvas.paste(thumb, (left, top))

    return canvas


def load_small_modules(device):
    """Load the auxiliary models and place them on ``device``.

    Imports are done lazily inside the function so that importing this
    module does not pull in the model libraries.

    Args:
        device: Target device, forwarded to each model constructor/loader.

    Returns:
        Tuple ``(autoencoder, caption_decoder, clip_text_model, clip_img_model)``.
    """
    from libs.clip import FrozenCLIPEmbedder
    from libs.caption_decoder import CaptionDecoder
    import libs.autoencoder
    import clip

    caption_decoder = CaptionDecoder(
        device=device,
        pretrained_path="models/caption_decoder.pth",
        hidden_dim=64,
        tokenizer_path="models/gpt2",
    )
    autoencoder = libs.autoencoder.get_model(pretrained_path='models/autoencoder_kl.pth',).to(device)
    clip_text_model = FrozenCLIPEmbedder(version="openai/clip-vit-large-patch14", device=device)

    # Pass ``device`` to ``clip.load`` so the model is built directly on the
    # target instead of on clip's default device and moved afterwards.
    # NOTE(review): per the OpenAI CLIP loader, this also lets it convert the
    # weights to float32 when the target is CPU — confirm against the
    # installed ``clip`` version.
    # The preprocessing transform returned alongside the model is not used.
    clip_img_model, _ = clip.load("ViT-B/32", device=device, jit=False)
    clip_img_model.eval().requires_grad_(False)
    return autoencoder, caption_decoder, clip_text_model, clip_img_model