"""
Improved Token Analysis Generator - Size-aware visualizations without inappropriate averaging.

This module generates visualizations that respect the dramatic differences in token
requirements across different graph sizes (5-250 nodes).
"""

import os
import json
from typing import Dict, Any
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from scripts.visualization.core.utils import setup_plot_style, save_plot
from scripts.visualization.core.prompt_token_data_collector import (
    collect_all_graph_token_data,
    get_token_summary_statistics,
)


def estimate_base_prompt_tokens() -> int:
    """
    Approximate the token cost of the fixed prompt scaffolding (no graphs included).

    The estimate is a heuristic: the whitespace-separated word count of a
    representative prompt template is scaled by 1.3 to allow for sub-word
    tokenization, then truncated to an int.

    Returns:
        Estimated number of tokens for the graph-free part of the prompt.
    """
    template = """Below are X examples of input graphs and their corresponding output graphs.

IMPORTANT: Structure your response using XML tags as follows:

<thinking>
Your step-by-step analysis and reasoning here. Explain your thought process, identify patterns from the examples, and work through the problem systematically.
</thinking>

<answer>
Only the number.
</answer>

Using these examples, and this final input graph, answer the following question:

How many connected components are there in this input graph?

Remember: Use <thinking></thinking> tags for your analysis and <answer></answer> tags for your final response."""

    tokens_per_word = 1.3  # rough inflation factor: tokenizers split many words into pieces
    word_count = len(template.split())
    return int(word_count * tokens_per_word)


def _adjacency_incident_means(token_df: pd.DataFrame, sizes) -> list:
    """
    Compute per-size, per-graph-type mean token counts for the two encodings.

    Returns a list of ``(size, io_type, adjacency_mean, incident_mean)`` tuples,
    in ``sizes`` order with "input" before "output" for each size. A mean is NaN
    when ``token_df`` has no rows for that encoding/size/type combination.
    """
    combos = []
    for size in sizes:
        for io_type in ("input", "output"):
            subset = token_df[
                (token_df["size"] == size) & (token_df["input_output"] == io_type)
            ]
            adj_mean = subset.loc[
                subset["encoding"] == "adjacency", "token_count"
            ].mean()
            inc_mean = subset.loc[
                subset["encoding"] == "incident", "token_count"
            ].mean()
            combos.append((size, io_type, adj_mean, inc_mean))
    return combos


def generate_size_aware_encoding_efficiency(
    token_df: pd.DataFrame, output_dir: str, verbose: bool = False
) -> Dict[str, Any]:
    """
    Generate encoding efficiency comparison that shows how efficiency varies by size.
    FIXED: No longer averages across different graph sizes.

    Produces a 2x2 figure saved as ``01_size_aware_encoding_efficiency.png``:
    mean tokens per size/encoding/type (log-log), an adjacency/incident ratio
    heatmap, the absolute adjacency - incident difference per size, and a
    per-size recommendation bar chart.

    Args:
        token_df: Rows with at least "size", "encoding", "input_output" and
            "token_count" columns.
        output_dir: Directory the PNG is written to.
        verbose: Print a progress line when True.

    Returns:
        Dict with "generated_files" (list of paths) and "stats" containing
        "size_encoding_stats", "efficiency_ratios", "differences" and
        "recommendations".
    """
    results = {"generated_files": [], "stats": {}}

    if verbose:
        print("   📊 Creating size-aware encoding efficiency analysis...")

    # Get available sizes, sorted
    available_sizes = sorted(token_df["size"].unique())

    # Create a comprehensive size-specific analysis
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))

    # 1. Encoding efficiency by size (line plot)
    ax1 = axes[0, 0]

    # Calculate mean tokens by size and encoding for input/output separately
    size_encoding_stats = (
        token_df.groupby(["size", "encoding", "input_output"])["token_count"]
        .agg(["mean", "std", "count"])
        .reset_index()
    )

    # Inputs are drawn first (solid, circles), then outputs (dashed, squares).
    for io_type, marker, linestyle in (("input", "o", "-"), ("output", "s", "--")):
        io_stats = size_encoding_stats[size_encoding_stats["input_output"] == io_type]
        for encoding in io_stats["encoding"].unique():
            enc_data = io_stats[io_stats["encoding"] == encoding]
            ax1.plot(
                enc_data["size"],
                enc_data["mean"],
                marker=marker,
                label=f"{encoding} ({io_type})",
                linewidth=2,
                linestyle=linestyle,
            )

    ax1.set_xlabel("Graph Size (nodes)")
    ax1.set_ylabel("Average Token Count")
    ax1.set_title("Token Requirements by Graph Size and Encoding")
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    ax1.set_xscale("log")
    ax1.set_yscale("log")

    # Compute adjacency/incident means once and derive both ratios and differences.
    ratios = []
    differences = []
    for size, io_type, adj_mean, inc_mean in _adjacency_incident_means(
        token_df, available_sizes
    ):
        if pd.isna(adj_mean) or pd.isna(inc_mean):
            continue
        differences.append(
            {"size": size, "input_output": io_type, "difference": adj_mean - inc_mean}
        )
        # A ratio is only meaningful when the denominator is positive.
        if inc_mean > 0:
            ratio = adj_mean / inc_mean
            ratios.append(
                {
                    "size": size,
                    "input_output": io_type,
                    "ratio": ratio,
                    "better_encoding": "incident" if ratio > 1.0 else "adjacency",
                }
            )

    # 2. Efficiency ratio heatmap (adjacency/incident)
    ax2 = axes[0, 1]

    if ratios:
        ratio_df = pd.DataFrame(ratios)
        ratio_pivot = ratio_df.pivot(
            index="size", columns="input_output", values="ratio"
        )

        sns.heatmap(
            ratio_pivot,
            annot=True,
            fmt=".2f",
            cmap="RdYlGn_r",
            center=1.0,
            ax=ax2,
            cbar_kws={"label": "Adjacency/Incident Ratio"},
        )
        ax2.set_title(
            "Encoding Efficiency Ratio by Size\n(<1.0 = Adjacency better, >1.0 = Incident better)"
        )
        ax2.set_xlabel("Graph Type")
        ax2.set_ylabel("Graph Size (nodes)")

    # 3. Token differences between encodings
    ax3 = axes[1, 0]

    if differences:
        diff_df = pd.DataFrame(differences)

        for io_type in diff_df["input_output"].unique():
            io_data = diff_df[diff_df["input_output"] == io_type]
            ax3.plot(
                io_data["size"],
                io_data["difference"],
                marker="o",
                label=io_type,
                linewidth=2,
            )

        ax3.axhline(y=0, color="black", linestyle="--", alpha=0.5)
        ax3.set_xlabel("Graph Size (nodes)")
        ax3.set_ylabel("Token Difference (Adjacency - Incident)")
        ax3.set_title("Absolute Token Difference Between Encodings")
        ax3.legend()
        ax3.grid(True, alpha=0.3)
        ax3.set_xscale("log")

    # 4. Encoding decision recommendations
    ax4 = axes[1, 1]

    recommendation_data = []
    if ratios:
        # Show which encoding is better for each size
        for size in available_sizes:
            size_ratios = ratio_df[ratio_df["size"] == size]
            if size_ratios.empty:
                continue
            # Combines the input and output ratios of the SAME size only.
            avg_ratio = size_ratios["ratio"].mean()
            # NOTE(review): abs(1 - ratio) is asymmetric between encodings
            # (ratio 2.0 -> 1.0 but ratio 0.5 -> 0.5); kept for compatibility.
            recommendation_data.append(
                {
                    "size": size,
                    "recommended_encoding": "Incident" if avg_ratio > 1.0 else "Adjacency",
                    "efficiency_gain": abs(1.0 - avg_ratio),
                }
            )

        if recommendation_data:
            rec_df = pd.DataFrame(recommendation_data)

            colors = [
                "green" if enc == "Incident" else "blue"
                for enc in rec_df["recommended_encoding"]
            ]
            bars = ax4.bar(
                range(len(rec_df)), rec_df["efficiency_gain"], color=colors, alpha=0.7
            )

            ax4.set_xlabel("Graph Size")
            ax4.set_ylabel("Efficiency Gain")
            ax4.set_title(
                "Recommended Encoding by Graph Size\n(Green=Incident, Blue=Adjacency)"
            )
            ax4.set_xticks(range(len(rec_df)))
            ax4.set_xticklabels([f"{size}" for size in rec_df["size"]], rotation=45)

            # Add recommendation labels
            for bar, rec in zip(bars, rec_df["recommended_encoding"]):
                ax4.text(
                    bar.get_x() + bar.get_width() / 2,
                    bar.get_height() + 0.01,
                    rec,
                    ha="center",
                    va="bottom",
                    fontsize=8,
                    rotation=90,
                )

    plt.tight_layout()
    filepath = f"{output_dir}/01_size_aware_encoding_efficiency.png"
    save_plot(fig, filepath, "Size-Aware Encoding Efficiency Analysis")
    results["generated_files"].append(filepath)

    # Store stats
    results["stats"] = {
        "size_encoding_stats": size_encoding_stats.to_dict("records"),
        "efficiency_ratios": ratios,
        "differences": differences,
        "recommendations": recommendation_data,
    }

    return results


def _mean_token_count(enc_data: pd.DataFrame, size: Any, io_type: str):
    """Return the averaged token count for ``size``/``io_type`` in ``enc_data``, or 0 if absent."""
    match = enc_data.loc[
        (enc_data["size"] == size) & (enc_data["input_output"] == io_type),
        "token_count",
    ]
    return match.iloc[0] if not match.empty else 0


def _prompt_config_total(base_tokens: int, input_avg, output_avg, config_index: int):
    """
    Total prompt tokens for example configuration ``config_index``.

    0 is the base prompt only. Any k >= 1 is the base prompt plus (k - 1)
    input/output example pairs plus the final input graph, so k == 1 is the
    base prompt plus the final input.
    """
    if config_index == 0:
        return base_tokens
    num_pairs = config_index - 1
    return base_tokens + (num_pairs * (input_avg + output_avg)) + input_avg


def generate_prompt_construction_table(
    token_df: pd.DataFrame, output_dir: str, verbose: bool = False
) -> Dict[str, Any]:
    """
    Generate the requested table showing token requirements by size and number of examples.
    Table structure:
    - Columns: Graph sizes (5, 10, 15, etc.)
    - Rows: Number of examples (0=base, 1=base+final, 2=base+1pair+final, etc.)

    For every encoding this writes ``prompt_tokens_table_<encoding>.csv`` and a
    heatmap PNG. When there are at least two encodings it also writes a
    side-by-side line plot for the first two. That plot only covers the first
    8 sizes.

    Args:
        token_df: Rows with "size", "encoding", "input_output", "token_count".
        output_dir: Directory the CSV/PNG files are written to.
        verbose: Print a progress line when True.

    Returns:
        Dict with "generated_files" and "stats" (one "<encoding>_table" record
        list per encoding).
    """
    results = {"generated_files": [], "stats": {}}

    if verbose:
        print("   📋 Creating prompt construction tables...")

    # Get available sizes and encodings
    available_sizes = sorted(token_df["size"].unique())
    available_encodings = sorted(token_df["encoding"].unique())

    base_prompt_tokens = estimate_base_prompt_tokens()

    # Calculate average tokens by size, encoding, and input/output type
    avg_tokens = (
        token_df.groupby(["size", "encoding", "input_output"])["token_count"]
        .mean()
        .reset_index()
    )

    row_labels = [
        "0: Base prompt only",
        "1: Base + final input",
        "2: Base + 1 pair + final input",
        "3: Base + 2 pairs + final input",
        "4: Base + 3 pairs + final input",
    ]

    for encoding in available_encodings:
        enc_data = avg_tokens[avg_tokens["encoding"] == encoding]

        # Create prompt construction table for this encoding
        table_data = []
        for config_index, row_label in enumerate(row_labels):
            row = {"Configuration": row_label}
            for size in available_sizes:
                # NOTE(review): a missing input/output average counts as 0 tokens,
                # which understates totals for sparse data.
                input_avg = _mean_token_count(enc_data, size, "input")
                output_avg = _mean_token_count(enc_data, size, "output")
                row[f"{size} nodes"] = int(
                    _prompt_config_total(
                        base_prompt_tokens, input_avg, output_avg, config_index
                    )
                )
            table_data.append(row)

        # Create DataFrame and save as CSV
        table_df = pd.DataFrame(table_data)
        csv_path = f"{output_dir}/prompt_tokens_table_{encoding}.csv"
        table_df.to_csv(csv_path, index=False)
        results["generated_files"].append(csv_path)

        # Create visualization of the table
        fig, ax = plt.subplots(figsize=(14, 8))

        # Prepare data for heatmap (exclude configuration column)
        heatmap_data = table_df.set_index("Configuration").astype(float)

        sns.heatmap(
            heatmap_data,
            annot=True,
            fmt=".0f",
            cmap="YlOrRd",
            ax=ax,
            cbar_kws={"label": "Total Prompt Tokens"},
        )

        ax.set_title(
            f"Prompt Token Requirements - {encoding.title()} Encoding\n"
            f"Rows: Example configurations, Columns: Graph sizes"
        )
        ax.set_xlabel("Graph Size (nodes)")
        ax.set_ylabel("Prompt Configuration")

        plt.tight_layout()
        heatmap_path = f"{output_dir}/03_prompt_tokens_heatmap_{encoding}.png"
        save_plot(fig, heatmap_path)
        results["generated_files"].append(heatmap_path)

        # Store table data in results
        results["stats"][f"{encoding}_table"] = table_df.to_dict("records")

    # Create comparison visualization showing both encodings
    if len(available_encodings) >= 2:
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8))

        config_labels = [
            "Base only",
            "Base+final",
            "Base+1pair+final",
            "Base+2pairs+final",
            "Base+3pairs+final",
        ]
        # Limit to first 8 sizes for readability
        sizes_for_plot = available_sizes[:8]

        # Only the first two encodings are compared.
        for ax, encoding in zip((ax1, ax2), available_encodings[:2]):
            enc_data = avg_tokens[avg_tokens["encoding"] == encoding]

            for config_index, label in enumerate(config_labels):
                tokens = [
                    _prompt_config_total(
                        base_prompt_tokens,
                        _mean_token_count(enc_data, size, "input"),
                        _mean_token_count(enc_data, size, "output"),
                        config_index,
                    )
                    for size in sizes_for_plot
                ]
                ax.plot(sizes_for_plot, tokens, marker="o", label=label, linewidth=2)

            ax.set_xlabel("Graph Size (nodes)")
            ax.set_ylabel("Total Prompt Tokens")
            ax.set_title(f"{encoding.title()} Encoding")
            ax.legend()
            ax.grid(True, alpha=0.3)
            ax.set_xscale("log")
            ax.set_yscale("log")

        plt.tight_layout()
        comparison_path = f"{output_dir}/04_encoding_comparison_by_config.png"
        save_plot(fig, comparison_path)
        results["generated_files"].append(comparison_path)

    return results


def _plot_scaling_panel(ax, io_stats: pd.DataFrame, marker: str, title: str) -> None:
    """Plot mean tokens vs. size per encoding with a +/- 1 std band on log-log axes."""
    for encoding in io_stats["encoding"].unique():
        enc_data = io_stats[io_stats["encoding"] == encoding]
        ax.plot(
            enc_data["size"], enc_data["mean"], marker=marker, label=encoding, linewidth=2
        )
        ax.fill_between(
            enc_data["size"],
            enc_data["mean"] - enc_data["std"],
            enc_data["mean"] + enc_data["std"],
            alpha=0.2,
        )

    ax.set_title(title)
    ax.set_xlabel("Graph Size (nodes)")
    ax.set_ylabel("Average Token Count")
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.set_xscale("log")
    ax.set_yscale("log")


def generate_token_scaling_analysis(
    token_df: pd.DataFrame, output_dir: str, verbose: bool = False
) -> Dict[str, Any]:
    """
    Generate token scaling analysis showing how tokens grow with graph size.
    This one is GOOD because it shows the RELATIONSHIP between size and tokens,
    rather than inappropriately averaging across different sizes.

    Produces ``02_token_scaling_analysis.png`` with four panels:
    - input scaling;
    - output scaling;
    - combined input vs. output;
    - mean tokens-per-additional-node growth rate.

    The growth rate is computed between consecutive sizes.

    Args:
        token_df: Rows with "size", "encoding", "input_output", "token_count".
        output_dir: Directory the PNG is written to.
        verbose: Print a progress line when True.

    Returns:
        Dict with "generated_files" and "stats" ("scaling_stats",
        "growth_rates").
    """
    results = {"generated_files": [], "stats": {}}

    if verbose:
        print("   📈 Creating token scaling analysis...")

    # Calculate mean tokens by size, encoding, and input/output
    scaling_stats = (
        token_df.groupby(["size", "encoding", "input_output"])["token_count"]
        .agg(["mean", "std", "count"])
        .reset_index()
    )

    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(16, 12))

    # 1. Input scaling by encoding / 2. Output scaling by encoding
    _plot_scaling_panel(
        ax1,
        scaling_stats[scaling_stats["input_output"] == "input"],
        "o",
        "Input Graph Token Scaling by Size",
    )
    _plot_scaling_panel(
        ax2,
        scaling_stats[scaling_stats["input_output"] == "output"],
        "s",
        "Output Graph Token Scaling by Size",
    )

    # 3. Combined scaling comparison
    for input_output in ["input", "output"]:
        io_data = scaling_stats[scaling_stats["input_output"] == input_output]
        for encoding in io_data["encoding"].unique():
            enc_data = io_data[io_data["encoding"] == encoding]
            ax3.plot(
                enc_data["size"],
                enc_data["mean"],
                marker="o",
                label=f"{encoding} ({input_output})",
                linewidth=2,
                linestyle="-" if input_output == "input" else "--",
            )

    ax3.set_title("Token Scaling: Input vs Output Comparison")
    ax3.set_xlabel("Graph Size (nodes)")
    ax3.set_ylabel("Average Token Count")
    ax3.legend()
    ax3.grid(True, alpha=0.3)
    ax3.set_xscale("log")
    ax3.set_yscale("log")

    # 4. Token growth rate analysis: slope between each pair of consecutive sizes
    growth_rates = []
    for encoding in scaling_stats["encoding"].unique():
        for input_output in scaling_stats["input_output"].unique():
            subset = scaling_stats[
                (scaling_stats["encoding"] == encoding)
                & (scaling_stats["input_output"] == input_output)
            ].sort_values("size")

            sizes = subset["size"].tolist()
            means = subset["mean"].tolist()
            for prev_size, curr_size, prev_tokens, curr_tokens in zip(
                sizes, sizes[1:], means, means[1:]
            ):
                if curr_size == prev_size:  # Avoid division by zero
                    continue
                growth_rates.append(
                    {
                        "encoding": encoding,
                        "input_output": input_output,
                        "size_from": prev_size,
                        "size_to": curr_size,
                        "tokens_per_node": (curr_tokens - prev_tokens)
                        / (curr_size - prev_size),
                    }
                )

    if growth_rates:
        growth_df = pd.DataFrame(growth_rates)

        growth_summary = (
            growth_df.groupby(["encoding", "input_output"])["tokens_per_node"]
            .agg(["mean", "std", "min", "max"])
            .reset_index()
        )

        # Bar plot showing average growth rates
        color_map = {"adjacency": "blue", "incident": "green"}
        bar_data = growth_summary["mean"].tolist()
        labels = [
            f"{enc}\n({io})"
            for enc, io in zip(growth_summary["encoding"], growth_summary["input_output"])
        ]
        colors = [color_map.get(enc, "gray") for enc in growth_summary["encoding"]]

        bars = ax4.bar(range(len(bar_data)), bar_data, color=colors, alpha=0.7)

        for i, (bar, row) in enumerate(zip(bars, growth_summary.itertuples())):
            # pandas std (ddof=1) is NaN for a single interval; a NaN error bar
            # and label position would be invalid, so treat it as zero spread.
            spread = 0.0 if pd.isna(row.std) else row.std
            ax4.errorbar(
                i, row.mean, yerr=spread, fmt="none", color="black", capsize=5
            )
            ax4.text(
                bar.get_x() + bar.get_width() / 2,
                bar.get_height() + spread + 0.5,
                f"{row.mean:.1f}",
                ha="center",
                va="bottom",
                fontsize=9,
            )

        ax4.set_title("Average Token Growth Rate\n(tokens per additional node)")
        ax4.set_ylabel("Tokens per Additional Node")
        ax4.set_xlabel("Encoding and Graph Type")
        ax4.set_xticks(range(len(labels)))
        ax4.set_xticklabels(labels)
        ax4.grid(True, axis="y", alpha=0.3)

    plt.tight_layout()
    filepath = f"{output_dir}/02_token_scaling_analysis.png"
    save_plot(fig, filepath, "Token Scaling Analysis by Graph Size")
    results["generated_files"].append(filepath)

    results["stats"] = {
        "scaling_stats": scaling_stats.to_dict("records"),
        "growth_rates": growth_rates,
    }

    return results


def _plot_size_heatmap(ax, io_data: pd.DataFrame, sizes, cmap: str, title: str) -> None:
    """Draw a task x size heatmap of mean token counts on ``ax``; no-op if ``io_data`` is empty."""
    if io_data.empty:
        return

    # Columns follow the sorted size order; sizes missing for a task show as blanks.
    pivot = io_data.pivot_table(
        index="task", columns="size", values="token_count"
    ).reindex(columns=sizes)

    sns.heatmap(
        pivot,
        annot=len(sizes) <= 12,  # annotations become unreadable with many sizes
        fmt=".0f",
        cmap=cmap,
        ax=ax,
        cbar_kws={"label": "Avg Tokens"},
    )
    ax.set_title(title)
    ax.set_xlabel("Graph Size (nodes)")
    ax.set_ylabel("Task")

    # Rotate x-axis labels if there are many sizes
    if len(sizes) > 8:
        ax.tick_params(axis="x", rotation=45)


def generate_task_complexity_by_actual_size(
    token_df: pd.DataFrame, output_dir: str, verbose: bool = False
) -> Dict[str, Any]:
    """
    Generate task complexity analysis showing exact token requirements for each available graph size.
    Uses actual sizes (5, 10, 15, 25, etc.) rather than arbitrary groupings.

    Writes ``05_task_complexity_by_actual_size.png``, with one row of
    input/output heatmaps per encoding. It also writes
    ``task_complexity_summary.csv``, giving the min/max of per-size mean tokens
    per task/encoding/type.

    Args:
        token_df: Rows with "task", "size", "encoding", "input_output",
            "token_count".
        output_dir: Directory the files are written to.
        verbose: Print progress lines when True.

    Returns:
        Dict with "generated_files" and "stats" containing
        "task_complexity_by_actual_size", "available_sizes" and, when
        non-empty, "complexity_summary".
    """
    results = {"generated_files": [], "stats": {}}

    if verbose:
        print("   🎯 Creating task complexity analysis by actual graph size...")

    # Get available sizes, sorted
    available_sizes = sorted(token_df["size"].unique())

    if verbose:
        print(f"      Found {len(available_sizes)} different sizes: {available_sizes}")

    # Calculate task complexity by actual size
    task_complexity = (
        token_df.groupby(["task", "size", "encoding", "input_output"])["token_count"]
        .mean()
        .reset_index()
    )

    available_encodings = sorted(token_df["encoding"].unique())

    # Determine figure size based on number of sizes (more sizes = wider figure)
    fig_width = max(16, len(available_sizes) * 1.5)
    fig_height = max(8, len(available_encodings) * 6)

    # squeeze=False keeps ``axes`` 2-D even with a single encoding, so
    # axes[row, col] always yields one Axes.
    fig, axes = plt.subplots(
        len(available_encodings), 2, figsize=(fig_width, fig_height), squeeze=False
    )

    for enc_idx, encoding in enumerate(available_encodings):
        enc_data = task_complexity[task_complexity["encoding"] == encoding]

        _plot_size_heatmap(
            axes[enc_idx, 0],
            enc_data[enc_data["input_output"] == "input"],
            available_sizes,
            "Blues",
            f"Input Graph Token Requirements - {encoding.title()}",
        )
        _plot_size_heatmap(
            axes[enc_idx, 1],
            enc_data[enc_data["input_output"] == "output"],
            available_sizes,
            "Reds",
            f"Output Graph Token Requirements - {encoding.title()}",
        )

    plt.tight_layout()
    filepath = f"{output_dir}/05_task_complexity_by_actual_size.png"
    save_plot(fig, filepath, "Task Complexity by Actual Graph Size")
    results["generated_files"].append(filepath)

    # Also create a summary table showing the range of token requirements
    summary_stats = []
    for task in token_df["task"].unique():
        for encoding in available_encodings:
            for io_type in ["input", "output"]:
                task_data = token_df[
                    (token_df["task"] == task)
                    & (token_df["encoding"] == encoding)
                    & (token_df["input_output"] == io_type)
                ]
                if task_data.empty:
                    continue

                # Per-size means first, so the range compares sizes, not raw rows.
                size_stats = task_data.groupby("size")["token_count"].mean()
                summary_stats.append(
                    {
                        "task": task,
                        "encoding": encoding,
                        "type": io_type,
                        "min_size": size_stats.index.min(),
                        "max_size": size_stats.index.max(),
                        "min_tokens": size_stats.min(),
                        "max_tokens": size_stats.max(),
                        "tokens_range": size_stats.max() - size_stats.min(),
                        "size_count": len(size_stats),
                    }
                )

    if summary_stats:
        summary_df = pd.DataFrame(summary_stats)
        summary_path = f"{output_dir}/task_complexity_summary.csv"
        summary_df.to_csv(summary_path, index=False)
        results["generated_files"].append(summary_path)
        results["stats"]["complexity_summary"] = summary_stats

    # Store detailed stats
    results["stats"]["task_complexity_by_actual_size"] = task_complexity.to_dict(
        "records"
    )
    results["stats"]["available_sizes"] = available_sizes

    return results


def generate_prompt_token_visualizations(
    datasets_dir: str = "datasets",
    output_dir: str = "visualizations",
    verbose: bool = False,
) -> Dict[str, Any]:
    """
    Run the full size-aware prompt token analysis and write all artifacts.

    The steps are:
    - Collect token data from ``datasets_dir`` and save it raw as CSV.
    - Run the encoding-efficiency, prompt-table, scaling and task-complexity
      analyses, in that order.
    - Dump summary statistics as JSON and write a README.

    All output goes under ``<output_dir>/prompt_token_analysis``.

    Args:
        datasets_dir: Directory the token data is collected from.
        output_dir: Parent directory for the analysis subdirectory.
        verbose: Print progress messages when True.

    Returns:
        Aggregated results dict with the token directory, per-analysis results,
        summary statistics, dataset structure and a running "total_files"
        count. If no data is found, returns a small dict with an "error" key
        and total_files 0.
    """

    def log(message: str) -> None:
        # Progress output is only emitted in verbose mode.
        if verbose:
            print(message)

    log("🔤 Generating Improved Prompt Token Analysis...")

    token_dir = os.path.join(output_dir, "prompt_token_analysis")
    os.makedirs(token_dir, exist_ok=True)

    setup_plot_style()

    log("  📊 Collecting token data...")
    token_df, structure = collect_all_graph_token_data(datasets_dir)

    if token_df.empty:
        # Reported regardless of verbosity: the caller otherwise gets no output.
        print("❌ No token data found. Check datasets directory.")
        return {"token_dir": token_dir, "total_files": 0, "error": "No data found"}

    raw_csv_path = f"{token_dir}/token_data_raw.csv"
    token_df.to_csv(raw_csv_path, index=False)

    all_results = {
        "token_dir": token_dir,
        "data_files": [raw_csv_path],
        "encoding_analysis": {},
        "scaling_analysis": {},
        "table_analysis": {},
        "task_complexity": {},
        "summary_statistics": {},
        "dataset_structure": structure,
        "total_files": 1,  # the raw CSV written above
    }

    # Each step shares the (token_df, output_dir, verbose) signature and returns
    # a dict whose "generated_files" list feeds the running file count.
    analysis_steps = (
        (
            "encoding_analysis",
            "  📊 Generating size-aware encoding analysis...",
            generate_size_aware_encoding_efficiency,
        ),
        (
            "table_analysis",
            "  📋 Generating prompt construction tables...",
            generate_prompt_construction_table,
        ),
        (
            "scaling_analysis",
            "  📈 Generating token scaling analysis...",
            generate_token_scaling_analysis,
        ),
        (
            "task_complexity",
            "  🎯 Generating task complexity by actual graph sizes...",
            generate_task_complexity_by_actual_size,
        ),
    )
    for result_key, banner, run_step in analysis_steps:
        log(banner)
        step_results = run_step(token_df, token_dir, verbose)
        all_results[result_key] = step_results
        all_results["total_files"] += len(step_results["generated_files"])

    summary_stats = get_token_summary_statistics(token_df, structure)
    all_results["summary_statistics"] = summary_stats

    summary_path = f"{token_dir}/improved_token_analysis_summary.json"
    with open(summary_path, "w", encoding="utf-8") as handle:
        json.dump(summary_stats, handle, indent=2, default=str)
    all_results["total_files"] += 1

    create_improved_analysis_readme(all_results, token_dir, verbose)
    all_results["total_files"] += 1

    log(f"  ✅ Generated {all_results['total_files']} improved token analysis files")

    return all_results


def create_improved_analysis_readme(
    results: Dict, output_dir: str, verbose: bool = False
) -> str:
    """Write ``README_IMPROVED_ANALYSIS.md`` describing the token analysis outputs.

    The README explains what the size-aware analysis changed, lists the files
    the analysis generates, gives usage guidance, and, when ``results``
    carries a non-empty ``"dataset_structure"`` entry, summarizes the dataset
    coverage.

    Args:
        results: Aggregated analysis results. Only the optional
            ``"dataset_structure"`` entry is read; its ``"total_files"`` and
            ``"sizes"`` keys are used when present.
        output_dir: Directory the README is written into. It must already
            exist; ``open`` does not create it.
        verbose: If True, print a progress message.

    Returns:
        Path of the README file that was written. Any existing file at that
        path is overwritten.
    """
    if verbose:
        print("   📋 Creating improved analysis README...")

    base_prompt_tokens = estimate_base_prompt_tokens()

    # Dataset-dependent bullets are optional. Build them as complete lines so a
    # missing entry leaves no stray blank line inside the Markdown list, and read
    # keys defensively so a partial (or None) structure does not raise.
    dataset_bullets = []
    structure = results.get("dataset_structure")
    if structure:
        sizes = list(structure.get("sizes") or [])
        total_files = structure.get("total_files")
        if total_files is not None:
            dataset_bullets.append(
                f"- **Dataset coverage**: Analyzed {total_files:,} graph files "
                f"across {len(sizes)} different sizes"
            )
        if sizes:
            dataset_bullets.append(
                f"- **Exact size analysis**: From {min(sizes)} to {max(sizes)} nodes "
                "with individual columns for each size"
            )
    dataset_insights = "".join(f"{bullet}\n" for bullet in dataset_bullets)

    # NOTE: growth is described as "steep", not "exponential". A simple graph
    # has at most n(n-1)/2 edges, so an encoding that is (presumably) linear in
    # nodes + edges grows polynomially. That is consistent with the log-log
    # scaling plots described in section 2 below.
    readme_content = f"""# Improved Prompt Token Analysis

## Key Improvements

### ❌ What We Fixed
- **No more size averaging**: Previous visualizations inappropriately averaged token counts across graph sizes from 5 to 250+ nodes
- **Removed misleading charts**: Eliminated visualizations that mixed dramatically different graph complexities
- **Size-unaware comparisons**: Fixed analyses that didn't account for the steep growth in token requirements with graph size

### ✅ What We Added

#### 1. Size-Aware Encoding Efficiency Analysis
- **Per-size comparisons**: Shows how adjacency vs incident encoding efficiency varies by graph size
- **Efficiency ratios**: Heatmap showing which encoding is better for each size and graph type
- **Decision recommendations**: Clear guidance on which encoding to use for different graph sizes

#### 2. Token Scaling Analysis (Preserved)
- **Growth patterns**: Shows how token requirements scale with graph size (log-log plots)
- **Encoding comparison**: Compare how adjacency vs incident encodings scale differently
- **Growth rate analysis**: Tokens per additional node for different configurations
- **Input vs output**: Compare token scaling between input and output graphs

#### 3. Prompt Construction Tables
- **Structured token budgeting**: Tables showing exact token requirements for different configurations
- **Configuration rows**:
  - Row 0: Base prompt only (~{base_prompt_tokens} tokens)
  - Row 1: Base + final input graph
  - Row 2: Base + 1 input-output pair + final input
  - Row 3: Base + 2 input-output pairs + final input
  - Row 4: Base + 3 input-output pairs + final input
- **Size columns**: Separate columns for each graph size (5, 10, 15, 25, etc. nodes)
- **Encoding-specific**: Separate tables for adjacency and incident encodings

#### 4. Task Complexity by Actual Graph Size
- **Real size columns**: Uses actual discovered sizes (5, 10, 15, 25, 50, 100, etc.) as columns
- **No arbitrary groupings**: Shows exact token requirements for each specific graph size
- **Task-encoding heatmaps**: Separate heatmaps for input/output graphs and each encoding
- **Summary table**: CSV with token ranges and statistics for each task

## Files Generated

### Data Files
- `token_data_raw.csv` - Individual file token counts
- `improved_token_analysis_summary.json` - Summary statistics

### Visualizations
1. `01_size_aware_encoding_efficiency.png` - Comprehensive encoding comparison by size
2. `02_token_scaling_analysis.png` - How token requirements scale with graph size
3. `03_prompt_tokens_heatmap_adjacency.png` - Adjacency encoding token table
4. `03_prompt_tokens_heatmap_incident.png` - Incident encoding token table
5. `04_encoding_comparison_by_config.png` - Side-by-side encoding comparison
6. `05_task_complexity_by_actual_size.png` - Task complexity heatmaps by actual graph size

### CSV Tables
- `prompt_tokens_table_adjacency.csv` - Exact token counts for adjacency encoding
- `prompt_tokens_table_incident.csv` - Exact token counts for incident encoding
- `task_complexity_summary.csv` - Token ranges and statistics by task

## How to Use These Results

### For LLM Cost Planning
1. **Check the CSV tables** to get exact token counts for your intended configuration
2. **Use size-specific recommendations** from the encoding efficiency analysis
3. **Budget appropriately** - token requirements grow steeply with graph size

### For Encoding Decisions
1. **Small graphs (≤10 nodes)**: Check the efficiency ratio heatmap
2. **Large graphs (>50 nodes)**: Encoding choice becomes critical for cost control
3. **Task-specific**: Some tasks may favor one encoding over another regardless of size
4. **Exact size guidance**: Use the actual size columns for precise decision-making

### For Prompt Design
1. **Number of examples**: Use the configuration tables to balance cost vs. performance
2. **Graph size selection**: Consider using smaller examples when possible
3. **Mixed sizes**: Be aware of the token cost implications when using varied example sizes

## Key Insights

{dataset_insights}- **No more averaging**: All comparisons now respect the dramatic size differences
- **Practical guidance**: Clear recommendations for encoding and configuration choices
- **Cost transparency**: Exact token counts for budgeting LLM API usage
- **Granular complexity**: Task complexity shown for each specific graph size

## Warning About Previous Analysis

If you see token analyses that show "average" values across all graph sizes, they are likely misleading due to the vast differences between small (5-node) and large (250+ node) graphs. Always use size-specific analysis for meaningful results.

---
*Generated by Improved Graph-Based ARC Prompt Token Analysis System*
"""

    readme_path = os.path.join(output_dir, "README_IMPROVED_ANALYSIS.md")
    with open(readme_path, "w", encoding="utf-8") as f:
        f.write(readme_content)

    return readme_path
