"""
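Generate plots comparing the DuMuX tracer to the CUDA metric-graph run.

Writes individual PNGs to ``<repo>/figs``:
1) Edge-mass L2 error vs time (absolute and relative, matched snapshots).
2) Runtime bars for DuMuX and the metric run.
3) Network layouts with edges colored by final tracer density (metric and DuMuX), plus text summaries of the configuration and error statistics.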
"""

from pathlib import Path
import subprocess
import time

import matplotlib.pyplot as plt
import numpy as np
from gitbud.gitbud import inject_repo_into_sys_path

inject_repo_into_sys_path()

from experiments.dumux_tracer.compare_metric_to_dumux import (
    config as compare_config,
    load_dumux,
    load_metric,
    match_times,
    _aggregate_to_metric,
)

# Repository root used to locate figs/ and the DuMuX build tree.
# NOTE(review): assumes inject_repo_into_sys_path() returns the repo root path
# (it is also called at import time above to make `experiments` importable) — confirm against gitbud.
REPO_ROOT = Path(inject_repo_into_sys_path())


def _map_coarse_to_refined(coarse_mass: np.ndarray, coarse_edges: np.ndarray, refined_edges: np.ndarray, refined_points: np.ndarray, coarse_points: np.ndarray) -> np.ndarray:
    """Map coarse edge mass onto refined edges using nearest coarse edge midpoint."""
    coarse_mid = (coarse_points[coarse_edges[:, 0]] + coarse_points[coarse_edges[:, 1]]) / 2.0
    refined_mid = (refined_points[refined_edges[:, 0]] + refined_points[refined_edges[:, 1]]) / 2.0
    dist2 = np.sum((refined_mid[:, None, :] - coarse_mid[None, :, :]) ** 2, axis=2)
    mapping = np.argmin(dist2, axis=1)
    mapped = np.zeros(refined_edges.shape[0], dtype=float)
    for ref_idx, coarse_idx in enumerate(mapping):
        mapped[ref_idx] = coarse_mass[coarse_idx]
    return mapped


def _run_dumux_example(example_path: Path, params_path: Path) -> float:
    start = time.perf_counter()
    subprocess.run(
        [str(example_path), str(params_path)],
        check=True,
        cwd=str(example_path.parent),
    )
    return time.perf_counter() - start


def _matched_errors(pairs, metric_times, metric_mass, dumux_mass, dumux_total) -> tuple:
    """Return ``(times, abs_errors, rel_errors)`` for each matched (metric, DuMuX) snapshot pair.

    The metric mass is scaled by the DuMuX total at the matched snapshot before
    taking the L2 norm of the per-edge difference. The relative error divides by
    that total, floored at 1e-20 to avoid division by zero.
    """
    times_matched = []
    errors = []
    rel_errors = []
    for m_idx, d_idx in pairs:
        target_total = dumux_total[d_idx]
        err = np.linalg.norm(metric_mass[m_idx] * target_total - dumux_mass[d_idx], ord=2)
        errors.append(err)
        times_matched.append(metric_times[m_idx])
        rel_errors.append(err / max(target_total, 1e-20))
    return times_matched, np.asarray(errors, dtype=float), np.asarray(rel_errors, dtype=float)


def _read_metric_metadata(metric_npz, default_horizon: float) -> dict:
    """Read scalar run metadata from the metric NPZ and close the file afterwards.

    Missing keys fall back to NaN (float fields) or -1 (int fields).
    ``time_horizon`` falls back to ``default_horizon``.
    """
    with np.load(metric_npz) as data:
        return {
            "total_wall_time": float(data.get("total_wall_time", np.nan)),
            "num_particles": int(data.get("num_particles", -1)),
            "num_bins": int(data.get("num_bins", -1)),
            "dt": float(data.get("dt", np.nan)),
            "steps": int(data.get("steps", -1)),
            "velocity_scale": float(data.get("velocity_scale", np.nan)),
            "record_interval": int(data.get("record_interval", -1)),
            "time_horizon": float(data.get("time_horizon", default_horizon)),
            "refine_segments": int(data.get("refine_segments", -1)),
        }


def _normalize(values) -> np.ndarray:
    """Min-max scale ``values`` into [0, 1]; the 1e-12 epsilon keeps constant input finite."""
    values = np.asarray(values, dtype=float)
    if values.size == 0:
        return values
    return (values - values.min()) / (values.max() - values.min() + 1e-12)


def _stats_line(label: str, values: np.ndarray) -> str:
    """Format min/median/max of ``values``, or an n/a note when there are none."""
    if values.size == 0:
        return f"{label} min/med/max: n/a (no matched snapshots)"
    return f"{label} min/med/max: {values.min():.3e} / {np.median(values):.3e} / {values.max():.3e}"


def _save_error_plot(times, values, ylabel: str, title: str, out_path: Path, color=None) -> None:
    """Plot one error series against time (with markers) and save it as a PNG."""
    plot_kwargs = {"marker": "o"}
    if color is not None:
        plot_kwargs["color"] = color
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.plot(times, values, **plot_kwargs)
    ax.set_xlabel("Time (s)")
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    fig.tight_layout()
    fig.savefig(out_path, dpi=200)
    plt.close(fig)


def _plot_graph_density(points, edges, values, cmap, title: str, out_path: Path) -> None:
    """Draw each edge as a segment coloured by its min-max-normalised value and save it."""
    from matplotlib.collections import LineCollection

    colors = cmap(_normalize(values))
    edges = np.asarray(edges, dtype=int)
    segments = np.stack([points[edges[:, 0], :2], points[edges[:, 1], :2]], axis=1)
    fig, ax = plt.subplots(figsize=(6, 6))
    # A single LineCollection instead of one ax.plot call per edge gives the same picture
    # with far fewer artists on large refined graphs. zorder=2 matches Line2D's default.
    ax.add_collection(LineCollection(segments, colors=colors, linewidths=0.8, alpha=0.9, zorder=2))
    ax.scatter(points[:, 0], points[:, 1], s=0.5, color="k", alpha=0.3)
    ax.autoscale_view()
    ax.set_aspect("equal", adjustable="datalim")
    ax.set_title(title)
    ax.set_xlabel("x (m)")
    ax.set_ylabel("y (m)")
    fig.tight_layout()
    fig.savefig(out_path, dpi=200)
    plt.close(fig)


def _save_text_figure(lines, out_path: Path) -> None:
    """Render ``lines`` as a left-aligned text block on an axis-free figure and save it."""
    fig, ax = plt.subplots(figsize=(6, 3))
    ax.axis("off")
    ax.text(0.02, 0.98, "\n".join(lines), va="top", ha="left", fontsize=10)
    fig.tight_layout()
    fig.savefig(out_path, dpi=200)
    plt.close(fig)


def main() -> None:
    """Compare the metric-graph run against DuMuX and write figures to ``REPO_ROOT/figs``.

    Steps:
    1. Load both NPZ outputs.
    2. Compute per-snapshot L2 errors at matched times.
    3. Re-run the DuMuX example to time it.
    4. Save seven PNGs: two error curves, runtime bars, two graph density maps,
       a config summary, and an error summary.

    Raises:
        RuntimeError: If the DuMuX and metric outputs have different edge counts.
        subprocess.CalledProcessError: If the DuMuX example exits non-zero.
    """
    figs_dir = REPO_ROOT / "figs"
    figs_dir.mkdir(exist_ok=True, parents=True)

    dumux_npz = compare_config.dumux_npz
    metric_npz = compare_config.metric_npz
    example_path = (
        REPO_ROOT
        / "external"
        / "dumux"
        / "build-cmake"
        / "examples"
        / "network_tracer_1d"
        / "example_network_tracer_1d"
    )
    params_path = example_path.parent / "params.input"

    dumux_times, dumux_mass, dumux_total, (dumux_edges, dumux_points), _radius = load_dumux(
        dumux_npz
    )
    # Metric mass as probabilities (no area); scale by DuMuX totals downstream.
    metric_times, metric_mass, (metric_edges, metric_points) = load_metric(metric_npz)
    if dumux_mass.shape[1] != metric_mass.shape[1]:
        raise RuntimeError(
            f"Edge count mismatch (dumux {dumux_mass.shape[1]} vs metric {metric_mass.shape[1]}). "
            "Run with consistent refinement settings."
        )
    pairs = match_times(metric_times, dumux_times, tol=1.0)
    if len(pairs) == 0:
        print("[dumux plots] Warning: no metric/DuMuX snapshots matched within tol=1.0; error plots will be empty.")

    times_matched, errors, rel_errors = _matched_errors(
        pairs, metric_times, metric_mass, dumux_mass, dumux_total
    )
    meta = _read_metric_metadata(metric_npz, metric_times[-1])
    total_wall_metric = meta["total_wall_time"]

    # Timed by re-running the example here; this includes process start-up and
    # everything the binary does before exiting.
    dumux_runtime = _run_dumux_example(example_path, params_path)

    _save_error_plot(
        times_matched,
        errors,
        "Edge-mass L2 error",
        "DuMuX vs Metric (mass-scaled) absolute error",
        figs_dir / "dumux_metric_error_abs.png",
    )
    _save_error_plot(
        times_matched,
        rel_errors,
        "Relative L2 error (per total mass)",
        "DuMuX vs Metric relative error",
        figs_dir / "dumux_metric_error_rel.png",
        color="C1",
    )

    # Runtime bar plot
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.bar(["DuMuX", "Metric"], [dumux_runtime, total_wall_metric], color=["C1", "C0"])
    ax.set_ylabel("Wall time (s)")
    ax.set_title("Runtime to produce densities")
    fig.tight_layout()
    fig.savefig(figs_dir / "dumux_metric_runtime.png", dpi=200)
    plt.close(fig)

    # Metric density at the last matched snapshot (or its last snapshot), mapped to the refined graph.
    final_mass = metric_mass[pairs[-1][0]] if len(pairs) else metric_mass[-1]
    mapped_mass = _map_coarse_to_refined(
        final_mass,
        metric_edges,
        dumux_edges,
        dumux_points,
        metric_points,
    )
    _plot_graph_density(
        dumux_points,
        dumux_edges,
        mapped_mass,
        plt.cm.plasma,
        "Metric final density mapped to refined DuMuX graph",
        figs_dir / "dumux_metric_graph_density.png",
    )

    # DuMuX final density on the refined graph. If DuMuX mass is stored per metric edge,
    # map it onto the refined edges the same way as the metric mass.
    # NOTE(review): this uses DuMuX's last snapshot, which need not coincide in time
    # with the metric snapshot plotted above — confirm that is intended.
    dumux_final = dumux_mass[-1]
    if dumux_final.shape[0] != dumux_edges.shape[0]:
        dumux_final = _map_coarse_to_refined(
            dumux_final,
            metric_edges,
            dumux_edges,
            dumux_points,
            metric_points,
        )
    _plot_graph_density(
        dumux_points,
        dumux_edges,
        dumux_final,
        plt.cm.inferno,
        "DuMuX final tracer on refined graph",
        figs_dir / "dumux_refined_graph_density.png",
    )

    cfg_lines = [
        f"DuMuX: edges={dumux_edges.shape[0]}, vertices={dumux_points.shape[0]}, T_end={dumux_times[-1]:.1f}s, outputs={len(dumux_times)}",
        f"Metric: edges={metric_edges.shape[0]}, vertices={metric_points.shape[0]}, num_particles={meta['num_particles']}, num_bins={meta['num_bins']}, dt={meta['dt']:.3e}, steps={meta['steps']}, horizon={meta['time_horizon']:.1f}s",
        f"Velocity scale={meta['velocity_scale']:.3g}, record_interval={meta['record_interval']}, total_wall_time={total_wall_metric:.3f}s",
        f"Refine segments={meta['refine_segments']}",
        f"DuMuX runtime={dumux_runtime:.3f}s",
    ]
    _save_text_figure(cfg_lines, figs_dir / "dumux_metric_config.png")

    err_lines = [
        _stats_line("L2 mass error", errors),
        _stats_line("Rel L2", rel_errors),
        f"Runtime (s): DuMuX {dumux_runtime:.3f}, Metric {total_wall_metric:.3f}",
    ]
    _save_text_figure(err_lines, figs_dir / "dumux_metric_error_summary.png")

    print(f"[dumux plots] Wrote individual figures to {figs_dir}/:")
    for name in (
        "dumux_metric_error_abs.png",
        "dumux_metric_error_rel.png",
        "dumux_metric_runtime.png",
        "dumux_metric_graph_density.png",
        "dumux_refined_graph_density.png",
        "dumux_metric_config.png",
        "dumux_metric_error_summary.png",
    ):
        print(f"  - {name}")
    print(f"[dumux plots] Runtime: DuMuX {dumux_runtime:.3f}s, Metric {total_wall_metric:.3f}s")


if __name__ == "__main__":
    main()
