import os
import argparse
from pathlib import Path
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import re

def parse_case_numbers(case_string):
    """Parse a case-selection string into a numerically sorted list of case ids.

    Accepts single numbers (``"3"``), inclusive ranges (``"1-5"``) and
    comma-separated combinations of both (``"1-3,7,10"``). Every number is
    normalised through ``int``, so leading zeros are dropped (``"01"`` -> ``"1"``)
    and duplicates collapse. Empty items (e.g. from a trailing comma) are
    ignored, and a reversed range such as ``"5-1"`` is treated like ``"1-5"``.

    Args:
        case_string: The raw selection string, or ``None``/``""``.

    Returns:
        ``None`` when ``case_string`` is empty or ``None``; otherwise a list of
        case-number strings in ascending numeric order.

    Raises:
        ValueError: if an item is neither an integer nor an ``a-b`` integer range.
    """
    if not case_string:
        return None

    case_numbers = set()
    for part in case_string.split(','):
        part = part.strip()
        if not part:
            # Tolerate "1,,2" and trailing commas instead of failing on int("").
            continue
        if '-' in part:
            # Handle range like "1-5" (inclusive); accept either endpoint order.
            start, end = part.split('-', 1)
            low, high = sorted((int(start), int(end)))
            case_numbers.update(str(i) for i in range(low, high + 1))
        else:
            # Handle single number
            case_numbers.add(str(int(part)))

    # Sort numerically: plain string sorting would put "10" before "2".
    return sorted(case_numbers, key=int)

def find_method_cases(base_path, case_number=None):
    """Map each method directory under *base_path* to its sorted case directories.

    A "method" is any immediate subdirectory of ``base_path``. Its cases are
    the subdirectories whose names start with ``'Case'``. When ``case_number``
    is given, only the subdirectory named exactly ``f'Case{case_number}'`` is
    kept. Methods left with no matching case are omitted from the result.

    Args:
        base_path: Directory (str or Path) that contains the method directories.
        case_number: Optional case id to restrict the search to.

    Returns:
        dict mapping method directory name -> sorted list of case ``Path`` objects.
    """
    root = Path(base_path)
    wanted_name = None if case_number is None else f'Case{case_number}'

    def _is_selected_case(entry):
        # Must be a 'Case*' directory, and match the requested case if one was given.
        if not (entry.is_dir() and entry.name.startswith('Case')):
            return False
        return wanted_name is None or entry.name == wanted_name

    found = {}
    for entry in root.iterdir():
        if not entry.is_dir():
            continue
        selected = sorted(child for child in entry.iterdir() if _is_selected_case(child))
        if selected:
            found[entry.name] = selected
    return found

def collect_images(methods, image_types):
    """Collect image files for each method and image type.

    For ``kp_overlay_Y.png`` and ``accuracy_histogram.png``, each case directory
    is searched for a file named either exactly that or with an optional
    ``_case_<n>`` suffix before ``.png`` (e.g. ``kp_overlay_Y_case_3.png``).
    At most one such file is taken per case. Candidates are scanned in sorted
    name order so the choice is deterministic, which means the unsuffixed name
    wins when both exist. Every other image type is looked up by its exact
    file name.

    Args:
        methods: ``{method_name: [case_dir Path, ...]}`` as produced by
            ``find_method_cases``.
        image_types: File names / type keys to collect.

    Returns:
        ``{method_name: {image_type: [Path, ...]}}``. Each list holds the found
        files in case-directory order, and is empty if nothing was found.
    """
    # Types whose file name may carry an optional "_case_<n>" suffix.
    # Compiled once per call instead of once per case directory.
    suffixed_patterns = {
        'kp_overlay_Y.png': re.compile(r'kp_overlay_Y(_case_\d+)?\.png'),
        'accuracy_histogram.png': re.compile(r'accuracy_histogram(_case_\d+)?\.png'),
    }

    collected = {}
    for method_name, cases in methods.items():
        per_type = {}
        for img_type in image_types:
            pattern = suffixed_patterns.get(img_type)
            images = []
            for case_dir in cases:
                if pattern is None:
                    # Plain type (e.g. 'acc_hist.png'): exact file name only.
                    img_path = case_dir / img_type
                    if img_path.exists():
                        images.append(img_path)
                    continue
                # fullmatch (not match) so names like "kp_overlay_Y.png.bak"
                # are rejected; sorted() makes "first match" reproducible.
                first_match = next(
                    (f for f in sorted(case_dir.iterdir())
                     if f.is_file() and pattern.fullmatch(f.name)),
                    None,
                )
                if first_match is not None:
                    images.append(first_match)
            per_type[img_type] = images
        collected[method_name] = per_type

    return collected

def crop_image(img_path, target_width):
    """Load an image and keep only its leftmost ``target_width`` pixel columns.

    Images no wider than ``target_width`` are returned uncropped. A
    non-positive ``target_width`` also returns the image uncropped. The pixel
    data is read into a new in-memory image and the file is closed before
    returning, so no file handle is left open.

    Args:
        img_path: Path of the image file to open.
        target_width: Maximum width, in pixels, of the returned image.

    Returns:
        A new, fully loaded ``PIL.Image.Image``.
    """
    with Image.open(img_path) as img:
        width, height = img.size
        if 0 < target_width < width:
            # Crop from left edge (0 to target_width); crop() loads the data
            # into a new image that remains valid after the file is closed.
            return img.crop((0, 0, target_width, height))
        # copy() forces a full load, so the result outlives the closed file.
        return img.copy()


def _method_sort_key(method):
    """Sort key for method names: the leading integer prefix, or +inf if absent."""
    match = re.match(r'^(\d+)', method)
    return int(match.group(1)) if match else float('inf')


def _display_name(method):
    """Return ``method`` without its leading number prefix and one '-'/'_' separator."""
    return re.sub(r'^\d+[-_]?', '', method)


def _load_collage_image(img_path, img_type, crop_width):
    """Load one collage image, cropping ``kp_overlay_Y.png`` images when ``crop_width > 0``.

    Uncropped images are copied into memory inside a ``with`` block so their
    file handle is closed immediately.
    """
    if img_type == 'kp_overlay_Y.png' and crop_width > 0:
        return crop_image(img_path, crop_width)
    with Image.open(img_path) as img:
        return img.copy()


def _draw_cell(ax, images, img_type, crop_width):
    """Show the first of ``images`` on ``ax``, or a 'No image' placeholder.

    Ticks and spines are hidden instead of calling ``ax.axis('off')``, because
    ``axis('off')`` also hides the y-axis label used as a row title.
    """
    if images:
        # Only the first available image per method/type pair is shown.
        ax.imshow(_load_collage_image(images[0], img_type, crop_width))
    else:
        ax.text(0.5, 0.5, 'No image', ha='center', va='center',
                transform=ax.transAxes, fontsize=10)
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)


def create_collage(collected_images, output_path, image_types, crop_width=1797, n_cols=4):
    """Render collected images as a grid and save it to ``output_path``.

    Layouts:
      * One image type: methods are placed row-major in a grid with at most
        ``n_cols`` columns, and every cell is titled with its method name.
      * Several image types: one column per method and one row per image type.
        Method names title the first row and image-type names label the
        first column.

    Methods are ordered by their leading integer prefix (e.g. ``"02_foo"``).
    Methods without a prefix go last, in their existing relative order.
    Cells without an image show a 'No image' placeholder.

    Args:
        collected_images: ``{method_name: {image_type: [Path, ...]}}`` as
            returned by ``collect_images``.
        output_path: File the collage is saved to (300 dpi).
        image_types: Image type keys to show; in the multi-type layout this is
            the row order.
        crop_width: If > 0, ``kp_overlay_Y.png`` images wider than this are
            cropped to their leftmost ``crop_width`` pixels.
        n_cols: Maximum column count for the single-type layout; values below
            1 are treated as 1.
    """
    methods = list(collected_images.keys())
    if not methods:
        print("No methods found!")
        return
    if not image_types:
        # GridSpec cannot be built with zero rows.
        print("No image types requested!")
        return

    # sorted() is stable, so unprefixed methods keep their relative order.
    methods = sorted(methods, key=_method_sort_key)
    n_methods = len(methods)
    n_image_types = len(image_types)

    if n_image_types == 1:
        # Single image type: wrap methods into at most n_cols columns.
        # Clamp to >= 1 so a bad n_cols cannot cause a ZeroDivisionError.
        n_cols = max(1, min(n_cols, n_methods))
        n_rows = (n_methods + n_cols - 1) // n_cols  # Ceiling division
        fig = plt.figure(figsize=(5 * n_cols, 4 * n_rows))
        gs = gridspec.GridSpec(n_rows, n_cols, figure=fig)

        img_type = image_types[0]
        for i, method in enumerate(methods):
            row, col = divmod(i, n_cols)
            ax = fig.add_subplot(gs[row, col])
            _draw_cell(ax, collected_images[method].get(img_type, []), img_type, crop_width)
            # Title placeholders too, so a method with a missing image is identifiable.
            ax.set_title(_display_name(method), fontsize=12, fontweight='bold')
    else:
        # Methods as columns, image types as rows.
        fig = plt.figure(figsize=(5 * n_methods, 4 * n_image_types))
        gs = gridspec.GridSpec(n_image_types, n_methods, figure=fig)

        for row, img_type in enumerate(image_types):
            for col, method in enumerate(methods):
                ax = fig.add_subplot(gs[row, col])
                _draw_cell(ax, collected_images[method].get(img_type, []), img_type, crop_width)
                # Headers are set regardless of whether this cell has an image.
                if row == 0:
                    ax.set_title(_display_name(method), fontsize=12, fontweight='bold')
                if col == 0:
                    ylabel = img_type.replace('.png', '').replace('_', ' ').title()
                    ax.set_ylabel(ylabel, fontsize=10, fontweight='bold')

    fig.tight_layout()
    fig.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.close(fig)
    print(f"Collage saved to: {output_path}")

def main():
    """Command-line entry point: build one collage per requested case.

    For each case selected with ``--case`` (or once over all cases when
    ``--case`` is omitted), find the method directories under ``base_path``,
    collect their images and write a collage to ``--output``. When a case is
    selected, ``_case_<n>`` is appended to the output file stem.

    Exits with status 1 when ``base_path`` is missing or is not a directory.
    Exits via ``parser.error`` (status 2) for an unparsable ``--case`` value
    or a ``--columns`` value below 1.
    """
    parser = argparse.ArgumentParser(description='Create collage of images from method results')
    parser.add_argument('base_path', type=str, help='Base path containing method directories')
    parser.add_argument('--case', '-k', type=str, default=None,
                       help='Case numbers: single (01), range (1-5), or comma-separated (01,03,05)')
    parser.add_argument('--output', '-o', type=str, default='collage.png', 
                       help='Output path for collage image')
    parser.add_argument('--crop_width', '-c', type=int, default=0,
                       help='Width to crop kp_overlay images to')
    parser.add_argument('--columns', '-n', type=int, default=4,
                       help='Number of columns for single image type layout')
    # crop_width 1797 is the default width for kp_overlay images

    args = parser.parse_args()

    if args.columns < 1:
        parser.error("--columns must be at least 1")

    # Validate base path: it is iterated as a directory, so a plain file is an error too.
    base_path = Path(args.base_path)
    if not base_path.exists():
        print(f"Error: Base path '{args.base_path}' does not exist!")
        raise SystemExit(1)
    if not base_path.is_dir():
        print(f"Error: Base path '{args.base_path}' is not a directory!")
        raise SystemExit(1)

    # Parse case numbers. NOTE: parse_case_numbers() passes each number through
    # int(), so "01" selects the directory "Case1", not "Case01".
    try:
        case_numbers = parse_case_numbers(args.case)
    except ValueError as err:
        parser.error(f"invalid --case value {args.case!r}: {err}")
    if case_numbers is None:
        case_numbers = [None]  # Process all cases if no specific cases given

    # Image types to look for (updated to include new histogram type)
    # image_types = ['kp_overlay_Y.png', 'acc_hist.png', 'overlay_init_Y.png']
    image_types = ['kp_overlay_Y.png']

    # Create collage for each case
    for case_num in case_numbers:
        # Find methods and cases
        methods = find_method_cases(args.base_path, case_num)
        if not methods:
            case_info = f" for case {case_num}" if case_num else ""
            print(f"No method directories with cases found{case_info}!")
            continue

        case_info = f" (Case {case_num})" if case_num else ""
        print(f"Found methods{case_info}: {list(methods.keys())}")
        for method, cases in methods.items():
            print(f"  {method}: {len(cases)} cases")

        # Collect images
        collected = collect_images(methods, image_types)

        # Create output path with case number
        output_path = Path(args.output)
        if case_num:
            stem = output_path.stem
            suffix = output_path.suffix
            output_path = output_path.parent / f"{stem}_case_{case_num}{suffix}"

        output_path.parent.mkdir(parents=True, exist_ok=True)

        # Create collage
        create_collage(collected, output_path, image_types, args.crop_width, args.columns)

if __name__ == '__main__':
    # Run the command-line interface only when executed as a script, not on import.
    main()
