"""Plot heldout/eval loss ratio vs training steps for a sweep."""

from __future__ import annotations

import argparse
import json
from pathlib import Path
from typing import Dict, List, Tuple

import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

# Sweep sub-directories whose name starts with this prefix are treated as
# per-rank runs by collect_series(); the text after the prefix is parsed as
# an integer and used as the sort key (e.g. "r8").
RANK_DIR_PREFIX = "r"
# Colours for the rank runs. plot_series() assigns them in order and wraps
# around (modulo the list length) when there are more runs than colours.
RANK_BASE_COLORS = [
    "#1f77b4",
    "#2ca02c",
    "#9467bd",
    "#d62728",
    "#e377c2",
    "#bcbd22",
]
# Colour used by plot_series() for the run named "full".
FULL_BASE_COLOR = "#8c564b"


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Plot heldout/eval loss ratio vs training steps."
    )
    parser.add_argument(
        "sweep_dir",
        type=Path,
        help="Path to the sweep directory (e.g. sweep_1614967).",
    )
    parser.add_argument(
        "--output",
        type=Path,
        default=None,
        help="Output path for the plot PNG.",
    )
    parser.add_argument(
        "--csv",
        type=Path,
        default=None,
        help="Optional path to save the plotted points as CSV.",
    )
    parser.add_argument(
        "--include-full",
        action="store_true",
        help="Include full fine-tuning results if the 'full' directory exists.",
    )
    return parser.parse_args()


def load_loss_series(metrics_path: Path) -> List[dict]:
    """Read the per-step eval and heldout losses from a metrics JSON file.

    The records are taken from the top-level ``"forgetting_curve"`` key, or
    from ``"records"`` when the former is missing or empty. Records lacking
    any of ``step``, ``eval_loss`` or ``heldout_eval_loss`` are skipped.

    Args:
        metrics_path: Path to the ``metrics.json`` file.

    Returns:
        A list of ``{"step": int, "eval_loss": float,
        "heldout_eval_loss": float}`` dicts sorted by ascending step.

    Raises:
        ValueError: If no record carries all three values.
    """
    # Decode explicitly as UTF-8 instead of relying on the locale's default
    # encoding, which is not UTF-8 on every platform.
    data = json.loads(metrics_path.read_text(encoding="utf-8"))
    records = data.get("forgetting_curve") or data.get("records") or []

    series: List[dict] = []
    for record in records:
        step = record.get("step")
        eval_loss = record.get("eval_loss")
        heldout_loss = record.get("heldout_eval_loss")
        # Compare against None so that a legitimate step 0 is kept.
        if step is None or eval_loss is None or heldout_loss is None:
            continue
        series.append(
            {
                "step": int(step),
                "eval_loss": float(eval_loss),
                "heldout_eval_loss": float(heldout_loss),
            }
        )

    if not series:
        raise ValueError(f"No eval/heldout losses found in {metrics_path}")
    return sorted(series, key=lambda item: item["step"])


def collect_series(sweep_dir: Path, include_full: bool) -> List[Tuple[str, List[dict]]]:
    """Load the loss series of every run found in a sweep directory.

    Rank runs are the sub-directories named ``RANK_DIR_PREFIX`` followed by
    an integer; they are returned in ascending numeric order. Directories
    that merely start with the prefix (e.g. ``results``) are ignored rather
    than aborting the whole sweep with a ``ValueError`` from ``int()``.

    Args:
        sweep_dir: Directory containing one sub-directory per run.
        include_full: When true, also load ``full/metrics.json`` if that
            file exists; the series is appended last under the name "full".

    Returns:
        A list of ``(run_name, series)`` pairs, where ``series`` is the
        output of ``load_loss_series``.

    Raises:
        FileNotFoundError: If a rank directory has no ``metrics.json``.
        ValueError: If no series was found at all, or if a metrics file
            contains no usable records.
    """
    ranked_dirs: List[Tuple[int, Path]] = []
    for entry in sweep_dir.iterdir():
        if not (entry.is_dir() and entry.name.startswith(RANK_DIR_PREFIX)):
            continue
        try:
            rank = int(entry.name[len(RANK_DIR_PREFIX):])
        except ValueError:
            # Shares the prefix but is not a rank directory; skip it.
            continue
        ranked_dirs.append((rank, entry))
    # Sort on the rank only, so ties keep their directory-listing order.
    ranked_dirs.sort(key=lambda pair: pair[0])

    series_list: List[Tuple[str, List[dict]]] = []
    for _, rank_dir in ranked_dirs:
        metrics_path = rank_dir / "metrics.json"
        if not metrics_path.exists():
            raise FileNotFoundError(f"Missing metrics.json in {rank_dir}")
        series_list.append((rank_dir.name, load_loss_series(metrics_path)))

    if include_full:
        metrics_path = sweep_dir / "full" / "metrics.json"
        # The full fine-tuning run is optional: silently skipped if absent.
        if metrics_path.exists():
            series_list.append(("full", load_loss_series(metrics_path)))

    if not series_list:
        raise ValueError(f"No series found in {sweep_dir}")
    return series_list


def ratio_series(series: List[dict]) -> List[dict]:
    """Compute the heldout-to-eval loss ratio for every point of a series.

    Args:
        series: Points with ``"step"``, ``"eval_loss"`` and
            ``"heldout_eval_loss"`` keys.

    Returns:
        One ``{"step": ..., "ratio": heldout_eval_loss / eval_loss}`` dict
        per input point, in the same order.

    Raises:
        ValueError: If any point has an ``eval_loss`` that is zero or
            negative, since the ratio would be undefined or meaningless.
    """
    ratios: List[dict] = []
    for point in series:
        eval_loss = point["eval_loss"]
        if eval_loss <= 0:
            raise ValueError(
                f"Non-positive eval_loss at step {point['step']}: {eval_loss}"
            )
        ratios.append(
            {
                "step": point["step"],
                "ratio": point["heldout_eval_loss"] / eval_loss,
            }
        )
    return ratios


def plot_series(series_list: List[Tuple[str, List[dict]]], output: Path) -> None:
    """Draw one heldout/eval ratio curve per run and save the figure.

    Each run is drawn as a line with circular markers on a log-scaled
    y-axis. The run named "full" uses ``FULL_BASE_COLOR``; every other run
    takes the next entry of ``RANK_BASE_COLORS``, wrapping around when the
    palette is exhausted.

    Args:
        series_list: ``(run_name, series)`` pairs; each series is passed
            through ``ratio_series`` before plotting.
        output: Destination image path. Missing parent directories are
            created.
    """
    fig, ax = plt.subplots(figsize=(9, 6))
    handles: List[Line2D] = []
    # Number of non-"full" runs drawn so far; drives the colour cycle.
    ranks_drawn = 0

    for run_name, run_series in series_list:
        if run_name != "full":
            run_color = RANK_BASE_COLORS[ranks_drawn % len(RANK_BASE_COLORS)]
            ranks_drawn += 1
        else:
            run_color = FULL_BASE_COLOR

        run_ratios = ratio_series(run_series)
        steps = [entry["step"] for entry in run_ratios]
        values = [entry["ratio"] for entry in run_ratios]

        ax.plot(steps, values, color=run_color, linewidth=1.6, alpha=0.9)
        ax.scatter(
            steps, values, color=run_color, s=38, edgecolor="none", alpha=0.9
        )

        # Proxy artist so the legend shows line and marker as one entry.
        proxy = Line2D(
            [0],
            [0],
            marker="o",
            color=run_color,
            markerfacecolor=run_color,
            markersize=7,
            label=run_name,
            linestyle="-",
        )
        handles.append(proxy)

    ax.set_xlabel("Training step")
    ax.set_ylabel("Heldout / eval loss")
    ax.set_title("Heldout-to-eval loss ratio over training")
    ax.set_yscale("log")
    ax.grid(True, linewidth=0.3, alpha=0.5)
    ax.legend(handles=handles, frameon=False, title="Run")

    fig.tight_layout()
    output.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(output, dpi=200)
    plt.close(fig)


def main() -> None:
    """Run the script: collect the sweep's series, write the CSV and the plot.

    When not given on the command line, the plot and the CSV default to
    ``loss_ratio_vs_steps.png`` / ``loss_ratio_vs_steps.csv`` inside the
    sweep directory. The "full" run is included whenever
    ``full/metrics.json`` exists, even without ``--include-full``.

    CSV rows are ordered by run name (plain string order) and then by step.
    """
    # Local import keeps this function self-contained; csv is stdlib.
    import csv

    args = parse_args()
    if args.output is None:
        args.output = args.sweep_dir / "loss_ratio_vs_steps.png"
    if args.csv is None:
        args.csv = args.sweep_dir / "loss_ratio_vs_steps.csv"
    if not args.include_full:
        args.include_full = (args.sweep_dir / "full" / "metrics.json").exists()

    series_list = collect_series(args.sweep_dir, args.include_full)

    fieldnames = [
        "run",
        "step",
        "eval_loss",
        "heldout_eval_loss",
        "heldout_over_eval",
    ]
    rows: List[Dict[str, float | int | str]] = []
    for name, series in series_list:
        # ratio_series yields one entry per point, in order, so zip pairs
        # every raw point with its own ratio.
        for point, ratio_point in zip(series, ratio_series(series)):
            rows.append(
                {
                    "run": name,
                    "step": point["step"],
                    "eval_loss": point["eval_loss"],
                    "heldout_eval_loss": point["heldout_eval_loss"],
                    "heldout_over_eval": ratio_point["ratio"],
                }
            )
    rows.sort(key=lambda item: (item["run"], item["step"]))

    # Create the destination directory like plot_series does for the image,
    # so a --csv path in a not-yet-existing directory does not fail.
    args.csv.parent.mkdir(parents=True, exist_ok=True)
    # newline="" hands line-ending control to the csv writer, which also
    # quotes run names containing commas or quotes.
    with args.csv.open("w", encoding="utf-8", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=fieldnames, lineterminator="\n")
        writer.writeheader()
        writer.writerows(rows)
    print(f"Saved CSV to {args.csv}")

    plot_series(series_list, args.output)
    print(f"Saved plot to {args.output}")


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
