"""
Overview Generator - Level 1 visualizations for high-level performance insights.

UPDATED to provide separate views for full_output and question-based prompts with clear naming.
UPDATED to support no_titles option and model name mapping.
"""

import os
from typing import Dict, Any, Tuple
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from scripts.visualization.core.chart_builders import (
    create_model_performance_chart,
    create_comparison_chart,
    create_breakdown_chart,
)
from scripts.visualization.core.data_loader import split_by_question_type
from scripts.visualization.core.utils import save_plot
from scripts.visualization.core.token_visualization_utils import (
    prepare_token_data,
    create_task_difficulty_analysis,
    create_question_type_reasoning_analysis,
)
from scripts.visualization.core.utils import (
    get_model_family_order,
    get_color_palette,
    apply_display_names_to_list,
)


def generate_token_overview(
    data: pd.DataFrame, output_dir: str, verbose: bool = False, no_titles: bool = False
) -> Dict[str, Any]:
    """
    Build the token-usage charts of the overview level and report what was written.

    Parameters:
    - data: Full evaluation dataset
    - output_dir: Directory the PNG files are written into
    - verbose: Whether to print progress info
    - no_titles: Whether to suppress figure and axes titles

    Returns:
    - Dict with "generated_files" (paths handed to save_plot, in creation order)
      and "stats" (basic token statistics; left empty when no token data exists)
    """
    written = []
    stats = {}
    summary = {"generated_files": written, "stats": stats}

    def announce(message):
        # Progress output is opt-in.
        if verbose:
            print(message)

    def strip_titles(figure):
        # The imported analysis builders take no no_titles flag, so their
        # titles are blanked after the fact.
        figure.suptitle("")
        for axes in figure.get_axes():
            axes.set_title("")

    def store(figure, filename, heading):
        # Save under output_dir and remember the path for the summary.
        target = f"{output_dir}/{filename}"
        save_plot(
            figure, target, heading if not no_titles else None, no_titles=no_titles
        )
        written.append(target)

    announce("   🧠 Analyzing token usage patterns...")

    token_data = prepare_token_data(data)
    if token_data.empty:
        announce("   ⚠️ No token data found in evaluation results")
        return summary

    # Headline numbers over every response that carries token information.
    reasoning_column = token_data["reasoning_tokens"]
    stats.update(
        {
            "total_responses_with_tokens": len(token_data),
            "avg_total_tokens": token_data["total_tokens"].mean(),
            "avg_reasoning_tokens": reasoning_column.mean(),
            "responses_with_reasoning": (reasoning_column > 0).sum(),
        }
    )

    full_output_token_data, question_token_data = split_by_question_type(token_data)
    have_full_output = not full_output_token_data.empty
    have_questions = not question_token_data.empty

    # 1. Task difficulty vs token requirements (full output only)
    if have_full_output and "benchmark" in full_output_token_data.columns:
        announce("   📈 Creating task difficulty analysis (full output)...")
        figure = create_task_difficulty_analysis(
            full_output_token_data, figsize=(16, 8)
        )
        if no_titles:
            strip_titles(figure)
        store(
            figure,
            "07_task_difficulty_vs_tokens_full_output.png",
            "Task Difficulty vs Token Requirements (Full Output Tasks)",
        )

    # 2a. Reasoning effort per question type (question-based only)
    if have_questions and "question_type" in question_token_data.columns:
        announce("   🎯 Creating question type reasoning analysis (question-based)...")
        figure = create_question_type_reasoning_analysis(
            question_token_data, figsize=(16, 8)
        )
        if no_titles:
            strip_titles(figure)
        store(
            figure,
            "08_question_type_reasoning_question_based.png",
            "Reasoning Effort by Question Type (Question-Based Tasks)",
        )

    # 2b. Reasoning patterns (full output only)
    if have_full_output:
        announce("   🧩 Creating reasoning patterns analysis (full output)...")
        figure = create_full_output_reasoning_analysis(
            full_output_token_data, figsize=(16, 8), no_titles=no_titles
        )
        store(
            figure,
            "08_reasoning_patterns_full_output.png",
            "Reasoning Patterns for Full Output Tasks",
        )

    # 3. Token efficiency per model (all task types, always produced)
    announce("   ⚡ Creating token efficiency analysis...")
    figure = create_token_efficiency_chart(
        token_data, figsize=(14, 8), no_titles=no_titles
    )
    store(
        figure,
        "09_token_efficiency.png",
        "Token Efficiency Analysis (All Task Types)",
    )

    return summary


def _reasoning_placeholder_figure(
    message: str, figsize: Tuple[int, int], no_titles: bool
) -> plt.Figure:
    """Return a single-axes figure showing *message* centred, for when nothing can be plotted."""
    fig, ax = plt.subplots(figsize=figsize)
    ax.text(
        0.5,
        0.5,
        message,
        ha="center",
        va="center",
        transform=ax.transAxes,
        fontsize=14,
    )
    if not no_titles:
        ax.set_title("Reasoning Patterns for Full Output Tasks")
    return fig


def create_full_output_reasoning_analysis(
    data: pd.DataFrame, figsize: Tuple[int, int] = (16, 8), no_titles: bool = False
) -> plt.Figure:
    """
    Create reasoning analysis specifically for full output tasks.
    Since full_output doesn't have question types, we analyze by task complexity.

    The figure has two panels: a box plot of reasoning tokens per benchmark
    (left) and a heatmap of mean ``correct`` per benchmark and reasoning-token
    bucket (right). Only rows with ``reasoning_tokens > 0`` are used.

    Parameters:
    - data: Rows to analyse; reads the "benchmark", "reasoning_tokens" and
      "correct" columns
    - figsize: Figure size passed to matplotlib
    - no_titles: Whether to suppress axes titles

    Returns:
    - The matplotlib figure. A placeholder figure with an explanatory message is
      returned when there is no "benchmark" column or no row with reasoning tokens.
    """
    if "benchmark" not in data.columns:
        return _reasoning_placeholder_figure(
            "No task data available for full output reasoning analysis",
            figsize,
            no_titles,
        )

    # NaN compares False against 0, so this also drops missing token counts.
    reasoning_data = data[data["reasoning_tokens"] > 0].copy()

    if reasoning_data.empty:
        return _reasoning_placeholder_figure(
            "No reasoning token data available for full output tasks",
            figsize,
            no_titles,
        )

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)

    # Left plot: reasoning tokens by task
    tasks = sorted(reasoning_data["benchmark"].unique())
    reasoning_by_task = [
        reasoning_data.loc[reasoning_data["benchmark"] == task, "reasoning_tokens"].values
        for task in tasks
    ]

    box_plot = ax1.boxplot(reasoning_by_task, patch_artist=True)
    # The ``labels`` keyword of boxplot is deprecated since Matplotlib 3.9
    # (renamed ``tick_labels``). Labelling the default 1..N box positions
    # explicitly works on both old and new releases.
    ax1.set_xticks(range(1, len(tasks) + 1))
    ax1.set_xticklabels(tasks)

    # Color the boxes
    colors = plt.cm.Set3(np.linspace(0, 1, len(tasks)))
    for patch, color in zip(box_plot["boxes"], colors):
        patch.set_facecolor(color)
        patch.set_alpha(0.7)

    ax1.set_xlabel("Task")
    ax1.set_ylabel("Reasoning Tokens")
    if not no_titles:
        ax1.set_title("Reasoning Token Distribution by Task (Full Output)")
    ax1.tick_params(axis="x", rotation=45)

    # Right plot: success rate by reasoning intensity.
    # Buckets are right-closed: (0, 500], (500, 1000], (1000, 2000], (2000, inf).
    reasoning_data["reasoning_category"] = pd.cut(
        reasoning_data["reasoning_tokens"],
        bins=[0, 500, 1000, 2000, float("inf")],
        labels=[
            "Low (0-500)",
            "Medium (500-1000)",
            "High (1000-2000)",
            "Very High (2000+)",
        ],
    )

    # ``observed`` is spelled out because grouping by a categorical column
    # without it emits a FutureWarning on recent pandas. ``False`` keeps the
    # historical result, and empty combinations are removed by the count
    # filter below anyway.
    success_by_reasoning = (
        reasoning_data.groupby(["benchmark", "reasoning_category"], observed=False)[
            "correct"
        ]
        .agg(["mean", "count"])
        .reset_index()
    )
    # Require at least 3 samples per cell so single outliers do not dominate.
    success_by_reasoning = success_by_reasoning[success_by_reasoning["count"] >= 3]

    if not success_by_reasoning.empty:
        pivot_data = success_by_reasoning.pivot(
            index="benchmark", columns="reasoning_category", values="mean"
        )

        sns.heatmap(
            pivot_data,
            annot=True,
            fmt=".3f",
            cmap="YlOrRd",
            ax=ax2,
            cbar_kws={"label": "Success Rate"},
        )
        if not no_titles:
            ax2.set_title("Success Rate by Task and Reasoning Effort")
        ax2.set_xlabel("Reasoning Category")
        ax2.set_ylabel("Task")
    else:
        ax2.text(
            0.5,
            0.5,
            "Insufficient data for heatmap",
            ha="center",
            va="center",
            transform=ax2.transAxes,
        )
        if not no_titles:
            ax2.set_title("Success Rate by Reasoning Effort")

    # Lay out this figure explicitly rather than whatever pyplot considers current.
    fig.tight_layout()
    return fig


def create_token_efficiency_chart(
    data: pd.DataFrame, figsize: Tuple[int, int] = (14, 8), no_titles: bool = False
) -> plt.Figure:
    """
    Create token efficiency analysis by model.

    Left panel: bar chart of "tokens per correct answer", computed as the mean
    ``total_tokens`` of a model divided by its mean ``correct`` (models whose
    mean ``correct`` is 0 get NaN and sort last). Right panel: scatter of mean
    ``reasoning_tokens`` against mean ``correct``, bubble area scaled by the
    number of rows per model.

    Parameters:
    - data: Rows to analyse; reads the "model", "total_tokens",
      "reasoning_tokens" and "correct" columns
    - figsize: Figure size passed to matplotlib
    - no_titles: Whether to suppress axes titles

    Returns:
    - The matplotlib figure with both panels
    """
    # Named aggregation yields flat column names directly, so no MultiIndex
    # flattening is needed. Only the aggregates that are plotted are computed.
    efficiency_stats = (
        data.groupby("model")
        .agg(
            total_tokens_mean=("total_tokens", "mean"),
            reasoning_tokens_mean=("reasoning_tokens", "mean"),
            correct_mean=("correct", "mean"),
            correct_count=("correct", "count"),
        )
        .round(3)
        .reset_index()
    )

    # Avoid division by zero: an accuracy of 0 becomes NaN instead of inf.
    efficiency_stats["tokens_per_correct"] = efficiency_stats[
        "total_tokens_mean"
    ] / efficiency_stats["correct_mean"].replace(0, np.nan)

    # Sort by efficiency (lower tokens per correct answer = more efficient)
    efficiency_stats = efficiency_stats.sort_values("tokens_per_correct")

    # Apply display names
    efficiency_stats["display_name"] = efficiency_stats["model"].apply(
        lambda name: apply_display_names_to_list([name])[0]
    )

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)

    # Left plot: tokens per correct answer
    positions = np.arange(len(efficiency_stats))
    bars = ax1.bar(
        positions,
        efficiency_stats["tokens_per_correct"],
        color="lightblue",
        alpha=0.7,
    )

    ax1.set_xlabel("Model")
    ax1.set_ylabel("Tokens per Correct Answer")
    if not no_titles:
        ax1.set_title("Token Efficiency by Model\n(Lower = More Efficient)")
    ax1.set_xticks(positions)
    # Unrotated labels are centred on their tick; right-aligning them would
    # shift every label to the left of its bar.
    ax1.set_xticklabels(efficiency_stats["display_name"], rotation=0, ha="center")

    # Value labels slightly above each bar; NaN bars get no label.
    for bar, value in zip(bars, efficiency_stats["tokens_per_correct"]):
        if pd.notna(value):
            ax1.text(
                bar.get_x() + bar.get_width() / 2,
                bar.get_height() * 1.01,
                f"{value:.0f}",
                ha="center",
                va="bottom",
                fontsize=9,
            )

    # Right plot: average reasoning tokens vs accuracy
    ax2.scatter(
        efficiency_stats["reasoning_tokens_mean"],
        efficiency_stats["correct_mean"],
        s=efficiency_stats["correct_count"] * 2,  # Size by sample count
        alpha=0.7,
        c=range(len(efficiency_stats)),
        cmap="viridis",
    )

    # Add model labels with display names
    for _, row in efficiency_stats.iterrows():
        ax2.annotate(
            row["display_name"],
            (row["reasoning_tokens_mean"], row["correct_mean"]),
            xytext=(5, 5),
            textcoords="offset points",
            fontsize=9,
            alpha=0.8,
        )

    ax2.set_xlabel("Average Reasoning Tokens")
    ax2.set_ylabel("Accuracy")
    if not no_titles:
        ax2.set_title("Reasoning Tokens vs Accuracy\n(Bubble size = sample count)")
    ax2.grid(True, alpha=0.3)

    # Lay out this figure explicitly rather than whatever pyplot considers current.
    fig.tight_layout()
    return fig


def generate_model_overview(
    data: pd.DataFrame, output_dir: str, verbose: bool = False, no_titles: bool = False
) -> Dict[str, Any]:
    """
    Produce the model-level overview charts and collect per-model accuracy stats.

    Up to three PNG files are written into ``output_dir``: accuracy per model on
    full-output tasks, accuracy per model on question-based tasks, and (when the
    question-based rows carry more than one distinct "target") a comparison
    across targets.

    Parameters:
    - data: Full evaluation dataset
    - output_dir: Directory the PNG files are written into
    - verbose: Whether to print progress info
    - no_titles: Whether to suppress chart titles

    Returns:
    - Dict with "generated_files" (saved paths in creation order) and "stats"
      (mean/count of "correct" for each chart that was produced)
    """
    saved_paths = []
    collected = {}

    def log(message):
        if verbose:
            print(message)

    def persist(figure, filename):
        # Write the figure and remember where it went.
        destination = f"{output_dir}/{filename}"
        save_plot(figure, destination, no_titles=no_titles)
        saved_paths.append(destination)

    def accuracy_table(frame, keys):
        # Mean and sample count of "correct" for each group, as nested dicts.
        return frame.groupby(keys)["correct"].agg(["mean", "count"]).to_dict()

    full_output_data, question_data = split_by_question_type(data)
    has_questions = not question_data.empty

    # 1. Overall model performance (primary metric: full_output)
    if not full_output_data.empty:
        log("   📊 Creating overall model performance chart...")
        persist(
            create_model_performance_chart(
                full_output_data,
                "Overall Model Performance - Full Graph Output Tasks",
                ylabel="Accuracy",
                show_sample_counts=True,
                figsize=(14, 8),
                no_titles=no_titles,
            ),
            "01_overall_model_performance.png",
        )
        collected["full_output_by_model"] = accuracy_table(full_output_data, "model")

    # 2. Question-based performance summary
    if has_questions:
        log("   🎯 Creating question type performance overview...")
        persist(
            create_model_performance_chart(
                question_data,
                "Model Performance - Question-Based Tasks",
                ylabel="Accuracy",
                show_sample_counts=True,
                figsize=(14, 8),
                no_titles=no_titles,
            ),
            "02_question_based_performance.png",
        )
        collected["question_based_by_model"] = accuracy_table(question_data, "model")

    # 3. Input vs output comparison; only meaningful with at least two targets
    if has_questions and "target" in question_data.columns:
        targets = question_data["target"].unique()
    else:
        targets = []

    if len(targets) > 1:
        log("   🔄 Creating input vs output comparison...")
        rows_per_target = {
            target: question_data[question_data["target"] == target]
            for target in targets
        }
        persist(
            create_comparison_chart(
                rows_per_target,
                "Model Performance: Input vs Output Target Analysis (Question-Based Tasks)",
                comparison_labels=list(targets),
                ylabel="Accuracy",
                figsize=(14, 8),
                no_titles=no_titles,
            ),
            "03_input_vs_output_comparison.png",
        )
        collected["by_target"] = accuracy_table(question_data, ["target", "model"])

    return {"generated_files": saved_paths, "stats": collected}


def generate_question_overview(
    data: pd.DataFrame, output_dir: str, verbose: bool = False, no_titles: bool = False
) -> Dict[str, Any]:
    """
    Produce breakdown charts separately for question-based and full-output tasks.

    Each chart groups accuracy by one column (question type, system prompt or
    benchmark) and colours bars by model. A chart is skipped when its subset is
    empty or lacks the column; the system-prompt charts are also skipped when
    the subset has only one distinct system prompt.

    Parameters:
    - data: Full evaluation dataset
    - output_dir: Directory the PNG files are written into
    - verbose: Whether to print progress info
    - no_titles: Whether to suppress chart titles

    Returns:
    - Dict with "generated_files" (saved paths in creation order) and "stats"
      (mean/count of "correct" per (column, model) for every chart produced)
    """
    saved_paths = []
    collected = {}

    full_output_data, question_data = split_by_question_type(data)

    # One row per chart, processed top to bottom:
    # (subset, column, needs >1 distinct value, progress message,
    #  chart title, figsize, file name, stats key)
    chart_specs = [
        (
            question_data,
            "question_type",
            False,
            "   📋 Creating question type breakdown...",
            "Performance Breakdown by Question Type (Question-Based Tasks)",
            (16, 8),
            "04_question_type_breakdown.png",
            "by_question_type",
        ),
        (
            question_data,
            "system_prompt",
            True,
            "   🎭 Creating system prompt impact analysis (question-based)...",
            "Impact of System Prompts on Question-Based Tasks",
            (12, 8),
            "05_system_prompt_impact_question_based.png",
            "question_based_by_system_prompt",
        ),
        (
            full_output_data,
            "system_prompt",
            True,
            "   🎭 Creating system prompt impact analysis (full output)...",
            "Impact of System Prompts on Full Output Tasks",
            (12, 8),
            "05_system_prompt_impact_full_output.png",
            "full_output_by_system_prompt",
        ),
        (
            question_data,
            "benchmark",
            False,
            "   🎯 Creating task performance summary (question-based)...",
            "Question-Based Performance by Task",
            (16, 10),
            "06_task_performance_question_based.png",
            "question_based_by_task",
        ),
        (
            full_output_data,
            "benchmark",
            False,
            "   🎯 Creating task performance summary (full output)...",
            "Full Output Performance by Task",
            (16, 10),
            "06_task_performance_full_output.png",
            "full_output_by_task",
        ),
    ]

    for (
        subset,
        column,
        needs_variation,
        message,
        chart_title,
        size,
        filename,
        stats_key,
    ) in chart_specs:
        if subset.empty or column not in subset.columns:
            continue
        # Comparing system prompts only makes sense with at least two of them.
        if needs_variation and len(subset[column].unique()) <= 1:
            continue

        if verbose:
            print(message)

        figure = create_breakdown_chart(
            subset,
            column,
            chart_title,
            ylabel="Accuracy",
            color_by="models",
            figsize=size,
            no_titles=no_titles,
        )

        destination = f"{output_dir}/{filename}"
        save_plot(figure, destination, no_titles=no_titles)
        saved_paths.append(destination)

        grouped = subset.groupby([column, "model"])["correct"]
        collected[stats_key] = grouped.agg(["mean", "count"]).to_dict()

    return {"generated_files": saved_paths, "stats": collected}


def _draw_encoding_accuracy_bars(
    ax, subset, models, display_names, colors, title, no_titles
):
    """Draw grouped accuracy bars on *ax* (one group per encoding, one bar per model).

    Returns ``(handles, labels)``: the first bar drawn for each model and the
    model's display name, suitable for building a shared legend.
    """
    encoding_perf = (
        subset.groupby(["encoding", "model"])["correct"].mean().unstack(fill_value=0)
    )
    # Same column order in every panel; models absent from this subset plot as 0.
    encoding_perf = encoding_perf.reindex(columns=models, fill_value=0)

    x = np.arange(len(encoding_perf.index))
    width = 0.8 / max(len(models), 1)

    handles, labels = [], []
    for i, (model, display_name) in enumerate(zip(models, display_names)):
        # Centre the group of bars around each encoding's tick.
        offset = (i - (len(models) - 1) / 2) * width
        bars = ax.bar(
            x + offset,
            encoding_perf[model],
            width,
            label=display_name,
            color=colors[model],
            alpha=0.8,
        )
        # Record handle and label together so the two lists cannot drift apart.
        if len(bars) > 0:
            handles.append(bars[0])
            labels.append(display_name)

    if not no_titles:
        ax.set_title(title)
    ax.set_ylabel("Accuracy")
    ax.set_xlabel("Encoding")
    ax.set_xticks(x)
    ax.set_xticklabels(encoding_perf.index)
    ax.grid(True, axis="y", alpha=0.3)
    return handles, labels


def _draw_encoding_preference_bars(
    ax, subset, models, colors, title, empty_message, no_titles
):
    """Draw per-model encoding preference strength on *ax*, or *empty_message* if none.

    Preference strength is the gap between a model's best and worst mean
    ``correct`` across encodings; models seen with fewer than two encodings in
    *subset* are left out. Each bar is annotated with the best encoding's name.
    """
    preferences = []
    if not subset.empty:
        for model in models:
            per_encoding = (
                subset[subset["model"] == model].groupby("encoding")["correct"].mean()
            )
            if len(per_encoding) > 1:
                preferences.append(
                    {
                        "model": model,
                        "display_name": apply_display_names_to_list([model])[0],
                        "preferred_encoding": per_encoding.idxmax(),
                        "preference_strength": per_encoding.max()
                        - per_encoding.min(),
                    }
                )

    if not preferences:
        ax.text(
            0.5,
            0.5,
            empty_message,
            ha="center",
            va="center",
            transform=ax.transAxes,
            fontsize=12,
        )
        if not no_titles:
            ax.set_title(title)
        return

    pref_df = pd.DataFrame(preferences)
    positions = range(len(pref_df))

    bars = ax.bar(
        positions,
        pref_df["preference_strength"],
        color=[colors[model] for model in pref_df["model"]],
        alpha=0.8,
    )

    if not no_titles:
        ax.set_title(title)
    ax.set_ylabel("Performance Difference")
    ax.set_xlabel("Model")
    ax.set_xticks(positions)
    ax.set_xticklabels(pref_df["display_name"], rotation=45, ha="center")

    # Leave headroom above the tallest bar for the rotated encoding labels.
    # Skipped when every strength is 0 (or NaN), where the limits would be
    # degenerate; matplotlib's automatic limits are used instead.
    max_strength = pref_df["preference_strength"].max()
    if max_strength > 0:
        ax.set_ylim(0, max_strength * 1.3)

    # Name of the preferred encoding just above each bar.
    for bar, pref in zip(bars, pref_df["preferred_encoding"]):
        height = bar.get_height()
        ax.text(
            bar.get_x() + bar.get_width() / 2,
            height + height * 0.1,
            f"{pref}",
            ha="center",
            va="bottom",
            fontsize=8,
            rotation=90,
            weight="bold",
        )

    ax.grid(True, axis="y", alpha=0.3)


def create_encoding_comparison_chart(
    data: pd.DataFrame, figsize: Tuple[int, int] = (16, 10), no_titles: bool = False
) -> plt.Figure:
    """
    Create a comprehensive comparison chart showing encoding effects across task types.

    The figure is a 2x2 grid: the top row shows accuracy per encoding and model
    for full-output tasks (left) and question-based tasks (right); the bottom
    row shows each model's encoding preference strength for the same two
    subsets. One shared model legend is placed below the grid.

    Parameters:
    - data: Rows to analyse; reads the "model", "encoding" and "correct" columns
    - figsize: Figure size passed to matplotlib
    - no_titles: Whether to suppress axes titles

    Returns:
    - The matplotlib figure
    """
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=figsize)

    full_output_data, question_data = split_by_question_type(data)

    # Shared model ordering and colours so all four panels agree.
    models = get_model_family_order(data["model"].unique())
    colors = get_color_palette(models, "models")
    display_names = apply_display_names_to_list(models)

    # Initialised up front: the legend code below runs even when the
    # full-output subset is empty and no bars were drawn on ax1.
    legend_handles, legend_labels = [], []

    # 1. Full output - overall encoding performance
    if not full_output_data.empty:
        legend_handles, legend_labels = _draw_encoding_accuracy_bars(
            ax1,
            full_output_data,
            models,
            display_names,
            colors,
            "Full Output Tasks",
            no_titles,
        )

    # 2. Question-based - overall encoding performance
    if not question_data.empty:
        handles, labels = _draw_encoding_accuracy_bars(
            ax2,
            question_data,
            models,
            display_names,
            colors,
            "Question-Based Tasks",
            no_titles,
        )
        # Both top panels use the same models and colours, so either one can
        # supply the legend entries; fall back to this one if ax1 drew nothing.
        if not legend_handles:
            legend_handles, legend_labels = handles, labels

    # 3. Encoding preference by model - full output tasks
    _draw_encoding_preference_bars(
        ax3,
        full_output_data,
        models,
        colors,
        "Encoding Preference Strength - Full Output",
        "No encoding preferences found\n(insufficient full output data)",
        no_titles,
    )

    # 4. Encoding preference by model - question-based tasks
    _draw_encoding_preference_bars(
        ax4,
        question_data,
        models,
        colors,
        "Encoding Preference Strength - Question-Based",
        "No encoding preferences found\n(insufficient question-based data)",
        no_titles,
    )

    # Shared legend at the bottom with 5 columns
    if legend_handles and legend_labels:
        fig.legend(
            legend_handles,
            legend_labels,
            loc="lower center",
            bbox_to_anchor=(0.5, -0.1),
            ncol=5,
            frameon=True,
            fancybox=True,
            shadow=True,
        )

    # Tighten first, then reserve room at the bottom for the legend.
    fig.tight_layout()
    fig.subplots_adjust(bottom=0.15)
    return fig


def calculate_encoding_comparison_stats(data: pd.DataFrame) -> Dict:
    """
    Summarise how accuracy ("correct") varies with the "encoding" column.

    The returned dict contains, in this order:
    - "overall_by_encoding": mean/count/std of "correct" per encoding
    - "full_output_by_encoding": mean/count per encoding (only when the
      full-output subset is non-empty)
    - "question_based_by_encoding": mean/count per encoding (only when the
      question-based subset is non-empty)
    - "full_output_model_preferences" / "question_based_model_preferences":
      per-model best/worst encoding and the accuracy gap between them
    - "preference_consistency": for models present in both preference
      tables, whether the preferred encoding is the same in both

    Parameters:
    - data: Evaluation results with "encoding", "model" and "correct" columns

    Returns:
    - Dict with the statistics described above
    """

    def accuracy_by_encoding(frame, aggregations):
        # Aggregate correctness per encoding and flatten to plain dicts.
        return frame.groupby("encoding")["correct"].agg(aggregations).to_dict()

    summary = {"overall_by_encoding": accuracy_by_encoding(data, ["mean", "count", "std"])}

    # The task-specific figures are computed on each half of the split.
    full_subset, question_subset = split_by_question_type(data)

    if not full_subset.empty:
        summary["full_output_by_encoding"] = accuracy_by_encoding(
            full_subset, ["mean", "count"]
        )
    if not question_subset.empty:
        summary["question_based_by_encoding"] = accuracy_by_encoding(
            question_subset, ["mean", "count"]
        )

    model_names = data["model"].unique()

    def rank_encodings(subset):
        # Map each model to its best/worst encoding within `subset`.
        ranked = {}
        if subset.empty:
            return ranked
        for name in model_names:
            rows = subset[subset["model"] == name]
            if len(rows) == 0:
                continue
            per_encoding = rows.groupby("encoding")["correct"].mean()
            # A preference needs at least two encodings to compare.
            if len(per_encoding) <= 1:
                continue
            ranked[name] = {
                "preferred_encoding": per_encoding.idxmax(),
                "least_preferred_encoding": per_encoding.idxmin(),
                "preference_strength": per_encoding.max() - per_encoding.min(),
                "performance_by_encoding": per_encoding.to_dict(),
            }
        return ranked

    full_prefs = rank_encodings(full_subset)
    question_prefs = rank_encodings(question_subset)
    summary["full_output_model_preferences"] = full_prefs
    summary["question_based_model_preferences"] = question_prefs

    # Only models with a preference in both task types can be compared.
    summary["preference_consistency"] = {
        name: {
            "full_output_preference": full_prefs[name]["preferred_encoding"],
            "question_based_preference": question_prefs[name]["preferred_encoding"],
            "consistent_preference": (
                full_prefs[name]["preferred_encoding"]
                == question_prefs[name]["preferred_encoding"]
            ),
            "full_output_strength": full_prefs[name]["preference_strength"],
            "question_based_strength": question_prefs[name]["preference_strength"],
        }
        for name in model_names
        if name in full_prefs and name in question_prefs
    }

    return summary


def _plot_prompt_accuracy_panel(
    ax, subset: pd.DataFrame, models, display_names, colors, title: str, no_titles: bool
):
    """
    Draw grouped accuracy bars: one group per system prompt, one bar per model.

    Parameters:
    - ax: Matplotlib axes to draw on
    - subset: Non-empty rows with "system_prompt", "model" and "correct" columns
    - models: Ordered model identifiers (one bar per entry in each group)
    - display_names: Legend labels, parallel to `models`
    - colors: Mapping from model identifier to bar colour
    - title: Panel title (ignored when `no_titles` is set)
    - no_titles: Whether to suppress the title

    Returns:
    - (handles, labels): one legend entry per model that produced bars
    """
    prompt_perf = (
        subset.groupby(["system_prompt", "model"])["correct"]
        .mean()
        .unstack(fill_value=0)
    )
    # After reindexing every model has a column (zero-filled when it has no
    # rows in this subset), so no per-model membership check is needed below.
    prompt_perf = prompt_perf.reindex(columns=models, fill_value=0)

    x = np.arange(len(prompt_perf.index))
    width = 0.8 / len(models)

    handles = []
    labels = []
    for i, (model, display_name) in enumerate(zip(models, display_names)):
        # Centre the group of bars on each tick.
        offset = (i - (len(models) - 1) / 2) * width
        bars = ax.bar(
            x + offset,
            prompt_perf[model],
            width,
            label=display_name,
            color=colors[model],
            alpha=0.8,
        )
        # Keep handles and labels in lockstep; an empty container adds neither.
        if len(bars) > 0:
            handles.append(bars[0])
            labels.append(display_name)

    if not no_titles:
        ax.set_title(title)
    ax.set_ylabel("Accuracy")
    ax.set_xlabel("System Prompt")
    ax.set_xticks(x)
    ax.set_xticklabels(prompt_perf.index, rotation=0, ha="center")
    ax.grid(True, axis="y", alpha=0.3)

    return handles, labels


def _collect_prompt_preferences(subset: pd.DataFrame, models) -> list:
    """
    Build one record per model that was evaluated with more than one system prompt.

    Each record holds the model, its display name, the system prompt with the
    highest mean "correct" value and the gap between the best and worst prompt.
    """
    preferences = []
    if subset.empty:
        return preferences

    for model in models:
        model_data = subset[subset["model"] == model]
        if len(model_data) == 0:
            continue
        model_prompt_perf = model_data.groupby("system_prompt")["correct"].mean()
        # A preference needs at least two prompts to compare.
        if len(model_prompt_perf) > 1:
            preferences.append(
                {
                    "model": model,
                    "display_name": apply_display_names_to_list([model])[0],
                    "preferred_prompt": model_prompt_perf.idxmax(),
                    "preference_strength": (
                        model_prompt_perf.max() - model_prompt_perf.min()
                    ),
                }
            )
    return preferences


def _plot_prompt_preference_panel(
    ax, preferences: list, colors, title: str, empty_message: str, no_titles: bool
) -> None:
    """
    Draw one bar per model showing its system prompt preference strength.

    The preferred prompt name is written vertically above each bar. When
    `preferences` is empty, `empty_message` is centred in the axes instead.
    """
    if not preferences:
        ax.text(
            0.5,
            0.5,
            empty_message,
            ha="center",
            va="center",
            transform=ax.transAxes,
            fontsize=12,
        )
        if not no_titles:
            ax.set_title(title)
        return

    pref_df = pd.DataFrame(preferences)

    bars = ax.bar(
        range(len(pref_df)),
        pref_df["preference_strength"],
        color=[colors[model] for model in pref_df["model"]],
        alpha=0.8,
    )

    if not no_titles:
        ax.set_title(title)
    ax.set_ylabel("Performance Difference")
    ax.set_xlabel("Model")
    ax.set_xticks(range(len(pref_df)))
    ax.set_xticklabels(pref_df["display_name"], rotation=0, ha="center")

    # Leave headroom for the rotated annotations. Skip when the maximum is 0
    # or NaN: (0, 0) is a singular range and NaN limits are rejected.
    max_strength = pref_df["preference_strength"].max()
    if max_strength > 0:
        ax.set_ylim(0, max_strength * 1.3)

    for bar, pref in zip(bars, pref_df["preferred_prompt"]):
        height = bar.get_height()
        ax.text(
            bar.get_x() + bar.get_width() / 2,
            height + height * 0.1,
            f"{pref}",
            ha="center",
            va="bottom",
            fontsize=8,
            rotation=90,
            weight="bold",
        )

    ax.grid(True, axis="y", alpha=0.3)


def create_system_prompt_comparison_chart(
    data: pd.DataFrame, figsize=(16, 10), no_titles: bool = False
) -> plt.Figure:
    """
    Create a 2x2 comparison chart showing system prompt effects across task types.

    Top row: accuracy per system prompt and model, for full-output tasks (left)
    and question-based tasks (right). Bottom row: per-model preference strength
    (best minus worst prompt accuracy) for the same two task types, annotated
    with the preferred prompt name. A shared model legend is placed below the
    panels.

    Parameters:
    - data: Evaluation results with "model", "system_prompt" and "correct" columns
    - figsize: Figure size passed to plt.subplots
    - no_titles: Whether to suppress panel titles

    Returns:
    - The assembled matplotlib Figure
    """
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=figsize)

    full_output_data, question_data = split_by_question_type(data)

    models = get_model_family_order(data["model"].unique())
    colors = get_color_palette(models, "models")
    display_names = apply_display_names_to_list(models)

    # Always bound, so the legend step below cannot raise NameError when the
    # full-output subset is empty.
    legend_handles = []
    legend_labels = []

    # 1. Full Output - overall system prompt performance
    if not full_output_data.empty:
        legend_handles, legend_labels = _plot_prompt_accuracy_panel(
            ax1,
            full_output_data,
            models,
            display_names,
            colors,
            "Full Output Tasks",
            no_titles,
        )

    # 2. Question-Based - overall system prompt performance
    if not question_data.empty:
        question_handles, question_labels = _plot_prompt_accuracy_panel(
            ax2,
            question_data,
            models,
            display_names,
            colors,
            "Question-Based Tasks",
            no_titles,
        )
        # Both panels use the same model colours, so fall back to this panel's
        # entries when the first panel produced none.
        if not legend_handles:
            legend_handles, legend_labels = question_handles, question_labels

    # 3. System prompt preference by model - full output tasks
    _plot_prompt_preference_panel(
        ax3,
        _collect_prompt_preferences(full_output_data, models),
        colors,
        "System Prompt Preference Strength - Full Output",
        "No system prompt preferences found\n(insufficient full output data)",
        no_titles,
    )

    # 4. System prompt preference by model - question-based tasks
    _plot_prompt_preference_panel(
        ax4,
        _collect_prompt_preferences(question_data, models),
        colors,
        "System Prompt Preference Strength - Question-Based",
        "No system prompt preferences found\n(insufficient question-based data)",
        no_titles,
    )

    # Shared legend below the panels, five entries per row.
    if legend_handles and legend_labels:
        fig.legend(
            legend_handles,
            legend_labels,
            loc="lower center",
            bbox_to_anchor=(0.5, -0.1),
            ncol=5,
            frameon=True,
            fancybox=True,
            shadow=True,
        )

    # Act on this figure explicitly rather than on pyplot's current figure,
    # then reserve space at the bottom for the legend.
    fig.tight_layout()
    fig.subplots_adjust(bottom=0.15)
    return fig


def calculate_system_prompt_comparison_stats(data: pd.DataFrame) -> Dict:
    """
    Summarise how accuracy ("correct") varies with the "system_prompt" column.

    The returned dict contains, in this order:
    - "overall_by_system_prompt": mean/count/std of "correct" per system prompt
    - "full_output_by_system_prompt": mean/count per prompt (only when the
      full-output subset is non-empty)
    - "question_based_by_system_prompt": mean/count per prompt (only when the
      question-based subset is non-empty)
    - "full_output_model_preferences" / "question_based_model_preferences":
      per-model best/worst prompt and the accuracy gap between them
    - "preference_consistency": for models present in both preference
      tables, whether the preferred prompt is the same in both

    Parameters:
    - data: Evaluation results with "system_prompt", "model" and "correct" columns

    Returns:
    - Dict with the statistics described above
    """

    def accuracy_by_prompt(frame, aggregations):
        # Aggregate correctness per system prompt and flatten to plain dicts.
        grouped = frame.groupby("system_prompt")["correct"]
        return grouped.agg(aggregations).to_dict()

    report = {}
    report["overall_by_system_prompt"] = accuracy_by_prompt(
        data, ["mean", "count", "std"]
    )

    # The task-specific figures are computed on each half of the split.
    full_rows, question_rows = split_by_question_type(data)

    for subset, key in (
        (full_rows, "full_output_by_system_prompt"),
        (question_rows, "question_based_by_system_prompt"),
    ):
        if not subset.empty:
            report[key] = accuracy_by_prompt(subset, ["mean", "count"])

    all_models = data["model"].unique()

    def prompt_preferences(subset):
        # Map each model to its best/worst system prompt within `subset`.
        table = {}
        if subset.empty:
            return table
        for model_name in all_models:
            model_rows = subset[subset["model"] == model_name]
            if len(model_rows) == 0:
                continue
            by_prompt = model_rows.groupby("system_prompt")["correct"].mean()
            # A preference needs at least two prompts to compare.
            if len(by_prompt) <= 1:
                continue
            table[model_name] = {
                "preferred_system_prompt": by_prompt.idxmax(),
                "least_preferred_system_prompt": by_prompt.idxmin(),
                "preference_strength": by_prompt.max() - by_prompt.min(),
                "performance_by_system_prompt": by_prompt.to_dict(),
            }
        return table

    full_table = prompt_preferences(full_rows)
    question_table = prompt_preferences(question_rows)
    report["full_output_model_preferences"] = full_table
    report["question_based_model_preferences"] = question_table

    # Only models with a preference in both task types can be compared.
    consistency = {}
    for model_name in all_models:
        if model_name not in full_table or model_name not in question_table:
            continue
        full_entry = full_table[model_name]
        question_entry = question_table[model_name]
        full_choice = full_entry["preferred_system_prompt"]
        question_choice = question_entry["preferred_system_prompt"]
        consistency[model_name] = {
            "full_output_preference": full_choice,
            "question_based_preference": question_choice,
            "consistent_preference": full_choice == question_choice,
            "full_output_strength": full_entry["preference_strength"],
            "question_based_strength": question_entry["preference_strength"],
        }
    report["preference_consistency"] = consistency

    return report


def generate_encoding_impact_overview(
    data: pd.DataFrame, output_dir: str, verbose: bool = False, no_titles: bool = False
) -> Dict[str, Any]:
    """
    Generate encoding impact charts per task type, plus the encoding and
    system prompt comparison charts that span both task types.

    Each chart is only produced when its data is present and the relevant
    column has more than one distinct value; the comparison charts also
    require both task-type subsets to be non-empty.

    Parameters:
    - data: Full evaluation dataset
    - output_dir: Directory to save visualizations
    - verbose: Whether to print progress info
    - no_titles: Whether to suppress titles

    Returns:
    - Dict with "generated_files" (saved paths) and "stats" (per-chart statistics)
    """
    results = {"generated_files": [], "stats": {}}

    if verbose:
        print("   📝 Analyzing encoding impact...")

    full_output_data, question_data = split_by_question_type(data)

    # 1-2. One breakdown chart per task type, driven by the same recipe:
    # (subset, progress message, chart title, output path, stats key)
    per_task_charts = (
        (
            full_output_data,
            "   📊 Creating encoding impact analysis (full output)...",
            "Impact of Graph Encoding on Full Output Tasks",
            f"{output_dir}/10_encoding_impact_full_output.png",
            "full_output_by_encoding",
        ),
        (
            question_data,
            "   📊 Creating encoding impact analysis (question-based)...",
            "Impact of Graph Encoding on Question-Based Tasks",
            f"{output_dir}/11_encoding_impact_question_based.png",
            "question_based_by_encoding",
        ),
    )

    for subset, progress_message, chart_title, filepath, stats_key in per_task_charts:
        if subset.empty or "encoding" not in subset.columns:
            continue
        # A single encoding leaves nothing to compare.
        if len(subset["encoding"].unique()) <= 1:
            continue

        if verbose:
            print(progress_message)

        fig = create_breakdown_chart(
            subset,
            "encoding",
            chart_title,
            ylabel="Accuracy",
            color_by="models",
            figsize=(12, 8),
            no_titles=no_titles,
        )
        save_plot(fig, filepath, no_titles=no_titles)
        results["generated_files"].append(filepath)

        per_encoding_model = subset.groupby(["encoding", "model"])["correct"].agg(
            ["mean", "count"]
        )
        results["stats"][stats_key] = per_encoding_model.to_dict()

    def comparable_across_tasks(column: str) -> bool:
        # Both task types must have rows and `column` must take several values.
        return (
            not full_output_data.empty
            and not question_data.empty
            and column in data.columns
            and len(data[column].unique()) > 1
        )

    # 3. Encoding comparison across both task types
    if comparable_across_tasks("encoding"):
        if verbose:
            print("   🔄 Creating encoding comparison analysis...")

        fig = create_encoding_comparison_chart(
            data, figsize=(16, 10), no_titles=no_titles
        )
        filepath = f"{output_dir}/12_encoding_comparison_analysis.png"
        save_plot(fig, filepath, no_titles=no_titles)
        results["generated_files"].append(filepath)

        results["stats"]["encoding_comparison"] = calculate_encoding_comparison_stats(
            data
        )

    # 4. System prompt comparison across both task types
    if comparable_across_tasks("system_prompt"):
        if verbose:
            print("   🎭 Creating system prompt comparison analysis...")

        fig = create_system_prompt_comparison_chart(
            data, figsize=(16, 10), no_titles=no_titles
        )
        filepath = f"{output_dir}/13_system_prompt_comparison_analysis.png"
        save_plot(fig, filepath, no_titles=no_titles)
        results["generated_files"].append(filepath)

        results["stats"]["system_prompt_comparison"] = (
            calculate_system_prompt_comparison_stats(data)
        )

    return results


def generate_overview_visualizations(
    data: pd.DataFrame, output_dir: str, verbose: bool = False, no_titles: bool = False
) -> Dict[str, Any]:
    """
    Generate all Level 1 overview visualizations.

    Creates an "overview" subdirectory under `output_dir` and runs the model,
    question, token and encoding overview generators in that order, each
    writing into that subdirectory.

    Parameters:
    - data: Full evaluation dataset
    - output_dir: Parent directory for the "overview" subdirectory
    - verbose: Whether to print progress info
    - no_titles: Whether to suppress titles

    Returns:
    - Dict with the overview directory, each generator's result, and the
      total number of generated files
    """

    def announce(message: str) -> None:
        # Progress output is opt-in.
        if verbose:
            print(message)

    announce("🎯 Generating Level 1 Overview Visualizations...")

    overview_dir = os.path.join(output_dir, "overview")
    os.makedirs(overview_dir, exist_ok=True)

    announce("  📊 Model Performance Overview...")
    model_results = generate_model_overview(data, overview_dir, verbose, no_titles)

    announce("  🎯 Question Type Overview...")
    question_results = generate_question_overview(
        data, overview_dir, verbose, no_titles
    )

    announce("  🧠 Token Usage Overview...")
    token_results = generate_token_overview(data, overview_dir, verbose, no_titles)

    announce("  📝 Encoding Impact Overview...")
    encoding_results = generate_encoding_impact_overview(
        data, overview_dir, verbose, no_titles
    )

    # Every generator reports the files it wrote under "generated_files".
    total_files = sum(
        len(section["generated_files"])
        for section in (
            model_results,
            question_results,
            token_results,
            encoding_results,
        )
    )

    announce(f"  ✅ Generated {total_files} overview visualizations in {overview_dir}")

    return {
        "overview_dir": overview_dir,
        "model_overview": model_results,
        "question_overview": question_results,
        "token_overview": token_results,
        "encoding_overview": encoding_results,
        "total_files": total_files,
    }
