"""Plotting utilities for perturbation experiments."""

import argparse
import json
from pathlib import Path
from typing import Iterable, List

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns


def load_records(path: Path) -> List[dict]:
    """Load a non-empty list of perturbation records from a JSON file.

    Args:
        path: JSON file whose top-level value is an array of objects.

    Returns:
        The decoded records, one ``dict`` per array element.

    Raises:
        ValueError: If the file is not valid JSON (``json.JSONDecodeError`` is a
            ``ValueError`` subclass), the top-level value is not a list, the
            list is empty, or any element is not a JSON object.
    """
    # Hand json the raw bytes: it detects UTF-8/16/32 itself and tolerates a
    # UTF-8 BOM, instead of depending on the platform's default text encoding.
    data = json.loads(path.read_bytes())
    if not isinstance(data, list):
        raise ValueError(f"Expected a list of records in {path}; found {type(data)!r}")
    if not data:
        raise ValueError(f"No records found in {path}")
    for index, item in enumerate(data):
        if not isinstance(item, dict):
            raise ValueError(
                f"Expected each record in {path} to be an object; "
                f"record {index} is {type(item)!r}"
            )
    return data


def _rank_sort_key(label: str) -> tuple:
    """Order rank labels numerically; non-numeric or NaN labels sort last, by text."""
    try:
        value = float(label)
    except ValueError:
        return (1, 0.0, label)
    if pd.isna(value):
        return (1, 0.0, label)
    return (0, value, label)


def plot_records(
    records: Iterable[dict],
    *,
    output: Path | None = None,
    show: bool = False,
    title: str = "Perturbation Amplitude vs. Perplexity",
) -> None:
    """Scatter-plot perturbation amplitude against perplexity using seaborn.

    Each record becomes one point at (``perturbation_norm``,
    ``perturbed_perplexity``) on log-log axes. If records carry a ``rank``
    field, points are coloured by rank and the legend lists ranks in numeric
    order. A dashed horizontal line marks the baseline perplexity.

    Args:
        records: Mappings with at least ``perturbation_norm``,
            ``perturbed_perplexity`` and ``baseline_perplexity``; an optional
            ``rank`` field selects the colour grouping.
        output: If given, save the figure there (parent directories are created).
        show: Open an interactive window. A window is also opened whenever
            ``output`` is None, since the figure would otherwise be discarded.
        title: Axes title.

    Raises:
        ValueError: If ``records`` is empty or a required field is missing.
    """
    records = list(records)
    if not records:
        raise ValueError("plot_records received no records to plot")

    df = pd.DataFrame.from_records(records)
    required_columns = {"perturbation_norm", "perturbed_perplexity", "baseline_perplexity"}
    missing = required_columns - set(df.columns)
    if missing:
        raise ValueError(f"Records missing required fields: {sorted(missing)}")

    hue = "rank" if "rank" in df.columns else None
    hue_order = None
    if hue is not None:
        # Cast to str so seaborn treats rank as categorical (discrete palette)
        # instead of mapping it onto a continuous colormap.
        df[hue] = df[hue].astype(str)
        # String categories default to record-appearance order; sort numerically.
        hue_order = sorted(df[hue].unique(), key=_rank_sort_key)

    # NOTE(review): only the first record's baseline is drawn — this assumes all
    # records share one baseline; confirm against the producer of the JSON.
    baseline = float(df.iloc[0]["baseline_perplexity"])

    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        sns.scatterplot(
            data=df,
            x="perturbation_norm",
            y="perturbed_perplexity",
            hue=hue,
            hue_order=hue_order,
            palette="tab20" if hue else None,
            alpha=0.7,
            edgecolor="none",
            ax=ax,
        )

        ax.axhline(
            baseline,
            color="gray",
            linestyle="--",
            linewidth=1,
            label="Baseline perplexity",
        )
        ax.set_xlabel("Perturbation Frobenius norm (||Δ||)")
        ax.set_ylabel("Perplexity")
        ax.set_xscale("log")
        ax.set_yscale("log")
        ax.set_title(title)
        ax.grid(True, linewidth=0.3, alpha=0.5)

        # Rebuild the legend so it also contains the baseline line added after
        # seaborn drew its own; re-apply the hue title that rebuilding drops.
        handles, labels = ax.get_legend_handles_labels()
        if handles:
            ax.legend(handles=handles, labels=labels, frameon=False, loc="upper left", title=hue)

        fig.tight_layout()

        if output is not None:
            output.parent.mkdir(parents=True, exist_ok=True)
            fig.savefig(output, dpi=200)
            print(f"Saved plot to {output}")
    except BaseException:
        # Don't leave an orphaned pyplot figure behind when drawing/saving fails.
        plt.close(fig)
        raise

    if show or output is None:
        plt.show()
    else:
        plt.close(fig)


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Plot perturbation amplitude vs perplexity from JSON records."
    )
    parser.add_argument(
        "json_path",
        type=Path,
        help="Path to the JSON results produced by grid_search.py",
    )
    parser.add_argument(
        "--output",
        type=Path,
        default=None,
        help="Optional path to save the plot image",
    )
    parser.add_argument(
        "--title",
        default="Perturbation Amplitude vs. Perplexity",
        help="Title for the plot",
    )
    parser.add_argument(
        "--no-show",
        action="store_true",
        help="Do not open an interactive window",
    )
    return parser.parse_args()


def main() -> None:
    """CLI entry point: read records from the given JSON file and plot them."""
    cli = parse_args()
    loaded = load_records(cli.json_path)
    open_window = not cli.no_show
    plot_records(
        loaded,
        output=cli.output,
        show=open_window,
        title=cli.title,
    )


# Run the CLI only when this file is executed directly, not when it is imported.
if __name__ == "__main__":  # pragma: no cover
    main()
