"""
Revision drift plots:
- Runtime per step vs dt (log scale); MG bars are annotated with the FVM/MG
  speedup wherever the FVM run was stable
- Steady-state tracer density layout for the drift case
"""

from __future__ import annotations

import argparse
import json
from pathlib import Path
from typing import Any, Dict, List, Tuple

import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
from gitbud.gitbud import inject_repo_into_sys_path

# Repository root as returned by gitbud's helper (which, judging by its name,
# also puts the repo on sys.path). Figures are written both to the repo's
# figs/ directory and to the ICLR paper's figure directory.
REPO_ROOT = Path(inject_repo_into_sys_path())
FIGS_DIR = REPO_ROOT.joinpath("figs")
ICLR_DIR = REPO_ROOT.joinpath("iclr2026", "figures")


def _load_manifest(path: Path) -> Tuple[Dict[str, Any], Path]:
    data = json.loads(path.read_text())
    return data, path.parent


def _runtime_data(
    runs: List[Dict[str, Any]]
) -> Tuple[List[float], List[float], List[float], List[str], List[bool]]:
    dts: List[float] = []
    mg_rt: List[float] = []
    fvm_rt: List[float] = []
    labels: List[str] = []
    fvm_stable: List[bool] = []

    for run in sorted(runs, key=lambda r: r["dt"]):
        dt = float(run["dt"])
        labels.append(f"{dt:g}")
        dts.append(dt)
        fvm = run.get("fvm", {})
        mg = run.get("mg", {})
        adaptive = bool(fvm.get("adaptive_time_step")) or bool(fvm.get("log_contains_retry"))
        exec_status = fvm.get("execution_status", fvm.get("status"))
        is_stable = (exec_status in ("ok", "timeout")) and not adaptive
        fvm_stable.append(is_stable)
        mg_rt.append(float(mg["per_step_s"]) if mg.get("per_step_s") is not None else np.nan)
        fvm_rt.append(float(fvm["per_step_s"]) if fvm.get("per_step_s") is not None else np.nan)

    return dts, mg_rt, fvm_rt, labels, fvm_stable


def plot_runtime(runs: List[Dict[str, Any]], out_paths: List[Path]) -> None:
    """Plot per-step runtime vs dt for DuMuX FVM and MG as grouped log-scale bars.

    For each dt the left slot holds the FVM runtime and the right slot the MG
    runtime. When the FVM run is not stable (or has no timing), a red ``x`` is
    drawn in its slot instead of a bar. MG bars are annotated with the FVM/MG
    speedup whenever the FVM run is stable and both timings are usable.

    Args:
        runs: Manifest run entries, as consumed by ``_runtime_data``.
        out_paths: Output files; missing parent directories are created.
    """
    dts, mg_rt, fvm_rt, labels, fvm_stable = _runtime_data(runs)
    x = np.arange(len(dts))
    width = 0.35

    fig, ax = plt.subplots(figsize=(7, 4.2))

    # "FVM unstable" markers sit at a tenth of the tallest finite bar so they
    # follow the runtime scale of the log axis; fall back to 0.1 when there is
    # no positive finite timing at all.
    heights = np.asarray(mg_rt + fvm_rt, dtype=float)
    finite_heights = heights[np.isfinite(heights)]
    max_height = finite_heights.max() if finite_heights.size else 0.0
    marker_y = 0.1 * max_height if max_height > 0 else 0.1

    fvm_label_added = False
    unstable_label_added = False

    # FVM bars (left slot of each dt)
    for xi, fvm, stable in zip(x, fvm_rt, fvm_stable):
        xpos = xi - width / 2
        if stable and not np.isnan(fvm):
            # Label the first FVM bar actually drawn; keying on index 0 would
            # drop the legend entry whenever the smallest dt is unstable.
            ax.bar(
                [xpos],
                [fvm],
                width,
                color="#d62728",
                label="DuMuX FVM" if not fvm_label_added else "_nolegend_",
            )
            fvm_label_added = True
        else:
            ax.scatter(
                xpos,
                marker_y,
                marker="x",
                color="red",
                s=100,
                label="FVM unstable" if not unstable_label_added else "_nolegend_",
            )
            unstable_label_added = True

    # MG bars (right slot), annotated with the FVM/MG speedup where defined.
    ax.bar(x + width / 2, mg_rt, width, label="$\\bf{Our}$ $\\bf{algorithm}$", color="#2ca02c")
    for xi, mg, fvm, stable in zip(x, mg_rt, fvm_rt, fvm_stable):
        if stable and not np.isnan(fvm) and not np.isnan(mg) and mg > 0:
            ax.text(
                xi + width / 2,
                mg * 1.01,
                f"{fvm / mg:.1f}x",
                ha="center",
                va="bottom",
                fontsize=12,
                fontweight="bold",
                color="black",
            )

    ax.set_xticks(x)
    ax.set_xticklabels(labels, fontsize=12)
    ax.set_ylabel("Runtime per step (s)", fontsize=13)
    ax.set_xlabel(r"$\Delta t$ (s)", fontsize=13)
    ax.set_title("Runtimes, DuMuX network, 1M particles", fontsize=14)
    ax.tick_params(axis="y", labelsize=12)
    ax.set_yscale("log")
    ax.legend(fontsize=11)
    ax.grid(True, axis="y", linestyle="--", alpha=0.4)
    fig.tight_layout()
    for out in out_paths:
        out.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out)
    plt.close(fig)
    print(f"[revision drift plots] runtime -> {[str(p) for p in out_paths]}")


def _load_density(run: Dict[str, Any], base_dir: Path) -> Tuple[np.ndarray, np.ndarray]:
    """Return (tracer_final, times) using MG (preferred for consistency)."""
    mg = run.get("mg", {})
    data = np.load(base_dir / mg["npz"])
    density = np.asarray(data["density"], dtype=float)
    times = np.asarray(data["times"], dtype=float)
    cell_lengths = np.asarray(data["cell_lengths"], dtype=float)
    cell_radii = np.asarray(data.get("cell_radii", []), dtype=float)
    area = np.pi * cell_radii * cell_radii if cell_radii.size == cell_lengths.size else np.ones_like(cell_lengths)
    molar_density = 1000.0 / 0.018
    tracer = density * (1.0 / (molar_density * np.maximum(area, 1e-20)))
    return tracer[-1], times


def plot_density(
    manifest: Dict[str, Any],
    base_dir: Path,
    out_paths: List[Path],
    density_override: Path | None = None,
) -> None:
    """Draw the final tracer density on the cell network as coloured segments.

    The tracer values come from ``density_override`` when given. Otherwise
    they come from the smallest-dt run whose MG status is ``"ok"``. In both
    cases they go through ``_load_density``, so the same area/molar
    conversion applies. Values are divided by their 99th percentile and
    shown on a fixed [0, 1] inferno colormap. Each cell edge from
    ``manifest["preprocess"]["input_npz"]`` is drawn as one segment, and the
    original points are overlaid as blue dots.

    Args:
        manifest: Parsed drift manifest; needs ``"preprocess"["input_npz"]``
            and, without an override, ``"runs"``.
        base_dir: Directory manifest-relative paths are resolved against.
        out_paths: Output files; missing parent directories are created.
        density_override: Optional MG NPZ used instead of a manifest run.

    Raises:
        RuntimeError: if no run has MG status ``"ok"`` (and no override).
    """
    if density_override:
        # Use the same MG loader as the manifest path so both get identical
        # per-cell area / molar conversion. ``Path() / p`` leaves relative and
        # absolute ``p`` unchanged.
        tracer_final, _ = _load_density({"mg": {"npz": density_override}}, Path())
    else:
        # Numeric sort: dt may be serialized as a string.
        runs = sorted(manifest["runs"], key=lambda r: float(r["dt"]))
        chosen = next(
            (r for r in runs if (r.get("mg") or {}).get("status") == "ok"),
            None,
        )
        if chosen is None:
            raise RuntimeError("No successful runs found for density plot")
        tracer_final, _ = _load_density(chosen, base_dir)

    # Close the geometry NPZ once its arrays are read.
    with np.load(base_dir / manifest["preprocess"]["input_npz"]) as input_data:
        cell_edges = np.asarray(input_data["cell_edges"], dtype=int)
        cell_points = np.asarray(input_data["cell_points"], dtype=float)
        orig_points = np.asarray(input_data["orig_points"], dtype=float)

    # Normalize by the 99th percentile so a few extreme cells don't wash out
    # the colormap; fall back to 1 when that percentile is not positive.
    p99 = np.percentile(tracer_final, 99) if tracer_final.size else 1.0
    scale = p99 if p99 > 0 else 1.0
    colors_raw = tracer_final / scale

    # One segment per edge (u, v), coloured by the matching tracer entry;
    # truncated to the shorter of the two sequences, as pairing with zip does.
    xy = cell_points[:, :2]
    edges = cell_edges.reshape(-1, 2)
    n_seg = min(len(edges), len(colors_raw))
    edges = edges[:n_seg]
    segments = np.stack([xy[edges[:, 0]], xy[edges[:, 1]]], axis=1)

    norm = matplotlib.colors.Normalize(vmin=0.0, vmax=1.0)
    lc = matplotlib.collections.LineCollection(
        segments,
        cmap="inferno",
        norm=norm,
        linewidths=1.5,
    )
    lc.set_array(np.asarray(colors_raw[:n_seg]))

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.add_collection(lc)
    ax.scatter(orig_points[:, 0], orig_points[:, 1], s=5, c="tab:blue", alpha=0.8, label="nodes")
    ax.autoscale()
    ax.set_aspect("equal")
    ax.set_xticks([])
    ax.set_yticks([])
    cbar = fig.colorbar(lc, ax=ax, fraction=0.046, pad=0.04)
    cbar.set_label("tracer density (normalized)")
    ax.set_title("Tracer steady state density (with DuMuX flows)", fontsize=14)
    fig.tight_layout()
    for out in out_paths:
        out.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out)
    plt.close(fig)
    print(f"[revision drift plots] density -> {[str(p) for p in out_paths]}")


def main() -> None:
    """Command-line entry point.

    Reads the drift manifest named by ``--manifest`` and renders the runtime
    and density figures. Each figure is saved under both ``FIGS_DIR`` and
    ``ICLR_DIR``. ``--density-npz`` is forwarded to ``plot_density`` as its
    override.
    """
    cli = argparse.ArgumentParser(description="Revision drift runtime/density plots")
    cli.add_argument("--manifest", type=Path, required=True, help="Path to drift manifest.json")
    cli.add_argument("--density-npz", type=Path, default=None, help="Optional MG NPZ for density plot")
    opts = cli.parse_args()

    manifest, base_dir = _load_manifest(opts.manifest)

    runtime_targets = [
        FIGS_DIR / "vascular_drift_runtime.pdf",
        ICLR_DIR / "vascular_drift_runtime.pdf",
    ]
    density_targets = [
        FIGS_DIR / "vascular_drift_density.pdf",
        ICLR_DIR / "vascular_drift_density.pdf",
    ]

    plot_runtime(manifest["runs"], out_paths=runtime_targets)
    plot_density(
        manifest,
        base_dir=base_dir,
        out_paths=density_targets,
        density_override=opts.density_npz,
    )


if __name__ == "__main__":
    main()
