"""Generate figures used by the ICML workshop paper."""

from __future__ import annotations

import csv
from pathlib import Path

import matplotlib.pyplot as plt
from matplotlib.patches import FancyArrowPatch, FancyBboxPatch


# Parent of the directory containing this script (presumably the repository
# root); input tables and output figures are resolved relative to it.
ROOT = Path(__file__).resolve().parents[1]
# Destination directory for every generated PDF.
OUT = ROOT.joinpath("paper", "figures")


# Display names used on axes and annotations, keyed by the CSV "strategy" id.
LABELS = dict(
    unguided="Unguided",
    family_soft="Soft family",
    family_guided="Hard gate",
    family_top_m="Top-m quota",
    family_rrf="RRF",
)


def setup() -> None:
    """Create the figure directory and apply the shared matplotlib style.

    The style uses a 9 pt serif font, hides the top and right spines, and
    selects Type 42 (TrueType) font embedding for PDF and PS output.
    """
    paper_style = {
        "font.family": "serif",
        "font.size": 9,
        "axes.spines.top": False,
        "axes.spines.right": False,
        "pdf.fonttype": 42,
        "ps.fonttype": 42,
    }
    OUT.mkdir(exist_ok=True, parents=True)
    plt.rcParams.update(paper_style)


def rounded_box(ax, xy, width, height, text, face, edge="#404040"):
    """Add a rounded rectangle to *ax* with *text* centred inside it.

    *xy* is the lower-left corner, and *width*/*height* its extent, in the
    same coordinates the patch is drawn in. *face* and *edge* are the fill
    and outline colours.
    """
    left, bottom = xy[0], xy[1]
    patch = FancyBboxPatch(
        xy,
        width,
        height,
        boxstyle="round,pad=0.018,rounding_size=0.02",
        linewidth=0.9,
        facecolor=face,
        edgecolor=edge,
    )
    ax.add_patch(patch)

    centre_x = left + width / 2
    centre_y = bottom + height / 2
    ax.text(centre_x, centre_y, text, ha="center", va="center", color="#1f1f1f", fontsize=9.2)


def arrow(ax, start, end, color="#404040", style="-", connectionstyle="arc3"):
    """Add a filled-head arrow patch from *start* to *end* on *ax*.

    *style* is the shaft's line style (e.g. ``"--"`` for dashed), and
    *connectionstyle* is forwarded unchanged to ``FancyArrowPatch``.
    """
    patch = FancyArrowPatch(
        start,
        end,
        arrowstyle="-|>",
        mutation_scale=13,
        linewidth=1.25,
        linestyle=style,
        color=color,
        connectionstyle=connectionstyle,
    )
    ax.add_patch(patch)


def candidate_path() -> None:
    """Draw the candidate-path schematic and save it as ``candidate_path.pdf``.

    The six numbered stages form two rows of three boxes. The top row runs
    left to right and the bottom row right to left, and solid arrows chain
    the stages in order. Below them, two grey boxes are joined by a dashed
    arrow and captioned "available only after replay".
    """
    fig, ax = plt.subplots(figsize=(7.2, 3.05))
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.axis("off")

    # Every stage box has the same size; only position, label and fill vary.
    stage_w, stage_h = 0.22, 0.18
    stages = [
        (0.04, 0.67, "1. Proof state\nGoal + context", "#dcecf7"),
        (0.39, 0.67, "2. State view\nState-visible fields", "#dcecf7"),
        (0.74, 0.67, "3. Candidate pool\nRetrieved tactics", "#e8f1df"),
        (0.74, 0.36, "4. Ranking\nSimilarity + family", "#f4ead6"),
        (0.39, 0.36, "5. Top-k list\nTactics to try", "#ece3f1"),
        (0.04, 0.36, "6. Lean check\nAccept or reject", "#f7dfdf"),
    ]
    for left, bottom, label, fill in stages:
        rounded_box(ax, (left, bottom), stage_w, stage_h, label, fill)

    # Each arrow joins the facing edges of consecutive stage boxes.
    flow = [
        ((0.26, 0.76), (0.39, 0.76)),
        ((0.61, 0.76), (0.74, 0.76)),
        ((0.85, 0.67), (0.85, 0.54)),
        ((0.74, 0.45), (0.61, 0.45)),
        ((0.39, 0.45), (0.26, 0.45)),
    ]
    for start, end in flow:
        arrow(ax, start, end)

    muted = "#8b8b8b"
    side_notes = [
        (0.24, "Future tactic fields\npremises + tactic AST"),
        (0.58, "Hindsight check\nnot a main input"),
    ]
    for left, label in side_notes:
        rounded_box(ax, (left, 0.08), 0.25, 0.17, label, "#f5f5f5", edge=muted)
    arrow(ax, (0.49, 0.165), (0.58, 0.165), color=muted, style="--")
    # Caption centred under the first grey box (0.24 + 0.25 / 2 = 0.365).
    ax.text(0.365, 0.035, "available only after replay", ha="center", va="center", fontsize=8.2, color="#555555")

    fig.savefig(OUT / "candidate_path.pdf", bbox_inches="tight", pad_inches=0.02)
    plt.close(fig)


# Left-to-right (bar and legend) order of strategies shared by the
# trace-vs-Lean figures; each must have a row in the accepted-alternatives CSV.
_STRATEGY_ORDER = ("unguided", "family_soft", "family_rrf", "family_top_m", "family_guided")


def _load_strategy_rows() -> dict[str, dict[str, str]]:
    """Read the accepted-alternatives table, keyed by its ``strategy`` column.

    Returns:
        Mapping from strategy id to its CSV row (values are raw strings).
        If a strategy appears more than once, the last row wins.

    Raises:
        KeyError: if any strategy in ``_STRATEGY_ORDER`` has no row. The
            message names the file and the missing strategies.
    """
    path = ROOT / "results" / "tables" / "execution_accepted_alternatives.csv"
    # The csv module requires newline="" so quoted fields containing
    # newlines are parsed correctly. An explicit encoding keeps the read
    # independent of the machine's locale.
    with path.open(newline="", encoding="utf-8") as f:
        rows = {row["strategy"]: row for row in csv.DictReader(f)}
    missing = [s for s in _STRATEGY_ORDER if s not in rows]
    if missing:
        raise KeyError(f"{path} has no rows for strategies: {', '.join(missing)}")
    return rows


def _covering_limits(lo: float, hi: float, values: list[float], margin: float) -> tuple[float, float]:
    """Return the hand-tuned axis window ``(lo, hi)``, widened only when needed.

    A side moves to the extreme value plus *margin* only when some value
    falls outside the default window. Figures whose data already fit keep
    exactly the original limits.
    """
    smallest, largest = min(values), max(values)
    if smallest < lo:
        lo = smallest - margin
    if largest > hi:
        hi = largest + margin
    return lo, hi


def trace_vs_lean() -> None:
    """Plot trace exact@5 vs Lean accept@5 as paired bars per strategy.

    The accepted non-trace fraction is overlaid as a line. The figure is
    written to ``OUT / "trace_vs_lean_acceptance.pdf"``.
    """
    rows = _load_strategy_rows()
    order = _STRATEGY_ORDER

    trace = [float(rows[s]["proxy_exact_at_5_on_sample"]) for s in order]
    lean = [float(rows[s]["lean_accept_at_5"]) for s in order]
    accepted_alt = [float(rows[s]["accepted_not_gold_at_5"]) for s in order]
    labels = [LABELS[s] for s in order]

    fig, ax = plt.subplots(figsize=(4.6, 2.55))
    x = list(range(len(order)))
    width = 0.28
    # Trace bars sit left of each tick, Lean bars to the right.
    ax.bar(
        [i - width / 2 for i in x],
        trace,
        width,
        label="Trace exact@5",
        color="#5b8db8",
        edgecolor="#2f5c83",
        linewidth=0.5,
    )
    ax.bar(
        [i + width / 2 for i in x],
        lean,
        width,
        label="Lean accept@5",
        color="#d08a3d",
        edgecolor="#945d20",
        linewidth=0.5,
    )
    ax.plot(
        x,
        accepted_alt,
        marker="o",
        linewidth=1.2,
        markersize=3.5,
        color="#4b7f52",
        label="Accepted non-trace",
    )

    # 0.78 is the hand-tuned top; it is raised only if a value would be clipped.
    ax.set_ylim(*_covering_limits(0, 0.78, trace + lean + accepted_alt, 0.02))
    # NOTE(review): the state count is hard-coded, not read from the CSV.
    # Confirm it still matches the evaluated sample if the table is regenerated.
    ax.set_ylabel("Fraction of 500 states")
    ax.set_xticks(x)
    ax.set_xticklabels(labels, rotation=18, ha="right")
    ax.grid(axis="y", linewidth=0.35, alpha=0.35)
    ax.legend(loc="upper center", bbox_to_anchor=(0.5, 1.20), ncol=3, frameon=False)
    fig.savefig(OUT / "trace_vs_lean_acceptance.pdf", bbox_inches="tight", pad_inches=0.02)
    plt.close(fig)


def trace_lean_scatter() -> None:
    """Scatter Lean accept@5 against trace exact@5, one point per strategy.

    Marker area grows linearly with ``accepted_not_gold_at_5``. The figure
    is written to ``OUT / "trace_lean_scatter.pdf"``.
    """
    rows = _load_strategy_rows()

    fig, ax = plt.subplots(figsize=(3.35, 2.75))
    colors = {
        "unguided": "#2f6f9f",
        "family_soft": "#c8792f",
        "family_rrf": "#4b7f52",
        "family_top_m": "#8a6db1",
        "family_guided": "#9d4f4f",
    }
    # (dx, dy) label offset in points plus horizontal alignment; presumably
    # hand-tuned so the annotations do not overlap.
    label_offsets = {
        "unguided": (-8, 12, "right"),
        "family_soft": (-8, 10, "right"),
        "family_rrf": (12, 8, "left"),
        "family_top_m": (14, 11, "left"),
        "family_guided": (13, 8, "left"),
    }
    xs: list[float] = []
    ys: list[float] = []
    for strategy in _STRATEGY_ORDER:
        row = rows[strategy]
        x = float(row["proxy_exact_at_5_on_sample"])
        y = float(row["lean_accept_at_5"])
        xs.append(x)
        ys.append(y)
        size = 320 * float(row["accepted_not_gold_at_5"]) + 35
        ax.scatter(
            x,
            y,
            s=size,
            color=colors[strategy],
            edgecolor="white",
            linewidth=0.7,
            alpha=0.92,
            zorder=3,
        )
        dx, dy, ha = label_offsets[strategy]
        ax.annotate(
            LABELS[strategy],
            xy=(x, y),
            xytext=(dx, dy),
            textcoords="offset points",
            ha=ha,
            va="center",
            fontsize=8,
            bbox={"facecolor": "white", "edgecolor": "none", "alpha": 0.65, "pad": 0.2},
        )

    # Hand-tuned zoom window, widened only if a point would otherwise be clipped.
    ax.set_xlim(*_covering_limits(0.055, 0.185, xs, 0.01))
    ax.set_ylim(*_covering_limits(0.54, 0.70, ys, 0.01))
    ax.set_xlabel("Trace exact@5")
    ax.set_ylabel("Lean accept@5")
    ax.grid(linewidth=0.35, alpha=0.35)

    fig.savefig(OUT / "trace_lean_scatter.pdf", bbox_inches="tight", pad_inches=0.02)
    plt.close(fig)


def main() -> None:
    """Apply the shared style, render every paper figure, and report the output directory."""
    setup()
    for build_figure in (candidate_path, trace_vs_lean, trace_lean_scatter):
        build_figure()
    print(f"Wrote figures to {OUT}")


# Only render figures when run as a script, not on import.
if __name__ == "__main__":
    main()
