"""Plot heldout eval loss vs eval loss with training progress for a sweep."""

from __future__ import annotations

import argparse
import csv
import json
import math
from bisect import bisect_right
from pathlib import Path
from typing import Dict, List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from matplotlib import animation, cm, colors
from matplotlib.lines import Line2D

# Rank runs live in sweep sub-directories named "<prefix><integer rank>".
RANK_DIR_PREFIX = "r"

# Base colors for rank runs (hex codes of matplotlib's "tab10" palette),
# handed out in run order and reused cyclically once exhausted.
RANK_BASE_COLORS = [
    "#1f77b4",  # blue
    "#2ca02c",  # green
    "#9467bd",  # purple
    "#d62728",  # red
    "#e377c2",  # pink
    "#bcbd22",  # olive
]

# Base color reserved for the "full" fine-tuning run (tab10 brown).
FULL_BASE_COLOR = "#8c564b"

# Portion of each run's light-to-dark colormap used to shade points from the
# first recorded step (PROGRESS_MIN) to the last (PROGRESS_MAX); the extreme
# ends are avoided so early points stay visible and late ones stay colored.
PROGRESS_MIN = 0.15
PROGRESS_MAX = 0.95


def gradient_cmap(base_color: str) -> colors.LinearSegmentedColormap:
    """Return a two-stop colormap running from a tint to a shade of ``base_color``.

    The low end (0.0) moves the color 65% of the way toward white; the high
    end (1.0) scales every RGB channel down to 65% of its value.
    """
    rgb = np.asarray(colors.to_rgb(base_color))
    tint = 1 - 0.35 * (1 - rgb)
    shade = 0.65 * rgb
    return colors.LinearSegmentedColormap.from_list("", [tint, shade])


def parse_args() -> argparse.Namespace:
    """Define the command-line interface and parse ``sys.argv``.

    Returns:
        Namespace with ``sweep_dir`` (Path), ``output`` and ``csv``
        (Path or None), ``include_full`` and ``animate`` (bool) and
        ``fps`` (int, default 8).
    """
    cli = argparse.ArgumentParser(
        description="Plot heldout eval loss vs eval loss with progress coloring."
    )
    cli.add_argument(
        "sweep_dir",
        type=Path,
        help="Path to the sweep directory (e.g. sweep_1614967).",
    )

    # Optional destinations; None is resolved to a path inside sweep_dir by main().
    path_options = (
        ("--output", "Output path for the plot PNG."),
        ("--csv", "Optional path to save the plotted points as CSV."),
    )
    for flag, help_text in path_options:
        cli.add_argument(flag, type=Path, default=None, help=help_text)

    # Boolean switches, off unless given on the command line.
    switch_options = (
        ("--include-full", "Include full fine-tuning results if the 'full' directory exists."),
        ("--animate", "Save an animation (GIF/MP4) instead of a static PNG."),
    )
    for flag, help_text in switch_options:
        cli.add_argument(flag, action="store_true", help=help_text)

    cli.add_argument(
        "--fps", type=int, default=8, help="Frames per second for animations."
    )
    return cli.parse_args()


def load_loss_series(metrics_path: Path) -> List[dict]:
    """Read the (step, eval loss, heldout eval loss) points from a metrics file.

    Records are taken from the ``forgetting_curve`` key, falling back to
    ``records``; a file whose top level is itself a JSON list is treated as
    the record list.  Records that are not objects, that lack any of
    ``step``/``eval_loss``/``heldout_eval_loss``, or whose losses are not
    finite (JSON ``NaN``/``Infinity`` parse to such floats) are skipped.

    Args:
        metrics_path: path to a UTF-8 encoded ``metrics.json``.

    Returns:
        Points sorted by step, each ``{"step": int, "eval_loss": float,
        "heldout_eval_loss": float}``.

    Raises:
        ValueError: if no usable point remains.
    """
    # JSON is UTF-8 by specification; don't depend on the platform locale.
    data = json.loads(metrics_path.read_text(encoding="utf-8"))
    if isinstance(data, list):
        records = data
    else:
        records = data.get("forgetting_curve") or data.get("records") or []

    series: List[dict] = []
    for record in records:
        if not isinstance(record, dict):
            continue
        step = record.get("step")
        eval_loss = record.get("eval_loss")
        heldout_loss = record.get("heldout_eval_loss")
        if step is None or eval_loss is None or heldout_loss is None:
            continue
        eval_loss = float(eval_loss)
        heldout_loss = float(heldout_loss)
        # Non-finite losses would poison min/max-based axis limits and log scales.
        if not (math.isfinite(eval_loss) and math.isfinite(heldout_loss)):
            continue
        series.append(
            {
                "step": int(step),
                "eval_loss": eval_loss,
                "heldout_eval_loss": heldout_loss,
            }
        )
    if not series:
        raise ValueError(f"No eval/heldout losses found in {metrics_path}")
    return sorted(series, key=lambda item: item["step"])


def collect_series(sweep_dir: Path, include_full: bool) -> List[Tuple[str, List[dict]]]:
    """Load the loss series of every run found in ``sweep_dir``.

    Rank runs are sub-directories named ``RANK_DIR_PREFIX`` followed by a
    decimal integer and are returned in ascending numeric order.
    Sub-directories that merely start with the prefix (e.g. ``results`` or
    ``runs``) are skipped instead of crashing the numeric sort.  When
    ``include_full`` is true and ``full/metrics.json`` exists, a ``"full"``
    run is appended last.

    Args:
        sweep_dir: the sweep directory to scan.
        include_full: whether to also load ``full/metrics.json`` if present.

    Returns:
        ``(run name, series)`` pairs, each series as returned by
        ``load_loss_series``.

    Raises:
        FileNotFoundError: a rank directory has no ``metrics.json``.
        ValueError: no runs were found, or a metrics file has no usable points.
    """
    ranked: List[Tuple[int, Path]] = []
    for path in sweep_dir.iterdir():
        if not (path.is_dir() and path.name.startswith(RANK_DIR_PREFIX)):
            continue
        suffix = path.name[len(RANK_DIR_PREFIX):]
        # int() accepts any non-empty str.isdecimal() string, so this guard
        # guarantees the conversion below cannot raise.
        if not suffix.isdecimal():
            continue
        ranked.append((int(suffix), path))
    # Sort on the numeric index only, so ties (e.g. "r8" vs "r08") keep
    # their directory-listing order.
    ranked.sort(key=lambda item: item[0])

    series_list: List[Tuple[str, List[dict]]] = []
    for _, rank_dir in ranked:
        metrics_path = rank_dir / "metrics.json"
        if not metrics_path.exists():
            raise FileNotFoundError(f"Missing metrics.json in {rank_dir}")
        series_list.append((rank_dir.name, load_loss_series(metrics_path)))

    if include_full:
        metrics_path = sweep_dir / "full" / "metrics.json"
        if metrics_path.exists():
            series_list.append(("full", load_loss_series(metrics_path)))

    if not series_list:
        raise ValueError(f"No series found in {sweep_dir}")
    return series_list


def iter_series_with_colors(
    series_list: List[Tuple[str, List[dict]]],
) -> List[Tuple[str, List[dict], str]]:
    """Pair each run with the base color used to draw it.

    The run named ``"full"`` always gets ``FULL_BASE_COLOR``.  Every other run
    takes the next entry of ``RANK_BASE_COLORS`` in input order, wrapping
    around when the palette runs out.  Input order is preserved.
    """
    tagged: List[Tuple[str, List[dict], str]] = []
    next_rank = 0
    for run_name, points in series_list:
        if run_name != "full":
            base = RANK_BASE_COLORS[next_rank % len(RANK_BASE_COLORS)]
            next_rank += 1
        else:
            base = FULL_BASE_COLOR
        tagged.append((run_name, points, base))
    return tagged


def plot_series(
    series_list: List[Tuple[str, List[dict]]],
    output: Path,
    *,
    ylim: Tuple[float, float] | None = (0.0, 1.0),
) -> None:
    """Save a static scatter plot of eval loss (x) vs heldout eval loss (y).

    Each run is drawn with a light-to-dark colormap derived from its base
    color.  Points are shaded in list order, lightest first; series from
    ``load_loss_series`` are in step order.  Runs with more than one point
    also get a faint seaborn regression line.  A grey colorbar explains the
    early/late shading.

    Args:
        series_list: ``(run name, points)`` pairs; each point is a dict with
            ``eval_loss`` and ``heldout_eval_loss`` keys.
        output: image destination; missing parent directories are created.
        ylim: y-axis limits.  The default ``(0.0, 1.0)`` clips any heldout
            loss above 1 from view; pass ``None`` to let matplotlib
            autoscale.
    """
    fig, ax = plt.subplots(figsize=(9, 6))
    try:
        legend_handles: List[Line2D] = []

        for name, series, base_color in iter_series_with_colors(series_list):
            cmap = gradient_cmap(base_color)

            colors_for_points = [
                cmap(v) for v in np.linspace(PROGRESS_MIN, PROGRESS_MAX, len(series))
            ]
            x_vals = [point["eval_loss"] for point in series]
            y_vals = [point["heldout_eval_loss"] for point in series]

            # A regression line needs at least two points.
            if len(series) > 1:
                sns.regplot(
                    x=x_vals,
                    y=y_vals,
                    ax=ax,
                    ci=None,
                    scatter=False,
                    line_kws={
                        "color": cmap(0.85),
                        "alpha": 0.3,
                        "linewidth": 1.0,
                    },
                    color=cmap(0.85),
                )
            ax.scatter(
                x_vals,
                y_vals,
                c=colors_for_points,
                marker=".",
                s=42,
                edgecolor="none",
            )
            legend_handles.append(
                Line2D(
                    [0],
                    [0],
                    marker=".",
                    color="none",
                    markerfacecolor=cmap(0.85),
                    markersize=8,
                    label=name,
                )
            )

        ax.set_xlabel("Non-trained loss")
        ax.set_ylabel("trained (validation) loss")
        ax.set_title(
            "Non-trained loss vs trained (validation) loss (training progress)"
        )
        if ylim is not None:
            ax.set_ylim(*ylim)
        ax.grid(True, linewidth=0.3, alpha=0.5)
        ax.legend(handles=legend_handles, frameon=False, title="Run")

        # The colorbar is a neutral legend for the shading: every run uses its
        # own hue, so a grey scale stands in for all of them.
        progress_norm = colors.Normalize(vmin=PROGRESS_MIN, vmax=PROGRESS_MAX)
        progress_map = cm.ScalarMappable(
            norm=progress_norm, cmap=plt.get_cmap("Greys"))
        progress_map.set_array([])
        cbar = fig.colorbar(progress_map, ax=ax, fraction=0.046, pad=0.04)
        cbar.set_label("Training progress")
        cbar.set_ticks([PROGRESS_MIN, PROGRESS_MAX])
        cbar.set_ticklabels(["early", "late"])

        fig.tight_layout()
        output.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(output, dpi=200)
    finally:
        # Release the figure even when plotting or saving fails, so callers
        # that catch the error do not accumulate open figures.
        plt.close(fig)


def _loss_limits(series_list: List[Tuple[str, List[dict]]]) -> tuple[tuple[float, float], tuple[float, float]]:
    eval_vals: List[float] = []
    heldout_vals: List[float] = []
    for _, series in series_list:
        for point in series:
            eval_vals.append(point["eval_loss"])
            heldout_vals.append(point["heldout_eval_loss"])
    if not eval_vals or not heldout_vals:
        raise ValueError("No loss values available for plotting.")
    if min(eval_vals) <= 0 or min(heldout_vals) <= 0:
        raise ValueError("Loss values must be positive for log scaling.")
    x_min = min(eval_vals) * 0.9
    x_max = max(eval_vals) * 1.1
    y_min = min(heldout_vals) * 0.9
    y_max = max(heldout_vals) * 1.1
    return (x_min, x_max), (y_min, y_max)


def plot_series_animated(
    series_list: List[Tuple[str, List[dict]]], output: Path, fps: int
) -> None:
    """Render the runs' loss trajectories as an animation revealed step by step.

    There is one frame per distinct evaluation step across all runs.  At the
    frame for step ``s``, each run shows its points with step ``<= s``,
    joined by a line.  Both axes are log-scaled, with limits from
    ``_loss_limits``.

    Args:
        series_list: ``(run name, points)`` pairs; each run's points must be
            sorted by step (``load_loss_series`` returns them that way).
        output: destination file; its suffix (``.gif`` or ``.mp4``) selects
            the writer.  Missing parent directories are created.
        fps: frames per second; must be at least 1.

    Raises:
        ValueError: unsupported suffix, ``fps < 1``, no evaluation steps, or
            non-positive losses.
        RuntimeError: ``.mp4`` requested but ffmpeg is not available.
    """
    # Validate everything that can fail cheaply before creating a figure, so
    # bad arguments neither leak a figure nor waste rendering time.
    suffix = output.suffix.lower()
    if suffix not in (".gif", ".mp4"):
        raise ValueError("Animation output must have a .gif or .mp4 suffix.")
    if fps < 1:
        raise ValueError(f"fps must be a positive integer, got {fps}.")

    steps = sorted({point["step"]
                    for _, series in series_list for point in series})
    if not steps:
        raise ValueError("No evaluation steps available for animation.")
    (x_min, x_max), (y_min, y_max) = _loss_limits(series_list)

    if suffix == ".mp4" and not animation.writers.is_available("ffmpeg"):
        raise RuntimeError("ffmpeg is required to save MP4 animations.")

    fig, ax = plt.subplots(figsize=(9, 6))
    try:
        ax.set_xlim(x_min, x_max)
        ax.set_ylim(y_min, y_max)
        ax.set_xscale("log")
        ax.set_yscale("log")

        ax.set_xlabel("Non-trained loss")
        ax.set_ylabel("trained (validation) loss")
        ax.set_title(
            "Non-trained loss vs trained (validation) loss (training progress)"
        )
        ax.grid(True, linewidth=0.3, alpha=0.5)

        legend_handles: List[Line2D] = []
        # Per run: (trajectory line, scatter, steps, x values, y values).
        artists: List[Tuple[Line2D, object,
                            List[int], List[float], List[float]]] = []
        for name, series, color in iter_series_with_colors(series_list):
            steps_for_series = [point["step"] for point in series]
            x_vals = [point["eval_loss"] for point in series]
            y_vals = [point["heldout_eval_loss"] for point in series]

            line, = ax.plot([], [], color=color, alpha=0.35, linewidth=1.0)
            scatter = ax.scatter([], [], color=color, s=42, edgecolor="none")
            legend_handles.append(
                Line2D(
                    [0],
                    [0],
                    marker="o",
                    color="none",
                    markerfacecolor=color,
                    markersize=8,
                    label=name,
                )
            )
            artists.append((line, scatter, steps_for_series, x_vals, y_vals))

        ax.legend(handles=legend_handles, frameon=False, title="Run")

        def update(frame_step: int):
            for line, scatter, steps_for_series, x_vals, y_vals in artists:
                # Number of this run's points recorded at or before frame_step;
                # relies on steps_for_series being sorted ascending.
                count = bisect_right(steps_for_series, frame_step)
                if count:
                    line.set_data(x_vals[:count], y_vals[:count])
                    scatter.set_offsets(np.column_stack(
                        (x_vals[:count], y_vals[:count])))
                else:
                    line.set_data([], [])
                    scatter.set_offsets(np.empty((0, 2)))
            return []

        anim = animation.FuncAnimation(
            fig, update, frames=steps, interval=1000 / fps, blit=False
        )

        if suffix == ".gif":
            writer = animation.PillowWriter(fps=fps)
        else:
            writer = animation.FFMpegWriter(fps=fps)
        output.parent.mkdir(parents=True, exist_ok=True)
        anim.save(output, writer=writer, dpi=200)
    finally:
        plt.close(fig)


def main() -> None:
    """Command-line entry point: export the sweep's loss points to CSV, then plot them.

    Unless overridden, the plot goes to
    ``<sweep_dir>/eval_loss_vs_heldout_loss.png`` (``.gif`` with
    ``--animate``) and the CSV to
    ``<sweep_dir>/eval_loss_vs_heldout_loss.csv``.  Full fine-tuning results
    are included whenever ``<sweep_dir>/full/metrics.json`` exists, with or
    without ``--include-full``.

    Raises:
        ValueError: ``--animate`` with an output that is not ``.gif``/``.mp4``,
            or propagated from loading and plotting.
    """
    args = parse_args()
    if args.output is None:
        default_name = (
            "eval_loss_vs_heldout_loss.gif"
            if args.animate
            else "eval_loss_vs_heldout_loss.png"
        )
        args.output = args.sweep_dir / default_name
    if args.csv is None:
        args.csv = args.sweep_dir / "eval_loss_vs_heldout_loss.csv"
    if not args.include_full:
        # NOTE(review): auto-enabling whenever the metrics file exists makes
        # --include-full redundant; kept to preserve existing behavior.
        args.include_full = (args.sweep_dir / "full" / "metrics.json").exists()
    # Reject an unusable animation target before doing any work, so a bad
    # --output does not leave a freshly written CSV behind.
    if args.animate and args.output.suffix.lower() not in {".gif", ".mp4"}:
        raise ValueError(
            "Animation output must use a .gif or .mp4 extension.")

    series_list = collect_series(args.sweep_dir, args.include_full)
    rows: List[Dict[str, float | int | str]] = [
        {
            "run": name,
            "step": point["step"],
            "eval_loss": point["eval_loss"],
            "heldout_eval_loss": point["heldout_eval_loss"],
        }
        for name, series in series_list
        for point in series
    ]
    rows.sort(key=lambda item: (item["run"], item["step"]))

    # The csv module quotes fields containing delimiters or quotes;
    # newline="" is required by csv, and lineterminator="\n" replaces its
    # default "\r\n".
    args.csv.parent.mkdir(parents=True, exist_ok=True)
    with args.csv.open("w", encoding="utf-8", newline="") as handle:
        writer = csv.DictWriter(
            handle,
            fieldnames=["run", "step", "eval_loss", "heldout_eval_loss"],
            lineterminator="\n",
        )
        writer.writeheader()
        writer.writerows(rows)
    print(f"Saved CSV to {args.csv}")

    if args.animate:
        plot_series_animated(series_list, args.output, fps=args.fps)
    else:
        plot_series(series_list, args.output)
    print(f"Saved plot to {args.output}")


if __name__ == "__main__":
    main()
