"""
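Summary plots comparing the diffusion and advection runs (FVM/DuMuX vs CUDA MG):
- grouped bar chart of FVM vs MG wall times (with speedup annotations)
- bar chart of median relative L2 mass error
- scatter plot of MG wall time vs median relative L2 mass error
- CDF of the per-edge relative mass error at the final snapshot
- CSV table of the summary metrics
All outputs are written to figs/summary/.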
"""

from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
from gitbud.gitbud import inject_repo_into_sys_path

inject_repo_into_sys_path()

from experiments.dumux_tracer.compare_metric_to_dumux import (
    _cell_mass,
    load_metric,
    load_dumux,
    match_times,
)


@dataclass(frozen=True)
class Scenario:
    name: str
    dumux_npz: Path
    metric_npz: Path
    dumux_runtime: Path
    metric_runtime: Optional[Path]
    flux_scale: float | None = None


def _load_runtime(path: Path) -> float | None:
    if not path.exists():
        return None
    try:
        return float(path.read_text().strip().split()[0])
    except Exception:
        return None


def _median_rel_error(dumux_npz: Path, metric_npz: Path, tol: float = 1e-6) -> float:
    """Return the median relative mass error over time-matched snapshots.

    Snapshots of the metric run and the DuMuX run are paired with
    ``match_times`` (tolerance *tol*). For each pair the error is the
    Euclidean norm of the difference between the two mass vectors, divided
    by the DuMuX total at that snapshot (floored at 1e-20 so a zero total
    cannot divide by zero).

    Returns NaN when no snapshots could be matched.
    """
    dumux_times, dumux_mass, dumux_total, cell_edges, cell_points, cell_lengths, cell_radii = load_dumux(dumux_npz)
    metric_times, metric_mass, _, _ = load_metric(
        metric_npz,
        cell_edges=cell_edges,
        cell_points=cell_points,
        cell_lengths=cell_lengths,
        cell_radii=cell_radii,
        # NOTE(review): presumably rescales the metric mass to the initial DuMuX total; confirm in load_metric.
        target_total=float(dumux_total[0] if dumux_total.size else 1.0),
    )
    pairs = match_times(metric_times, dumux_times, tol=tol)
    if not pairs:
        return float("nan")
    rel = [
        np.linalg.norm(metric_mass[m_idx] - dumux_mass[d_idx]) / max(dumux_total[d_idx], 1e-20)
        for m_idx, d_idx in pairs
    ]
    return float(np.median(rel))


def _edge_errors_final(dumux_npz: Path, metric_npz: Path) -> np.ndarray:
    """Return the per-edge relative mass error at the final DuMuX snapshot.

    Cell masses are summed per edge with ``np.bincount`` using the
    ``cell_to_edge`` map stored in *dumux_npz*. This is done for the last
    DuMuX snapshot and for the metric snapshot whose time is closest to it.
    Each edge's error is ``|dumux - metric| / max(|dumux|, 1e-20)``.

    Returns an empty float array when there is nothing to compare: no cells,
    no DuMuX snapshots, or no metric snapshots.
    """
    # Only the cell -> edge map is needed from the raw archive; the context
    # manager closes the file handle that np.load would otherwise keep open.
    with np.load(dumux_npz) as dumux:
        cell_to_edge = np.asarray(dumux["cell_to_edge"], dtype=int)
    dumux_times, dumux_mass, dumux_total, cell_edges, cell_points, cell_lengths, cell_radii = load_dumux(dumux_npz)
    metric_times, metric_mass, _, _ = load_metric(
        metric_npz,
        cell_edges=cell_edges,
        cell_points=cell_points,
        cell_lengths=cell_lengths,
        cell_radii=cell_radii,
        target_total=float(dumux_total[0] if dumux_total.size else 1.0),
    )
    metric_times = np.asarray(metric_times, dtype=float)
    if cell_to_edge.size == 0 or len(dumux_times) == 0 or metric_times.size == 0:
        return np.empty(0, dtype=float)

    # Compare the final DuMuX snapshot with the metric snapshot nearest in time.
    d_idx = len(dumux_times) - 1
    m_idx = int(np.argmin(np.abs(metric_times - dumux_times[d_idx])))
    num_edges = int(cell_to_edge.max()) + 1
    du_edge = np.bincount(cell_to_edge, weights=dumux_mass[d_idx], minlength=num_edges)
    mg_edge = np.bincount(cell_to_edge, weights=metric_mass[m_idx], minlength=num_edges)
    return np.abs(du_edge - mg_edge) / np.maximum(np.abs(du_edge), 1e-20)


def _default_scenarios(repo: Path) -> list[Scenario]:
    """Return the diffusion and advection scenarios with data paths under *repo*."""
    return [
        Scenario(
            "diffusion",
            dumux_npz=repo / "data/dumux_network_tracer_1d.npz",
            metric_npz=repo / "data/dumux_metric_graph_sim.npz",
            dumux_runtime=repo / "data/dumux_network_tracer_1d_runtime.txt",
            metric_runtime=None,  # use NPZ runtime
            flux_scale=0.0,
        ),
        Scenario(
            "advection",
            dumux_npz=repo / "data/dumux_advection_tracer_1d.npz",
            metric_npz=repo / "data/dumux_advection_metric_graph_sim.npz",
            dumux_runtime=repo / "data/dumux_advection_tracer_1d_runtime.txt",
            metric_runtime=None,
            flux_scale=None,  # read from the metric NPZ ("flux_scale")
        ),
    ]


def _is_finite(value: float | None) -> bool:
    """Return True when *value* is a real, finite number (not None, NaN or inf)."""
    return value is not None and bool(np.isfinite(value))


def _collect_row(sc: Scenario) -> dict:
    """Gather runtimes, error metrics and mesh statistics for one scenario.

    Keys: name, flux, fvm, mg, speedup, rel, dx_med, num_cells, particles,
    edge_rel. Missing runtimes are None; undefined ratios are NaN.
    """
    # Context managers close the file handles that np.load keeps open.
    with np.load(sc.dumux_npz) as dumux_data:
        cell_lengths = np.asarray(dumux_data["cell_lengths"], dtype=float)
    with np.load(sc.metric_npz) as metric_data:
        flux = sc.flux_scale
        if flux is None:
            # Best effort: a missing or non-scalar entry is reported as NaN.
            try:
                flux = float(metric_data["flux_scale"])
            except (KeyError, TypeError, ValueError):
                flux = float("nan")
        particles = int(metric_data["num_particles"]) if "num_particles" in metric_data else None
        # Prefer the MG wall time stored in the NPZ; fall back to a runtime file.
        if "total_wall_time" in metric_data:
            mg = float(metric_data["total_wall_time"])
        elif sc.metric_runtime is not None:
            mg = _load_runtime(sc.metric_runtime)
        else:
            mg = None
    fvm = _load_runtime(sc.dumux_runtime)
    speedup = fvm / mg if (fvm is not None and mg is not None and mg > 0) else float("nan")
    return dict(
        name=sc.name,
        flux=flux,
        fvm=fvm,
        mg=mg,
        speedup=speedup,
        rel=_median_rel_error(sc.dumux_npz, sc.metric_npz),
        dx_med=float(np.median(cell_lengths)) if cell_lengths.size else float("nan"),
        num_cells=cell_lengths.size,
        particles=particles,
        edge_rel=_edge_errors_final(sc.dumux_npz, sc.metric_npz),
    )


def _plot_runtime(rows: list[dict], out_dir: Path) -> None:
    """Save a grouped bar chart of FVM vs MG wall time, annotated with speedups."""
    labels = [row["name"] for row in rows]
    # matplotlib cannot draw a bar of height None; NaN leaves the slot empty.
    fvm_heights = [row["fvm"] if row["fvm"] is not None else np.nan for row in rows]
    mg_heights = [row["mg"] if row["mg"] is not None else np.nan for row in rows]
    x = np.arange(len(rows))
    width = 0.35
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.bar(x - width / 2, fvm_heights, width, label="FVM (DuMuX)", color="C1")
    ax.bar(x + width / 2, mg_heights, width, label="CUDA MG", color="C0")
    ax.set_ylabel("Wall time (s)")
    ax.set_xticks(x, labels)
    ax.legend()
    for xi, row in zip(x, rows):
        fvm, mg, speedup = row["fvm"], row["mg"], row["speedup"]
        if _is_finite(fvm):
            ax.text(xi - width / 2, fvm, f"{fvm:.1f}s", ha="center", va="bottom", fontsize=8)
        if _is_finite(mg):
            ax.text(xi + width / 2, mg, f"{mg:.1f}s", ha="center", va="bottom", fontsize=8)
        if _is_finite(speedup):
            # A finite speedup implies both runtimes are present and mg > 0.
            # Offset upward in points so it clears the per-bar value labels.
            ax.annotate(
                f"{speedup:.2f}×",
                xy=(xi, max(fvm, mg)),
                xytext=(0, 12),
                textcoords="offset points",
                ha="center",
                va="bottom",
                fontsize=9,
            )
    ax.margins(y=0.15)  # headroom for the labels above the tallest bar
    fig.tight_layout()
    fig.savefig(out_dir / "runtime_comparison.png", dpi=200)
    plt.close(fig)


def _plot_rel_error(rows: list[dict], out_dir: Path) -> None:
    """Save a log-scale bar chart of the median relative mass error."""
    labels = [row["name"] for row in rows]
    errors = [row["rel"] for row in rows]
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.bar(labels, errors, color="C2")
    ax.set_ylabel("Median relative L2 (mass)")
    ax.set_yscale("log")
    for label, value in zip(labels, errors):
        if _is_finite(value):  # text at a NaN position cannot be drawn
            ax.text(label, value, f"{value:.2e}", ha="center", va="bottom", fontsize=8)
    fig.tight_layout()
    fig.savefig(out_dir / "rel_error_comparison.png", dpi=200)
    plt.close(fig)


def _plot_runtime_vs_error(rows: list[dict], out_dir: Path) -> None:
    """Save a log-log scatter of MG runtime vs median relative error."""
    fig, ax = plt.subplots(figsize=(4.5, 4))
    for row in rows:
        mg, rel = row["mg"], row["rel"]
        # Log axes can only show strictly positive, finite values.
        if not (_is_finite(mg) and _is_finite(rel) and mg > 0 and rel > 0):
            continue
        ax.scatter(mg, rel, label=row["name"])
        ax.text(mg, rel, f" {row['name']}\n{row['speedup']:.2f}×", fontsize=8)
    ax.set_xscale("log")
    ax.set_yscale("log")
    ax.set_xlabel("MG wall time (s)")
    ax.set_ylabel("Median relative L2 (mass)")
    ax.set_title("Accuracy vs MG runtime")
    ax.grid(True, which="both", ls=":", alpha=0.4)
    fig.tight_layout()
    fig.savefig(out_dir / "runtime_vs_error.png", dpi=200)
    plt.close(fig)


def _plot_edge_cdf(rows: list[dict], out_dir: Path) -> None:
    """Save the empirical CDF of the per-edge relative error (final snapshot)."""
    fig, ax = plt.subplots(figsize=(5, 4))
    for row in rows:
        errors = np.sort(row["edge_rel"])
        n = errors.size
        # Empirical CDF: the i-th smallest of n values (1-based) sits at i/n,
        # so the curve ends at 1 rather than at (n - 1)/n.
        ax.plot(errors, np.arange(1, n + 1) / max(n, 1), label=row["name"])
    ax.set_xscale("log")
    ax.set_xlabel("Per-edge relative mass error (final)")
    ax.set_ylabel("CDF")
    ax.set_title("Edge-wise error distribution (final snapshot)")
    ax.grid(True, which="both", ls=":", alpha=0.4)
    ax.legend()
    fig.tight_layout()
    fig.savefig(out_dir / "edge_error_cdf.png", dpi=200)
    plt.close(fig)


def _write_csv(rows: list[dict], out_dir: Path) -> None:
    """Write one summary line per scenario to summary_metrics.csv."""
    csv_lines = [
        "scenario,flux_scale,fvm_wall,mg_wall,speedup,median_rel_l2,dx_median,num_cells,mg_particles"
    ]
    csv_lines.extend(
        f"{row['name']},{row['flux']},{row['fvm']},{row['mg']},{row['speedup']},"
        f"{row['rel']},{row['dx_med']},{row['num_cells']},{row['particles']}"
        for row in rows
    )
    (out_dir / "summary_metrics.csv").write_text("\n".join(csv_lines) + "\n")


def main() -> None:
    """Compute the comparison metrics and write all figures and the CSV to figs/summary."""
    repo = Path(__file__).resolve().parents[2]
    rows = [_collect_row(sc) for sc in _default_scenarios(repo)]

    out_dir = repo / "figs" / "summary"
    out_dir.mkdir(parents=True, exist_ok=True)
    _plot_runtime(rows, out_dir)
    _plot_rel_error(rows, out_dir)
    _plot_runtime_vs_error(rows, out_dir)
    _plot_edge_cdf(rows, out_dir)
    _write_csv(rows, out_dir)
    print(
        f"[summary] wrote {out_dir}/runtime_comparison.png, "
        f"{out_dir}/rel_error_comparison.png, {out_dir}/runtime_vs_error.png, "
        f"{out_dir}/edge_error_cdf.png and {out_dir}/summary_metrics.csv"
    )


if __name__ == "__main__":
    # Generate the summary outputs only when run as a script, not on import.
    main()
