import json
import os
from typing import Any, Dict, List

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.lines import Line2D

# Set up aesthetically pleasing style for publication-worthy plots
plt.style.use("seaborn-v0_8-whitegrid")
mpl.rcParams["font.family"] = "sans-serif"
# Arial is preferred; DejaVu Sans (bundled with matplotlib) is listed explicitly
# as a fallback so machines without Arial don't emit "findfont" warnings.
mpl.rcParams["font.sans-serif"] = ["Arial", "DejaVu Sans"]
mpl.rcParams["axes.edgecolor"] = "#333333"
mpl.rcParams["axes.linewidth"] = 1.2
mpl.rcParams["xtick.major.width"] = 1.2
mpl.rcParams["ytick.major.width"] = 1.2
mpl.rcParams["axes.grid"] = True
mpl.rcParams["grid.alpha"] = 0.3

# Load JSON data. The path is relative to the current working directory
# (presumably the repository root — confirm how the script is launched).
# JSON is UTF-8, so decode explicitly instead of relying on the locale default.
with open(
    "src/eliciting_contexts/benchmark/results/diagram_stories.json",
    "r",
    encoding="utf-8",
) as f:
    data = json.load(f)


# One row per compared method:
#   (result-set identifier, key of the method inside each story in the JSON,
#    display name used in figures/LaTeX, plot colour)
result_sets_config = [
    ("human", "Human", "Human", "#1f77b4"),  # blue
    ("GPT4o", "BlackBox", "GPT-4o", "#2ca02c"),  # green
    ("EPO", "EPO", "EPO", "#FFA500"),  # orange
    ("EPOAssistHobbled-Final", "EPO+", "EPO-Ast.", "#FF0000"),  # red
    ("GCG", "GCG", "GCG", "#9467bd"),  # purple
]

# Parallel per-method lists, all in result_sets_config order.
methods = [data_key for _, data_key, _, _ in result_sets_config]  # JSON keys
display_names = [label for _, _, label, _ in result_sets_config]  # plot labels
colors = [colour for *_, colour in result_sets_config]
# Marker shape per method, in the same order as result_sets_config.
markers = ["s", "X", "D", "^", "o"]


# Function to select a specific example
def select_example(entries: List[Dict[str, Any]], index: int = 0) -> str:
    """Return the ``input_text`` of ``entries[index]``.

    Falls back to the first entry when ``index`` is past the end of the list.
    An empty ``entries`` list raises IndexError.
    """
    chosen = entries[index] if index < len(entries) else entries[0]
    return chosen["input_text"]


# Function to get desired/undesired words
def get_completion_words(entries: List[Dict[str, Any]], index: int = 0) -> str:
    """Return ``" [<desired_word> / <undesired_word>]"`` for the selected entry.

    The entry is ``entries[index]``, or ``entries[0]`` when ``index`` is past
    the end. Returns ``""`` unless the entry has both word keys.
    """
    entry = entries[index] if index < len(entries) else entries[0]
    if not ("desired_word" in entry and "undesired_word" in entry):
        return ""
    desired, undesired = entry["desired_word"], entry["undesired_word"]
    return f" [{desired} / {undesired}]"


# Bold formatter for text
def make_bold(text: str) -> str:
    r"""Return *text* wrapped as bold matplotlib mathtext, i.e. ``$\bf{...}$``.

    Mathtext ignores literal whitespace, so spaces are turned into the explicit
    ``\ `` space command (otherwise "not happy" would render as "nothappy").
    ``$`` and ``_`` are escaped so they are not read as a math delimiter or a
    subscript operator.
    """
    escaped = text.replace("$", r"\$").replace("_", r"\_").replace(" ", r"\ ")
    return rf"$\bf{{{escaped}}}$"


# Story types: the top-level keys of the JSON, in file order. Every figure and
# the LaTeX export below iterate this list.
story_types = list(data.keys())

# FIGURE 1: 2x2 grid of scatter plots (cross entropy vs. logit-diff improvement),
# one panel per story type.
# NOTE(review): the grid is fixed at 4 panels — more than 4 story types would
# raise IndexError on axs_plots[i], fewer would leave empty panels visible.
fig_plots, axs_plots = plt.subplots(2, 2, figsize=(20, 16))
axs_plots = axs_plots.flatten()

# Reference labels for each method (parallel to `methods`). The same letters are
# reused in the text figures and the LaTeX export to tie each context to its plot.
ref_labels = ["A", "B", "C", "D", "E"]

# Create plots for each story type
for i, story_type in enumerate(story_types):
    ax = axs_plots[i]

    # Plot each method's data points (methods missing from this story are skipped)
    for j, method in enumerate(methods):
        if method in data[story_type]:
            # Extract cross entropy and logit diff improvement values
            entries = data[story_type][method]
            cross_entropies = [entry["cross_entropy"] for entry in entries]
            # Use logit diff improvement values for y-axis; entries without that
            # field fall back to the raw "logit_diff" value
            logit_diffs_improvement = [
                (
                    entry["logit_diff_improvement"]
                    if "logit_diff_improvement" in entry
                    else entry["logit_diff"]
                )
                for entry in entries
            ]

            # Sort the points by cross_entropy so the connecting line runs
            # left-to-right. (The comprehensions' `i` is local to them and does
            # not overwrite the outer story index.)
            sorted_indices = np.argsort(cross_entropies)
            sorted_entropies = [cross_entropies[i] for i in sorted_indices]
            sorted_diffs = [logit_diffs_improvement[i] for i in sorted_indices]

            # Plot lines connecting points of the same method
            ax.plot(
                sorted_entropies,
                sorted_diffs,
                color=colors[j],
                alpha=0.5,
                linestyle="-",
                linewidth=1.5,
            )

            # Plot points (the returned collection `sc` is not used afterwards)
            sc = ax.scatter(
                cross_entropies,
                logit_diffs_improvement,
                color=colors[j],
                marker=markers[j],
                s=120,
                label=display_names[j],  # Use display name for label
                alpha=0.8,
                edgecolors="white",
                linewidth=0.8,
            )
            # Annotate the first point in data order (not sorted order) with
            # the method's reference letter
            if cross_entropies and logit_diffs_improvement:
                ax.annotate(
                    ref_labels[j],
                    (cross_entropies[0], logit_diffs_improvement[0]),
                    textcoords="offset points",
                    xytext=(10, 10),
                    ha="left",
                    fontsize=18,
                    fontweight="bold",
                    color=colors[j],
                    bbox=dict(
                        boxstyle="round,pad=0.2",
                        fc="white",
                        ec=colors[j],
                        lw=1.5,
                        alpha=0.8,
                    ),
                )

    # Dashed vertical line at cross entropy = 9, with a rotated "threshold"
    # label placed 10% of the way up the current y-range
    ax.axvline(x=9, color="#555555", linestyle="--", linewidth=1.5, alpha=0.7)
    ax.text(
        9.1,
        ax.get_ylim()[0] + 0.1 * (ax.get_ylim()[1] - ax.get_ylim()[0]),
        "threshold",
        fontsize=16,
        color="#555555",
        fontweight="bold",
        ha="left",
        va="bottom",
        rotation=90,
    )

    # Set title and labels with larger font sizes
    ax.set_title(story_type, fontsize=28, fontweight="bold")
    ax.set_xlabel("Cross Entropy", fontsize=22, fontweight="bold")
    ax.set_ylabel("Logit Difference Improvement", fontsize=22, fontweight="bold")
    ax.tick_params(axis="both", which="major", labelsize=18)

    # Hide the top/right spines and thicken the remaining axes
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.spines["bottom"].set_linewidth(1.5)
    ax.spines["left"].set_linewidth(1.5)
    ax.xaxis.set_tick_params(width=1.5, length=6)

# One shared figure-level legend built from proxy artists (marker only, no line)
legend_elements = [
    Line2D(
        [0],
        [0],
        marker=markers[i],
        color="w",
        markerfacecolor=colors[i],
        markersize=12,
        label=display_names[i],
        markeredgecolor="white",
        markeredgewidth=0.8,
    )
    for i in range(len(methods))
]

fig_plots.legend(
    handles=legend_elements,
    loc="upper center",
    bbox_to_anchor=(0.5, 0.98),
    ncol=len(methods),
    fontsize=16,
    frameon=True,
    framealpha=0.9,
    edgecolor="#cccccc",
)

# Adjust spacing between subplots; top=0.92 leaves room for the legend above.
# (plt.* here acts on fig_plots, the current figure.)
plt.tight_layout()
plt.subplots_adjust(hspace=0.3, wspace=0.3, top=0.92)

# Save the plot figure
plt.savefig("diagram_stories_plots.png", dpi=300, bbox_inches="tight")

# FIGURE 2: Text-only companion figure. Each of the 2x2 panels shows one story
# type: the prompt template in a shaded box at the top and, below it, the first
# example context produced by each method.
fig_text, axs_text = plt.subplots(2, 2, figsize=(20, 16))
fig_text.subplots_adjust(hspace=0.5, wspace=0.5)  # Large spacing between subplots

# Flatten the axes for easier indexing
axs_text = axs_text.flatten()

# Constants for text layout, all in axes-fraction coordinates (0..1)
TEMPLATE_BOX_HEIGHT = 0.3  # Height of template box
CONTEXT_BOX_HEIGHT = 0.5  # Height of context box
FONTSIZE = 14  # Consistent font size for all text
TEMPLATE_Y = 0.8  # Vertical centre of the template box (and its text)
CONTEXT_Y = 0.35  # Vertical centre of the context box (and its text)
BOX_WIDTH = 0.8  # Width of both template and context boxes

# The divider sits midway in the gap between the template box's bottom edge
# and the context box's top edge.
DIVIDER_Y = (
    (TEMPLATE_Y - TEMPLATE_BOX_HEIGHT / 2) + (CONTEXT_Y + CONTEXT_BOX_HEIGHT / 2)
) / 2

# Process each story
for i, story_type in enumerate(story_types):
    ax = axs_text[i]

    # Remove ticks but keep a visible, thick frame around the panel
    ax.set_xticks([])
    ax.set_yticks([])
    for side in ("top", "right", "bottom", "left"):
        ax.spines[side].set_visible(True)
        ax.spines[side].set_linewidth(1.5)

    # Set white background
    ax.set_facecolor("white")

    # Set title
    ax.set_title(story_type, fontsize=24, fontweight="bold", pad=15)

    # Get the template and relevant information
    info = data[story_type]["info"]
    template = info["template"]

    # Desired text is appended in bold, undesired text in brackets (if present)
    desired_text = ""
    undesired_text = ""
    if "desired_text" in info:
        desired_text = f" {make_bold(info['desired_text'])}"
    if "undesired_text" in info:
        undesired_text = f" [{info['undesired_text']}]"

    # Create template text; the "{0}" slot is shown as a bold "context"
    template_content = (
        "Template: "
        + template.replace("{0}", make_bold("context"))
        + desired_text
        + undesired_text
    )

    # Draw template box. Rectangle takes its lower-left corner, so subtract half
    # the height to centre the box on TEMPLATE_Y (as done for the context box).
    template_rect = plt.Rectangle(
        (0.5 - BOX_WIDTH / 2, TEMPLATE_Y - TEMPLATE_BOX_HEIGHT / 2),
        BOX_WIDTH,
        TEMPLATE_BOX_HEIGHT,
        transform=ax.transAxes,
        facecolor="#f5f5f5",
        edgecolor="#aaaaaa",
        linewidth=1,
        alpha=1.0,
    )
    ax.add_patch(template_rect)

    # Add template text
    ax.text(
        0.5,
        TEMPLATE_Y,
        template_content,
        transform=ax.transAxes,
        fontsize=FONTSIZE,
        ha="center",
        va="center",
        wrap=True,
        linespacing=1.3,
    )

    # One line per method that has at least one entry for this story
    examples_content = ["Contexts:"]
    for j, method in enumerate(methods):
        if method in data[story_type] and len(data[story_type][method]) > 0:
            example = select_example(data[story_type][method], 0)
            completion_words = get_completion_words(data[story_type][method], 0)
            line = f"({ref_labels[j]}) {display_names[j]}: {example}{completion_words}"
            examples_content.append(line)

    # Join all examples content
    examples_text = "\n".join(examples_content)

    # Draw context box, centred on CONTEXT_Y
    context_rect = plt.Rectangle(
        (0.5 - BOX_WIDTH / 2, CONTEXT_Y - CONTEXT_BOX_HEIGHT / 2),
        BOX_WIDTH,
        CONTEXT_BOX_HEIGHT,
        transform=ax.transAxes,
        facecolor="white",
        edgecolor="#cccccc",
        linewidth=1,
        alpha=1.0,
    )
    ax.add_patch(context_rect)

    # Add context text
    ax.text(
        0.5,
        CONTEXT_Y,
        examples_text,
        transform=ax.transAxes,
        fontsize=FONTSIZE,
        ha="center",
        va="center",
        wrap=True,
        linespacing=1.5,
    )

    # Divider between the template and context boxes, drawn in axes coordinates
    # so its position does not depend on the (unused) data limits.
    ax.plot(
        [0, 1],
        [DIVIDER_Y, DIVIDER_Y],
        transform=ax.transAxes,
        color="#bbbbbb",
        linestyle="-",
        linewidth=1,
        alpha=0.5,
    )

# Save the text figure
fig_text.savefig("diagram_stories_text.png", dpi=300, bbox_inches="tight")

# FIGURE 3: Single tall figure listing, for every story type, its title, the
# template, and each method's example context, stacked top-to-bottom.
# Create one tall figure with each box stacked vertically
fig_vertical = plt.figure(figsize=(10, 30))  # Tall figure for vertical layout

# Constants for vertical layout (fractions of the axes height)
V_TEXT_WIDTH = 0.7  # Width for all text (consistent) — NOTE(review): not used below
V_SPACING = 0.02  # Spacing between text sections
V_FONTSIZE = 14  # Consistent font size
V_LINE_HEIGHT = 0.05  # Height for each line of text

# Single hidden axes spanning 80% x 90% of the figure; limits are pinned to
# (0, 1) so data coordinates coincide with axes-fraction coordinates.
ax = fig_vertical.add_axes([0.1, 0.05, 0.8, 0.9])
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.axis("off")  # Hide the axis

# current_y is the top edge of the next item; each item moves it down.
# NOTE(review): with many/long contexts it can drop below 0 — verify the
# saved image still shows everything.
current_y = 0.95

# Process each story type
for i, story_type in enumerate(story_types):
    # Add story title
    ax.text(
        0.5,
        current_y,
        story_type,
        ha="center",
        va="top",
        fontsize=20,
        fontweight="bold",
    )
    current_y -= 2 * V_SPACING

    # Get the template and relevant information
    template = data[story_type]["info"]["template"]

    # Get desired_text and undesired_text from info if available
    desired_text = ""
    undesired_text = ""
    if "desired_text" in data[story_type]["info"]:
        desired_text = f" {make_bold(data[story_type]['info']['desired_text'])}"
    if "undesired_text" in data[story_type]["info"]:
        undesired_text = f" [{data[story_type]['info']['undesired_text']}]"

    # Create template text; the "{0}" slot is shown as a bold "context"
    template_content = (
        "Template: "
        + template.replace("{0}", make_bold("context"))
        + desired_text
        + undesired_text
    )

    # Add template text with consistent width
    ax.text(
        0.5,
        current_y,
        template_content,
        fontsize=V_FONTSIZE,
        ha="center",
        va="top",
        wrap=True,
        linespacing=1.5,
        bbox=dict(facecolor="none", edgecolor="none", pad=0.0, boxstyle="square"),
        transform=ax.transAxes,
    )

    # Estimate the rendered height from the character count, assuming roughly
    # 60 characters per wrapped line (a heuristic, not a measurement)
    text_height = len(template_content) / 60 * V_LINE_HEIGHT  # Rough estimate
    current_y -= text_height + 3 * V_SPACING

    # Thin separator between the template and the contexts; xmin/xmax are
    # fractions of the axes width
    ax.axhline(
        y=current_y + V_SPACING,
        color="#999999",
        linestyle="-",
        linewidth=0.5,
        alpha=0.5,
        xmin=0.15,
        xmax=0.85,
    )

    # Add "Contexts:" label
    ax.text(
        0.5,
        current_y,
        "Contexts:",
        ha="center",
        va="top",
        fontsize=V_FONTSIZE,
        fontweight="bold",
    )
    current_y -= 2 * V_SPACING

    # Add each method's context (methods with no entries for this story are skipped)
    for j, method in enumerate(methods):
        if method in data[story_type] and len(data[story_type][method]) > 0:
            example = select_example(data[story_type][method], 0)
            completion_words = get_completion_words(data[story_type][method], 0)
            context_content = (
                f"({ref_labels[j]}) {display_names[j]}: {example}{completion_words}"
            )

            # Add context text with consistent width
            ax.text(
                0.5,
                current_y,
                context_content,
                fontsize=V_FONTSIZE,
                ha="center",
                va="top",
                wrap=True,
                linespacing=1.5,
                bbox=dict(
                    facecolor="none", edgecolor="none", pad=0.0, boxstyle="square"
                ),
                transform=ax.transAxes,
            )

            # Same ~60-characters-per-line height heuristic as for the template
            text_height = len(context_content) / 60 * V_LINE_HEIGHT  # Rough estimate
            current_y -= text_height + 2 * V_SPACING

    # Add extra spacing between story types
    current_y -= 4 * V_SPACING

    # Add horizontal separator line between stories
    ax.axhline(
        y=current_y + 2 * V_SPACING,
        color="#999999",
        linestyle="-",
        linewidth=1.0,
        alpha=0.7,
        xmin=0.1,
        xmax=0.9,
    )

# Save the vertical layout figure (plt.savefig targets fig_vertical, the
# current figure since plt.figure above)
plt.savefig("diagram_stories_vertical.png", dpi=300, bbox_inches="tight")

plt.close("all")  # Close all figures

# Show that we're done
print(
    "Generated three figures: diagram_stories_plots.png, diagram_stories_text.png, and diagram_stories_vertical.png"
)


def save_individual_plots(data: Dict[str, Any], output_dir: str = "plots") -> None:
    """
    Save each story type's plot as a separate PNG (with legend) in ``output_dir``.

    This is not called by the script itself; ``save_individual_plots_new``
    (PDF, no legend) is used instead.

    Args:
        data: Mapping of story type -> {method key -> list of result entries, ...}.
            Each entry needs ``cross_entropy`` and either
            ``logit_diff_improvement`` or ``logit_diff``.
        output_dir: Directory to save the plots (created if missing; default 'plots').

    Per-method styling comes from the module-level ``methods``, ``display_names``,
    ``colors``, ``markers`` and ``ref_labels`` lists. Files are named
    ``diagram_stories_<story type, lower-cased, spaces -> underscores>.png``.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Iterate the mapping that was passed in rather than the module-level
    # story list, so the function works for any ``data`` argument.
    for story_type in data:
        fig, ax = plt.subplots(figsize=(10, 8))

        # Plot each method's data points (methods missing from this story are skipped)
        for j, method in enumerate(methods):
            if method not in data[story_type]:
                continue

            entries = data[story_type][method]
            cross_entropies = [entry["cross_entropy"] for entry in entries]
            # Prefer the improvement metric; fall back to the raw logit diff
            logit_diffs_improvement = [
                (
                    entry["logit_diff_improvement"]
                    if "logit_diff_improvement" in entry
                    else entry["logit_diff"]
                )
                for entry in entries
            ]

            # Connect the points in increasing cross-entropy order
            order = np.argsort(cross_entropies)
            ax.plot(
                [cross_entropies[k] for k in order],
                [logit_diffs_improvement[k] for k in order],
                color=colors[j],
                alpha=0.5,
                linestyle="-",
                linewidth=1.5,
            )

            ax.scatter(
                cross_entropies,
                logit_diffs_improvement,
                color=colors[j],
                marker=markers[j],
                s=120,
                label=display_names[j],  # Use display name for label
                alpha=0.8,
                edgecolors="white",
                linewidth=0.8,
            )

            # Annotate the first point (in data order) with the reference label
            if cross_entropies:
                ax.annotate(
                    ref_labels[j],
                    (cross_entropies[0], logit_diffs_improvement[0]),
                    textcoords="offset points",
                    xytext=(10, 10),
                    ha="left",
                    fontsize=18,
                    fontweight="bold",
                    color=colors[j],
                    bbox=dict(
                        boxstyle="round,pad=0.2",
                        fc="white",
                        ec=colors[j],
                        lw=1.5,
                        alpha=0.8,
                    ),
                )

        # Dashed vertical line at cross entropy = 9 labelled "threshold"
        ax.axvline(x=9, color="#555555", linestyle="--", linewidth=1.5, alpha=0.7)
        y_low, y_high = ax.get_ylim()
        ax.text(
            9.1,
            y_low + 0.1 * (y_high - y_low),
            "threshold",
            fontsize=16,
            color="#555555",
            fontweight="bold",
            ha="left",
            va="bottom",
            rotation=90,
        )

        # Axis labels (no title — the story type is encoded in the filename)
        ax.set_xlabel("Cross Entropy", fontsize=22, fontweight="bold")
        ax.set_ylabel("Logit Difference Improvement", fontsize=22, fontweight="bold")
        ax.tick_params(axis="both", which="major", labelsize=18)

        # Hide the top/right spines and thicken the remaining axes
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        ax.spines["bottom"].set_linewidth(1.5)
        ax.spines["left"].set_linewidth(1.5)
        ax.xaxis.set_tick_params(width=1.5, length=6)

        # Figure-level legend from marker-only proxy artists
        legend_elements = [
            Line2D(
                [0],
                [0],
                marker=marker,
                color="w",
                markerfacecolor=color,
                markersize=12,
                label=name,
                markeredgecolor="white",
                markeredgewidth=0.8,
            )
            for marker, color, name in zip(markers, colors, display_names)
        ]
        fig.legend(
            handles=legend_elements,
            loc="upper center",
            bbox_to_anchor=(0.5, 0.98),
            ncol=len(methods),
            fontsize=16,
            frameon=True,
            framealpha=0.9,
            edgecolor="#cccccc",
        )

        # Layout on this figure explicitly (not via pyplot's "current figure");
        # top=0.92 leaves room for the legend.
        fig.tight_layout()
        fig.subplots_adjust(top=0.92)

        safe_name = story_type.lower().replace(" ", "_")
        fig.savefig(
            os.path.join(output_dir, f"diagram_stories_{safe_name}.png"),
            dpi=300,
            bbox_inches="tight",
        )
        plt.close(fig)


def save_individual_plots_new(
    data: Dict[str, Any],
    story_types,
    methods,
    colors,
    markers,
    ref_labels,
    output_dir: str = "plots",
) -> None:
    """
    Write one PDF per story panel – with neither title nor legend.

    Args:
        data: Mapping of story type -> {method key -> list of result entries, ...}.
            Each entry needs ``cross_entropy`` and either
            ``logit_diff_improvement`` or ``logit_diff``.
        story_types: Story types to plot. Each is written to
            ``<output_dir>/<story type, lower-cased, spaces -> underscores>.pdf``.
        methods: Method keys looked up in ``data[story_type]``; absent ones are skipped.
        colors: Per-method colours, parallel to ``methods``.
        markers: Per-method marker shapes, parallel to ``methods``.
        ref_labels: Per-method letters annotated on each series' first point.
        output_dir: Destination directory (created if missing).
    """
    os.makedirs(output_dir, exist_ok=True)

    for story_type in story_types:
        fig, ax = plt.subplots(figsize=(10, 8))

        # ----- data series ---------------------------------------------------
        for j, method in enumerate(methods):
            if method not in data[story_type]:
                continue

            entries = data[story_type][method]
            xvals = [e["cross_entropy"] for e in entries]
            # Prefer the improvement metric and only read "logit_diff" when it is
            # missing. (dict.get(k, e["logit_diff"]) would evaluate the default
            # eagerly and raise KeyError for entries without "logit_diff".)
            yvals = [
                (
                    e["logit_diff_improvement"]
                    if "logit_diff_improvement" in e
                    else e["logit_diff"]
                )
                for e in entries
            ]
            # keep lines monotone in x for nicer connection
            order = np.argsort(xvals)
            ax.plot(
                np.array(xvals)[order],
                np.array(yvals)[order],
                color=colors[j],
                alpha=0.5,
                lw=1.5,
            )

            ax.scatter(
                xvals,
                yvals,
                color=colors[j],
                marker=markers[j],
                s=120,
                alpha=0.8,
                edgecolors="white",
                linewidth=0.8,
            )

            # label first point (in data order) with the method's letter
            if xvals:
                ax.annotate(
                    ref_labels[j],
                    (xvals[0], yvals[0]),
                    xytext=(10, 10),
                    textcoords="offset points",
                    ha="left",
                    fontsize=18,
                    fontweight="bold",
                    color=colors[j],
                    bbox=dict(
                        boxstyle="round,pad=0.2",
                        fc="white",
                        ec=colors[j],
                        lw=1.5,
                        alpha=0.8,
                    ),
                )

        # ----- threshold (dashed vertical line at cross entropy = 9) ---------
        ax.axvline(9, ls="--", lw=1.5, color="#555555", alpha=0.7)
        y_low, y_high = ax.get_ylim()
        ax.text(
            9.1,
            y_low + 0.1 * (y_high - y_low),
            "threshold",
            rotation=90,
            va="bottom",
            ha="left",
            fontsize=16,
            fontweight="bold",
            color="#555555",
        )

        # ----- cosmetics -----------------------------------------------------
        ax.set_xlabel("Cross Entropy", fontsize=22, fontweight="bold")
        ax.set_ylabel("Logit Difference Improvement", fontsize=22, fontweight="bold")
        ax.tick_params(axis="both", which="major", labelsize=18)
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        ax.spines["bottom"].set_linewidth(1.5)
        ax.spines["left"].set_linewidth(1.5)
        ax.xaxis.set_tick_params(width=1.5, length=6)

        # ----- save ----------------------------------------------------------
        safe_name = story_type.lower().replace(" ", "_")
        fig.tight_layout()
        fig.savefig(
            os.path.join(output_dir, f"{safe_name}.pdf"), bbox_inches="tight"
        )  # vector preferred
        plt.close(fig)


def save_shared_legend(
    methods, colors, markers, filename: str = "legend.pdf", labels=None
) -> None:
    """
    Build and export the single legend as a standalone, transparent PDF.

    Args:
        methods: Method keys; one legend entry is produced per method.
        colors: Marker face colour per method (parallel to ``methods``).
        markers: Marker shape per method (parallel to ``methods``).
        filename: Output path (default ``legend.pdf``).
        labels: Legend text per method. Defaults to the display names in the
            module-level ``result_sets_config`` (the ``methods`` keys such as
            "BlackBox" are data keys, not the names shown in the figures).
    """
    if labels is None:
        labels = [display_name for _, _, display_name, _ in result_sets_config]

    # Use the styling actually passed in rather than re-deriving it from globals.
    legend_elements = [
        Line2D(
            [0],
            [0],
            marker=marker,
            markersize=12,
            markerfacecolor=color,
            markeredgecolor="white",
            markeredgewidth=0.8,
            linestyle="none",
            label=label,
        )
        for _, label, color, marker in zip(methods, labels, colors, markers)
    ]

    fig = plt.figure(figsize=(4, 1.2))
    fig.legend(
        handles=legend_elements,
        ncol=max(1, len(legend_elements)),  # single row; ncol must be >= 1
        loc="center",
        frameon=False,
    )
    fig.savefig(filename, bbox_inches="tight", transparent=True)
    plt.close(fig)


# Export one PDF panel per story type, then the standalone shared legend.
save_individual_plots_new(
    data,
    story_types=story_types,
    methods=methods,
    colors=colors,
    markers=markers,
    ref_labels=ref_labels,
)
save_shared_legend(methods, colors=colors, markers=markers)


def save_latex_commands_old(
    data: Dict[str, Any], output_file: str = "plots/latex.txt"
) -> None:
    r"""
    Legacy writer for the ``\StoryText<X>`` LaTeX commands, one per story type.

    Superseded by ``save_latex_commands``, which is the version this script calls.
    Each command contains the template (with every ``{0}`` replaced by a bold
    "<context>" marker followed by a space), optional green/red desired/undesired
    text, and an itemized list of each method's first example context.

    Args:
        data: Mapping of story type -> {"info": {...}, method key -> entries}.
        output_file: Destination path; its parent directory is created if needed.
    """
    # dirname() is "" for a bare filename, and os.makedirs("") would raise.
    os.makedirs(os.path.dirname(output_file) or ".", exist_ok=True)

    placeholder = "\\textbf{\\textless context\\textgreater} "
    chunks: List[str] = []

    # Create commands for each story type
    for i, story_type in enumerate(story_types):
        chunks.append(f"\\newcommand{{\\StoryText{chr(65+i)}}}{{%\n")

        info = data[story_type]["info"]
        template = info["template"]
        desired_text = info.get("desired_text", "")
        undesired_text = info.get("undesired_text", "")

        # Replace every "{0}" slot. A template without one is emitted verbatim
        # (indexing split()[1] used to raise IndexError), and text after a second
        # slot is no longer dropped.
        formatted_template = placeholder.join(template.split("{0}"))

        # Add desired/undesired text only when both are present
        if desired_text and undesired_text:
            formatted_template += f" \\textcolor{{green!50!black}}{{{desired_text}}} / \\textcolor{{red!70!black}}{{{undesired_text}}}"

        chunks.append(f"  \\textbf{{Template:}} {formatted_template}\\\\[2pt]\n")

        # Contexts header and list
        chunks.append("  \\textbf{Contexts:}\n")
        chunks.append(
            "  \\begin{itemize}[leftmargin=1.4cm,labelsep=0.1cm,nosep,itemsep=1.5pt]\n"
        )

        # One item per method that has at least one entry for this story
        for j, method in enumerate(methods):
            if method in data[story_type] and len(data[story_type][method]) > 0:
                context = select_example(data[story_type][method], 0)
                chunks.append(
                    f"    \\item[({ref_labels[j]}) {display_names[j]}:] {context}\n"
                )

        # Close itemize and command
        chunks.append("  \\end{itemize}\n")
        chunks.append("}\n\n")

    with open(output_file, "w", encoding="utf-8") as f:
        f.write("".join(chunks))

    print(f"Saved LaTeX commands to {output_file}")


# Characters that are special in LaTeX text mode, mapped to literal-safe forms.
_LATEX_SPECIAL_CHARS = {
    "\\": r"\textbackslash{}",
    "&": r"\&",
    "%": r"\%",
    "$": r"\$",
    "#": r"\#",
    "_": r"\_",
    "{": r"\{",
    "}": r"\}",
    "~": r"\textasciitilde{}",
    "^": r"\textasciicircum{}",
}


def _latex_escape(text: Any) -> str:
    """Escape LaTeX special characters in ``str(text)`` so it typesets literally."""
    return "".join(_LATEX_SPECIAL_CHARS.get(ch, ch) for ch in str(text))


def save_latex_commands(
    data: Dict[str, Any], output_file: str = "plots/latex.txt"
) -> None:
    r"""
    Save LaTeX commands for every story, following the required template while
    pulling the actual content from ``data``.

    One command is produced per story type – ``\StoryTextA``, ``\StoryTextB``, … –
    each wrapping the story snippet in the format

        \newcommand{\StoryTextX}{%
          \textbf{Template:} …\\[2pt]
          \textbf{Contexts:}
          \begin{itemize}[leftmargin=1.4cm,labelsep=0.1cm,nosep,itemsep=1.5pt]
            \item[(A) Human:] …
            …
          \end{itemize}
        }

    Notes
    -----
    * Two‑space indentation is used throughout to match the target layout.
    * The placeholder ``{0}`` in the template is replaced by
      ``\textbf{\textless context\textgreater}``.
    * If *both* ``desired_text`` **and** ``undesired_text`` are present, they are
      appended in the “green / red” colour scheme just like the reference example.
    * Only methods that have at least one entry for the given story are emitted.
    * Story text, desired/undesired text, display names and contexts are
      LaTeX-escaped (``\ { } _ % $ & # ^ ~``), since generated contexts can
      contain arbitrary characters that would otherwise break compilation.
    """
    # ensure the output directory exists ("" for a bare filename -> cwd)
    os.makedirs(os.path.dirname(output_file) or ".", exist_ok=True)

    placeholder = "\\textbf{\\textless context\\textgreater}"
    lines: list[str] = []

    for i, story_type in enumerate(story_types):
        cmd_letter = chr(65 + i)  # A, B, C, …
        info = data[story_type]["info"]

        # ---------- header --------------------------------------------------
        lines.append(f"\\newcommand{{\\StoryText{cmd_letter}}}{{%")

        # Template line ------------------------------------------------------
        # Escape the literal text around each "{0}" slot, then insert the
        # (unescaped) bold placeholder into the slots.
        template_raw = placeholder.join(
            _latex_escape(part) for part in info["template"].split("{0}")
        )
        desired = info.get("desired_text", "")
        undesired = info.get("undesired_text", "")

        if desired and undesired:
            template_raw += (
                f" \\textcolor{{green!50!black}}{{{_latex_escape(desired)}}} / "
                f"\\textcolor{{red!70!black}}{{{_latex_escape(undesired)}}}"
            )

        lines.append(f"  \\textbf{{Template:}} {template_raw}\\\\[2pt]")

        # Context list -------------------------------------------------------
        lines.append("  \\textbf{Contexts:}")
        lines.append(
            "  \\begin{itemize}[leftmargin=1.4cm,labelsep=0.1cm,nosep,itemsep=1.5pt]"
        )

        for j, method in enumerate(methods):
            if method not in data[story_type] or not data[story_type][method]:
                continue
            example = _latex_escape(select_example(data[story_type][method], 0))
            label = ref_labels[j]
            name = _latex_escape(display_names[j])
            lines.append(f"    \\item[({label}) {name}:] {example}")

        lines.append("  \\end{itemize}")
        lines.append("}")  # close \newcommand
        lines.append("")  # blank line between commands

    with open(output_file, "w", encoding="utf-8") as f:
        f.write("\n".join(lines))

    print(f"Saved LaTeX commands to {output_file}")


# Export the \StoryText<X> LaTeX macros to the default path (plots/latex.txt).
save_latex_commands(data=data)
