from PIL import Image
from tqdm import tqdm
import os

# --- Run configuration ---------------------------------------------------
# Gap between panels, as a fraction of image height (0.0 = panels touch).
gap_ratio = 0.0
# Which of the num_subs equal-width panels in each strip to keep, top-to-bottom.
subimage_idx = [3, 0, 4]
# Training iterations to show, left-to-right; each maps to save/it{N}-0.png.
it_idx = [200, 500, 1000, 2000, 3600, 5000, 8000]
# Number of equal-width panels each saved strip is split into.
num_subs = 6
# Scale applied to each panel before it is cropped back to its original size.
zoom_factor = 1.5
# Shift of the crop centre in scaled pixels (negative x = left, negative y = up).
x_offset = -20
y_offset = -40
# Run output folder containing the save/ subdirectory; update as necessary.
directory = (
    'outputs/mass_runs/20240503_175509_Cat_overfitted_default/'
    'ours-inversion-2/'
    'A_DSLR_photo_of_a_white_fluffy_cat@20240503-175516/'
)

def zoom_image(image, zoom_factor, x_offset, y_offset, resample=None):
    """Scale *image* by *zoom_factor*, then crop it back to its original size.

    The crop window has the same size as *image*. It is centred on the middle
    of the scaled image, shifted by (*x_offset*, *y_offset*) pixels measured in
    scaled-image coordinates. Any part of the window outside the scaled image
    is handled by Pillow's ``Image.crop`` out-of-bounds padding.

    Args:
        image: PIL image to zoom.
        zoom_factor: Scale factor applied before cropping (> 1 zooms in).
        x_offset: Horizontal shift of the crop centre (positive = right).
        y_offset: Vertical shift of the crop centre (positive = down).
        resample: Resampling filter for the resize. Defaults to
            ``Image.LANCZOS``.

    Returns:
        A new image with exactly the same size as *image*.

    Raises:
        ValueError: If *zoom_factor* is not positive.
    """
    if zoom_factor <= 0:
        raise ValueError(f"zoom_factor must be positive, got {zoom_factor}")
    if resample is None:
        resample = Image.LANCZOS

    width, height = image.size
    new_width, new_height = int(width * zoom_factor), int(height * zoom_factor)
    zoomed_image = image.resize((new_width, new_height), resample)

    center_x = new_width // 2 + x_offset
    center_y = new_height // 2 + y_offset

    left = center_x - width // 2
    top = center_y - height // 2
    # Derive right/bottom from left/top rather than center + size // 2: the
    # latter drops one pixel for odd dimensions, so the output would not keep
    # the input's size.
    return zoomed_image.crop((left, top, left + width, top + height))

def _concat_with_gaps(images, gap_size, vertical):
    """Paste *images* end to end along one axis, with transparent gaps between them.

    Args:
        images: Non-empty sequence of PIL images.
        gap_size: Thickness in pixels of the gap between consecutive images.
        vertical: Stack top-to-bottom when True, left-to-right otherwise.

    Returns:
        A new RGBA image. Along the stacking axis its size is the sum of the
        image sizes plus the gaps. Across that axis it takes the largest image
        size. Gaps and any uncovered area are transparent white,
        (255, 255, 255, 0).
    """
    widths = [im.width for im in images]
    heights = [im.height for im in images]
    total_gap = gap_size * (len(images) - 1)
    if vertical:
        canvas_size = (max(widths), sum(heights) + total_gap)
    else:
        canvas_size = (sum(widths) + total_gap, max(heights))
    canvas = Image.new('RGBA', canvas_size, (255, 255, 255, 0))

    offset = 0
    for im in images:
        canvas.paste(im, (0, offset) if vertical else (offset, 0))
        offset += (im.height if vertical else im.width) + gap_size
    return canvas


def _build_column(file_path, subimage_idx, num_subs, gap_ratio, zoom_factor, x_offset, y_offset):
    """Cut selected panels from one strip image, zoom each, and stack them vertically.

    The file is split into *num_subs* equal-width panels. The panels listed in
    *subimage_idx* are kept in that order, top to bottom.
    """
    with Image.open(file_path) as strip:
        strip_width, strip_height = strip.size
        panel_width = strip_width // num_subs
        # crop() returns an independent image, so panels stay valid after the
        # source file is closed.
        panels = [
            zoom_image(
                strip.crop((sub_idx * panel_width, 0, (sub_idx + 1) * panel_width, strip_height)),
                zoom_factor, x_offset, y_offset,
            )
            for sub_idx in subimage_idx
        ]
    # Gaps between rows are sized relative to the source strip height.
    return _concat_with_gaps(panels, int(strip_height * gap_ratio), vertical=True)


def process_images(directory, it_idx, subimage_idx, gap_ratio, num_subs, zoom_factor, x_offset, y_offset):
    """Build a grid of zoomed panels across training iterations and save it as a PNG.

    For each iteration ``idx`` in *it_idx*, the function:

    1. Loads ``<directory>/save/it{idx}-0.png``.
    2. Splits it into *num_subs* equal-width panels.
    3. Keeps the panels in *subimage_idx* and zooms each with ``zoom_image``.
    4. Stacks them into a column.

    Columns are laid out left to right in *it_idx* order. The grid is saved
    as ``<directory>/evolution_image.png``.

    Args:
        directory: Run folder containing the ``save/`` subdirectory.
        it_idx: Iteration numbers to include, one column each.
        subimage_idx: Panel indices (0 <= i < num_subs) to keep, one row each.
        gap_ratio: Gap size as a fraction of height (>= 0). Between rows it is
            relative to the source strip height. Between columns it is relative
            to the assembled column height.
        num_subs: Number of equal-width panels in each strip (>= 1).
        zoom_factor, x_offset, y_offset: Forwarded to ``zoom_image``.

    Returns:
        Path of the saved PNG.

    Raises:
        ValueError: If *it_idx* or *subimage_idx* is empty, *num_subs* < 1,
            *gap_ratio* < 0, or a panel index is out of range.
    """
    # Materialize so generators work and subimage_idx can be reused per image.
    it_idx = list(it_idx)
    subimage_idx = list(subimage_idx)
    if not it_idx:
        raise ValueError("it_idx must contain at least one iteration")
    if not subimage_idx:
        raise ValueError("subimage_idx must contain at least one panel index")
    if num_subs < 1:
        raise ValueError(f"num_subs must be >= 1, got {num_subs}")
    if gap_ratio < 0:
        raise ValueError(f"gap_ratio must be >= 0, got {gap_ratio}")
    out_of_range = [s for s in subimage_idx if not 0 <= s < num_subs]
    if out_of_range:
        raise ValueError(f"subimage_idx entries outside [0, {num_subs}): {out_of_range}")

    columns = [
        _build_column(
            os.path.join(directory, f'save/it{idx}-0.png'),
            subimage_idx, num_subs, gap_ratio, zoom_factor, x_offset, y_offset,
        )
        for idx in tqdm(it_idx)
    ]

    column_height = max(col.height for col in columns)
    # NOTE(review): column gaps scale with the full column height, while row
    # gaps scale with the source strip height, so column gaps are wider. This
    # preserves the original layout; confirm it is intended.
    result = _concat_with_gaps(columns, int(gap_ratio * column_height), vertical=False)

    out_path = os.path.join(directory, 'evolution_image.png')
    result.save(out_path)
    # Report only once the file has actually been written.
    print("\n Output saved:\n", out_path)
    return out_path

# Build and save the evolution grid using the module-level configuration.
process_images(
    directory=directory,
    it_idx=it_idx,
    subimage_idx=subimage_idx,
    gap_ratio=gap_ratio,
    num_subs=num_subs,
    zoom_factor=zoom_factor,
    x_offset=x_offset,
    y_offset=y_offset,
)
