"""
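Plot per-edge density profiles comparing DuMuX and the metric-graph model.

The two edges with the largest DuMuX tracer value at the last matched snapshot
are plotted; one PNG per edge is written to ``<repo>/figs``.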
"""

from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from gitbud.gitbud import inject_repo_into_sys_path

inject_repo_into_sys_path()

from experiments.dumux_tracer.compare_config import config as compare_config
from experiments.dumux_tracer.compare_metric_to_dumux import match_times

# Repository root, taken from the value returned by gitbud's helper; the
# script's output figures are written beneath it (see main()).
REPO_ROOT: Path = Path(inject_repo_into_sys_path())


def _bin_centers(edge_length: float, num_bins: int) -> np.ndarray:
    return (np.arange(num_bins, dtype=float) + 0.5) / num_bins * edge_length


def _load_npz_arrays(path, dtypes: dict) -> dict:
    """Read selected arrays from an ``.npz`` archive and close the archive.

    Args:
        path: Location of the ``.npz`` file, passed to :func:`numpy.load`.
        dtypes: Mapping ``array name -> dtype``; each named array is converted
            with :func:`numpy.asarray` to that dtype.

    Returns:
        Dict mapping each requested name to its converted array. Arrays read
        from an ``NpzFile`` stay valid after the archive is closed.

    Raises:
        KeyError: If a requested name is not present in the archive.
    """
    with np.load(path) as archive:
        return {name: np.asarray(archive[name], dtype=dtype) for name, dtype in dtypes.items()}


def main() -> None:
    """Write edge-profile comparison figures for the top DuMuX edges.

    Loads the DuMuX and metric-graph ``.npz`` outputs named in ``compare_config``
    and takes the last (metric, DuMuX) snapshot pair returned by ``match_times``.
    For the two edges with the largest DuMuX tracer value at that snapshot, it
    writes one PNG to ``<repo>/figs``. Each PNG overlays the binned metric-graph
    density (multiplied by a scale factor) on the DuMuX per-edge value, which is
    drawn as a constant line.

    Raises:
        RuntimeError: If the two runs have different edge counts, or if no
            snapshot times match within ``compare_config.time_tolerance``.
    """
    figs_dir = REPO_ROOT / "figs"
    figs_dir.mkdir(parents=True, exist_ok=True)

    # Read everything up front through a context manager so the .npz file
    # handles are released before plotting starts.
    dumux = _load_npz_arrays(
        compare_config.dumux_npz,
        {"times": float, "tracer": float, "edges": int, "edge_lengths": float},
    )
    metric = _load_npz_arrays(
        compare_config.metric_npz,
        {"times": float, "density": float, "edges": int, "edge_lengths": float},
    )

    dumux_times = dumux["times"]
    dumux_tracer = dumux["tracer"]  # (T, edges)
    dumux_edge_lengths = dumux["edge_lengths"]
    num_dumux_edges = dumux["edges"].shape[0]

    metric_times = metric["times"]
    metric_density = metric["density"]  # (T, edges, bins)
    metric_edge_lengths = metric["edge_lengths"]
    num_metric_edges = metric["edges"].shape[0]

    # NOTE(review): only the edge *counts* are compared; the indexing below
    # assumes edge i denotes the same edge in both files — confirm the orderings agree.
    if num_dumux_edges != num_metric_edges:
        raise RuntimeError(
            f"Edge count mismatch (dumux {num_dumux_edges} vs metric {num_metric_edges}). "
            "Run with matching refine settings."
        )

    pairs = match_times(metric_times, dumux_times, tol=compare_config.time_tolerance)
    if not pairs:
        raise RuntimeError("No matching times found; widen tolerance.")

    # Use the last matched snapshot (nearest the final time, assuming
    # match_times returns pairs in increasing time order).
    metric_idx, dumux_idx = pairs[-1]
    dumux_final = dumux_tracer[dumux_idx]

    # Scale factor applied to the metric density: the summed DuMuX tracer at this
    # snapshot times the mean edge length, falling back to 1 when not positive.
    # NOTE(review): this equals a total mass only if tracer is a per-length quantity
    # and all edges have the same length; (dumux_final * dumux_edge_lengths).sum()
    # would weight each edge by its own length — confirm which is intended.
    target_total = float(dumux_final.sum() * dumux_edge_lengths.mean())
    scaling = target_total if target_total > 0 else 1.0

    # Pick the two edges with the largest DuMuX tracer value at that snapshot.
    top_edges = np.argsort(-dumux_final)[:2]

    num_bins = metric_density.shape[2]  # every edge uses the same bin count
    for edge_id in top_edges:
        edge_length = metric_edge_lengths[edge_id]
        x = _bin_centers(edge_length, num_bins)
        metric_profile = metric_density[metric_idx, edge_id] * scaling
        # DuMuX provides a single value per edge, so it is drawn as a flat line.
        dumux_profile = np.full_like(x, dumux_final[edge_id])

        out_path = figs_dir / f"dumux_metric_edge_profile_{edge_id}.png"
        fig, ax = plt.subplots(figsize=(6, 4))
        try:
            # x is scaled by 1e6 for the µm axis label (lengths assumed to be metres).
            ax.plot(x * 1e6, metric_profile, label="Metric", lw=2)
            ax.plot(x * 1e6, dumux_profile, label="DuMuX (const.)", lw=2, ls="--")
            ax.set_xlabel("Position along edge (µm)")
            ax.set_ylabel("Density")
            ax.set_title(
                f"Edge {edge_id} profile at t≈{metric_times[metric_idx]:.1f}s\n"
                f"(length={edge_length*1e6:.2f} µm, bins={num_bins})"
            )
            ax.legend()
            fig.tight_layout()
            fig.savefig(out_path, dpi=200)
        finally:
            # Always release the figure, even if drawing or saving fails.
            plt.close(fig)
        print(f"[edge profile] wrote {out_path}")


if __name__ == "__main__":
    # Generate the figures only when run as a script, not when imported.
    main()
