"""
Detailed Generator - Level 2 visualizations for deep-dive analysis.

UPDATED to provide separate views for full_output and question-based prompts in per-task analysis.
"""

import os
from typing import Dict, Any
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scripts.visualization.core.chart_builders import (
    create_comparison_chart,
    create_breakdown_chart,
)
from scripts.visualization.core.data_loader import split_by_question_type
from scripts.visualization.core.utils import (
    save_plot,
    get_color_palette,
    get_model_family_order,
    get_size_pattern_order,
)
from scripts.evaluate_responses import get_ground_truth_answer, compare_answers


def generate_input_output_answer_transfer_analysis(
    data: pd.DataFrame, task_name: str, output_dir: str, verbose: bool = False
) -> Dict[str, Any]:
    """
    Analyze cases where a model gives the wrong answer for an input-targeted question,
    but that answer would be correct for the corresponding output graph.
    (UNCHANGED - this only applies to question-based data anyway)
    """
    results = {"generated_files": [], "stats": {}, "task_name": task_name}

    if verbose:
        print(f"   🔄 Analyzing Input-Output Answer Transfer for {task_name}...")

    # Check all models in the dataset first for comprehensive reporting
    all_models_in_task = (
        sorted(data["model"].unique()) if "model" in data.columns else []
    )

    # Filter to incorrect input-targeted question-based responses
    transfer_candidates = data[
        (data["question_type"] != "full_output")
        & (data["target"] == "input")
        & (data["correct"] == False)
    ].copy()

    # Case 1: Perfect performance by all models
    if transfer_candidates.empty:
        if verbose:
            print(f"   ✅ Perfect input performance by all models on {task_name}!")
            print(f"   📊 Models analyzed: {', '.join(all_models_in_task)}")

        # Still create a visualization showing this excellent result
        fig, ax = plt.subplots(figsize=(12, 6))
        ax.text(
            0.5,
            0.5,
            f"🎯 Perfect Input Performance\n\n"
            f"All {len(all_models_in_task)} models correctly answered\n"
            f"every input-targeted question for {task_name}\n\n"
            f"Models: {', '.join(all_models_in_task[:3])}{'...' if len(all_models_in_task) > 3 else ''}",
            ha="center",
            va="center",
            transform=ax.transAxes,
            fontsize=14,
            bbox=dict(boxstyle="round,pad=1", facecolor="lightgreen", alpha=0.8),
        )
        ax.set_title(
            f"{task_name} - Input-Output Answer Transfer Analysis", fontsize=14, pad=20
        )
        ax.axis("off")  # Hide axes for this message-only plot

        filepath = f"{output_dir}/{task_name}_input_output_transfer.png"
        save_plot(fig, filepath)
        results["generated_files"].append(filepath)

        results["stats"] = {
            "perfect_performance": True,
            "all_models": all_models_in_task,
            "total_models": len(all_models_in_task),
            "message": "All models achieved perfect input performance",
        }
        return results

    # Continue with existing analysis logic for cases with incorrect responses...
    # [Rest of the function remains the same as in the previous implementation]
    models_with_errors = sorted(transfer_candidates["model"].unique())
    models_with_perfect_input = [
        m for m in all_models_in_task if m not in models_with_errors
    ]

    if verbose:
        print(
            f"   📊 Found {len(transfer_candidates)} incorrect input responses from {len(models_with_errors)} models"
        )
        if models_with_perfect_input:
            print(
                f"   ✅ Perfect input performance: {', '.join(models_with_perfect_input)}"
            )

        # Show breakdown by model
        model_breakdown = transfer_candidates["model"].value_counts().sort_index()
        print("   📋 Incorrect responses by model:")
        for model, count in model_breakdown.items():
            print(f"      {model}: {count}")

    # Analyze transfer cases (same as before, but enhanced)
    transfer_cases = []
    debug_stats = {
        "total_candidates": len(transfer_candidates),
        "missing_answer": 0,
        "empty_answer": 0,
        "missing_output_file": 0,
        "computation_errors": 0,
        "valid_cases": 0,
        "transfer_cases": 0,
        "models_with_perfect_input": models_with_perfect_input,
        "models_with_errors": models_with_errors,
    }

    # [Same analysis loop as before - omitted for brevity, but includes all the enhanced error handling]
    for _, row in transfer_candidates.iterrows():
        try:
            # Enhanced model answer extraction
            model_answer = ""
            try:
                details = row.get("details", {})
                if isinstance(details, dict):
                    response_metadata = details.get("response_metadata", {})
                    if isinstance(response_metadata, dict):
                        raw_answer = response_metadata.get("answer", "")
                        if raw_answer is not None:
                            model_answer = str(raw_answer).strip()

                if not model_answer:
                    debug_stats["empty_answer"] += 1
                    continue

            except Exception:
                debug_stats["missing_answer"] += 1
                continue

            # Construct paths and analyze (same as before)
            ground_truth_input_path = row.get("ground_truth_path", "")
            if not ground_truth_input_path:
                continue

            ground_truth_output_path = ground_truth_input_path.replace(
                "/input/", "/output/"
            )

            if not os.path.exists(ground_truth_output_path):
                debug_stats["missing_output_file"] += 1
                continue

            try:
                correct_output_answer = get_ground_truth_answer(
                    ground_truth_output_path, row["question_type"], "output"
                )

                is_transfer_case = compare_answers(
                    correct_output_answer, model_answer, row["question_type"]
                )

                debug_stats["valid_cases"] += 1
                if is_transfer_case:
                    debug_stats["transfer_cases"] += 1

                # Get correct input answer
                correct_input_answer = ""
                try:
                    ground_truth_metadata = row.get("details", {}).get(
                        "ground_truth_metadata", {}
                    )
                    if isinstance(ground_truth_metadata, dict):
                        correct_input_answer = str(
                            ground_truth_metadata.get("answer", "")
                        ).strip()
                except:
                    pass

                transfer_cases.append(
                    {
                        "model": row["model"],
                        "question_type": row["question_type"],
                        "encoding": row.get("encoding", "unknown"),
                        "system_prompt": row.get("system_prompt", "unknown"),
                        "size": row.get("size", "unknown"),
                        "model_answer": model_answer,
                        "correct_input_answer": correct_input_answer,
                        "correct_output_answer": correct_output_answer,
                        "is_transfer_case": is_transfer_case,
                        "response_path": row.get("response_path", ""),
                    }
                )

            except Exception:
                debug_stats["computation_errors"] += 1
                continue

        except Exception:
            continue

    # Enhanced reporting
    if verbose:
        print("   🔍 Analysis Results:")
        print(f"      Total incorrect responses: {debug_stats['total_candidates']}")
        print(f"      Valid answers extracted: {debug_stats['valid_cases']}")
        print(f"      Transfer cases found: {debug_stats['transfer_cases']}")
        if debug_stats["missing_answer"] + debug_stats["empty_answer"] > 0:
            print(
                f"      ⚠️ Answer extraction issues: {debug_stats['missing_answer'] + debug_stats['empty_answer']}"
            )

    # Case 3: We have analysis data - create statistics and visualization
    if transfer_cases:
        transfer_df = pd.DataFrame(transfer_cases)

        try:
            transfer_stats = (
                transfer_df.groupby(["model", "question_type"])["is_transfer_case"]
                .agg(["sum", "count", "mean"])
                .reset_index()
            )
            transfer_stats.columns = [
                "model",
                "question_type",
                "transfer_count",
                "total_incorrect",
                "transfer_rate",
            ]

            if verbose:
                total_transfer_cases = int(transfer_stats["transfer_count"].sum())
                total_analyzed = int(transfer_stats["total_incorrect"].sum())
                overall_transfer_rate = (
                    total_transfer_cases / total_analyzed if total_analyzed > 0 else 0
                )
                print(
                    f"   📈 Overall: {total_transfer_cases}/{total_analyzed} transfer cases ({overall_transfer_rate:.1%})"
                )

            # Always create visualization - the improved chart handles zero cases well
            fig = create_input_output_transfer_chart(
                transfer_stats,
                f"{task_name} - Input-Output Answer Transfer Analysis",
                figsize=(14, 8),
            )

            filepath = f"{output_dir}/{task_name}_input_output_transfer.png"
            save_plot(fig, filepath)
            results["generated_files"].append(filepath)

            # Comprehensive stats
            results["stats"] = {
                "total_candidates": len(transfer_candidates),
                "valid_cases_analyzed": len(transfer_cases),
                "transfer_cases_found": int(transfer_stats["transfer_count"].sum()),
                "overall_transfer_rate": (
                    float(transfer_stats["transfer_count"].sum() / len(transfer_cases))
                    if len(transfer_cases) > 0
                    else 0
                ),
                "by_model_question": transfer_stats.to_dict("records"),
                "debug_stats": debug_stats,
                "models_analyzed": sorted(transfer_df["model"].unique()),
                "models_with_perfect_input": models_with_perfect_input,
                "all_models_in_task": all_models_in_task,
                "has_zero_transfer_cases": int(transfer_stats["transfer_count"].sum())
                == 0,
                "perfect_input_output_distinction": int(
                    transfer_stats["transfer_count"].sum()
                )
                == 0,
            }

        except Exception as e:
            if verbose:
                print(f"   ❌ Error creating statistics: {e}")
            results["stats"] = {
                "stats_creation_error": str(e),
                "debug_stats": debug_stats,
            }

    else:
        # Case 4: No valid cases could be analyzed
        if verbose:
            print(f"   ⚠️ No valid transfer cases could be analyzed for {task_name}")

        results["stats"] = {
            "no_valid_cases": True,
            "debug_stats": debug_stats,
            "models_with_perfect_input": models_with_perfect_input,
            "all_models_in_task": all_models_in_task,
        }

    return results


def create_input_output_transfer_chart(
    transfer_stats: pd.DataFrame, title: str, figsize=(14, 8)
) -> plt.Figure:
    """
    Create a grouped bar chart of Input-Output Answer Transfer rates.

    There is one bar group per question type and one bar per model. Each bar
    that has samples is annotated with its rate and ``n`` (the number of
    incorrect responses it was computed from). When every rate is zero, short
    stub bars and a banner are drawn so the zero result is still visible.

    Parameters:
    - transfer_stats: DataFrame with ``model``, ``question_type``,
      ``transfer_rate`` (fraction, formatted as a percentage) and
      ``total_incorrect`` columns.
    - title: chart title.
    - figsize: matplotlib figure size.

    Returns:
    - The matplotlib Figure. It is neither saved nor closed here.
    """
    fig, ax = plt.subplots(figsize=figsize)

    # Handle completely empty case
    if transfer_stats.empty:
        ax.text(
            0.5,
            0.5,
            "No transfer cases found\n(Perfect input-output distinction by all models)",
            ha="center",
            va="center",
            transform=ax.transAxes,
            fontsize=16,
            bbox=dict(boxstyle="round,pad=1", facecolor="lightgreen", alpha=0.7),
        )
        ax.set_title(title, fontsize=14, pad=20)
        ax.set_ylabel("Input-Output Answer Transfer Rate")
        ax.set_xlabel("Question Type")
        ax.set_ylim(0, 1)
        return fig

    question_types = sorted(transfer_stats["question_type"].unique())
    models = get_model_family_order(transfer_stats["model"].unique())

    x = np.arange(len(question_types))
    width = 0.8 / len(models) if len(models) > 0 else 0.8
    colors = get_color_palette(models, "models")

    max_rate = transfer_stats["transfer_rate"].max()
    has_any_nonzero = max_rate > 0
    # When every rate is zero, draw short stub bars so each model stays visible.
    min_bar_height = 0 if has_any_nonzero else 0.02

    for i, model in enumerate(models):
        model_data = transfer_stats[transfer_stats["model"] == model]

        transfer_rates = []
        sample_counts = []
        display_heights = []  # What is actually drawn (may differ from the rate)

        for qtype in question_types:
            qtype_data = model_data[model_data["question_type"] == qtype]
            if qtype_data.empty:
                transfer_rates.append(0)
                sample_counts.append(0)
                display_heights.append(0)
                continue

            rate = qtype_data["transfer_rate"].iloc[0]
            count = qtype_data["total_incorrect"].iloc[0]
            transfer_rates.append(rate)
            sample_counts.append(count)
            display_heights.append(
                max(rate, min_bar_height)
                if rate == 0 and not has_any_nonzero
                else rate
            )

        offset = (i - (len(models) - 1) / 2) * width
        bars = ax.bar(
            x + offset,
            display_heights,
            width,
            label=model,
            color=colors[model],
            alpha=0.8,
        )

        for bar, rate, count, display_height in zip(
            bars, transfer_rates, sample_counts, display_heights
        ):
            if count <= 0:  # No samples: leave the slot unlabeled.
                continue
            bar_center_x = bar.get_x() + bar.get_width() / 2.0

            if rate == 0:
                # Boxed label so a zero rate reads as data, not as a missing bar.
                label_y = display_height + (0.01 if has_any_nonzero else 0.005)
                ax.text(
                    bar_center_x,
                    label_y,
                    f"0%\n(n={count})",
                    ha="center",
                    va="bottom",
                    fontsize=8,
                    bbox=dict(
                        boxstyle="round,pad=0.2",
                        facecolor="white",
                        alpha=0.8,
                        edgecolor="gray",
                    ),
                )
            else:
                label_y = display_height + max(0.01, max_rate * 0.01)
                ax.text(
                    bar_center_x,
                    label_y,
                    f"{rate:.1%}\n(n={count})",
                    ha="center",
                    va="bottom",
                    fontsize=8,
                )

    ax.set_title(title, fontsize=14, pad=20)
    ax.set_ylabel("Input-Output Answer Transfer Rate")
    ax.set_xlabel("Question Type")

    if has_any_nonzero:
        ax.set_ylim(0, max(1.0, max_rate * 1.15))
    else:
        # All-zero case: a 0-5% range keeps the stub bars visible.
        ax.set_ylim(0, 0.05)
        ax.text(
            0.5,
            0.8,
            "🎯 Excellent Performance: No Input-Output Confusion Found",
            ha="center",
            va="center",
            transform=ax.transAxes,
            fontsize=12,
            weight="bold",
            bbox=dict(boxstyle="round,pad=0.5", facecolor="lightgreen", alpha=0.8),
        )
    ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: f"{y:.0%}"))

    ax.grid(True, axis="y", alpha=0.3)
    ax.set_xticks(x)
    # Horizontal labels must be centered; ha="right" with no rotation shifts
    # each label to the left of its bar group.
    ax.set_xticklabels(question_types, rotation=0, ha="center")
    ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")

    return fig


def _draw_no_data_message(ax, message: str, title: str) -> None:
    """Replace an axis' content with a centered placeholder message and a title."""
    ax.text(
        0.5,
        0.5,
        message,
        ha="center",
        va="center",
        transform=ax.transAxes,
        fontsize=12,
    )
    ax.set_title(title)


def _draw_size_pattern_accuracy(ax, subset, title, models, colors, show_legend) -> None:
    """Draw grouped accuracy bars (one group per size pattern, one bar per model) on *ax*.

    Model/pattern combinations with no rows are drawn as 0 (``fill_value=0``)
    and left unlabeled.
    """
    pattern_perf = (
        subset.groupby(["size_pattern", "model"])["correct"]
        .mean()
        .unstack(fill_value=0)
    )
    pattern_perf = pattern_perf.reindex(columns=models, fill_value=0)

    # Order patterns as get_size_pattern_order defines them, not alphabetically.
    ordered_patterns = get_size_pattern_order(list(pattern_perf.index))
    pattern_perf = pattern_perf.reindex(index=ordered_patterns)

    x = np.arange(len(pattern_perf.index))
    width = 0.8 / len(models)

    # After the column reindex above every model has a column.
    for i, model in enumerate(models):
        offset = (i - (len(models) - 1) / 2) * width
        bars = ax.bar(
            x + offset,
            pattern_perf[model],
            width,
            label=model,
            color=colors[model],
            alpha=0.8,
        )
        for bar, val in zip(bars, pattern_perf[model]):
            if val > 0:
                ax.text(
                    bar.get_x() + bar.get_width() / 2,
                    bar.get_height() + 0.01,
                    f"{val:.3f}",
                    ha="center",
                    va="bottom",
                    fontsize=8,
                )

    ax.set_title(title)
    ax.set_ylabel("Accuracy")
    ax.set_xlabel("Size Pattern")
    ax.set_xticks(x)
    # Horizontal labels must be centered under their group.
    ax.set_xticklabels(pattern_perf.index, rotation=0, ha="center")
    if show_legend:
        ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
    ax.grid(True, axis="y", alpha=0.3)


def _size_pattern_sensitivity_rows(subset, models) -> list:
    """Per model, the best/worst size pattern by mean ``correct`` and their difference.

    Models with rows for fewer than two size patterns are omitted.
    """
    rows = []
    for model in models:
        model_data = subset[subset["model"] == model]
        if model_data.empty:
            continue
        model_pattern_perf = model_data.groupby("size_pattern")["correct"].mean()
        if len(model_pattern_perf) > 1:
            rows.append(
                {
                    "model": model,
                    "performance_drop": model_pattern_perf.max()
                    - model_pattern_perf.min(),
                    "best_pattern": model_pattern_perf.idxmax(),
                    "worst_pattern": model_pattern_perf.idxmin(),
                }
            )
    return rows


def _draw_size_pattern_sensitivity(ax, subset, title, models, colors) -> None:
    """Draw one bar per model showing its best-minus-worst size-pattern accuracy."""
    rows = _size_pattern_sensitivity_rows(subset, models)
    if not rows:
        _draw_no_data_message(
            ax, "Insufficient pattern data\nfor sensitivity analysis", title
        )
        return

    drop_df = pd.DataFrame(rows)
    positions = range(len(drop_df))
    bars = ax.bar(
        positions,
        drop_df["performance_drop"],
        color=[colors[model] for model in drop_df["model"]],
        alpha=0.8,
    )

    ax.set_title(title)
    ax.set_ylabel("Performance Drop (Best - Worst)")
    ax.set_xlabel("Model")
    ax.set_xticks(positions)
    ax.set_xticklabels(drop_df["model"], rotation=0, ha="center")

    for bar, row in zip(bars, drop_df.itertuples()):
        height = bar.get_height()
        if height <= 0.01:  # Only annotate meaningful drops.
            continue
        center_x = bar.get_x() + bar.get_width() / 2
        ax.text(
            center_x,
            height + 0.01,
            f"{height:.3f}",
            ha="center",
            va="bottom",
            fontsize=8,
            weight="bold",
        )
        # Best (↑) and worst (↓) pattern names, drawn inside the bar.
        ax.text(
            center_x,
            height / 2,
            f"↑{row.best_pattern}\n↓{row.worst_pattern}",
            ha="center",
            va="center",
            fontsize=6,
            color="white",
            weight="bold",
        )

    ax.grid(True, axis="y", alpha=0.3)


def create_size_pattern_comparison_chart(
    data: pd.DataFrame, task_name: str, figsize=(16, 10)
) -> plt.Figure:
    """
    Create a 2x2 comparison of size-pattern effects for one task.

    The top row shows accuracy per size pattern and model: full-output tasks
    on the left (with the legend) and question-based tasks on the right. The
    bottom row shows each model's pattern sensitivity (best minus worst
    pattern accuracy) for the same two splits. Patterns are ordered with
    ``get_size_pattern_order``.

    Parameters:
    - data: DataFrame with ``model``, ``size_pattern`` and ``correct``
      columns, split with ``split_by_question_type``.
    - task_name: used as a prefix in every subplot title.
    - figsize: matplotlib figure size.

    Returns:
    - The matplotlib Figure. It is neither saved nor closed here.
    """
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=figsize)

    full_output_data, question_data = split_by_question_type(data)

    models = get_model_family_order(data["model"].unique())
    colors = get_color_palette(models, "models")

    panels = (
        (full_output_data, "Full Output", ax1, ax3, True),
        (question_data, "Question-Based", ax2, ax4, False),
    )
    for subset, label, accuracy_ax, sensitivity_ax, show_legend in panels:
        accuracy_title = f"{task_name} - {label} Tasks"
        sensitivity_title = f"{task_name} - Pattern Sensitivity ({label})"
        if subset.empty:
            _draw_no_data_message(
                accuracy_ax, f"No {label} data available", accuracy_title
            )
            _draw_no_data_message(
                sensitivity_ax, f"No {label} data available", sensitivity_title
            )
            continue
        _draw_size_pattern_accuracy(
            accuracy_ax, subset, accuracy_title, models, colors, show_legend
        )
        _draw_size_pattern_sensitivity(
            sensitivity_ax, subset, sensitivity_title, models, colors
        )

    # Lay out this figure explicitly rather than whatever figure pyplot considers current.
    fig.tight_layout()
    return fig


def calculate_size_pattern_comparison_stats(data: pd.DataFrame, task_name: str) -> Dict:
    """
    Summarize how the ``correct`` rate varies with ``size_pattern`` for one task.

    ``task_name`` is accepted for signature symmetry but is not used.

    Returns a dict with:
    - ``overall_by_size_pattern``: ``{"mean"|"count"|"std": {pattern: value}}``
      over all rows.
    - ``full_output_by_size_pattern`` / ``question_based_by_size_pattern``:
      ``{"mean"|"count": {pattern: value}}``. Each key is present only when
      ``split_by_question_type`` yields a non-empty subset for it.
    - ``model_pattern_sensitivity``: per model with rows for at least two
      patterns, its best/worst pattern, the max-minus-min spread, and the
      mean per pattern.
    """

    def summarize(frame, aggregations):
        # Group-by-pattern aggregation, returned as {aggregation: {pattern: value}}.
        return frame.groupby("size_pattern")["correct"].agg(aggregations).to_dict()

    summary = {"overall_by_size_pattern": summarize(data, ["mean", "count", "std"])}

    full_output_rows, question_rows = split_by_question_type(data)
    for stats_key, subset in (
        ("full_output_by_size_pattern", full_output_rows),
        ("question_based_by_size_pattern", question_rows),
    ):
        if not subset.empty:
            summary[stats_key] = summarize(subset, ["mean", "count"])

    sensitivity_by_model = {}
    for model_name in data["model"].unique():
        rows_for_model = data[data["model"] == model_name]
        if rows_for_model.empty:
            continue
        accuracy_by_pattern = rows_for_model.groupby("size_pattern")["correct"].mean()
        if len(accuracy_by_pattern) <= 1:
            # A single pattern has no spread to report.
            continue
        sensitivity_by_model[model_name] = {
            "best_pattern": accuracy_by_pattern.idxmax(),
            "worst_pattern": accuracy_by_pattern.idxmin(),
            "sensitivity": accuracy_by_pattern.max() - accuracy_by_pattern.min(),
            "performance_by_pattern": accuracy_by_pattern.to_dict(),
        }

    summary["model_pattern_sensitivity"] = sensitivity_by_model
    return summary


def create_improved_question_type_chart(
    question_data: pd.DataFrame, task_name: str, figsize=(16, 10), no_titles: bool = False
) -> plt.Figure:
    """
    Create a question-type accuracy chart split into input- and output-target subplots.

    Parameters:
    - question_data: DataFrame with question-based rows only; the
      "question_type", "target", "model" and "correct" columns are read
    - task_name: Name of the task, used in the subplot titles
    - figsize: Figure size
    - no_titles: Currently unused inside this function (callers pass the same
      flag to ``save_plot`` when writing the figure)

    Returns:
    - matplotlib Figure with two stacked subplots (input targets on top,
      output targets below), a shared horizontal model legend underneath and
      a small summary-statistics text box in the lower-left corner
    """
    fig, (ax_input, ax_output) = plt.subplots(2, 1, figsize=figsize)

    # Filter data by target
    input_data = question_data[question_data["target"] == "input"]
    output_data = question_data[question_data["target"] == "output"]

    # Shared x-axis categories for both subplots (excluding full_output)
    all_question_types = sorted(
        qt for qt in question_data["question_type"].unique() if qt != "full_output"
    )

    # Model ordering and colors are computed once so both subplots agree.
    models = get_model_family_order(question_data["model"].unique())
    colors = get_color_palette(models, "models")

    def create_subplot_bars(ax, data, title_suffix, target_type):
        """Draw grouped bars for one target; return {model: first bar}, or None if no data."""
        if data.empty:
            ax.text(
                0.5,
                0.5,
                f"No {target_type} data available",
                ha="center",
                va="center",
                transform=ax.transAxes,
                fontsize=12,
            )
            ax.set_title(f"{task_name} - {title_suffix}")
            return None

        # Accuracy and sample count per (question type, model)
        perf_data = (
            data.groupby(["question_type", "model"])["correct"]
            .agg(["mean", "count"])
            .reset_index()
        )
        perf_data.columns = ["question_type", "model", "accuracy", "sample_count"]
        # One dict lookup per cell instead of re-filtering the frame each time.
        lookup = {
            (row.question_type, row.model): (row.accuracy, row.sample_count)
            for row in perf_data.itertuples(index=False)
        }

        x = np.arange(len(all_question_types))
        width = 0.8 / len(models) if len(models) > 0 else 0.8

        bars_by_model = {}  # first bar of each model, used as its legend handle

        for i, model in enumerate(models):
            # Pairs with no rows are drawn as zero-height bars with count 0.
            cells = [lookup.get((qt, model), (0, 0)) for qt in all_question_types]
            accuracies = [acc for acc, _ in cells]
            sample_counts = [count for _, count in cells]

            offset = (i - (len(models) - 1) / 2) * width
            bars = ax.bar(
                x + offset,
                accuracies,
                width,
                label=model,
                color=colors[model],
                alpha=0.8,
                edgecolor="black",
                linewidth=0.5,
            )

            bars_by_model[model] = bars[0] if len(bars) > 0 else None

            for bar, acc, count in zip(bars, accuracies, sample_counts):
                # Label every bar backed by real samples -- including a genuine
                # 0.000 accuracy -- and skip only cells that have no data.
                if count == 0:
                    continue
                center = bar.get_x() + bar.get_width() / 2.0
                # Accuracy on top
                ax.text(
                    center,
                    bar.get_height() + 0.01,
                    f"{acc:.3f}",
                    ha="center",
                    va="bottom",
                    fontsize=8,
                    weight="bold",
                )
                # Sample count inside the bar; skipped for zero-height bars,
                # where it would collide with the accuracy label.
                if acc > 0:
                    ax.text(
                        center,
                        0.02,
                        f"n={count}",
                        ha="center",
                        va="bottom",
                        fontsize=7,
                        alpha=0.7,
                        rotation=90,
                    )

        # Customize subplot
        ax.set_title(f"{task_name} - {title_suffix}", fontsize=12, pad=15)
        ax.set_ylabel("Accuracy", fontsize=11)
        # Headroom above 1.0 so labels of perfect-score bars stay inside the axes.
        ax.set_ylim(0, 1.1)
        ax.grid(True, axis="y", alpha=0.3)
        ax.set_xticks(x)
        # Unrotated labels must be centred on their ticks ("right" shifts them left).
        ax.set_xticklabels(all_question_types, rotation=0, ha="center", fontsize=10)

        return bars_by_model

    input_bars = create_subplot_bars(
        ax_input, input_data, "Question Performance - Input Targets", "input"
    )
    output_bars = create_subplot_bars(
        ax_output, output_data, "Question Performance - Output Targets", "output"
    )

    # Only add x-label to bottom subplot
    ax_output.set_xlabel("Question Type", fontsize=11)

    # Shared horizontal legend: use handles from whichever subplot has data,
    # preferring input.
    legend_bars = input_bars if input_bars else output_bars

    if legend_bars:
        legend_handles = []
        legend_labels = []

        for model in models:
            if model in legend_bars and legend_bars[model] is not None:
                legend_handles.append(legend_bars[model])
                legend_labels.append(model)

        fig.legend(
            legend_handles,
            legend_labels,
            loc="lower center",
            bbox_to_anchor=(0.5, -0.05),
            ncol=min(len(legend_labels), 4),  # Max 4 columns to keep it reasonable
            fontsize=10,
            title="Models",
            title_fontsize=11,
            frameon=True,
            fancybox=True,
            shadow=True,
        )

    # Overall statistics text box
    total_input_samples = len(input_data) if not input_data.empty else 0
    total_output_samples = len(output_data) if not output_data.empty else 0
    avg_input_acc = input_data["correct"].mean() if not input_data.empty else 0
    avg_output_acc = output_data["correct"].mean() if not output_data.empty else 0

    stats_text = (
        f"Input: {total_input_samples} samples, {avg_input_acc:.3f} avg accuracy\n"
        f"Output: {total_output_samples} samples, {avg_output_acc:.3f} avg accuracy"
    )

    fig.text(
        0.02,
        0.02,
        stats_text,
        fontsize=9,
        bbox=dict(boxstyle="round,pad=0.3", facecolor="lightblue", alpha=0.7),
        verticalalignment="bottom",
    )

    # Adjust layout to make room for legend and stats
    plt.tight_layout()
    plt.subplots_adjust(bottom=0.15)  # Make room for legend

    return fig


def generate_task_analysis(
    data: pd.DataFrame, task_name: str, output_dir: str, verbose: bool = False, no_titles: bool = False
) -> Dict[str, Any]:
    """
    Generate detailed analysis charts and statistics for a single task.

    Full-output rows and question-based rows (as separated by
    ``split_by_question_type``) are analyzed separately, so the two prompt
    styles are never mixed in one chart.

    Parameters:
    - data: Rows for this task only. The "question_type", "model" and
      "correct" columns are read, plus "target", "system_prompt",
      "graph_type", "encoding" and "size_pattern" when present.
    - task_name: Task identifier; used in titles, file names and the
      ``task_<task_name>`` subdirectory.
    - output_dir: Parent directory; charts are written to
      ``<output_dir>/task_<task_name>/``.
    - verbose: Print progress messages.
    - no_titles: Forwarded to ``save_plot`` and the question-type chart.

    Returns:
    - dict with "generated_files" (paths of written charts), "stats" (nested
      dict of per-section aggregates) and "task_name".
    """
    results = {"generated_files": [], "stats": {}, "task_name": task_name}

    if data.empty:
        if verbose:
            print(f"   ⚠️ No data found for task {task_name}")
        return results

    # Create task subdirectory
    task_dir = os.path.join(output_dir, f"task_{task_name}")
    os.makedirs(task_dir, exist_ok=True)

    # Split by question type
    full_output_data, question_data = split_by_question_type(data)
    has_question_targets = not question_data.empty and "target" in question_data.columns

    if has_question_targets:
        # 1. Question-type performance matrix with input/output split
        if verbose:
            print(
                f"   📊 Creating improved question type performance matrix for {task_name}..."
            )

        fig = create_improved_question_type_chart(
            question_data, task_name, figsize=(16, 10), no_titles=no_titles
        )
        filepath = f"{task_dir}/{task_name}_question_types.png"
        save_plot(fig, filepath, no_titles=no_titles)
        results["generated_files"].append(filepath)

        for target in ("input", "output"):
            per_question_type = (
                question_data[question_data["target"] == target]
                .groupby(["question_type", "model"])["correct"]
                .agg(["mean", "count"])
            )
            results["stats"][f"by_question_type_{target}"] = per_question_type.to_dict()

        # 2. Input vs output comparison (only meaningful with >1 target value)
        targets = question_data["target"].unique()
        if len(targets) > 1:
            if verbose:
                print(f"   🔄 Creating input vs output breakdown for {task_name}...")

            data_by_target = {
                target: question_data[question_data["target"] == target]
                for target in targets
            }

            fig = create_comparison_chart(
                data_by_target,
                f"{task_name} - Input vs Output Analysis (Question-Based Tasks)",
                comparison_labels=list(targets),
                ylabel="Accuracy",
                figsize=(14, 8),
            )
            filepath = f"{task_dir}/{task_name}_input_vs_output.png"
            save_plot(fig, filepath, no_titles=no_titles)
            results["generated_files"].append(filepath)

            target_stats = question_data.groupby(["target", "model"])["correct"].agg(
                ["mean", "count"]
            )
            results["stats"]["by_target"] = target_stats.to_dict()

    # 3. Input-output answer transfer analysis. The transfer function indexes
    # data["target"] directly, so apply the same column guard as sections 1-2
    # instead of crashing the whole task analysis with a KeyError.
    if "target" in data.columns:
        transfer_results = generate_input_output_answer_transfer_analysis(
            data, task_name, task_dir, verbose
        )
        results["generated_files"].extend(transfer_results["generated_files"])
        results["stats"]["input_output_transfer"] = transfer_results["stats"]
    else:
        results["stats"]["input_output_transfer"] = {}

    # 4-7. One breakdown chart per (dimension, prompt style) pair, produced in
    # the order: system prompt, graph type, encoding, size pattern; each for
    # full-output then question-based rows.
    breakdown_dimensions = (
        # (column, emoji, progress noun, title text, file stem, figsize)
        ("system_prompt", "🎭", "system prompt", "System Prompt Impact", "system_prompts", (12, 8)),
        ("graph_type", "🎲", "graph type", "Performance by Graph Type", "graph_types", (12, 8)),
        ("encoding", "📝", "encoding", "Performance by Encoding Type", "encodings", (10, 8)),
        ("size_pattern", "📏", "size pattern", "Performance by Size Pattern", "size_patterns", (14, 8)),
    )
    prompt_subsets = (
        # (rows, progress label, title label, file-name / stats-key prefix)
        (full_output_data, "full output", "Full Output Tasks", "full_output"),
        (question_data, "question-based", "Question-Based Tasks", "question_based"),
    )

    for column, emoji, noun, title_text, file_stem, chart_size in breakdown_dimensions:
        for subset, progress_label, title_label, prefix in prompt_subsets:
            if subset.empty or column not in subset.columns:
                continue
            if len(subset[column].unique()) <= 1:
                continue  # a single value leaves nothing to compare
            if verbose:
                print(
                    f"   {emoji} Creating {noun} analysis ({progress_label}) for {task_name}..."
                )

            fig = create_breakdown_chart(
                subset,
                column,
                f"{task_name} - {title_text} ({title_label})",
                ylabel="Accuracy",
                color_by="models",
                figsize=chart_size,
            )
            filepath = f"{task_dir}/{task_name}_{file_stem}_{prefix}.png"
            save_plot(fig, filepath, no_titles=no_titles)
            results["generated_files"].append(filepath)

            breakdown_stats = subset.groupby([column, "model"])["correct"].agg(
                ["mean", "count"]
            )
            results["stats"][f"{prefix}_by_{column}"] = breakdown_stats.to_dict()

    # 8. Size pattern comparison (only when both prompt styles are present)
    if (
        not full_output_data.empty
        and not question_data.empty
        and "size_pattern" in data.columns
        and len(data["size_pattern"].unique()) > 1
    ):
        if verbose:
            print(f"   🔄 Creating size pattern comparison for {task_name}...")

        fig = create_size_pattern_comparison_chart(data, task_name, figsize=(16, 10))
        filepath = f"{task_dir}/{task_name}_size_pattern_comparison.png"
        save_plot(fig, filepath, no_titles=no_titles)
        results["generated_files"].append(filepath)

        comparison_stats = calculate_size_pattern_comparison_stats(data, task_name)
        results["stats"]["size_pattern_comparison"] = comparison_stats

    return results


def generate_error_analysis(
    data: pd.DataFrame, output_dir: str, verbose: bool = False, no_titles: bool = False
) -> Dict[str, Any]:
    """
    Generate error analysis visualizations.

    Parameters:
    - data: Evaluation rows; "question_type", "model" and "correct" are read,
      plus "target" and "benchmark" when present.
    - output_dir: Parent directory; charts go to ``<output_dir>/error_analysis/``.
    - verbose: Print progress messages.
    - no_titles: Forwarded to ``save_plot`` for every chart, matching the
      per-task analysis (defaults to False, the previous behavior).

    Returns:
    - dict with "generated_files" (paths of written charts) and "stats"
      ("success_rates", "target_difficulty", "task_difficulty" as lists of
      records, each present only when its section ran).
    """
    results = {"generated_files": [], "stats": {}}

    # Create error analysis subdirectory
    error_dir = os.path.join(output_dir, "error_analysis")
    os.makedirs(error_dir, exist_ok=True)

    if verbose:
        print("   🔍 Analyzing error patterns...")

    # 1. Success Rate by Question Type
    _, question_data = split_by_question_type(data)

    if not question_data.empty:
        if verbose:
            print("   📊 Creating success rate analysis...")

        success_rates = (
            question_data.groupby(["question_type", "model"])["correct"]
            .agg(["mean", "count", "sum"])
            .reset_index()
        )
        success_rates.columns = [
            "question_type",
            "model",
            "success_rate",
            "total_attempts",
            "successes",
        ]

        fig = create_breakdown_chart(
            question_data,
            "question_type",
            "Success Rate by Question Type (Question-Based Tasks)",
            ylabel="Success Rate",
            color_by="models",
            figsize=(14, 8),
        )

        filepath = f"{error_dir}/success_rates_by_question_type.png"
        save_plot(fig, filepath, no_titles=no_titles)
        results["generated_files"].append(filepath)

        results["stats"]["success_rates"] = success_rates.to_dict("records")

    # 2. Performance by Target (Input vs Output difficulty)
    if not question_data.empty and "target" in question_data.columns:
        if verbose:
            print("   🎯 Analyzing target difficulty...")

        target_difficulty = (
            question_data.groupby(["target", "question_type"])["correct"]
            .agg(["mean", "count"])
            .reset_index()
        )
        target_difficulty.columns = ["target", "question_type", "accuracy", "count"]

        data_by_target = {
            target: question_data[question_data["target"] == target]
            for target in question_data["target"].unique()
        }

        fig = create_comparison_chart(
            data_by_target,
            "Target Difficulty Analysis - Input vs Output Questions",
            comparison_labels=list(data_by_target.keys()),
            ylabel="Accuracy",
            figsize=(14, 8),
        )

        filepath = f"{error_dir}/target_difficulty_analysis.png"
        save_plot(fig, filepath, no_titles=no_titles)
        results["generated_files"].append(filepath)

        results["stats"]["target_difficulty"] = target_difficulty.to_dict("records")

    # 3. Task Difficulty Ranking (ALL DATA)
    if "benchmark" in data.columns:
        if verbose:
            print("   📈 Creating task difficulty ranking...")

        task_difficulty = (
            data.groupby("benchmark")["correct"].agg(["mean", "count"]).reset_index()
        )
        task_difficulty.columns = ["benchmark", "accuracy", "count"]
        task_difficulty = task_difficulty.sort_values("accuracy")

        # Chart only the 10 lowest-accuracy tasks
        hardest_tasks = task_difficulty.head(10)["benchmark"].tolist()
        hardest_data = data[data["benchmark"].isin(hardest_tasks)]

        if not hardest_data.empty:
            fig = create_breakdown_chart(
                hardest_data,
                "benchmark",
                "Performance on Most Challenging Tasks (All Task Types)",
                ylabel="Accuracy",
                color_by="models",
                figsize=(16, 10),
            )

            filepath = f"{error_dir}/challenging_tasks_analysis.png"
            save_plot(fig, filepath, no_titles=no_titles)
            results["generated_files"].append(filepath)

        results["stats"]["task_difficulty"] = task_difficulty.to_dict("records")

    return results


def generate_detailed_visualizations(
    data: pd.DataFrame, output_dir: str, verbose: bool = False, no_titles: bool = False
) -> Dict[str, Any]:
    """
    Generate all Level 2 (detailed) visualizations.

    Runs the per-task analysis for every value of the "benchmark" column (when
    present) and then the cross-task error analysis, writing everything under
    ``<output_dir>/detailed/``. ``no_titles`` is forwarded to both analyses.

    Returns:
    - dict with "detailed_dir", "task_analyses" (results keyed by task),
      "error_analysis" and "total_files" (number of generated chart files).
    """
    if verbose:
        print("🔍 Generating Level 2 Detailed Visualizations...")

    # Create detailed subdirectory
    detailed_dir = os.path.join(output_dir, "detailed")
    os.makedirs(detailed_dir, exist_ok=True)

    all_results = {
        "detailed_dir": detailed_dir,
        "task_analyses": {},
        "error_analysis": {},
        "total_files": 0,
    }

    # Generate per-task analysis
    if "benchmark" in data.columns:
        tasks = data["benchmark"].unique()
        if verbose:
            print(f"  📊 Analyzing {len(tasks)} tasks individually...")

        for task in tasks:
            if verbose:
                print(f"    🎯 Analyzing task: {task}")

            task_data = data[data["benchmark"] == task]
            task_results = generate_task_analysis(
                task_data, task, detailed_dir, verbose, no_titles
            )
            all_results["task_analyses"][task] = task_results

    # Generate error analysis (previously ignored no_titles)
    if verbose:
        print("  🔍 Generating error analysis...")
    error_results = generate_error_analysis(
        data, detailed_dir, verbose, no_titles=no_titles
    )
    all_results["error_analysis"] = error_results

    # Calculate total files
    total_files = sum(
        len(task_result["generated_files"])
        for task_result in all_results["task_analyses"].values()
    )
    total_files += len(error_results["generated_files"])
    all_results["total_files"] = total_files

    if verbose:
        print(f"  ✅ Generated {total_files} detailed visualizations in {detailed_dir}")

    return all_results
