"""Compare Mamba variants with variance bands from a JSONL training log.

Reads the `train_runs.jsonl` format produced by `python -m mamba_experiments.train`
and plots mean ± std curves across runs (typically across seeds).
"""

from __future__ import annotations

import argparse
import hashlib
import json
import re
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Iterable

import numpy as np


def _iter_jsonl(path: Path) -> Iterable[dict[str, Any]]:
    if not path.exists():
        return
    with path.open("r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                yield json.loads(line)
            except json.JSONDecodeError:
                continue


def _ensure_matplotlib():
    try:
        import matplotlib.pyplot as plt  # type: ignore
    except Exception as e:  # pragma: no cover
        raise RuntimeError("matplotlib is required for plotting: `pip install matplotlib`") from e
    return plt


_DIAG_SSD_LABEL = "Diagonal SSD Mamba"

# Legend labels keyed by the `model` field of a log record.
MODEL_DISPLAY_NAMES = {
    "mamba_simple": "Original Mamba",
    "mamba_SSD_diag_exp": _DIAG_SSD_LABEL,
    # Back-compat for older logs; shares the diagonal SSD label.
    "mamba_gssm_diag_exp": _DIAG_SSD_LABEL,
}


# Record fields that together identify one experimental configuration. Records
# that agree on all of these values are compared on the same plot; `model` and
# `seed` are deliberately absent so different models/seeds land in one group.
# Order matters: the tuple built from these keys is hashed into output
# filenames, so reordering would rename every generated plot.
GROUP_KEYS = (
    # Data / task configuration.
    "dataset",
    "seq_len",
    "vocab_size",
    "train_size",
    "val_size",
    "max_vocab",
    "train_max_samples",
    "val_max_samples",
    # Training / model hyperparameters.
    "batch_size",
    "epochs",
    "d_model",
    "n_state",
    "n_layers",
    "lr",
)


def _group_key(rec: dict[str, Any]) -> tuple:
    """Return *rec*'s configuration signature: its GROUP_KEYS values, in order.

    Fields missing from the record contribute ``None``.
    """
    return tuple(map(rec.get, GROUP_KEYS))


def _display_model(model: str) -> str:
    """Return the legend label for *model*, or *model* itself if it has none."""
    try:
        return MODEL_DISPLAY_NAMES[model]
    except KeyError:
        return model


def _sanitize_filename_part(value: object) -> str:
    s = str(value)
    s = s.strip().replace(" ", "-")
    s = re.sub(r"[^A-Za-z0-9_.-]+", "-", s)
    return re.sub(r"-{2,}", "-", s).strip("-") or "x"


def _short_hash(value: object) -> str:
    payload = repr(value).encode("utf-8")
    return hashlib.md5(payload).hexdigest()[:8]


def _config_suffix(meta: dict[str, Any]) -> str:
    parts: list[str] = []
    for k in ("seq_len", "d_model", "n_state", "n_layers", "epochs", "batch_size", "lr", "max_vocab"):
        v = meta.get(k)
        if v is None:
            continue
        parts.append(f"{k}{_sanitize_filename_part(v)}")
    return "_".join(parts) if parts else "config"


@dataclass(frozen=True)
class CurveBundle:
    """Mean and spread of one metric across the runs of a single model/config, for plotting."""

    # 1-based x positions, one per aligned history entry.
    epochs: list[int]
    # Per-position mean across runs.
    mean: np.ndarray
    # Per-position sample standard deviation across runs (all zeros when only one run).
    std: np.ndarray
    # Number of run curves that were averaged.
    n_runs: int
    # Number of distinct integer seed values seen among the aggregated records.
    n_seeds: int


def _extract_curve(history: list[dict[str, Any]], metric: str) -> list[float]:
    vals: list[float] = []
    for row in history:
        if metric in row:
            vals.append(float(row[metric]))
    return vals


def _bundle_curves(records: list[dict[str, Any]], metric: str) -> CurveBundle | None:
    histories = [r.get("history") for r in records]
    histories = [h for h in histories if isinstance(h, list) and h]
    if not histories:
        return None

    curves = [_extract_curve(h, metric) for h in histories]
    curves = [c for c in curves if c]
    if not curves:
        return None

    T = int(min(len(c) for c in curves))
    mat = np.array([c[:T] for c in curves], dtype=np.float64)
    mean = mat.mean(axis=0)
    std = mat.std(axis=0, ddof=1) if mat.shape[0] > 1 else np.zeros(T, dtype=np.float64)
    epochs = list(range(1, T + 1))

    seeds = sorted({r.get("seed") for r in records if isinstance(r.get("seed"), int)})
    return CurveBundle(epochs=epochs, mean=mean, std=std, n_runs=len(curves), n_seeds=len(seeds))


def _plot_with_band(plt, x, mean, std, *, color: str, label: str) -> None:
    plt.plot(x, mean, linewidth=4, alpha=0.9, color=color, label=label)
    plt.fill_between(x, mean - std, mean + std, color=color, alpha=0.18)


def _style_axes(plt, title: str, ylabel: str) -> None:
    plt.xlabel("Epoch", fontsize=24)
    plt.ylabel(ylabel, fontsize=24)
    plt.title(title, fontsize=20, fontweight="bold")
    plt.tick_params(axis="both", which="major", labelsize=20)
    plt.grid(True, alpha=0.3)
    plt.legend(fontsize=14)
    plt.tight_layout()


# Line color per model name; models not listed fall back to black.
_MODEL_COLORS = {
    "mamba_simple": "darkgrey",
    "mamba_SSD_diag_exp": "darkred",
    "mamba_gssm_diag_exp": "darkred",
}

# Plot-title fragment per `--metric` choice.
_METRIC_TITLES = {
    "train_loss": "Training Loss",
    "val_loss": "Validation Loss",
    "train_acc": "Training Accuracy",
    "val_acc": "Validation Accuracy",
}

# Y-axis label per `--metric` choice.
_METRIC_YLABELS = {
    "train_loss": "Loss",
    "val_loss": "Loss",
    "train_acc": "Accuracy",
    "val_acc": "Accuracy",
}


def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser used by `main`."""
    p = argparse.ArgumentParser(description="Plot Original Mamba vs SSD Mamba with variance bands.")
    p.add_argument(
        "--log_path",
        type=str,
        default="outputs/mamba_experiments/train_runs.jsonl",
        help="JSONL file produced by `python -m mamba_experiments.train`.",
    )
    p.add_argument(
        "--out_dir",
        type=str,
        default="outputs/mamba_experiments/plots_compare",
        help="Directory to save comparison plots.",
    )
    p.add_argument("--dataset", type=str, default=None, help="Filter by dataset name (e.g. wikitext2).")
    p.add_argument(
        "--models",
        type=str,
        nargs="+",
        default=["mamba_simple", "mamba_SSD_diag_exp"],
        help="Models to compare on the same plot.",
    )
    p.add_argument(
        "--metric",
        type=str,
        default="val_loss",
        choices=["train_loss", "val_loss", "train_acc", "val_acc"],
        help="Which curve from `history` to plot.",
    )
    p.add_argument(
        "--require_all_models",
        action="store_true",
        help="If set, skip configs that don't contain all requested models.",
    )
    return p


def _plot_config(
    plt,
    key: tuple,
    recs: list[dict[str, Any]],
    *,
    models: list[str],
    metric: str,
    require_all_models: bool,
    out_dir: Path,
) -> Path | None:
    """Write one comparison plot for a single config group.

    Returns the path of the written PNG. Returns ``None`` when the group is
    skipped, either because a requested model is missing (with
    ``require_all_models``) or because no requested model has data for *metric*.
    """
    meta = dict(zip(GROUP_KEYS, key, strict=True))
    dataset = str(meta.get("dataset") or "dataset")

    by_model: dict[str, list[dict[str, Any]]] = defaultdict(list)
    for r in recs:
        m = r.get("model")
        if isinstance(m, str):
            by_model[m].append(r)

    if require_all_models and not all(m in by_model for m in models):
        return None

    bundles = []
    for m in models:
        bundle = _bundle_curves(by_model.get(m, []), metric=metric)
        if bundle is not None:
            bundles.append((m, bundle))
    if not bundles:
        return None

    title = f"{dataset}: {_METRIC_TITLES.get(metric, metric)}"
    ylabel = _METRIC_YLABELS.get(metric, "Accuracy" if metric.endswith("acc") else "Value")
    # The short hash of the full key keeps filenames unique even when two
    # configs differ only in fields that `_config_suffix` leaves out.
    fname = f"{_sanitize_filename_part(dataset)}_{_config_suffix(meta)}_{_short_hash(key)}_{metric}_compare.png"
    out_path = out_dir / fname

    plt.figure(figsize=(8, 6))
    try:
        for m, bundle in bundles:
            _plot_with_band(
                plt,
                bundle.epochs,
                bundle.mean,
                bundle.std,
                color=_MODEL_COLORS.get(m, "black"),
                label=f"{_display_model(m)} (runs={bundle.n_runs})",
            )
        _style_axes(plt, title, ylabel)
        plt.savefig(out_path, dpi=225, bbox_inches="tight")
    finally:
        # Close even if plotting or saving fails, so figures don't pile up across groups.
        plt.close()
    return out_path


def main() -> None:
    """Plot mean ± std curves of the requested models, one figure per config group.

    Raises SystemExit in three cases: the log file is missing, no records match
    the filters, or no plot could be produced.
    """
    args = _build_arg_parser().parse_args()

    log_path = Path(args.log_path)
    out_dir = Path(args.out_dir)
    # `_iter_jsonl` silently yields nothing for a missing path, so report that
    # case explicitly instead of the generic "no records" message.
    if not log_path.exists():
        raise SystemExit(f"Log file not found: {log_path}")

    all_recs = list(_iter_jsonl(log_path))
    if args.dataset is not None:
        all_recs = [r for r in all_recs if r.get("dataset") == args.dataset]
    if not all_recs:
        raise SystemExit(f"No records found in {log_path} for the requested filters.")

    # Group by config (excluding model/seed), then compare models within each config.
    grouped: dict[tuple, list[dict[str, Any]]] = defaultdict(list)
    for r in all_recs:
        grouped[_group_key(r)].append(r)

    plt = _ensure_matplotlib()
    # Create the output dir only once there is something to plot.
    out_dir.mkdir(parents=True, exist_ok=True)
    # De-duplicate while preserving order, so a model passed twice isn't drawn twice.
    models = list(dict.fromkeys(args.models))

    plotted = 0
    for key, recs in grouped.items():
        out_path = _plot_config(
            plt,
            key,
            recs,
            models=models,
            metric=args.metric,
            require_all_models=args.require_all_models,
            out_dir=out_dir,
        )
        if out_path is None:
            continue
        print(f"Wrote {out_path}")
        plotted += 1

    if plotted == 0:
        raise SystemExit(
            "No plots produced. Try removing `--require_all_models`, "
            "or ensure both models were run with identical hyperparameters."
        )


if __name__ == "__main__":
    main()
