import argparse
from pathlib import Path
import re
from typing import Dict, List, Tuple, Union
import numpy as np
from PIL import Image
import torch
from torchvision.utils import make_grid
import torchvision.transforms as transforms

# Rows of the trajectory grid, listed top to bottom.
# Each entry is a (method, seed, sample_idx) tuple. An entry whose method is
# "break" (with None for the other two fields) inserts a gap row instead of
# a trajectory. Alternative rows are kept commented out so they can be
# toggled back on.
TRAJECTORY_SPECS = [
    # ("DDPM", 0, 0), ("ours", 0, 0), ("break", None, None),
    # ("DDPM", 0, 1), ("ours", 0, 1), ("break", None, None),
    # ("DDPM", 0, 2), ("ours", 0, 2), ("break", None, None),
    # ("DDPM", 0, 3), ("ours", 0, 3), ("break", None, None),
    ("DDPM", 0, 4),
    ("ours", 0, 4),
    # ("break", None, None),
    # ("DDPM", 0, 5), ("ours", 0, 5),
]

def get_dataset_from_path(path: Path) -> Union[str, None]:
    """Map an experiment path to the display name of its dataset.

    The whole path string is lower-cased and searched for known dataset
    keys as substrings; the first key that occurs wins.

    Args:
        path: Experiment path to inspect.

    Returns:
        The display name (e.g. "CIFAR10"), or None when no known dataset
        key occurs anywhere in the path.
    """
    # Ordered pairs rather than a dict so the precedence is explicit:
    # "fashion_mnist" contains "mnist" and therefore must be tried first.
    known_datasets = (
        ("fashion_mnist", "FMNIST"),
        ("cifar10", "CIFAR10"),
        ("celeba_hq", "CelebA-HQ"),
        ("afhq", "AFHQv2"),
        ("mnist", "MNIST"),
    )

    # Lower-case once instead of on every loop iteration.
    lowered = str(path).lower()
    for key, display_name in known_datasets:
        if key in lowered:
            return display_name
    return None

def get_method_from_path(path: Path) -> Union[str, None]:
    """Extract the method name from an experiment path.

    The path string is searched (case-sensitively) for a method name
    followed by an underscore, e.g. "ours_". Markers are tried in a fixed
    order and the first one found anywhere in the path wins.

    Args:
        path: Experiment path to inspect.

    Returns:
        One of "wiener", "kamb", "niedoba", "ours", "optimal", "unet", or
        None when no marker occurs in the path.
    """
    # Order matters when a path contains more than one marker.
    known_methods = ("wiener", "kamb", "niedoba", "ours", "optimal", "unet")

    path_str = str(path)
    for method in known_methods:
        if f"{method}_" in path_str:
            return method
    return None

def get_seed_from_path(path: Path) -> Union[int, None]:
    """Extract the seed number from an experiment path.

    Looks for the first occurrence of "seed" immediately followed by one
    or more digits anywhere in the path string.

    Args:
        path: Experiment path to inspect.

    Returns:
        The seed as an int, or None when the path has no "seed<digits>"
        component.
    """
    match = re.search(r'seed(\d+)', str(path))
    if match is None:
        return None
    return int(match.group(1))

def find_experiment_dir_for_method_and_seed(results_dir: Path, dataset: str, method: str, seed: int) -> Union[Path, None]:
    """Find the experiment directory for a specific method, dataset, and seed.

    Scans the immediate subdirectories of ``results_dir`` and keeps those
    whose path yields the requested dataset, method and seed according to
    get_dataset_from_path, get_method_from_path and get_seed_from_path.

    Args:
        results_dir: Directory whose direct children are experiment directories.
        dataset: Dataset display name as returned by get_dataset_from_path.
        method: Method name as returned by get_method_from_path.
        seed: Seed number as returned by get_seed_from_path.

    Returns:
        The matching directory, or None when nothing matches. When several
        directories match, the one that compares greatest as a Path is
        returned (presumably the most recent run if names carry a sortable
        timestamp -- verify against the naming scheme).
    """
    wanted = (dataset, method, seed)
    matching_dirs = [
        exp_dir
        for exp_dir in results_dir.iterdir()
        if exp_dir.is_dir()
        and (
            get_dataset_from_path(exp_dir),
            get_method_from_path(exp_dir),
            get_seed_from_path(exp_dir),
        ) == wanted
    ]

    return max(matching_dirs) if matching_dirs else None

def load_image(image_path: Path) -> torch.Tensor:
    """Load an image file as a normalized 64x64 RGB tensor.

    Args:
        image_path: Path of the image file to read.

    Returns:
        Tensor of shape [1, 3, 64, 64] with values in the [-1, 1] range.
    """
    transform = transforms.Compose([
        transforms.Resize((64, 64)),  # Resize all images to 64x64
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    # Force three channels: grayscale, palette or RGBA files would otherwise
    # produce 1- or 4-channel tensors, which neither the 3-channel Normalize
    # above nor the 3-channel gap tensors can be combined with.
    # The context manager closes the underlying file once pixels are read.
    with Image.open(image_path) as img:
        tensor = transform(img.convert("RGB"))
    return tensor.unsqueeze(0)  # Add a leading dimension to match gap tensor layout

def create_gap_tensor(height: int, width: int, channels: int = 3) -> torch.Tensor:
    """Build an all-zero placeholder frame used as a gap between trajectories.

    Args:
        height: Frame height in pixels.
        width: Frame width in pixels.
        channels: Number of image channels.

    Returns:
        Zero tensor of shape [1, 1, channels, height, width]; the two
        leading singleton axes are the timestep and batch dimensions.
    """
    gap_shape = (1, 1, channels, height, width)
    return torch.zeros(gap_shape)

def get_available_timesteps(directory: Path, pattern: str) -> List[int]:
    """Collect the distinct timesteps encoded in matching image filenames.

    Args:
        directory: Directory to search (non-recursively, via ``glob``).
        pattern: Glob pattern selecting the candidate files.

    Returns:
        Ascending list of the unique integers found in a ``step_<digits>_``
        fragment of each matching filename; files without such a fragment
        are ignored.
    """
    step_regex = re.compile(r'step_(\d+)_')
    hits = (step_regex.search(candidate.name) for candidate in directory.glob(pattern))
    return sorted({int(hit.group(1)) for hit in hits if hit})

# Maps the method label used in trajectory specs to the method name that
# find_experiment_dir_for_method_and_seed matches experiment directories on.
_SPEC_METHOD_TO_DIR_METHOD = {
    "DDPM": "unet",
    "ours": "ours",
    "oursNN": "ours",
}


def _resolve_trajectory_source(
    results_path: Path,
    dataset: str,
    method: Union[str, None],
    seed: Union[int, None],
) -> Tuple[Union[Path, None], Union[str, None], Union[str, None]]:
    """Locate the image directory and filename tag for one trajectory spec.

    Returns ``(individual_dir, image_type, None)`` on success, or
    ``(None, None, warning_message)`` when the spec cannot be resolved.
    """
    dir_method = _SPEC_METHOD_TO_DIR_METHOD.get(method)
    if dir_method is None:
        return None, None, f"Warning: Unknown method {method}"

    exp_dir = find_experiment_dir_for_method_and_seed(results_path, dataset, dir_method, seed)
    if exp_dir is None:
        return None, None, f"Warning: No experiment directory found for {method} with seed {seed}"

    individual_dir = exp_dir / "individual_images"
    if not individual_dir.exists():
        return None, None, f"Warning: No individual_images directory found in {exp_dir}"

    # For DDPM, we need to use unet_trajectory_x0
    image_type = "unet_trajectory_x0" if method == "DDPM" else f"{dir_method}_x0"
    return individual_dir, image_type, None


def generate_trajectory_grid(
    results_dir: str = "experiment_results",
    dataset: str = "CIFAR10",
    trajectory_specs: List[Tuple[Union[str, None], Union[int, None], Union[int, None]]] = TRAJECTORY_SPECS,
    output_file: str = "trajectory_grid.png"
):
    """Generate a grid of image trajectories showing progression over time.

    The grid has one row per trajectory spec (or gap row for "break"
    entries) and one column per timestep. The set of columns is the union
    of the timesteps found for all resolvable specs; a trajectory that
    lacks an image for some timestep gets a blank cell there so columns
    stay aligned.

    Args:
        results_dir: Directory containing experiment results
        dataset: Dataset name to generate grid for
        trajectory_specs: List of (method, seed, sample_idx) tuples
        output_file: Output file path for the grid image; its parent
            directory is created if it does not exist yet

    Raises:
        ValueError: If no timesteps (and hence no images) are found for
            any of the given specs.
    """
    results_path = Path(results_dir)
    specs = list(trajectory_specs)

    # First pass: resolve every spec once and collect all available timesteps.
    # "break" entries have no source and are stored as None.
    sources = []
    all_timesteps = set()
    for method, seed, sample_idx in specs:
        if method == "break":
            sources.append(None)
            continue

        source = _resolve_trajectory_source(results_path, dataset, method, seed)
        sources.append(source)
        individual_dir, image_type, _ = source
        if individual_dir is None:
            continue

        pattern = f"step_*_{image_type}_sample_{sample_idx:02d}.png"
        all_timesteps.update(get_available_timesteps(individual_dir, pattern))

    timesteps = sorted(all_timesteps)
    if not timesteps:
        raise ValueError("No timesteps found in any of the trajectories")

    print(f"Found timesteps: {timesteps}")

    # Blank cell for a missing frame. load_image returns [1, C, H, W], so
    # drop the gap tensor's extra leading axis; mixing 4-D and 5-D tensors
    # would make torch.stack fail.
    blank_frame = create_gap_tensor(64, 64).squeeze(0)

    # Second pass: build one [timesteps, 1, channels, height, width] tensor per row.
    trajectory_images = []
    for (method, seed, sample_idx), source in zip(specs, sources):
        if method == "break":
            # Gap row: one full-size blank frame repeated for every timestep.
            gap_row = create_gap_tensor(64, 64).repeat(len(timesteps), 1, 1, 1, 1)
            trajectory_images.append(gap_row)
            continue

        individual_dir, image_type, problem = source
        if individual_dir is None:
            print(problem)
            continue

        # Collect images for all timesteps
        trajectory = []
        for t in timesteps:
            pattern = f"step_{t:04d}_{image_type}_sample_{sample_idx:02d}.png"
            img_path = next(individual_dir.glob(pattern), None)
            if img_path is None:
                print(f"Warning: No image found for {method} at timestep {t}")
                # Add a blank image to maintain alignment
                trajectory.append(blank_frame)
            else:
                trajectory.append(load_image(img_path))

        # Shape: [timesteps, 1, channels, height, width]
        trajectory_images.append(torch.stack(trajectory))

    if not trajectory_images:
        raise ValueError("No images found for the specified parameters")

    # Concatenate all rows, then drop the singleton axis so make_grid
    # receives [batch, channels, height, width].
    final_grid = torch.cat(trajectory_images, dim=0).squeeze(1)

    # nrow is the number of images per grid row, i.e. one column per timestep.
    grid = make_grid(final_grid, nrow=len(timesteps), padding=0, normalize=True)

    # Save the grid, creating the destination directory if needed.
    Path(output_file).parent.mkdir(parents=True, exist_ok=True)
    transforms.ToPILImage()(grid).save(output_file)
    print(f"Trajectory grid saved to {output_file}")

if __name__ == "__main__":
    # Command-line entry point: choose a dataset and write its trajectory
    # grid under paper/images/.
    supported_datasets = ["CIFAR10", "CelebA-HQ", "AFHQv2", "MNIST", "FMNIST"]

    cli = argparse.ArgumentParser(description="Generate trajectory grid from experiment results")
    cli.add_argument(
        "--dataset",
        type=str,
        choices=supported_datasets,
        default="CelebA-HQ",
        help="Dataset to generate grid for",
    )
    args = cli.parse_args()

    grid_path = f"paper/images/{args.dataset}_trajectory_grid.png"
    generate_trajectory_grid(dataset=args.dataset, output_file=grid_path)