"""
Summary Generator - Level 3 visualizations for reports and data quality analysis.

This module generates summary visualizations and reports, including:
- Performance summary tables
- Data coverage reports
- Sample size analysis
- Size-pattern and number-of-examples analyses
- Cross-modal analysis
"""

import os
import json
from typing import Dict, Any
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from matplotlib.lines import Line2D
from ..core.data_loader import split_by_question_type, get_available_values
from scripts.visualization.core.utils import (
    save_plot,
    get_color_palette,
    get_model_family_order,
    get_size_pattern_order,
)


def _accuracy_summary(data: pd.DataFrame, column: str) -> pd.DataFrame:
    """Aggregate ``correct`` by *column* into accuracy/count/sum/std, best first."""
    summary = (
        data.groupby(column)["correct"].agg(["mean", "count", "sum", "std"]).round(4)
    )
    summary.columns = [
        "accuracy",
        "total_evaluations",
        "correct_count",
        "std_dev",
    ]
    return summary.sort_values("accuracy", ascending=False)


def generate_performance_summary_table(
    data: pd.DataFrame, output_dir: str, verbose: bool = False, no_titles: bool = False
) -> Dict[str, Any]:
    """
    Generate comprehensive performance summary tables and save as text and JSON.

    Writes ``performance_summary.txt`` and ``performance_summary.json`` into
    ``output_dir``. Per-task and per-question-type tables are included only
    when the ``benchmark`` / ``question_type`` columns exist.

    Parameters:
    - data: Full evaluation dataset (needs ``model`` and ``correct`` columns)
    - output_dir: Directory to save summaries
    - verbose: Whether to print progress info
    - no_titles: Unused here (no plots are produced); accepted so all
      generators share the same signature

    Returns:
    - Dict with ``generated_files`` (text path, then JSON path) and
      ``summary_data`` (the dict written to JSON)
    """
    results = {"generated_files": [], "summary_data": {}}

    if verbose:
        print("   📊 Creating performance summary tables...")

    has_tasks = "benchmark" in data.columns
    has_question_types = "question_type" in data.columns

    # Overall statistics
    overall_stats = {
        "total_evaluations": len(data),
        "overall_accuracy": data["correct"].mean(),
        "unique_models": len(data["model"].unique()),
        "unique_tasks": len(data["benchmark"].unique()) if has_tasks else 0,
        "unique_question_types": (
            len(data["question_type"].unique()) if has_question_types else 0
        ),
    }

    # Per-dimension accuracy tables, each sorted best-first
    model_summary = _accuracy_summary(data, "model")
    task_summary = _accuracy_summary(data, "benchmark") if has_tasks else None
    question_summary = (
        _accuracy_summary(data, "question_type") if has_question_types else None
    )

    # Create text summary
    summary_text = f"""
=== GRAPH-BASED ARC EVALUATION SUMMARY ===

OVERALL STATISTICS:
• Total evaluations: {overall_stats['total_evaluations']:,}
• Overall accuracy: {overall_stats['overall_accuracy']:.4f}
• Unique models: {overall_stats['unique_models']}
• Unique tasks: {overall_stats['unique_tasks']}
• Unique question types: {overall_stats['unique_question_types']}

MODEL PERFORMANCE (sorted by accuracy):
{model_summary.to_string()}

"""

    if task_summary is not None:
        summary_text += f"""
TASK PERFORMANCE (sorted by accuracy):
{task_summary.to_string()}

"""

    if question_summary is not None:
        summary_text += f"""
QUESTION TYPE PERFORMANCE (sorted by accuracy):
{question_summary.to_string()}

"""

    # Tables are sorted descending, so first/last rows are best/worst.
    if not model_summary.empty:
        best_model = model_summary.index[0]
        worst_model = model_summary.index[-1]
        summary_text += f"""
HIGHLIGHTS:
• Best performing model: {best_model} ({model_summary.loc[best_model, 'accuracy']:.4f})
• Lowest performing model: {worst_model} ({model_summary.loc[worst_model, 'accuracy']:.4f})
"""

    if task_summary is not None and not task_summary.empty:
        easiest_task = task_summary.index[0]
        hardest_task = task_summary.index[-1]
        summary_text += f"""• Easiest task: {easiest_task} ({task_summary.loc[easiest_task, 'accuracy']:.4f})
• Hardest task: {hardest_task} ({task_summary.loc[hardest_task, 'accuracy']:.4f})
"""

    # Save text summary
    text_path = f"{output_dir}/performance_summary.txt"
    with open(text_path, "w", encoding="utf-8") as f:
        f.write(summary_text)
    results["generated_files"].append(text_path)

    # Prepare data for JSON export
    summary_data = {
        "overall_statistics": overall_stats,
        "model_performance": model_summary.to_dict(),
        "task_performance": (
            task_summary.to_dict() if task_summary is not None else None
        ),
        "question_type_performance": (
            question_summary.to_dict() if question_summary is not None else None
        ),
    }

    # Save JSON summary
    json_path = f"{output_dir}/performance_summary.json"
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(summary_data, f, indent=2)
    results["generated_files"].append(json_path)

    results["summary_data"] = summary_data

    return results


def _save_coverage_heatmap(
    matrix: pd.DataFrame,
    figsize: tuple,
    cmap: str,
    title: str,
    filepath: str,
    no_titles: bool,
) -> None:
    """Draw a margins=True crosstab (minus its 'All' row/column) as a heatmap and save it."""
    fig, ax = plt.subplots(figsize=figsize)
    # Exclude the 'All' margin row and column added by pd.crosstab(margins=True)
    plot_data = matrix.iloc[:-1, :-1]
    # Pass ax explicitly instead of relying on pyplot's implicit "current axes".
    sns.heatmap(plot_data, annot=True, fmt="d", cmap=cmap, ax=ax)
    if not no_titles:
        ax.set_title(title)
    save_plot(fig, filepath, no_titles=no_titles)


def _json_default(obj):
    """json.dump fallback converting numpy scalars/arrays to native Python values."""
    if hasattr(obj, "tolist"):  # numpy scalar or array (also pandas Series/Index)
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def generate_data_coverage_report(
    data: pd.DataFrame, output_dir: str, verbose: bool = False, no_titles: bool = False
) -> Dict[str, Any]:
    """
    Generate data coverage analysis showing sample distributions.

    For each available pair of dimensions (model × benchmark, model ×
    question_type, system_prompt × question_type) a sample-count crosstab is
    built and saved as a heatmap PDF. A text report and a JSON dump of the
    per-dimension values and crosstabs are also written.

    Parameters:
    - data: Full evaluation dataset
    - output_dir: Directory to save reports
    - verbose: Whether to print progress info
    - no_titles: Omit plot titles when True (also forwarded to save_plot)

    Returns:
    - Dict with ``generated_files`` (heatmaps, then text, then JSON) and
      ``coverage_data`` (the dict written to JSON)
    """
    results = {"generated_files": [], "coverage_data": {}}

    if verbose:
        print("   📈 Analyzing data coverage...")

    # Get available values
    available_values = get_available_values(data)

    # Create coverage matrices showing sample counts
    coverage_matrices = {}

    # (matrix key, row column, column column, figsize, cmap, title, file stem)
    heatmap_specs = [
        (
            "model_task",
            "model",
            "benchmark",
            (16, 8),
            "Blues",
            "Sample Coverage: Models × Tasks",
            "coverage_model_task",
        ),
        (
            "model_question",
            "model",
            "question_type",
            (14, 8),
            "Greens",
            "Sample Coverage: Models × Question Types",
            "coverage_model_question",
        ),
        (
            "system_question",
            "system_prompt",
            "question_type",
            (12, 6),
            "Oranges",
            "Sample Coverage: System Prompts × Question Types",
            "coverage_system_question",
        ),
    ]

    for key, row_col, col_col, figsize, cmap, title, stem in heatmap_specs:
        # Skip any pairing whose columns are not present in this dataset.
        if row_col not in data.columns or col_col not in data.columns:
            continue
        matrix = pd.crosstab(data[row_col], data[col_col], margins=True)
        coverage_matrices[key] = matrix

        filepath = f"{output_dir}/{stem}.pdf"
        _save_coverage_heatmap(matrix, figsize, cmap, title, filepath, no_titles)
        results["generated_files"].append(filepath)

    # Create coverage report text
    coverage_text = """
=== DATA COVERAGE REPORT ===

SAMPLE DISTRIBUTION:
"""

    # Map plural dimension names to actual column names
    dimension_mapping = {
        "models": "model",
        "tasks": "benchmark",
        "question_types": "question_type",
        "targets": "target",
        "system_prompts": "system_prompt",
        "encodings": "encoding",
    }

    for dimension, values in available_values.items():
        if values:  # Only show dimensions that have data
            column_name = dimension_mapping.get(dimension, dimension)
            if column_name in data.columns:
                counts = data[column_name].value_counts()
                coverage_text += f"\n{dimension.upper()}:\n"
                for value in values:
                    count = counts.get(value, 0)
                    coverage_text += f"  • {value}: {count:,} samples\n"

    coverage_text += "\nCOVERAGE MATRICES:\n"
    for matrix_name, matrix in coverage_matrices.items():
        coverage_text += (
            f"\n{matrix_name.upper().replace('_', ' × ')}:\n{matrix.to_string()}\n\n"
        )

    # Save coverage report
    text_path = f"{output_dir}/data_coverage_report.txt"
    with open(text_path, "w", encoding="utf-8") as f:
        f.write(coverage_text)
    results["generated_files"].append(text_path)

    # Save coverage data as JSON
    coverage_data = {
        "available_values": available_values,
        "coverage_matrices": {
            name: matrix.to_dict() for name, matrix in coverage_matrices.items()
        },
    }

    json_path = f"{output_dir}/data_coverage.json"
    with open(json_path, "w", encoding="utf-8") as f:
        # NOTE(review): available_values comes from get_available_values, whose
        # element types are not visible here; the default hook converts numpy
        # values that the stdlib encoder would otherwise reject.
        json.dump(coverage_data, f, indent=2, default=_json_default)
    results["generated_files"].append(json_path)

    results["coverage_data"] = coverage_data

    return results


def generate_sample_size_analysis(
    data: pd.DataFrame, output_dir: str, verbose: bool = False, no_titles: bool = False
) -> Dict[str, Any]:
    """
    Generate analysis of sample sizes and statistical significance.

    For each of ``model``, ``benchmark``, ``question_type`` and
    ``system_prompt`` present in ``data``, the number of rows per group is
    summarized, and the number of groups that meet each of several
    minimum-size thresholds is counted. Bar charts are saved for the first
    three available groupings; a text report and JSON dump are always written.

    Parameters:
    - data: Full evaluation dataset
    - output_dir: Directory to save analysis
    - verbose: Whether to print progress info
    - no_titles: Omit plot titles when True (also forwarded to save_plot)

    Returns:
    - Dict with ``generated_files`` and ``sample_analysis`` (per-grouping stats)
    """
    results = {"generated_files": [], "sample_analysis": {}}

    if verbose:
        print("   🔬 Analyzing sample sizes...")

    # Minimum sample sizes reported for every grouping
    sample_size_thresholds = [1, 5, 10, 20, 50]

    # Only group by dimensions present in this dataset, in a fixed order
    groupings = [
        column
        for column in ("model", "benchmark", "question_type", "system_prompt")
        if column in data.columns
    ]

    sample_analysis = {}
    # Cache per-grouping counts so the plotting loop does not recompute them
    group_sizes_by_grouping = {}

    for grouping in groupings:
        group_sizes = data[grouping].value_counts().sort_values(ascending=False)
        group_sizes_by_grouping[grouping] = group_sizes
        total_groups = len(group_sizes)

        analysis = {
            "total_groups": total_groups,
            "mean_sample_size": group_sizes.mean(),
            "median_sample_size": group_sizes.median(),
            "min_sample_size": group_sizes.min(),
            "max_sample_size": group_sizes.max(),
            "groups_by_threshold": {},
        }

        for threshold in sample_size_thresholds:
            groups_above_threshold = (group_sizes >= threshold).sum()
            analysis["groups_by_threshold"][threshold] = {
                "count": groups_above_threshold,
                # Guard: an all-NaN column yields zero groups (value_counts drops NaN)
                "percentage": (
                    (groups_above_threshold / total_groups) * 100
                    if total_groups
                    else 0.0
                ),
            }

        sample_analysis[grouping] = analysis

    # Create sample size distribution charts
    for grouping in groupings[:3]:  # Limit to first 3 to avoid too many charts
        group_sizes = group_sizes_by_grouping[grouping]
        positions = range(len(group_sizes))

        fig, ax = plt.subplots(figsize=(12, 6))
        bars = ax.bar(positions, group_sizes.values, color="skyblue", alpha=0.7)

        if not no_titles:
            ax.set_title(f"Sample Size Distribution by {grouping.title()}")
        ax.set_xlabel(f"{grouping.title()} (sorted by sample size)")
        ax.set_ylabel("Number of Samples")

        # Model labels stay horizontal; every other grouping is rotated 45°.
        ax.set_xticks(positions)
        if grouping == "model":
            ax.set_xticklabels(group_sizes.index, rotation=0, ha="center")
        else:
            ax.set_xticklabels(group_sizes.index, rotation=45, ha="right")

        # Add value labels on bars
        for bar in bars:
            height = bar.get_height()
            ax.text(
                bar.get_x() + bar.get_width() / 2.0,
                height + height * 0.01,
                f"{int(height)}",
                ha="center",
                va="bottom",
                fontsize=8,
            )

        filepath = f"{output_dir}/sample_sizes_{grouping}.pdf"
        save_plot(fig, filepath, no_titles=no_titles)
        results["generated_files"].append(filepath)

    # Create sample size analysis text
    analysis_text = """
=== SAMPLE SIZE ANALYSIS ===

STATISTICAL RELIABILITY ASSESSMENT:

"""

    for grouping, analysis in sample_analysis.items():
        analysis_text += f"""
{grouping.upper()}:
• Total groups: {analysis['total_groups']}
• Mean sample size: {analysis['mean_sample_size']:.1f}
• Median sample size: {analysis['median_sample_size']:.1f}
• Range: {analysis['min_sample_size']} - {analysis['max_sample_size']}

Groups with sufficient samples:"""

        for threshold, data_point in analysis["groups_by_threshold"].items():
            analysis_text += f"""
  • ≥{threshold} samples: {data_point['count']} groups ({data_point['percentage']:.1f}%)"""

        analysis_text += "\n"

    # Add recommendations
    analysis_text += """
RECOMMENDATIONS:
• For statistical significance, consider groups with ≥10 samples
• For robust comparisons, prefer groups with ≥20 samples  
• Groups with <5 samples should be interpreted cautiously
• Consider aggregating small groups for more reliable analysis
"""

    # Save analysis
    text_path = f"{output_dir}/sample_size_analysis.txt"
    with open(text_path, "w", encoding="utf-8") as f:
        f.write(analysis_text)
    results["generated_files"].append(text_path)

    # Save analysis data as JSON (convert numpy types to Python types)
    def convert_numpy_types(obj):
        """Convert numpy types to Python native types for JSON serialization."""
        if hasattr(obj, "item"):  # numpy scalar
            return obj.item()
        elif hasattr(obj, "tolist"):  # numpy array
            return obj.tolist()
        elif isinstance(obj, dict):
            return {key: convert_numpy_types(value) for key, value in obj.items()}
        elif isinstance(obj, list):
            return [convert_numpy_types(item) for item in obj]
        else:
            return obj

    json_path = f"{output_dir}/sample_size_analysis.json"
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(convert_numpy_types(sample_analysis), f, indent=2)
    results["generated_files"].append(json_path)

    results["sample_analysis"] = sample_analysis

    return results


def _plot_average_accuracy_by_pattern(
    subset: pd.DataFrame,
    patterns,
    color: str,
    title: str,
    filepath: str,
    no_titles: bool,
) -> pd.DataFrame:
    """
    Plot mean accuracy per size pattern for one data subset and save the figure.

    Parameters:
    - subset: Rows to aggregate (uses the "size_pattern" and "correct" columns)
    - patterns: Pattern order for the x-axis; every entry gets a bar slot
    - color: Bar color
    - title: Plot title (omitted when no_titles is True)
    - filepath: Destination passed to save_plot
    - no_titles: Whether to suppress the title

    Returns:
    - DataFrame with columns size_pattern/mean/count, one row per entry of
      `patterns` in that order; patterns absent from `subset` have NaN stats
    """
    fig, ax = plt.subplots(figsize=(12, 6))

    stats = (
        subset.groupby("size_pattern")["correct"].agg(["mean", "count"]).reset_index()
    )
    # Align rows with the shared pattern order so bars match the tick labels
    stats = stats.set_index("size_pattern").reindex(patterns).reset_index()

    positions = range(len(stats))
    # Patterns missing from this subset reindex to NaN; draw them as empty slots
    bars = ax.bar(positions, stats["mean"].fillna(0), color=color, alpha=0.8)

    for bar, acc, count in zip(bars, stats["mean"], stats["count"]):
        if pd.isna(acc):
            # No samples for this pattern: leave it unlabeled instead of "nan%"
            continue
        x_center = bar.get_x() + bar.get_width() / 2

        # Accuracy percentage on top
        ax.text(
            x_center,
            bar.get_height() + 0.02,
            f"{acc:.1%}",
            ha="center",
            va="bottom",
            fontsize=11,
            fontweight="bold",
        )

        # Sample size at bottom. reindex() turns "count" into float when NaN
        # rows are introduced, so cast back to print "n=12" rather than "n=12.0".
        ax.text(
            x_center,
            0.02,
            f"n={int(count)}",
            ha="center",
            va="bottom",
            fontsize=9,
            color="darkgray",
            fontweight="bold",
        )

    # Styling
    ax.set_xlabel("Size Pattern", fontsize=12, fontweight="bold")
    ax.set_ylabel("Average Accuracy", fontsize=12, fontweight="bold")
    if not no_titles:
        ax.set_title(title, fontsize=14, fontweight="bold", pad=20)
    ax.set_xticks(positions)
    ax.set_xticklabels(patterns, fontsize=11, fontweight="bold")
    ax.set_ylim(0, 1.05)

    # Format y-axis as percentages
    ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: f"{y:.0%}"))

    # Add subtle grid
    ax.grid(True, axis="y", alpha=0.3, linestyle="--", linewidth=0.5)
    ax.set_axisbelow(True)

    # Clean up spines
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)

    save_plot(fig, filepath, no_titles=no_titles)
    return stats


def generate_pattern_example_analysis(
    data: pd.DataFrame, output_dir: str, verbose: bool = False, no_titles: bool = False
) -> Dict[str, Any]:
    """
    Generate analyses for size patterns and number of examples (n_pairs).

    Produces:
    1. Grouped bars of per-model accuracy for each size pattern
    2. Average accuracy per size pattern, separately for full-output and
       question-based rows (as split by split_by_question_type), when non-empty
    3. Per-model accuracy lines against the number of examples (n_pairs)
    4. A benchmark × n_pairs accuracy heatmap

    Parameters:
    - data: Evaluation dataset; uses the model, size_pattern, n_pairs,
      benchmark and correct columns
    - output_dir: Directory to save plots
    - verbose: Whether to print progress info
    - no_titles: Omit plot titles when True (also forwarded to save_plot)

    Returns:
    - Dict with ``generated_files``, ``pattern_data`` and ``example_data``
      (aggregated statistics as lists of records)
    """
    results = {"generated_files": [], "pattern_data": {}, "example_data": {}}

    if verbose:
        print("   🎨 Creating improved pattern and example size analyses...")

    # Get model ordering and colors using the enhanced color system
    models = get_model_family_order(data["model"].unique())
    model_colors = get_color_palette(models, "models")

    # ========================================================================
    # 1. MODEL PERFORMANCE BY SIZE PATTERN
    # ========================================================================

    fig, ax = plt.subplots(figsize=(14, 8))

    # Calculate statistics for each pattern-model combination
    pattern_model_stats = (
        data.groupby(["size_pattern", "model"])["correct"]
        .agg(["mean", "count", "std"])
        .reset_index()
    )

    # Get unique patterns and sort them appropriately
    patterns = get_size_pattern_order(data["size_pattern"].unique())

    # Set up positions for grouped bars
    x_positions = np.arange(len(patterns))
    bar_width = 0.8 / len(models)

    # Create bars for each model
    for i, model in enumerate(models):
        model_data = pattern_model_stats[pattern_model_stats["model"] == model]

        # Align data with pattern order (missing combinations plot as 0)
        accuracies = []
        sample_counts = []

        for pattern in patterns:
            pattern_data = model_data[model_data["size_pattern"] == pattern]
            if not pattern_data.empty:
                accuracies.append(pattern_data["mean"].iloc[0])
                sample_counts.append(pattern_data["count"].iloc[0])
            else:
                accuracies.append(0)
                sample_counts.append(0)

        # Offset this model's bars so the group is centered on the tick
        x_pos = x_positions + (i - (len(models) - 1) / 2) * bar_width

        bars = ax.bar(
            x_pos,
            accuracies,
            bar_width,
            label=model,
            color=model_colors[model],
            alpha=0.85,
        )

        # Add annotations
        for bar, acc, count in zip(bars, accuracies, sample_counts):
            if acc > 0:  # Only annotate non-zero bars
                height = bar.get_height()

                # Accuracy label on top
                ax.text(
                    bar.get_x() + bar.get_width() / 2,
                    height + 0.01,
                    f"{acc:.3f}",
                    ha="center",
                    va="bottom",
                    fontsize=8,
                    fontweight="bold",
                )

                # Sample size at bottom (if bar is tall enough)
                if height > 0.15:
                    ax.text(
                        bar.get_x() + bar.get_width() / 2,
                        0.02,
                        f"n={count}",
                        ha="center",
                        va="bottom",
                        fontsize=7,
                        color="darkgray",
                        fontweight="bold",
                    )

    # Styling
    ax.set_xlabel("Size Pattern", fontsize=12, fontweight="bold")
    ax.set_ylabel("Accuracy", fontsize=12, fontweight="bold")
    if not no_titles:
        ax.set_title(
            "Model Performance by Size Pattern", fontsize=14, fontweight="bold", pad=20
        )
    ax.set_xticks(x_positions)
    ax.set_xticklabels(patterns, fontsize=10)
    ax.set_ylim(0, 1.05)

    # Format y-axis as percentages
    ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: f"{y:.0%}"))

    # Add subtle grid
    ax.grid(True, axis="y", alpha=0.3, linestyle="--", linewidth=0.5)
    ax.set_axisbelow(True)

    # Legend outside the plot area
    ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left", title="Models")

    # Clean up spines
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)

    fp1 = f"{output_dir}/performance_by_pattern_model_improved.pdf"
    save_plot(fig, fp1, no_titles=no_titles)
    results["generated_files"].append(fp1)

    # ========================================================================
    # 2. AVERAGE PERFORMANCE BY PATTERN - Full Output / Question-Based
    # ========================================================================

    # Split data by question type
    full_output_data, question_data = split_by_question_type(data)

    # Defaults so the result-storing code below works when a subset is empty
    full_output_pattern_stats = pd.DataFrame()
    question_pattern_stats = pd.DataFrame()

    if not full_output_data.empty:
        fp2a = f"{output_dir}/average_performance_by_pattern_full_output.pdf"
        full_output_pattern_stats = _plot_average_accuracy_by_pattern(
            full_output_data,
            patterns,
            "steelblue",
            "Full Output Performance by Size Pattern",
            fp2a,
            no_titles,
        )
        results["generated_files"].append(fp2a)

    if not question_data.empty:
        fp2b = f"{output_dir}/average_performance_by_pattern_question_based.pdf"
        question_pattern_stats = _plot_average_accuracy_by_pattern(
            question_data,
            patterns,
            "darkorange",
            "Question-Based Performance by Size Pattern",
            fp2b,
            no_titles,
        )
        results["generated_files"].append(fp2b)

    # ========================================================================
    # 3. PERFORMANCE VS NUMBER OF EXAMPLES - Line Chart
    # ========================================================================

    fig, ax = plt.subplots(figsize=(12, 7))

    # groupby sorts keys, so each model's rows come out in ascending n_pairs
    example_stats = (
        data.groupby(["n_pairs", "model"])["correct"]
        .agg(["mean", "count"])
        .reset_index()
    )

    # Plot lines for each model
    for model in models:
        model_data = example_stats[example_stats["model"] == model]
        if not model_data.empty:
            ax.plot(
                model_data["n_pairs"],
                model_data["mean"],
                marker="o",
                markersize=8,
                linewidth=2.5,
                label=model,
                color=model_colors[model],
                alpha=0.8,
            )

            # Add value annotations at each point
            for _, row in model_data.iterrows():
                ax.annotate(
                    f'{row["mean"]:.3f}',
                    (row["n_pairs"], row["mean"]),
                    textcoords="offset points",
                    xytext=(0, 10),
                    ha="center",
                    fontsize=8,
                    fontweight="bold",
                    color=model_colors[model],
                )

    # Styling
    ax.set_xlabel("Number of Examples", fontsize=12, fontweight="bold")
    ax.set_ylabel("Accuracy", fontsize=12, fontweight="bold")
    if not no_titles:
        ax.set_title(
            "Model Performance vs Number of Examples",
            fontsize=14,
            fontweight="bold",
            pad=20,
        )
    ax.set_ylim(0, 1.05)

    # Format axes
    ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: f"{y:.0%}"))
    available_n_pairs = sorted(data["n_pairs"].unique())
    ax.set_xticks(available_n_pairs)
    ax.set_xticklabels([f"{n}" for n in available_n_pairs], fontsize=11)

    # Add grid
    ax.grid(True, alpha=0.3, linestyle="--", linewidth=0.5)
    ax.set_axisbelow(True)

    # Legend
    ax.legend(title="Models", loc="best")

    # Clean up spines
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)

    fp3 = f"{output_dir}/performance_vs_num_examples_improved.pdf"
    save_plot(fig, fp3, no_titles=no_titles)
    results["generated_files"].append(fp3)

    # ========================================================================
    # 4. TASK-SPECIFIC PERFORMANCE VS EXAMPLES - Heatmap
    # ========================================================================

    fig, ax = plt.subplots(figsize=(12, 8))

    # Calculate task performance by number of examples
    task_example_stats = (
        data.groupby(["benchmark", "n_pairs"])["correct"]
        .agg(["mean", "count"])
        .reset_index()
    )

    # Create pivot table for heatmap (rows: benchmark, columns: n_pairs)
    task_pivot = task_example_stats.pivot(
        index="benchmark", columns="n_pairs", values="mean"
    )

    sns.heatmap(
        task_pivot,
        annot=True,
        fmt=".3f",
        cmap="RdYlBu_r",
        center=0.5,
        cbar_kws={"label": "Accuracy"},
        ax=ax,
    )

    # Styling
    ax.set_xlabel("Number of Examples", fontsize=12, fontweight="bold")
    ax.set_ylabel("Task", fontsize=12, fontweight="bold")
    if not no_titles:
        ax.set_title(
            "Task Performance vs Number of Examples",
            fontsize=14,
            fontweight="bold",
            pad=20,
        )

    # Resize tick labels; keep task names horizontal
    ax.set_xticklabels(ax.get_xticklabels(), fontsize=10)
    ax.set_yticklabels(ax.get_yticklabels(), fontsize=9, rotation=0)

    fp4 = f"{output_dir}/task_performance_vs_examples_heatmap.pdf"
    save_plot(fig, fp4, no_titles=no_titles)
    results["generated_files"].append(fp4)

    # ========================================================================
    # Store results
    # ========================================================================

    results["pattern_data"]["by_pattern_model"] = pattern_model_stats.to_dict("records")

    # Store separate stats for full_output and question_based if they exist
    if not full_output_data.empty and not full_output_pattern_stats.empty:
        results["pattern_data"]["full_output_by_pattern"] = (
            full_output_pattern_stats.to_dict("records")
        )

    if not question_data.empty and not question_pattern_stats.empty:
        results["pattern_data"]["question_based_by_pattern"] = (
            question_pattern_stats.to_dict("records")
        )

    results["example_data"]["by_n_pairs_model"] = example_stats.to_dict("records")
    results["example_data"]["by_n_pairs_task"] = task_example_stats.to_dict("records")

    if verbose:
        print(
            f"   ✅ Generated {len(results['generated_files'])} improved pattern/example analysis files"
        )

    return results


def generate_cross_modal_analysis(
    data: pd.DataFrame, output_dir: str, verbose: bool = False, no_titles: bool = False
) -> Dict[str, Any]:
    """
    Compare per-model accuracy on the same tasks between full-output and
    question-based evaluations.

    For every task (``benchmark`` value) present in both halves returned by
    ``split_by_question_type``, and every model evaluated on that task in both
    halves, the mean of ``correct`` is computed under each approach. The pairs
    are plotted with ``create_enhanced_cross_modal_plot`` and summarized in
    ``cross_modal_analysis.txt``.

    Parameters:
    - data: Full evaluation dataset (uses the ``model``, ``correct`` and
      ``benchmark`` columns)
    - output_dir: Directory to save the plot and text report
    - verbose: Whether to print progress info
    - no_titles: Forwarded to the plot helper to suppress titles

    Returns:
    - Dict with ``generated_files`` (list of written paths) and
      ``cross_modal_data`` (correlation, mean difference, per-pair rows and
      the alphabetically sorted task list). ``cross_modal_data`` stays empty
      when no task/model pair is available in both modes.
    """
    results = {"generated_files": [], "cross_modal_data": {}}

    if verbose:
        print("   🔀 Performing cross-modal analysis...")

    # Split by question type
    full_output_data, question_data = split_by_question_type(data)

    if full_output_data.empty or question_data.empty:
        if verbose:
            print("   ⚠️ Insufficient data for cross-modal analysis")
        return results

    # Without a "benchmark" column there is nothing to match tasks on.
    full_output_tasks = (
        set(full_output_data["benchmark"].unique())
        if "benchmark" in full_output_data.columns
        else set()
    )
    question_tasks = (
        set(question_data["benchmark"].unique())
        if "benchmark" in question_data.columns
        else set()
    )
    # Alphabetical order keeps the plot, report and returned data consistent.
    common_tasks = sorted(full_output_tasks & question_tasks)

    if not common_tasks:
        if verbose:
            print("   ⚠️ No tasks found with both full_output and question-based data")
        return results

    cross_modal_comparison = []

    for task in common_tasks:
        # Mean accuracy per model on this task, for each approach.
        full_output_perf = (
            full_output_data[full_output_data["benchmark"] == task]
            .groupby("model")["correct"]
            .mean()
        )
        question_perf = (
            question_data[question_data["benchmark"] == task]
            .groupby("model")["correct"]
            .mean()
        )

        # Index.intersection keeps the (sorted) groupby order of the left
        # index. A plain set intersection iterates in hash order, which
        # changes between interpreter runs and made the output order unstable.
        common_models = full_output_perf.index.intersection(question_perf.index)

        for model in common_models:
            full_acc = full_output_perf[model]
            question_acc = question_perf[model]
            cross_modal_comparison.append(
                {
                    "task": task,
                    "model": model,
                    "full_output_accuracy": full_acc,
                    "question_based_accuracy": question_acc,
                    "difference": question_acc - full_acc,
                }
            )

    if not cross_modal_comparison:
        return results

    cross_modal_df = pd.DataFrame(cross_modal_comparison)

    # Create the improved visualization
    create_enhanced_cross_modal_plot(cross_modal_df, output_dir, results, no_titles)

    # Pearson correlation across all pairs; pandas yields NaN when it cannot
    # be computed (e.g. a single pair or zero variance).
    correlation = cross_modal_df["full_output_accuracy"].corr(
        cross_modal_df["question_based_accuracy"]
    )
    mean_difference = cross_modal_df["difference"].mean()

    results["cross_modal_data"] = {
        "correlation": correlation,
        "mean_difference": mean_difference,
        "comparison_data": cross_modal_comparison,
        "tasks_analyzed": common_tasks,  # Already sorted alphabetically
    }

    # Create analysis text
    analysis_text = f"""
=== CROSS-MODAL ANALYSIS ===

COMPARISON: Full Output vs Question-Based Approaches

SUMMARY STATISTICS:
• Tasks analyzed: {len(common_tasks)}
• Model-task combinations: {len(cross_modal_comparison)}
• Performance correlation: {correlation:.4f}
• Average difference (Question - Full): {mean_difference:.4f}

INTERPRETATION:
• Correlation > 0.7: Strong agreement between approaches
• Correlation 0.3-0.7: Moderate agreement  
• Correlation < 0.3: Weak agreement (different skills tested)

TASKS INCLUDED (alphabetical order):
{chr(10).join(f'• {task}' for task in common_tasks)}

DETAILED COMPARISON:
"""

    # Only genuinely positive / negative differences are listed, each ordered
    # from the largest magnitude down. Previously the two lists could overlap
    # and the improvements list could print values like "+-0.123".
    top_improvements = cross_modal_df[cross_modal_df["difference"] > 0].nlargest(
        5, "difference"
    )
    top_decreases = cross_modal_df[cross_modal_df["difference"] < 0].nsmallest(
        5, "difference"
    )

    for heading, rows in (
        ("\nLargest improvements with question-based approach:\n", top_improvements),
        ("\nLargest decreases with question-based approach:\n", top_decreases),
    ):
        analysis_text += heading
        if rows.empty:
            analysis_text += "• none\n"
        for _, row in rows.iterrows():
            # "+" format flag prints the sign explicitly for both directions.
            analysis_text += (
                f"• {row['model']} on {row['task']}: {row['difference']:+.3f}\n"
            )

    # Save analysis
    text_path = f"{output_dir}/cross_modal_analysis.txt"
    with open(text_path, "w", encoding="utf-8") as f:
        f.write(analysis_text)
    results["generated_files"].append(text_path)

    return results


def create_enhanced_cross_modal_plot(
    cross_modal_df: pd.DataFrame,
    output_dir: str,
    results: Dict,
    no_titles: bool = False,
):
    """
    Draw the cross-modal scatter plot and save it as ``cross_modal_comparison.pdf``.

    Layout is a 2x3 grid. The scatter plot spans the left two columns. The
    right column holds a model color legend (top) and a text table mapping
    task abbreviations to full task names (bottom). Each point is one
    model/task pair, colored by model and annotated with
    ``abbreviate_task_name(task)``. A dashed y = x line marks equal accuracy
    under both approaches.

    Parameters:
    - cross_modal_df: Rows with ``model``, ``task``, ``full_output_accuracy``
      and ``question_based_accuracy`` columns
    - output_dir: Directory the PDF is written to
    - results: Result dict whose ``generated_files`` list receives the PDF
      path (mutated in place)
    - no_titles: Skip the main title and forward the flag to ``save_plot``
    """
    x_col = "full_output_accuracy"
    y_col = "question_based_accuracy"

    # Alphabetical ordering for stable colors and legend entries.
    model_names = sorted(cross_modal_df["model"].unique())
    task_names = sorted(cross_modal_df["task"].unique())
    palette = get_color_palette(model_names, "models")

    fig = plt.figure(figsize=(20, 10))
    grid_shape = (2, 3)
    scatter_ax = plt.subplot2grid(grid_shape, (0, 0), colspan=2, rowspan=2)
    model_key_ax = plt.subplot2grid(grid_shape, (0, 2))
    model_key_ax.axis("off")
    task_key_ax = plt.subplot2grid(grid_shape, (1, 2))
    task_key_ax.axis("off")

    # One scatter call per model so each model gets a single color.
    for name in model_names:
        subset = cross_modal_df.loc[cross_modal_df["model"] == name]
        scatter_ax.scatter(
            subset[x_col],
            subset[y_col],
            color=palette[name],
            s=80,
            alpha=0.7,
            edgecolors="black",
            linewidths=0.5,
            label=name,
        )

    # Short task tag next to every point; the full names go in the side table.
    for task, x_val, y_val in zip(
        cross_modal_df["task"], cross_modal_df[x_col], cross_modal_df[y_col]
    ):
        scatter_ax.annotate(
            abbreviate_task_name(task),
            (x_val, y_val),
            xytext=(3, 3),
            textcoords="offset points",
            fontsize=6,
            alpha=0.8,
            bbox=dict(
                boxstyle="round,pad=0.2", facecolor="white", alpha=0.7, edgecolor="none"
            ),
        )

    # Reference line y = x across the observed accuracy range.
    lower = min(cross_modal_df[x_col].min(), cross_modal_df[y_col].min())
    upper = max(cross_modal_df[x_col].max(), cross_modal_df[y_col].max())
    scatter_ax.plot(
        [lower, upper],
        [lower, upper],
        "k--",
        alpha=0.5,
        linewidth=2,
        label="Perfect Correlation",
    )

    scatter_ax.set_xlabel("Full Output Accuracy", fontsize=12)
    scatter_ax.set_ylabel("Question-Based Accuracy", fontsize=12)
    if not no_titles:
        scatter_ax.set_title(
            "Cross-Modal Performance Comparison\nFull Output vs Question-Based Approaches",
            fontsize=14,
        )
    scatter_ax.grid(True, alpha=0.3)

    # Proxy artists so the legend axis can show model colors without data.
    legend_handles = []
    for name in model_names:
        legend_handles.append(
            Line2D(
                [0],
                [0],
                marker="o",
                color=palette[name],
                linestyle="",
                markersize=8,
                label=name,
                markeredgecolor="black",
                markeredgewidth=0.5,
            )
        )

    model_key_ax.legend(
        handles=legend_handles,
        title="Models",
        loc="center",
        frameon=True,
        fancybox=True,
        shadow=True,
    )
    model_key_ax.set_title("Model Legend", fontweight="bold", pad=20)

    # Abbreviation table, with a blank line after every 10 entries.
    table_parts = ["Task Abbreviations:\n\n"]
    for position, task in enumerate(task_names, start=1):
        table_parts.append(f"{abbreviate_task_name(task)}: {task}\n")
        if position % 10 == 0:
            table_parts.append("\n")
    table_text = "".join(table_parts)

    task_key_ax.text(
        0.05,
        0.95,
        table_text,
        transform=task_key_ax.transAxes,
        verticalalignment="top",
        fontsize=8,
        bbox=dict(boxstyle="round,pad=0.5", facecolor="lightgray", alpha=0.3),
    )
    task_key_ax.set_title("Task Legend", fontweight="bold", pad=20)

    pdf_path = f"{output_dir}/cross_modal_comparison.pdf"
    save_plot(fig, pdf_path, no_titles=no_titles)
    results["generated_files"].append(pdf_path)


def abbreviate_task_name(task_name: str) -> str:
    """
    Build a short, upper-case label for a task name.

    Known words are replaced by short forms (e.g. ``color`` -> ``c``,
    ``neighbors`` -> ``nb``, ``at least`` -> ``@ ≥``). A replacement only
    applies to a whole token, i.e. a run of letters/digits delimited by
    spaces, underscores, hyphens or other punctuation. A trailing plural
    ``s`` is allowed and kept, so ``edges`` becomes ``es``. Words that merely
    contain a key, such as ``pattern`` or ``rotate``, are no longer mangled
    into ``p@tern`` / ``rot@e``.

    Filler words are then dropped from the space-separated words. If the
    remaining words still total more than 8 characters, the label becomes
    the first two characters of each of the first three words, or the first
    8 characters when only one word remains.

    Parameters:
    - task_name: Task identifier, e.g. a ``benchmark`` value

    Returns:
    - The abbreviation in upper case ("" for an empty name)
    """
    import re

    # Common abbreviation patterns
    abbreviations = {
        "color": "c",
        "degree": "d",
        "leaves": "lv",
        "internal": "int",
        "neighbors": "nb",
        "components": "comp",
        "distance": "dist",
        "equidistant": "eqd",
        "bipartition": "bip",
        "completion": "cmp",
        "subgraph": "sub",
        "complement": "cpl",
        "remove": "rm",
        "same": "s",
        "edge": "e",
        "node": "n",
        "hub": "h",
        "merge": "mrg",
        "blue": "b",
        "at": "@",
        "least": "≥",
        "maximum": "max",
        "minimum": "min",
    }

    # Whole-token match: not preceded or followed by a letter/digit (so "_"
    # and "-" act as separators). The optional plural "s" is kept.
    token_pattern = re.compile(
        r"(?<![a-z0-9])("
        + "|".join(re.escape(word) for word in abbreviations)
        + r")(s?)(?![a-z0-9])"
    )
    abbrev = token_pattern.sub(
        lambda match: abbreviations[match.group(1)] + match.group(2),
        task_name.lower(),
    )

    # Remove common words ("at" has already become "@" above, so it is kept).
    remove_words = ["and", "the", "to", "of", "in", "on", "at", "for", "with"]
    words = [w for w in abbrev.split() if w not in remove_words]

    # Take first few characters if still too long
    if len("".join(words)) > 8:
        if len(words) > 1:
            abbrev = "".join(w[:2] for w in words[:3])
        else:
            abbrev = words[0][:8]
    else:
        abbrev = "".join(words)

    return abbrev.upper()


def generate_summary_visualizations(
    data: pd.DataFrame, output_dir: str, verbose: bool = False, no_titles: bool = False
) -> Dict[str, Any]:
    """
    Run every Level 3 summary generator and collect their results.

    Outputs go to a ``summary`` subdirectory of ``output_dir``, which is
    created if missing. The generators run in a fixed order:

    1. performance summary
    2. data coverage
    3. sample size
    4. pattern/example analysis
    5. cross-modal analysis

    Parameters:
    - data: Full evaluation dataset
    - output_dir: Directory under which ``summary/`` is created
    - verbose: Whether to print progress info (also forwarded to each generator)
    - no_titles: Forwarded to each generator to suppress plot titles

    Returns:
    - Dict holding ``summary_dir``, each generator's result dict under its
      own key, and ``total_files`` (sum of their ``generated_files`` counts)
    """
    if verbose:
        print("📈 Generating Level 3 Summary Visualizations...")

    summary_dir = os.path.join(output_dir, "summary")
    os.makedirs(summary_dir, exist_ok=True)

    all_results = {
        "summary_dir": summary_dir,
        "performance_summary": {},
        "data_coverage": {},
        "sample_analysis": {},
        "pattern_example_analysis": {},
        "cross_modal_analysis": {},
        "total_files": 0,
    }

    # (result key, progress message, generator) in execution order.
    pipeline = (
        (
            "performance_summary",
            "  📊 Creating performance summary...",
            generate_performance_summary_table,
        ),
        (
            "data_coverage",
            "  📈 Creating data coverage report...",
            generate_data_coverage_report,
        ),
        (
            "sample_analysis",
            "  🔬 Creating sample size analysis...",
            generate_sample_size_analysis,
        ),
        (
            "pattern_example_analysis",
            "  🎨 Creating pattern and example size analyses...",
            generate_pattern_example_analysis,
        ),
        (
            "cross_modal_analysis",
            "  🔀 Creating cross-modal analysis...",
            generate_cross_modal_analysis,
        ),
    )

    for result_key, progress_message, generator in pipeline:
        if verbose:
            print(progress_message)
        all_results[result_key] = generator(data, summary_dir, verbose, no_titles)

    # Counted only after every generator has run.
    file_count = sum(
        len(all_results[result_key]["generated_files"])
        for result_key, _, _ in pipeline
    )
    all_results["total_files"] = file_count

    if verbose:
        print(f"  ✅ Generated {file_count} summary visualizations in {summary_dir}")

    return all_results
