import os
import re
import textwrap
from collections import defaultdict

import matplotlib.image as mpimg
import matplotlib.pyplot as plt

# --- Configuration ---
# Root folder to scan: one sub-directory per model, each holding layer_<N>
# folders full of PNG visualizations.
VISUALIZATIONS_DIR = 'visualizations_iclr_appendix'
# Where the finished comparison grids are written.
OUTPUT_DIR = 'comparison_grids_iclr_appendix_rasterized_pdf'

# Grid columns, left to right. Each entry is matched against the directory
# names under VISUALIZATIONS_DIR: an exact name is preferred, otherwise a
# directory whose name starts with the entry is used. Toggle entries by
# commenting them in or out.
MODEL_ORDER_HINT = [
    "Qwen_Qwen2_5-1_5B-Instruct",
    # "Qwen-Finetuned-GRPO__Best_Move_Checkpoint_Trained_from_Base__Ckpt_300",
    "Qwen-Finetuned-GRPO__Best_Move_Checkpoint_Trained_from_Base__Ckpt_600",
    # "Qwen-Finetuned-GRPO__Best_Move_Checkpoint_Trained_from_Base__Ckpt_900",
    "Qwen-Finetuned-GRPO__Best_Move_Checkpoint_Trained_from_Base__Ckpt_1200",
    # "Qwen-Finetuned-GRPO__Best_Move_Checkpoint_Trained_from_Base__Ckpt_1500",
    "Qwen-Finetuned-GRPO__Best_Move_Checkpoint_Trained_from_Base_",
    # "Qwen-Finetuned-GRPO__Best_Move_Checkpoint_Trained_from_Legal_Move",
]

# Grid rows, top to bottom; sorted so rows stay ascending however the
# numbers are typed in.
LAYER_ORDER = sorted((4, 8, 12, 13, 14, 15, 16, 17, 18, 20, 24))
# Smaller alternatives:
# LAYER_ORDER = sorted((4, 8, 12, 16, 20, 24))
# LAYER_ORDER = sorted((12, 16))

# ICLR Appendix
# --- Configuration ---
# VISUALIZATIONS_DIR = 'visualizations'
# OUTPUT_DIR = 'comparison_grids_appendix'
# # Define the order for your columns (models) and rows (layers)
# # This ensures the grid is always in the same, predictable order.
# # The script will automatically find the full directory names.
# MODEL_ORDER_HINT = [
#     "Qwen_Qwen2_5-1_5B-Instruct",
#     # "Qwen-Finetuned-GRPO__Best_Move_Checkpoint_Trained_from_Base__Ckpt_300",
#     "Qwen-Finetuned-GRPO__Best_Move_Checkpoint_Trained_from_Base__Ckpt_600",
#     # "Qwen-Finetuned-GRPO__Best_Move_Checkpoint_Trained_from_Base__Ckpt_900",
#     "Qwen-Finetuned-GRPO__Best_Move_Checkpoint_Trained_from_Base__Ckpt_1200",
#     # "Qwen-Finetuned-GRPO__Best_Move_Checkpoint_Trained_from_Base__Ckpt_1500",
#     "Qwen-Finetuned-GRPO__Best_Move_Checkpoint_Trained_from_Base_",
#     # "Qwen-Finetuned-GRPO__Best_Move_Checkpoint_Trained_from_Legal_Move",
# ]
# LAYER_ORDER = sorted([4, 8, 12, 13, 14, 15, 16, 17, 18, 20, 24])
# LAYER_ORDER = sorted([4, 8, 12, 16, 20, 24])
# LAYER_ORDER = sorted([12,16])


def format_model_name(model_dir: str, *, max_len: int = 28) -> str:
    r"""Turn a raw model directory name into a short, multi-line column title.

    Steps:
    - Strip trailing underscores.
    - Substitute known verbose tokens with concise forms
      (e.g. ``Qwen-Finetuned-GRPO`` -> ``GRPO``).
    - Split on double underscores ``__``; each segment starts a new line.
    - Within a segment, turn underscores and hyphens into spaces and collapse
      runs of whitespace.
    - Greedily re-wrap any segment longer than ``max_len`` at word boundaries.
      A single word longer than ``max_len`` is kept intact on its own line.
    - Join all lines with '\n' so matplotlib renders a multi-line title.

    Args:
        model_dir: Directory name of the model.
        max_len: Target maximum characters per rendered line; must be >= 1.

    Returns:
        The formatted title. If formatting leaves nothing (e.g. a name made
        only of underscores), ``model_dir`` is returned unchanged.

    Example:
        >>> format_model_name("Qwen-Finetuned-GRPO__Best_Move_Checkpoint_Trained_from_Base__Ckpt_600")
        'GRPO\nBest Move (Base)\nCkpt 600'
    """
    # Applied in order, as plain substring replacements.
    replacements = (
        ('Qwen-Finetuned-GRPO', 'GRPO'),
        ('Qwen_Qwen2_5-1_5B-Instruct', 'Qwen2.5 1.5B Instruct'),
        ('Best_Move_Checkpoint_Trained_from_Base', 'Best Move (Base)'),
        ('Best_Move_Checkpoint_Trained_from_Legal_Move', 'Best Move (Legal)'),
    )

    # Trailing underscores (e.g. "..._from_Base_") carry no display meaning.
    name = re.sub(r'_+$', '', model_dir)
    for verbose, concise in replacements:
        name = name.replace(verbose, concise)

    lines = []
    for segment in name.split('__'):
        segment = segment.replace('_', ' ').replace('-', ' ')
        segment = re.sub(r'\s+', ' ', segment).strip()
        if not segment:
            continue
        # Greedy word wrap. Hyphens were already turned into spaces, so only
        # whitespace is a break point, and over-long words are never split.
        lines.extend(
            textwrap.wrap(
                segment,
                width=max_len,
                break_long_words=False,
                break_on_hyphens=False,
            )
        )

    # Fall back to the raw name rather than rendering an empty title.
    return '\n'.join(lines) or model_dir


def find_full_model_dirs(root_dir, hints):
    """Resolve each model hint to an actual sub-directory of ``root_dir``.

    Hints are processed in order:

    1. A directory whose name equals the hint exactly is used.
    2. Otherwise, the alphabetically first directory whose name starts with the
       hint is used. Directories already assigned to an earlier hint are skipped.
       So are directories that exactly match some other hint, so a short prefix
       hint cannot steal the column of a more specific hint.

    A hint that resolves to nothing produces a printed warning and is omitted.

    Args:
        root_dir: Directory whose immediate sub-directories are model folders.
        hints: Model names or name prefixes, in the desired column order.

    Returns:
        Resolved directory names in hint order, without duplicates. Returns an
        empty list if ``root_dir`` does not exist.
    """
    try:
        entries = os.listdir(root_dir)
    except FileNotFoundError:
        return []

    # os.listdir order is arbitrary; sort so prefix fallbacks are reproducible.
    actual_dirs = sorted(d for d in entries if os.path.isdir(os.path.join(root_dir, d)))
    # Directories named exactly like some hint belong to that hint.
    reserved = set(hints).intersection(actual_dirs)

    ordered_dirs = []
    for hint in hints:
        if hint in reserved and hint not in ordered_dirs:
            ordered_dirs.append(hint)
            continue
        # Fallback: prefix match (e.g. hint without final checkpoint suffix).
        candidate = next(
            (
                d for d in actual_dirs
                if d.startswith(hint) and d not in ordered_dirs and d not in reserved
            ),
            None,
        )
        if candidate is None:
            print(f"Warning: Could not find a directory for model hint: {hint}")
        else:
            ordered_dirs.append(candidate)
    return ordered_dirs

def collect_image_paths(root_dir):
    """Scan ``root_dir`` and index every PNG by file name, layer and model.

    The traversal expects this layout::

        root_dir/<model_dir>/<folder whose name contains "layer_<N>">/<name>.png

    Args:
        root_dir: Directory holding one sub-directory per model.

    Returns:
        A mapping ``{png_file_name: {(layer_num, model_dir): png_path}}``.
        The PNG file name acts as the "hypothesis" key shared across models
        and layers. Returns an empty dict (and prints an error) if
        ``root_dir`` is not an existing directory.
    """
    layer_pattern = re.compile(r'layer_(\d+)')
    image_data = defaultdict(dict)

    if not os.path.isdir(root_dir):
        print(f"Error: Directory not found: {root_dir}")
        return {}

    # Listings are sorted so that the result is deterministic. Otherwise two
    # folders mapping to the same layer number (e.g. "layer_4" and "layer_4_old")
    # would win based on arbitrary os.listdir order.
    for model_dir in sorted(os.listdir(root_dir)):
        model_path = os.path.join(root_dir, model_dir)
        if not os.path.isdir(model_path):
            continue

        for layer_dir in sorted(os.listdir(model_path)):
            layer_path = os.path.join(model_path, layer_dir)
            # Stray files such as "layer_12_summary.txt" also match the regex;
            # listing them would raise NotADirectoryError.
            if not os.path.isdir(layer_path):
                continue
            match = layer_pattern.search(layer_dir)
            if not match:
                continue
            layer_num = int(match.group(1))

            for image_file in sorted(os.listdir(layer_path)):
                if image_file.endswith('.png'):
                    # The hypothesis is the image filename itself
                    image_data[image_file][(layer_num, model_dir)] = os.path.join(layer_path, image_file)

    return image_data

def create_comparison_grid(hypothesis_name, image_map, models, layers, output_dir=None):
    """Render a layers x models grid of images for one hypothesis and save it as a PDF.

    Args:
        hypothesis_name: Image file name shared by every cell (e.g. "foo_bar.png").
            Used for the figure title and the output file name.
        image_map: Mapping ``{(layer, model_dir): image_path}`` for this hypothesis.
        models: Model directory names, one per column (left to right).
        layers: Layer numbers, one per row (top to bottom).
        output_dir: Directory to write into; defaults to the module-level OUTPUT_DIR.

    Cells whose image is missing (no entry, or the file is gone) show a red
    "Image not found" placeholder. Images are drawn with ``rasterized=True``
    and the figure is saved at 300 dpi as ``comparison_<stem>.pdf``.
    If either ``models`` or ``layers`` is empty, nothing is drawn.
    """
    if not models or not layers:
        # plt.subplots rejects a zero-sized grid, and there is nothing to show.
        print(f"Skipping {hypothesis_name}: no models or layers to plot.")
        return

    target_dir = OUTPUT_DIR if output_dir is None else output_dir
    num_rows = len(layers)
    num_cols = len(models)

    # squeeze=False keeps ``axes`` a 2-D array even for a single row or column,
    # so ``axes[r, c]`` works for every grid shape.
    fig, axes = plt.subplots(
        num_rows, num_cols, figsize=(num_cols * 4, num_rows * 4.5), squeeze=False
    )
    try:
        fig.suptitle(hypothesis_name.replace('_', ' ').replace('.png', ''), fontsize=24, y=0.95)

        for r, layer in enumerate(layers):
            for c, model in enumerate(models):
                ax = axes[r, c]
                image_path = image_map.get((layer, model))

                if image_path and os.path.exists(image_path):
                    ax.imshow(mpimg.imread(image_path), rasterized=True)
                else:
                    # Handle missing images gracefully
                    ax.text(0.5, 0.5, 'Image not found', ha='center', va='center', fontsize=12, color='red')

                ax.axis('off')

                # Column titles (models) go on the first row only.
                if r == 0:
                    ax.set_title(format_model_name(model), fontsize=12, pad=12, loc='center', wrap=True)

            # Row title (layer), placed just left of the first column in axes coordinates.
            axes[r, 0].text(-0.05, 0.5, f"Layer {layer}", transform=axes[r, 0].transAxes,
                            ha='right', va='center', fontsize=16, rotation=90)

        # Leave margins for the row labels (left) and the suptitle (top).
        fig.tight_layout(rect=[0.05, 0.02, 1, 0.90])

        # Only a PDF is written. Swap just the extension, not every ".png" in the path.
        stem, _ = os.path.splitext(hypothesis_name)
        pdf_output_filename = os.path.join(target_dir, f"comparison_{stem}.pdf")
        fig.savefig(pdf_output_filename, dpi=300, bbox_inches='tight', format='pdf')
    finally:
        # Release the figure even if reading or saving fails, so long batch
        # runs do not accumulate open figures.
        plt.close(fig)

    print(f"Saved grid to {pdf_output_filename}")


def main():
    """Generate one comparison grid per visualization type found on disk.

    Steps:
    1. Ensure OUTPUT_DIR exists.
    2. Index every PNG under VISUALIZATIONS_DIR by file name, layer and model.
    3. Resolve MODEL_ORDER_HINT to concrete model directories (the columns).
    4. Render a grid for every file name that does not contain "prediction"
       or "illegal" (case-insensitive), using LAYER_ORDER as the rows.
    """
    if not os.path.exists(OUTPUT_DIR):
        os.makedirs(OUTPUT_DIR)
        print(f"Created output directory: {OUTPUT_DIR}")

    print("Scanning for visualization files...")
    grouped_images = collect_image_paths(VISUALIZATIONS_DIR)

    if not grouped_images:
        print("No images found. Did you run the visualization script first?")
        return

    # find_full_model_dirs already walks MODEL_ORDER_HINT in order, maps each hint
    # to at most one directory, and warns about unresolved hints; use its result
    # as-is. Do not re-match hints against these names with startswith: a short
    # hint such as "..._from_Base_" also prefixes "..._from_Base__Ckpt_600".
    # That would duplicate the Ckpt_600 column and drop the directory actually
    # resolved for the hint.
    full_model_names = find_full_model_dirs(VISUALIZATIONS_DIR, MODEL_ORDER_HINT)
    if not full_model_names:
        print("No model directories matched MODEL_ORDER_HINT; nothing to plot.")
        return

    print(f"\nFound {len(grouped_images)} unique visualization types.")
    print(f"Models for grid (columns): {full_model_names}")
    print(f"Layers for grid (rows): {LAYER_ORDER}\n")

    for hypothesis, image_map in grouped_images.items():
        lowered = hypothesis.lower()
        # Prediction / illegal-move plots are intentionally excluded from the grids.
        if "prediction" in lowered or "illegal" in lowered:
            continue
        create_comparison_grid(hypothesis, image_map, full_model_names, LAYER_ORDER)

    print("\nAll comparison grids have been generated successfully!")


if __name__ == "__main__":
    main()
