"""Plot utilities for time-series experiment logs (JSONL)."""

from __future__ import annotations

import argparse
import json
from collections import defaultdict
from pathlib import Path
from typing import Any, Iterable
import numpy as np


def _iter_jsonl(path: Path) -> Iterable[dict[str, Any]]:
    """Yield the JSON objects stored one per line in ``path``.

    The reader is deliberately forgiving:

    * a path that does not exist yields nothing;
    * blank lines and lines that are not valid JSON are skipped;
    * lines holding valid JSON that is not an object (a list, number,
      string or ``null``) are skipped as well, so every yielded item
      really is a ``dict``.

    Args:
        path: Location of the JSONL file.

    Yields:
        One ``dict`` per well-formed JSON-object line, in file order.
    """
    if not path.exists():
        return
    with path.open("r", encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue
            # Keep the ``try`` limited to the parse itself so that errors
            # raised by the consumer are never mistaken for a bad line.
            try:
                rec = json.loads(line)
            except json.JSONDecodeError:
                continue
            # Consumers call ``.get`` on every record; a stray scalar or
            # list would make them fail with AttributeError.
            if isinstance(rec, dict):
                yield rec


def _get_model_kind(rec: dict[str, Any]) -> str:
    """Return the record's model identifier as a string.

    ``model_kind`` is consulted first, then ``block_kind``. A missing or
    falsy value falls through to the next candidate, and ``"unknown"`` is
    returned when neither field holds a usable value.
    """
    for field in ("model_kind", "block_kind"):
        candidate = rec.get(field)
        if candidate:
            return str(candidate)
    return "unknown"


def _get_d_state(rec: dict[str, Any]) -> int | None:
    """Return the record's ``d_state`` as an ``int``.

    ``None`` is returned when the field is absent, is ``None``, or cannot
    be converted with ``int()``.
    """
    raw = rec.get("d_state")
    try:
        return None if raw is None else int(raw)
    except Exception:
        # Best-effort parsing: an unconvertible value is treated as missing.
        return None


def _last_by_key(records: Iterable[dict[str, Any]], key_fn) -> list[dict[str, Any]]:
    """Keep only the final record seen for each key.

    ``key_fn`` maps a record to a hashable key. When several records share
    a key the last one wins, while the output order follows the first
    appearance of each key.
    """
    # A dict comprehension overwrites the value on a repeated key but keeps
    # the key's original insertion position.
    newest = {key_fn(item): item for item in records}
    return [*newest.values()]


def _ensure_matplotlib():
    """Import ``matplotlib.pyplot`` on demand and return the module.

    The import is deferred to call time so the rest of the file can be
    imported without matplotlib installed.

    Raises:
        RuntimeError: If the import fails for any reason; the original
            exception is attached as the cause.
    """
    try:
        from matplotlib import pyplot  # type: ignore
    except Exception as exc:  # pragma: no cover
        raise RuntimeError("matplotlib is required for plotting: `pip install matplotlib`") from exc
    else:
        return pyplot


def _display_model_name(model_kind: str) -> str:
    """Translate an internal model identifier into its legend label.

    Identifiers without a dedicated label are returned unchanged.
    """
    labels = {
        "mamba_diag_exp": "Diagonal SSD",
        "mamba_simple": "Mamba Simple (extra)",
    }
    try:
        return labels[model_kind]
    except KeyError:
        return model_kind


def _group_key(rec: dict[str, Any]) -> tuple:
    """Build the hyperparameter tuple used to group comparable runs.

    The tuple is ``(task, model_kind, d_state, d_model, n_layers, epochs)``.
    ``epochs`` is read from the top-level ``epochs`` field, falling back to
    ``train_config["epochs"]`` when ``train_config`` is a dict; it becomes
    ``None`` when absent or not convertible to ``int``.
    """
    epochs_raw = rec.get("epochs")
    if epochs_raw is None:
        train_cfg = rec.get("train_config")
        if isinstance(train_cfg, dict):
            epochs_raw = train_cfg.get("epochs")

    epochs: int | None = None
    if epochs_raw is not None:
        try:
            epochs = int(epochs_raw)
        except Exception:
            # Best-effort parsing: an unconvertible value counts as missing.
            epochs = None

    task = rec.get("task")
    model_kind = _get_model_kind(rec)
    d_state = _get_d_state(rec)
    d_model = rec.get("d_model")
    n_layers = rec.get("n_layers")
    return (task, model_kind, d_state, d_model, n_layers, epochs)


def _plot_summary_vs_d_state(
    records: list[dict[str, Any]],
    out_path: Path,
    metric: str,
    title: str,
    ylabel: str,
    with_variance: bool,
    logy: bool = False,
) -> None:
    """Plot a scalar metric against ``d_state``, one line per model kind.

    Args:
        records: Run records; those without a usable ``d_state`` or without
            ``metric`` are ignored.
        out_path: Destination image file; parent directories are created.
        metric: Record field holding the value to plot.
        title: Figure title.
        ylabel: Y-axis label (prefixed with ``"Log "`` when ``logy`` is set).
        with_variance: When true, records are grouped with ``_group_key``
            and each group contributes its mean, with a shaded mean ± std
            band (sample std, 0 for single-run groups). When false, every
            record contributes one point.
        logy: Use a logarithmic y-axis.

    Raises:
        RuntimeError: If no record provides both ``metric`` and ``d_state``,
            or if matplotlib cannot be imported.
    """
    plt = _ensure_matplotlib()

    # model_kind -> [(d_state, value)] for the line and for the std band.
    points: dict[str, list[tuple[int, float]]] = defaultdict(list)
    spreads: dict[str, list[tuple[int, float]]] = defaultdict(list)

    if not with_variance:
        for rec in records:
            n = _get_d_state(rec)
            if n is None:
                continue
            value = rec.get(metric)
            if value is None:
                continue
            points[_get_model_kind(rec)].append((n, float(value)))
    else:
        buckets: dict[tuple, list[dict[str, Any]]] = defaultdict(list)
        for rec in records:
            buckets[_group_key(rec)].append(rec)
        for group_id, members in buckets.items():
            kind, n = group_id[1], group_id[2]
            if n is None:
                continue
            values = [v for v in (m.get(metric) for m in members) if v is not None]
            if not values:
                continue
            sample = np.array(values, dtype=np.float64)
            # Sample std needs at least two runs; a lone run gets a zero-width band.
            spread = float(sample.std(ddof=1)) if sample.size > 1 else 0.0
            points[kind].append((int(n), float(sample.mean())))
            spreads[kind].append((int(n), spread))

    if not points:
        raise RuntimeError(f"No records with metric={metric} and d_state found.")

    plt.figure(figsize=(8, 6))
    colors = {"mamba_diag_exp": "darkred", "mamba_simple": "black"}
    markers = {"mamba_diag_exp": "o", "mamba_simple": "^"}

    # Known models are drawn first in a fixed order, then any others alphabetically.
    preferred = ("mamba_diag_exp", "mamba_simple")
    draw_order = [kind for kind in preferred if kind in points]
    draw_order.extend(kind for kind in sorted(points) if kind not in preferred)

    for kind in draw_order:
        ordered = sorted(points[kind], key=lambda pair: pair[0])
        xs, ys = (list(column) for column in zip(*ordered))
        color = colors.get(kind, None)

        has_band = with_variance and kind in spreads
        label = _display_model_name(kind)
        if has_band:
            label = f"{label} (mean ± std)"
            # Both lists were filled in the same order and the sort is stable,
            # so the sorted stds line up with the sorted means.
            ordered_spread = sorted(spreads[kind], key=lambda pair: pair[0])
            std_arr = np.asarray([pair[1] for pair in ordered_spread], dtype=np.float64)
            ys_arr = np.asarray(ys, dtype=np.float64)
            lower = np.maximum(ys_arr - std_arr, 1e-12)
            upper = ys_arr + std_arr

        plt.plot(
            xs,
            ys,
            marker=markers.get(kind, "o"),
            linewidth=4,
            alpha=0.9,
            color=color,
            label=label,
        )
        if has_band:
            plt.fill_between(xs, lower, upper, color=color, alpha=0.18)

    plt.title(title, fontsize=20, fontweight="bold")
    plt.xlabel("D_state (N)", fontsize=24)
    plt.ylabel(ylabel, fontsize=24)
    if logy:
        plt.yscale("log")
        plt.ylabel(f"Log {ylabel}", fontsize=24)
    plt.grid(True, alpha=0.3)
    plt.tick_params(axis="both", which="major", labelsize=20)
    plt.legend(fontsize=16, title="Model", title_fontsize=16)
    plt.tight_layout()
    out_path.parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(out_path, dpi=225, bbox_inches="tight")
    plt.close()


def _plot_learning_curves(
    records: list[dict[str, Any]],
    out_path: Path,
    split: str,
    loss_key: str,
    title: str,
    logy: bool = False,
    with_variance: bool = False,
) -> None:
    """Plot per-epoch loss curves taken from each record's ``history`` list.

    Args:
        records: Run records. ``history`` is expected to be a list of
            per-epoch dicts; entries that are not dicts, lack ``loss_key``,
            or hold a value that cannot be converted to ``float`` are
            skipped. The x-axis is the 1-based position among kept entries.
        out_path: Destination image file; parent directories are created.
        split: Name used in the y-axis label (``"Val"`` when empty).
        loss_key: Key of the loss value inside each history entry.
        title: Figure title.
        logy: Use a logarithmic y-axis; values are floored at ``1e-12``.
        with_variance: When true, runs are grouped with ``_group_key``,
            truncated to the shortest run of the group, and drawn as a mean
            line with a mean ± std band (sample std, 0 for single-run
            groups). When false, every run is drawn on its own.

    Curves are drawn in (model kind, d_state) order in both modes so the
    legend order does not depend on the order of the log file.

    Raises:
        RuntimeError: If no record yields a curve for ``loss_key``, or if
            matplotlib cannot be imported.
    """
    plt = _ensure_matplotlib()

    cmap_map = {
        "mamba_diag_exp": "Reds",
        "mamba_simple": "Blues",
    }
    base_color_map = {
        "mamba_diag_exp": "darkred",
        "mamba_simple": "black",
    }
    eps = 1e-12

    d_states_by_model: dict[str, set[int]] = defaultdict(set)
    for rec in records:
        d_state = _get_d_state(rec)
        if d_state is not None:
            d_states_by_model[_get_model_kind(rec)].add(d_state)
    unique_sorted_d_states_by_model = {k: sorted(v) for k, v in d_states_by_model.items()}

    def _line_color(model_kind: str, d_state: int | None):
        """Pick a shade for the run: darker for larger ``d_state`` within a model."""
        cmap_name = cmap_map.get(model_kind)
        d_states = unique_sorted_d_states_by_model.get(model_kind, [])
        if d_state is None or cmap_name is None or len(d_states) <= 1:
            return base_color_map.get(model_kind, None)
        idx = d_states.index(d_state)
        frac = 0.35 + 0.55 * (idx / (len(d_states) - 1))  # avoid too-light/too-dark extremes
        return plt.get_cmap(cmap_name)(frac)

    def _loss_history(rec: dict[str, Any]) -> list[float]:
        """Extract the usable loss values from one record, in epoch order."""
        history = rec.get("history")
        if not isinstance(history, list):
            return []
        values: list[float] = []
        for entry in history:
            if not isinstance(entry, dict):
                continue
            raw = entry.get(loss_key)
            if raw is None:
                continue
            try:
                values.append(float(raw))
            except (TypeError, ValueError):
                # Best-effort: one malformed epoch must not abort the whole plot.
                continue
        return values

    def _draw_order(model_kind: str, d_state: int | None) -> tuple[str, int]:
        """Sort key shared by both modes; runs without d_state come first."""
        return (model_kind, -1 if d_state is None else d_state)

    # Each entry: (model_kind, d_state, y values, std or None).
    # All data is gathered before a figure exists, so an empty result can be
    # reported without leaving an open figure behind.
    series: list[tuple[str, int | None, Any, Any]] = []

    if with_variance:
        grouped: dict[tuple, list[list[float]]] = defaultdict(list)
        for rec in records:
            ys_raw = _loss_history(rec)
            if ys_raw:
                grouped[_group_key(rec)].append(ys_raw)

        for key in sorted(grouped, key=lambda k: _draw_order(k[1], k[2])):
            curves = grouped[key]
            # Runs may differ in length; only the common prefix can be averaged.
            T = min(len(c) for c in curves)
            mat = np.array([c[:T] for c in curves], dtype=np.float64)
            if logy:
                mat = np.maximum(mat, eps)
            mean = mat.mean(axis=0)
            std = mat.std(axis=0, ddof=1) if mat.shape[0] > 1 else np.zeros(T, dtype=np.float64)
            series.append((key[1], key[2], mean, std))
    else:
        for rec in sorted(records, key=lambda r: _draw_order(_get_model_kind(r), _get_d_state(r))):
            ys = _loss_history(rec)
            if not ys:
                continue
            if logy:
                ys = [max(v, eps) for v in ys]
            series.append((_get_model_kind(rec), _get_d_state(rec), ys, None))

    if not series:
        raise RuntimeError(f"No learning curves found for key={loss_key}.")

    fig = plt.figure(figsize=(10, 8))
    try:
        for model_kind, d_state, ys, std in series:
            xs = list(range(1, len(ys) + 1))
            label = _display_model_name(model_kind)
            if d_state is not None:
                label = f"{label} N={d_state}"
            color = _line_color(model_kind, d_state)
            plt.plot(xs, ys, linewidth=4, alpha=0.9, color=color, label=label)
            if std is not None:
                plt.fill_between(xs, ys - std, ys + std, color=color, alpha=0.18)

        plt.title(title, fontsize=20, fontweight="bold")
        plt.xlabel("Epoch", fontsize=24)
        split_label = split.strip().title() if split else "Val"
        y_label = f"{split_label} Loss (MSE)"
        if logy:
            y_label = f"Log {y_label}"
        plt.ylabel(y_label, fontsize=24)
        if logy:
            plt.yscale("log")
        plt.grid(True, alpha=0.3)
        plt.tick_params(axis="both", which="major", labelsize=20)
        plt.legend(fontsize=14, title="Run", title_fontsize=14)
        plt.tight_layout()
        out_path.parent.mkdir(parents=True, exist_ok=True)
        plt.savefig(out_path, dpi=225, bbox_inches="tight")
    finally:
        # Release the figure even when drawing or saving fails.
        plt.close(fig)


def main() -> None:
    """Command-line entry point: filter logged runs and write the plots.

    Steps, in order:

    1. Read the JSONL log and keep the records of the requested task.
    2. Apply the optional ``--epochs``, ``--latest_only``, ``--model_kinds``
       and ``--d_states`` filters.
    3. Print mean/std of the final validation loss per ``d_state``.
    4. Write ``synthetic_best_val_mse_vs_N.png`` and
       ``synthetic_val_mse_curves.png`` into ``--out_dir``.

    Raises:
        SystemExit: If no record matches the task, or none survive filtering.
        RuntimeError: Propagated from the plotting helpers when they find
            nothing to plot or matplotlib is unavailable.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--log_path",
        type=str,
        default="outputs/time_series_experiments/ts_runs.jsonl",
        help="Path to JSONL produced by the experiments.",
    )
    parser.add_argument(
        "--out_dir",
        type=str,
        default="outputs/time_series_experiments",
        help="Directory to write PNG plots into.",
    )
    parser.add_argument(
        "--task",
        type=str,
        default="synthetic_decays",
        choices=["synthetic_decays"],
    )
    parser.add_argument(
        "--d_states",
        type=int,
        nargs="+",
        default=None,
        help="Optional filter for which d_state (N) values to include (e.g. `--d_states 1 2`).",
    )
    parser.add_argument(
        "--model_kinds",
        type=str,
        nargs="+",
        default=None,
        help="Optional filter for which model_kind values to include (e.g. `--model_kinds mamba_diag_exp`).",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=None,
        help="Optional filter: keep only runs with this exact epochs value.",
    )
    parser.add_argument(
        "--strict_hparams",
        action="store_true",
        help="If set, only aggregate/compare runs with identical hyperparameters (task, model_kind, d_state, d_model, n_layers, epochs).",
    )
    parser.add_argument(
        "--log_val_curves",
        action="store_true",
        help="Plot validation curves with a log y-axis (recommended for MSE).",
    )
    parser.add_argument(
        "--latest_only",
        action="store_true",
        help="Keep only the last run per (task, model_kind, d_state, d_model, n_layers).",
    )
    parser.add_argument(
        "--variance_bands",
        action="store_true",
        # The flag is forwarded to both plots below, so the help covers both.
        help="Aggregate runs per config and plot mean ± std for best-val MSE vs N and for the validation learning curves.",
    )
    parser.add_argument(
        "--log_best_mse",
        action="store_true",
        help="Plot best-val MSE vs N with a log y-axis.",
    )
    args = parser.parse_args()

    log_path = Path(args.log_path)
    out_dir = Path(args.out_dir)

    records_all = [r for r in _iter_jsonl(log_path) if r.get("task") == args.task]
    if not records_all:
        raise SystemExit(f"No records for task={args.task} in {log_path}")

    records = records_all

    def _rec_epochs(r: dict[str, Any]) -> int | None:
        """Return the run's epoch count (top level or ``train_config``), if parseable."""
        val = r.get("epochs")
        if val is None and isinstance(r.get("train_config"), dict):
            val = r["train_config"].get("epochs")
        try:
            return int(val) if val is not None else None
        except Exception:
            return None

    if args.epochs is not None:
        records = [r for r in records if _rec_epochs(r) == args.epochs]

    if args.latest_only:
        # With --strict_hparams the key also includes epochs (see _group_key).
        records = _last_by_key(records, _group_key if args.strict_hparams else lambda r: (
            r.get("task"),
            _get_model_kind(r),
            _get_d_state(r),
            r.get("d_model"),
            r.get("n_layers"),
        ))

    if args.model_kinds is not None:
        allowed_kinds = set(args.model_kinds)
        records = [r for r in records if _get_model_kind(r) in allowed_kinds]

    if args.d_states is not None:
        allowed_d_states = set(args.d_states)
        records = [r for r in records if _get_d_state(r) in allowed_d_states]

    if not records:
        raise SystemExit("No records left after filtering; adjust `--d_states/--model_kinds` or logging file.")

    # Report mean final val MSE per N (d_state) over the filtered records.
    final_mse_by_N: dict[int, list[float]] = defaultdict(list)
    for rec in records:
        d_state = _get_d_state(rec)
        if d_state is None:
            continue
        # Explicit ``is None`` test: a legitimate loss of exactly 0.0 is falsy,
        # and ``a or b`` would silently replace it with the fallback field.
        val = rec.get("final_val_loss")
        if val is None:
            val = rec.get("val_loss")
        if val is None:
            continue
        try:
            final_mse_by_N[d_state].append(float(val))
        except (TypeError, ValueError, OverflowError):
            # Best-effort report: skip values that are not numeric.
            continue
    if final_mse_by_N:
        print("Mean final val MSE by N (filtered records):")
        for N in sorted(final_mse_by_N):
            arr = np.array(final_mse_by_N[N], dtype=np.float64)
            mean = arr.mean()
            # Sample std is undefined for a single run; report 0 instead.
            std = arr.std(ddof=1) if arr.size > 1 else 0.0
            print(f"  N={N}: mean={mean:.6f} std={std:.6f} (runs={arr.size})")

    _plot_summary_vs_d_state(
        records,
        out_dir / "synthetic_best_val_mse_vs_N.png",
        metric="best_val_loss",
        title="Synthetic Mixture-of-Decays: Best Val MSE vs N",
        ylabel="Best Val MSE",
        with_variance=args.variance_bands,
        logy=args.log_best_mse,
    )
    _plot_learning_curves(
        records,
        out_dir / "synthetic_val_mse_curves.png",
        split="val",
        loss_key="val_loss",
        title="Synthetic Mixture-of-Decays: Log Val Loss Curves"
        if args.log_val_curves
        else "Synthetic Mixture-of-Decays: Val Loss Curves",
        logy=args.log_val_curves,
        with_variance=args.variance_bands,
    )

    print(f"Wrote plots to {out_dir}")


# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
