"""
Revision diffusion plots: steady-state relative L2 error against the uniform
steady state, and per-step runtime normalized to the FVM baseline.

Reads the manifest.json produced by revision_diffusion.py.
"""

from __future__ import annotations

import json
from pathlib import Path
from typing import Dict, List

import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
from gitbud.gitbud import inject_repo_into_sys_path

# A single call performs the sys.path injection and returns the repository
# root (the helper itself lives in gitbud, not in this file).
REPO_ROOT = Path(inject_repo_into_sys_path())
FIG_DIR_DEFAULT = REPO_ROOT / "figs"
ICLR_DIR_DEFAULT = REPO_ROOT / "iclr2026" / "figures"


def _load_manifest(path: Path) -> Dict:
    with path.open() as f:
        return json.load(f)


def _collect_by_particles(mg_runs: List[Dict]) -> Dict[int, Dict[str, np.ndarray]]:
    grouped: Dict[int, Dict[str, List[float]]] = {}
    for run in mg_runs:
        if run.get("status") != "ok":
            continue
        p = int(run["particles"])
        grouped.setdefault(p, {"err": [], "per_step": []})
        grouped[p]["err"].append(float(run["rel_l2_uniform"]))
        grouped[p]["per_step"].append(float(run.get("per_step_s", np.nan)))
    out: Dict[int, Dict[str, np.ndarray]] = {}
    for p, vals in grouped.items():
        out[p] = {
            "err_mean": np.nanmean(vals["err"]),
            "err_std": np.nanstd(vals["err"]),
            "rt_mean": np.nanmean(vals["per_step"]),
            "rt_std": np.nanstd(vals["per_step"]),
            "n": len(vals["err"]),
        }
    return out


def plot_error(grouped: Dict[int, Dict[str, np.ndarray]], fvm_err: float, out_paths: List[Path]) -> None:
    """Plot mean steady-state error against particle count, with FVM as a reference line.

    Both axes are logarithmic and the error bars show one standard deviation
    across runs. The same figure is saved (at 300 dpi) to every path in
    *out_paths*, creating parent directories as needed.

    Args:
        grouped: per-particle-count statistics as built by ``_collect_by_particles``.
        fvm_err: FVM relative L2 error, drawn as a horizontal dashed line.
        out_paths: destination files for the figure.

    Raises:
        ValueError: if *grouped* is empty (there is nothing to plot).
    """
    if not grouped:
        # The x-limits below need at least one particle count; failing before
        # plt.subplots also avoids leaving an unclosed figure behind.
        raise ValueError("plot_error: no successful runs to plot")

    particles = sorted(grouped)
    # Statistics are looked up with the original int keys; only the plotted x values are floats.
    xs = np.array(particles, dtype=float)
    err_mean = np.array([grouped[p]["err_mean"] for p in particles])
    err_std = np.array([grouped[p]["err_std"] for p in particles])

    fig, ax = plt.subplots(figsize=(6.5, 4.2))
    try:
        ax.set_xscale("log")
        ax.set_yscale("log")
        ax.errorbar(xs, err_mean, yerr=err_std, fmt="o-", lw=2, capsize=4, color="#2ca02c", label="Algorithm 1")
        ax.axhline(fvm_err, color="#d62728", ls="--", lw=2, label="FVM")
        # Pad the log-scaled x range so the first and last markers are not clipped.
        ax.set_xlim(xs.min() * 0.7, xs.max() * 1.3)
        ax.set_xlabel("Particles")
        ax.set_ylabel("Relative L2 error")
        ax.set_title("Steady-state error, diffusion-only")
        ax.grid(True, which="both", alpha=0.3, linestyle="--")
        ax.legend()
        fig.tight_layout()
        for path in out_paths:
            path.parent.mkdir(parents=True, exist_ok=True)
            fig.savefig(path, dpi=300)
            print(f"[revision diffusion] wrote {path}")
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)


def plot_runtime(grouped: Dict[int, Dict[str, np.ndarray]], fvm_per_step: float, out_paths: List[Path]) -> None:
    xs = sorted(grouped.keys())
    rt = np.array([grouped[x]["rt_mean"] for x in xs])
    rt_std = np.array([grouped[x]["rt_std"] for x in xs])
    norm = rt / fvm_per_step if fvm_per_step > 0 else rt
    norm_std = rt_std / fvm_per_step if fvm_per_step > 0 else rt_std
    speedup = fvm_per_step / rt if fvm_per_step > 0 else np.ones_like(rt)

    fig, ax = plt.subplots(figsize=(6.5, 4.2))
    bars = ax.bar(np.arange(len(xs)), norm, yerr=norm_std, capsize=4, color="#2ca02c", label="Algorithm 1")
    for i, val in enumerate(norm):
        ax.text(i, val + 0.02, f"{speedup[i]:.1f}x", ha="center", va="bottom", fontsize=8)
    ax.axhline(1.0, color="#d62728", ls="--", lw=2, label="FVM = 1")
    ax.set_xticks(np.arange(len(xs)))
    ax.set_xticklabels([f"{x//1000}k" if x >= 1000 else str(x) for x in xs])
    ax.set_ylabel("Normalized runtime per step")
    ax.set_xlabel("Particles")
    ax.set_title("Normalized runtimes, diffusion-only")
    ax.legend()
    ax.grid(True, axis="y", alpha=0.3, linestyle="--")
    fig.tight_layout()
    for p in out_paths:
        p.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(p, dpi=300)
        print(f"[revision diffusion] wrote {p}")
    plt.close(fig)


def main() -> None:
    """Command-line entry point: read a manifest and write the error and runtime figures.

    Every figure is written once under ``--fig-dir`` and once under
    ``--iclr-dir``, using the same file name in both directories.
    """
    import argparse

    cli = argparse.ArgumentParser(description="Revision diffusion plots from manifest.")
    cli.add_argument("--manifest", type=Path, required=True, help="Path to manifest.json from revision_diffusion.py")
    cli.add_argument("--fig-dir", type=Path, default=FIG_DIR_DEFAULT, help="Directory for figs outputs")
    cli.add_argument("--iclr-dir", type=Path, default=ICLR_DIR_DEFAULT, help="Directory for iclr figures")
    opts = cli.parse_args()

    manifest = _load_manifest(opts.manifest)
    fvm_stats = manifest["fvm"]
    by_particles = _collect_by_particles(manifest["mg_runs"])

    out_dirs = (opts.fig_dir, opts.iclr_dir)

    def targets(filename: str) -> List[Path]:
        # Same file name under each output directory, figs directory first.
        return [directory / filename for directory in out_dirs]

    error_paths = targets("vascular_diffusion_error.pdf")
    runtime_paths = targets("vascular_diffusion_runtime.pdf")

    fvm_error = float(fvm_stats["rel_l2_uniform"])
    plot_error(by_particles, fvm_err=fvm_error, out_paths=error_paths)

    fvm_step_time = float(fvm_stats["per_step_s"])
    plot_runtime(by_particles, fvm_per_step=fvm_step_time, out_paths=runtime_paths)


# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
