"""Convenience script to plot variance bands for Mamba vs SSD Mamba runs.

Reads a JSONL log produced by `python -m mamba_experiments.train` and overlays
multiple models (e.g., `mamba_simple`, `mamba_SSD_diag_exp`) on the same plot
with mean ± std bands across runs/seeds.
"""

from __future__ import annotations

import argparse
import sys
from pathlib import Path
from typing import Any

# Make the `mamba_experiments` package importable when this file is executed
# directly as a script: the parent of the directory containing this file is
# put at the front of sys.path unless it is already listed there.
# NOTE(review): this assumes the script lives exactly one directory below the
# repository root — confirm if the file is ever moved.
REPO_ROOT = Path(__file__).resolve().parents[1]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

from mamba_experiments import plot_compare


def _run_compare(
    log_path: Path,
    out_dir: Path,
    dataset: str | None,
    models: list[str],
    metric: str,
    require_all_models: bool,
) -> None:
    """Write one mean ± std comparison plot per record group.

    Records are read from ``log_path``, optionally restricted to ``dataset``,
    and grouped with ``plot_compare._group_key``. For every group, each
    requested model that yields a curve bundle is drawn on one shared figure,
    which is then saved as a PNG inside ``out_dir``.

    Args:
        log_path: JSONL run log to read.
        out_dir: Directory that receives the PNG files; created if missing.
        dataset: When not ``None``, keep only records whose ``"dataset"``
            field equals this value.
        models: Model names to overlay. A name listed more than once is
            drawn only once, at its first position.
        metric: Name of the curve to plot (e.g. ``"val_loss"``).
        require_all_models: If true, skip groups that do not contain records
            for every requested model.

    Raises:
        SystemExit: If no record survives the filter, or if no plot at all
            could be written.
    """
    out_dir.mkdir(parents=True, exist_ok=True)

    # Filter while streaming so discarded records are never kept in memory.
    records = [
        rec
        for rec in plot_compare._iter_jsonl(log_path)
        if dataset is None or rec.get("dataset") == dataset
    ]
    if not records:
        raise SystemExit(f"No records found in {log_path} for the requested filters.")

    # Group by config (excluding model/seed), then compare models within each config.
    grouped: dict[tuple, list[dict[str, Any]]] = {}
    for rec in records:
        grouped.setdefault(plot_compare._group_key(rec), []).append(rec)

    plt = plot_compare._ensure_matplotlib()

    # Drop repeated model names (keeping first-seen order) so a curve and its
    # legend entry are never drawn twice.
    requested = list(dict.fromkeys(models))

    colors = {
        "mamba_simple": "darkgrey",
        "mamba_SSD_diag_exp": "darkred",
        "mamba_gssm_diag_exp": "darkred",
    }
    # Models without a preset colour used to all fall back to "black", making
    # them indistinguishable on a shared plot. Hand out distinct fallbacks in
    # request order instead; the first unknown model still gets "black".
    fallback_palette = ("black", "navy", "darkgreen", "darkorange", "purple")
    unknown_models = [m for m in requested if m not in colors]
    for idx, m in enumerate(unknown_models):
        colors[m] = fallback_palette[idx % len(fallback_palette)]

    metric_titles = {
        "train_loss": "Training Loss",
        "val_loss": "Validation Loss",
        "train_acc": "Training Accuracy",
        "val_acc": "Validation Accuracy",
    }
    ylabel_map = {
        "train_loss": "Loss",
        "val_loss": "Loss",
        "train_acc": "Accuracy",
        "val_acc": "Accuracy",
    }
    # Unknown metrics are shown under their raw name.
    title_metric = metric_titles.get(metric, metric)
    ylabel = ylabel_map.get(metric, "Accuracy" if metric.endswith("acc") else "Value")

    plotted = 0
    for key, recs in grouped.items():
        by_model: dict[str, list[dict[str, Any]]] = {}
        for rec in recs:
            model_name = rec.get("model")
            # Records without a string model name cannot be attributed to a curve.
            if isinstance(model_name, str):
                by_model.setdefault(model_name, []).append(rec)

        if require_all_models and any(m not in by_model for m in requested):
            continue

        bundles: list[tuple[str, plot_compare.CurveBundle]] = []
        for m in requested:
            bundle = plot_compare._bundle_curves(by_model.get(m, []), metric=metric)
            if bundle is not None:
                bundles.append((m, bundle))

        if not bundles:
            continue

        # Naming metadata is only needed for groups that are actually plotted.
        meta = dict(zip(plot_compare.GROUP_KEYS, key, strict=True))
        dataset_name = str(meta.get("dataset") or "dataset")
        suffix = plot_compare._config_suffix(meta)
        tag = plot_compare._short_hash(key)
        fname = (
            f"{plot_compare._sanitize_filename_part(dataset_name)}_"
            f"{suffix}_{tag}_{metric}_compare.png"
        )
        out_path = out_dir / fname

        plt.figure(figsize=(8, 6))
        try:
            for m, bundle in bundles:
                label = f"{plot_compare._display_model(m)} (runs={bundle.n_runs})"
                plot_compare._plot_with_band(
                    plt,
                    bundle.epochs,
                    bundle.mean,
                    bundle.std,
                    color=colors[m],
                    label=label,
                )
            plot_compare._style_axes(plt, f"{dataset_name}: {title_metric}", ylabel)
            plt.savefig(out_path, dpi=225, bbox_inches="tight")
        finally:
            # Always release the figure, even when drawing or saving fails.
            plt.close()

        print(f"Wrote {out_path}")
        plotted += 1

    if plotted == 0:
        raise SystemExit(
            "No plots produced. Try removing `--require-all-models`, "
            "or ensure both models were run with identical hyperparameters."
        )


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options of the comparison-plot script.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes
            argparse read them from ``sys.argv``, which is what the script
            entry point relies on; passing a list allows programmatic use.

    Returns:
        Namespace with the attributes ``log_path``, ``out_dir``, ``dataset``,
        ``models``, ``metric`` and ``require_all_models``.
    """
    parser = argparse.ArgumentParser(
        description="Plot variance bands for Mamba vs SSD Mamba runs."
    )
    parser.add_argument(
        "--log-path",
        type=Path,
        default=Path("outputs/mamba_experiments/train_runs.jsonl"),
        help="Path to JSONL produced by mamba_experiments.train.",
    )
    parser.add_argument(
        "--out-dir",
        type=Path,
        default=Path("outputs/mamba_experiments/plots_compare"),
        help="Directory to save comparison plots.",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default=None,
        help="Optional dataset filter (e.g. wikitext2).",
    )
    parser.add_argument(
        "--models",
        type=str,
        nargs="+",
        default=["mamba_simple", "mamba_SSD_diag_exp"],
        help="Models to compare on the same plot.",
    )
    parser.add_argument(
        "--metric",
        type=str,
        default="val_loss",
        choices=["train_loss", "val_loss", "train_acc", "val_acc"],
        help="Which curve from `history` to plot.",
    )
    parser.add_argument(
        "--require-all-models",
        action="store_true",
        help="If set, skip configs that don't contain all requested models.",
    )
    return parser.parse_args(argv)


def main() -> None:
    """Script entry point: parse the CLI options and produce the plots."""
    cli = parse_args()
    # Forward every parsed option by keyword so the call does not depend on
    # the parameter order of `_run_compare`.
    _run_compare(
        log_path=cli.log_path,
        out_dir=cli.out_dir,
        dataset=cli.dataset,
        models=cli.models,
        metric=cli.metric,
        require_all_models=cli.require_all_models,
    )


# Run the CLI only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()

