#!/usr/bin/env python3
import os
import sys
from PIL import Image

def concat_images(iteration, root_dir):
    """Stack per-run iteration images vertically and build a 4x2 preview grid.

    For every immediate subdirectory ``d`` of ``root_dir`` the file
    ``<root_dir>/<d>/save/it<iteration>-0.png`` is collected if it exists.
    Two files are then written into ``root_dir``:

    * ``<iteration>-full.png`` -- all collected images stacked top to bottom
      on a canvas as wide as the widest image.
    * ``<iteration>-preview.png`` -- a 4-column by 2-row grid holding the
      left fifth (relative to the widest image) of up to the first eight
      images.

    Subdirectories are visited in sorted name order so the output is
    reproducible; ``os.listdir`` itself gives no ordering guarantee.

    Args:
        iteration: Iteration identifier interpolated into the file names.
        root_dir: Directory containing the per-run subdirectories; the
            output images are written here as well.

    Returns:
        None. If no matching image exists a message is printed and nothing
        is written.
    """
    grid_cols = 4
    grid_rows = 2
    crop_divisor = 5  # the preview keeps the left 1/5 of the widest image

    it = f"it{iteration}-0.png"
    full_image_path = os.path.join(root_dir, f"{iteration}-full.png")
    preview_image_path = os.path.join(root_dir, f"{iteration}-preview.png")

    # Sort so the stacking order (and which images make the preview) does
    # not depend on the filesystem's directory enumeration order.
    subdirs = sorted(
        d for d in os.listdir(root_dir)
        if os.path.isdir(os.path.join(root_dir, d))
    )
    candidates = (os.path.join(root_dir, d, 'save', it) for d in subdirs)
    paths = [p for p in candidates if os.path.isfile(p)]

    # Check if any images were found
    if not paths:
        print("No images found for iteration", iteration)
        return

    images = []
    try:
        # Open one at a time so that a failure part-way through still lets
        # the ``finally`` clause close the images opened so far.
        for p in paths:
            images.append(Image.open(p))

        # Create full concatenated image
        total_height = sum(img.height for img in images)
        max_width = max(img.width for img in images)
        new_img_full = Image.new('RGB', (max_width, total_height))
        y_offset = 0
        for img in images:
            new_img_full.paste(img, (0, y_offset))
            y_offset += img.height
        new_img_full.save(full_image_path)
        print("Saved concatenated image as", full_image_path)

        # Create preview image (grid 4x2 of first fifth)
        preview_images = images[:grid_cols * grid_rows]
        # Clamp to 1 so very narrow inputs do not yield a zero-width canvas.
        crop_width = max(1, max_width // crop_divisor)
        # Use one row height for the whole grid (the tallest previewed
        # image) so rows never overlap or overflow when heights differ.
        row_height = max(img.height for img in preview_images)
        grid_img = Image.new(
            'RGB', (crop_width * grid_cols, row_height * grid_rows)
        )
        for index, img in enumerate(preview_images):
            row, col = divmod(index, grid_cols)
            cropped_img = img.crop((0, 0, crop_width, img.height))
            grid_img.paste(cropped_img, (col * crop_width, row * row_height))

        grid_img.save(preview_image_path)
        print("Saved preview grid image as", preview_image_path)
    finally:
        # Release the file handles held by the opened images.
        for img in images:
            img.close()

if __name__ == '__main__':
    # Exactly two positional arguments are required after the script name:
    # the root directory first, then the iteration number.
    cli_args = sys.argv[1:]
    if len(cli_args) != 2:
        print("Usage: python concat_images.py <root_path> <iteration_number>")
        sys.exit(1)

    root_path, iteration_num = cli_args
    concat_images(iteration_num, root_path)
