import argparse
from pathlib import Path
import re
from typing import Dict, List, Tuple
import numpy as np
from PIL import Image
import torch
from torchvision.utils import make_grid
import torchvision.transforms as transforms


# Hardcoded row order for the grid, top to bottom. The trailing comment on each
# entry names what that row shows. generate_image_grid reads the
# "unet_trajectory_x0" and "closest_real_ours" rows from the "ours" experiment
# directory; every other row is looked up by the method prefix before the first
# "_" (e.g. "wiener_x0" -> "wiener").
ROW_ORDER: List[str] = [
    "unet_trajectory_x0",  # DDPM
    "unet_x0",            # AnotherDDPM
    "ours_x0",            # Ours
    "wiener_x0",          # Wiener
    "kamb_x0",            # Kamb
    "niedoba_x0",         # Niedoba
    "closest_real_ours",  # nearest to ours
]

# Default grid columns, left to right, as (seed, sample_idx) pairs: the seed
# selects the experiment directory and sample_idx selects the image file
# within it. Used when --column-specs is not given on the command line.
COLUMN_SPECS: List[Tuple[int, int]] = [
    (0, 0),
    (0, 1),
    (0, 2),
    (0, 3),
    (0, 4),
    (0, 5),
    (1, 0),
    (1, 1),
    (1, 2),
    (1, 3),
    (1, 4),
    # (1, 5),  # currently disabled column
]

def get_dataset_from_path(path: Path) -> str:
    """Return the display name of the dataset encoded in ``path``.

    The full path string is lower-cased and searched for known dataset keys.
    Keys are tried in the order listed below and the first hit wins, which is
    why ``fashion_mnist`` is tested before its substring ``mnist``.
    Returns None when no key occurs anywhere in the path.
    """
    lowered = str(path).lower()
    known = (
        ("fashion_mnist", "FMNIST"),
        ("cifar10", "CIFAR10"),
        ("celeba_hq", "CelebA-HQ"),
        ("afhq", "AFHQv2"),
        ("mnist", "MNIST"),
    )
    return next((label for key, label in known if key in lowered), None)

def get_method_from_path(path: Path) -> str:
    """Return the method tag encoded in ``path``, or None if none is present.

    Each tag is matched case-sensitively as the substring ``"<tag>_"`` of the
    full path string. Tags are tried in a fixed priority order and the first
    match wins.
    """
    text = str(path)
    priority = ("wiener", "kamb", "niedoba", "ours", "optimal", "unet")
    for tag in priority:
        if f"{tag}_" in text:
            return tag
    return None

def get_seed_from_path(path: Path) -> int:
    """Return the integer after the first ``seed<digits>`` in ``path``.

    For example ``".../ours_cifar10_seed13"`` gives 13. Returns None when the
    path contains no such run.
    """
    found = re.search(r'seed(\d+)', str(path))
    return None if found is None else int(found.group(1))

def find_experiment_dir_for_method_and_seed(results_dir: Path, dataset: str, method: str, seed: int) -> Path:
    """Find the experiment directory for a specific method, dataset, and seed.

    Scans the direct children of ``results_dir``. A child directory matches
    when its path yields ``dataset``, ``method`` and ``seed`` via
    ``get_dataset_from_path``, ``get_method_from_path`` and
    ``get_seed_from_path``. If several match, the one that sorts last
    (Path ordering) is returned. Returns None if nothing matches.
    """
    hits = [
        entry
        for entry in results_dir.iterdir()
        if entry.is_dir()
        and get_dataset_from_path(entry) == dataset
        and get_method_from_path(entry) == method
        and get_seed_from_path(entry) == seed
    ]
    return max(hits, default=None)

def load_image(image_path: Path) -> torch.Tensor:
    """Load an image file as a 3-channel float tensor scaled to [-1, 1].

    The image is converted to RGB before tensor conversion, so grayscale
    (``L``), palette (``P``) and ``RGBA`` files work as well. Without this, the
    3-element ``Normalize`` below fails on any tensor whose channel count is
    not 3. Any alpha channel is dropped, not composited.

    ``ToTensor`` maps the 8-bit RGB pixels to [0, 1]. ``Normalize`` with
    mean = std = 0.5 then maps that to [-1, 1].

    Args:
        image_path: Path to the image file.

    Returns:
        Tensor of shape (3, H, W) with values in [-1, 1].
    """
    # The context manager releases the file handle. convert() loads the pixel
    # data into a new image first, so the result stays valid after the file closes.
    with Image.open(image_path) as img:
        rgb = img.convert("RGB")
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    return transform(rgb)

def get_image_type_for_method(method: str) -> str:
    """Return the image-type name associated with a method tag.

    Most methods map to ``"<method>_x0"``. ``"unet"`` maps to plain ``"unet"``.
    Unknown methods give None.
    """
    image_types = {
        "unet": "unet",
        "ours": "ours_x0",
        "wiener": "wiener_x0",
        "kamb": "kamb_x0",
        "niedoba": "niedoba_x0",
        "optimal": "optimal_x0",
    }
    return image_types.get(method)

def _load_grid_cell(exp_dir, method: str, seed: int, tag: str, timestep: int, sample_idx: int):
    """Load one grid cell image, or print a warning and return None if it is missing.

    Looks for
    ``<exp_dir>/individual_images/step_<timestep:04d>_<tag>_sample_<sample_idx:02d>.png``.
    ``exp_dir`` may be None, meaning no matching experiment directory exists.
    ``method`` and ``seed`` are used only in the warning text.
    """
    if exp_dir is None:
        print(f"Warning: No experiment directory found for {method} with seed {seed}")
        return None
    individual_dir = exp_dir / "individual_images"
    if not individual_dir.exists():
        print(f"Warning: No individual_images directory found in {exp_dir}")
        return None
    image_path = individual_dir / f"step_{timestep:04d}_{tag}_sample_{sample_idx:02d}.png"
    if not image_path.is_file():
        print(f"Warning: No image {image_path.name} found in {individual_dir}")
        return None
    return load_image(image_path)


def generate_image_grid(
    results_dir: str = "experiment_results",
    dataset: str = "CIFAR10",
    column_specs: List[Tuple[int, int]] = COLUMN_SPECS,
    timestep: int = 1000,
    output_file: str = "image_grid.png"
):
    """Generate a grid of images from different methods for comparison.

    Rows follow ``ROW_ORDER`` (top to bottom) and columns follow
    ``column_specs`` (left to right). The cell for a given row and a
    ``(seed, sample_idx)`` column is read from
    ``<exp_dir>/individual_images/step_<timestep:04d>_<tag>_sample_<sample_idx:02d>.png``,
    where:

    * rows ``unet_trajectory_x0`` and ``closest_real_ours`` use the "ours"
      experiment directory for that seed, with their own row name as ``<tag>``;
    * every other row uses the experiment directory of the method named by
      the row's prefix (``"wiener_x0"`` -> ``"wiener"``), with the tag
      ``ours_x0``.

    Rows with no image at all are left out of the grid. Missing cells in the
    remaining rows are filled with a flat placeholder tile, so every image
    stays in its own column.

    Args:
        results_dir: Directory containing experiment results
        dataset: Dataset name to generate grid for, as returned by
            ``get_dataset_from_path`` (e.g. "CIFAR10")
        column_specs: List of (seed, sample_idx) tuples specifying which images to use
        timestep: Timestep to use for the images
        output_file: Output file path for the grid image; missing parent
            directories are created

    Raises:
        ValueError: If no image is found for any row and column.
    """
    results_path = Path(results_dir)

    # These rows live in the "ours" experiment directory under their own tag.
    ours_dir_rows = ("unet_trajectory_x0", "closest_real_ours")

    # Every lookup rescans results_dir, and the same (method, seed) pair is
    # needed by several rows ("ours" three times per column), so memoize it.
    dir_cache: Dict[Tuple[str, int], Path] = {}

    def lookup_dir(method: str, seed: int):
        key = (method, seed)
        if key not in dir_cache:
            dir_cache[key] = find_experiment_dir_for_method_and_seed(results_path, dataset, method, seed)
        return dir_cache[key]

    # One entry per column for every row; None marks a missing image. Keeping
    # these gaps is what keeps make_grid's row-major layout aligned.
    row_images = {row: [] for row in ROW_ORDER}
    for seed, sample_idx in column_specs:
        for row_name in ROW_ORDER:
            if row_name in ours_dir_rows:
                method, tag = "ours", row_name
            else:
                # e.g. "wiener_x0" -> "wiener"
                method = row_name.split("_")[0]
                # NOTE(review): every method directory is searched for the
                # "ours_x0" file name rather than "<row_name>"; presumably each
                # experiment saves its own prediction under that tag -- confirm.
                tag = "ours_x0"
            row_images[row_name].append(
                _load_grid_cell(lookup_dir(method, seed), method, seed, tag, timestep, sample_idx)
            )

    present_rows = [row for row in ROW_ORDER if any(img is not None for img in row_images[row])]
    if not present_rows:
        raise ValueError("No images found for the specified parameters")

    loaded = [img for row in present_rows for img in row_images[row] if img is not None]
    # normalize=True rescales by the min/max of the whole batch. Filling gaps
    # with the smallest value already present makes the placeholder render
    # black and leaves the normalization range of the real images unchanged.
    fill_value = min(float(img.min()) for img in loaded)
    placeholder = torch.full_like(loaded[0], fill_value)
    tiles = [
        img if img is not None else placeholder
        for row in present_rows
        for img in row_images[row]
    ]
    final_grid = torch.stack(tiles)

    # make_grid lays tiles out row-major, len(column_specs) tiles per grid row.
    grid = make_grid(final_grid, nrow=len(column_specs), padding=0, normalize=True)

    Path(output_file).parent.mkdir(parents=True, exist_ok=True)
    transforms.ToPILImage()(grid).save(output_file)
    print(f"Grid image saved to {output_file}")
    print(f"Found images for rows: {present_rows}")

if __name__ == "__main__":
    import sys

    parser = argparse.ArgumentParser(description="Generate image grid from ablation experiment results")
    parser.add_argument(
        "--dataset",
        type=str,
        choices=["CIFAR10", "CelebA-HQ", "AFHQv2", "MNIST", "FMNIST"],
        default="CelebA-HQ",
        help="Dataset to generate grid for"
    )
    parser.add_argument(
        "--timestep",
        type=int,
        default=1000,
        help="Timestep to use for the images"
    )
    parser.add_argument(
        "--column-specs",
        type=str,
        default=None,
        help="Comma-separated list of seed,sample_idx pairs (e.g., '33,0,33,1,13,0')"
    )
    parser.add_argument(
        "--results-dir",
        type=str,
        default="experiment_results",
        help="Directory containing the experiment result directories"
    )
    parser.add_argument(
        "--output-file",
        type=str,
        default=None,
        help="Output path for the grid image (default: paper/images/<dataset>_ablation_grid.png)"
    )
    args = parser.parse_args()

    # Parse column specifications if provided: a flat "seed,idx,seed,idx,..." list.
    column_specs = COLUMN_SPECS
    if args.column_specs:
        try:
            values = [int(x) for x in args.column_specs.split(',')]
        except ValueError:
            print("Error: column-specs must be comma-separated integers", file=sys.stderr)
            sys.exit(1)
        # Reject an odd count up front; pairing would otherwise run off the end.
        if len(values) % 2:
            print(
                "Error: column-specs must contain an even number of integers "
                "(seed,sample_idx pairs)",
                file=sys.stderr,
            )
            sys.exit(1)
        column_specs = list(zip(values[0::2], values[1::2]))

    output_file = args.output_file or f"paper/images/{args.dataset}_ablation_grid.png"

    generate_image_grid(
        results_dir=args.results_dir,
        dataset=args.dataset,
        column_specs=column_specs,
        timestep=args.timestep,
        output_file=output_file,
    )
    