"""
Create a GIF of metric-graph density evolution on a subset of edges.
"""

from pathlib import Path

import matplotlib
matplotlib.use("Agg")
import imageio.v2 as imageio
import matplotlib.pyplot as plt
import numpy as np
from gitbud.gitbud import inject_repo_into_sys_path

inject_repo_into_sys_path()

from experiments.dumux_tracer.network import load_dumux_network
from experiments.dumux_tracer.metric_graph_config import config as metric_config

# Repository root, used below to place the output under ``figs/``.
# NOTE(review): assumes ``inject_repo_into_sys_path`` returns the repo root
# path. It is also called earlier (return value discarded) before the
# ``experiments`` imports, so it runs twice — confirm it is idempotent.
REPO_ROOT = Path(inject_repo_into_sys_path())


def edge_mass(density: np.ndarray, edge_lengths: np.ndarray) -> np.ndarray:
    """Integrate per-edge mass from a binned density.

    Each edge is split into ``num_bins`` equal cells of width
    ``edge_length / num_bins``. The mass is the rectangle-rule (Riemann) sum
    of density times cell width over the bins.

    Args:
        density: Density values of shape ``(edges, bins)``. Leading batch
            dimensions are also accepted, e.g. ``(times, edges, bins)``; the
            integration is always over the last axis.
        edge_lengths: Length of each edge, shape ``(edges,)``.

    Returns:
        Integrated mass with the bin axis removed: shape ``(edges,)`` for 2-D
        input, or ``(..., edges)`` for batched input.
    """
    density = np.asarray(density)
    num_bins = density.shape[-1]
    # Bin width per edge, shaped (edges, 1) so it broadcasts across the bin
    # axis and across any leading batch axes.
    dx = np.asarray(edge_lengths)[:, None] / num_bins
    return (density * dx).sum(axis=-1)


def _load_metric_output(path) -> tuple:
    """Read ``density``, ``times`` and ``edge_lengths`` (as float arrays) from a ``.npz`` file."""
    # The context manager closes the underlying zip file handle. Members are
    # read into in-memory arrays on access, so they remain valid afterwards.
    with np.load(path) as data:
        density = np.asarray(data["density"], dtype=float)
        times = np.asarray(data["times"], dtype=float)
        edge_lengths = np.asarray(data["edge_lengths"], dtype=float)
    return density, times, edge_lengths


def _render_frame(
    coords: np.ndarray,
    segments: np.ndarray,
    colors: np.ndarray,
    title: str,
) -> np.ndarray:
    """Draw one frame (coloured edges over all network points) and return it as an (H, W, 3) RGB array."""
    from matplotlib.collections import LineCollection

    fig, ax = plt.subplots(figsize=(6, 6))
    try:
        # One collection for all edges instead of one Line2D per edge.
        # zorder=2 matches Line2D so edges stay drawn above the scatter points.
        ax.add_collection(
            LineCollection(segments, colors=colors, linewidths=1.5, alpha=0.9, zorder=2)
        )
        ax.scatter(coords[:, 0], coords[:, 1], s=1, color="k", alpha=0.2)
        ax.autoscale_view()
        ax.set_aspect("equal", adjustable="datalim")
        ax.set_title(title)
        ax.axis("off")
        fig.canvas.draw()
        # Drop the alpha channel; copy so the frame outlives the figure buffer.
        return np.asarray(fig.canvas.buffer_rgba())[:, :, :3].copy()
    finally:
        plt.close(fig)


def main() -> None:
    """Render the metric-graph density evolution GIF.

    Steps:
        1. Load ``density``, ``times`` and ``edge_lengths`` from
           ``metric_config.output_path``.
        2. Integrate the density to a mass per edge.
        3. Animate that mass on the (up to) 200 edges with the largest radius.
        4. Write ``figs/dumux_metric_density_evolution.gif`` under the
           repository root, creating the directory if needed.

    Raises:
        RuntimeError: If there are no time steps or edges to plot, or if the
            saved data and the loaded network disagree on the number of time
            steps, edges, or radii.
    """
    out_path = REPO_ROOT / "figs" / "dumux_metric_density_evolution.gif"

    density, times, edge_lengths = _load_metric_output(metric_config.output_path)
    if len(times) == 0:
        raise RuntimeError("No time steps in metric output; nothing to plot.")
    if density.shape[0] != len(times):
        # Frame idx labels masses[idx] with times[idx], so the counts must agree.
        raise RuntimeError(
            f"Time-step mismatch (density steps={density.shape[0]}, times={len(times)})."
        )

    net = load_dumux_network(
        metric_config.dgf_path,
        refine_segments=metric_config.refine_segments,
    )
    coords = net.points
    edges = net.edges
    radii = net.radii
    num_edges = edges.shape[0]
    if num_edges == 0:
        raise RuntimeError("Network has no edges; nothing to plot.")
    if len(radii) != num_edges:
        # Radii are ranked to pick edge indices, so one radius per edge is required.
        raise RuntimeError(
            f"Radius count mismatch (radii={len(radii)}, graph edges={num_edges}); "
            "expected one radius per edge."
        )

    masses = np.stack([edge_mass(d, edge_lengths) for d in density])
    if masses.shape[1] != num_edges:
        raise RuntimeError(
            f"Edge count mismatch (density edges={masses.shape[1]}, graph edges={num_edges}). "
            "Ensure refine settings match."
        )

    subset = min(200, num_edges)
    # Indices of the widest edges, largest radius first.
    top_indices = np.argsort(-radii)[:subset]
    masses = masses[:, top_indices]
    # Shared colour scale across all frames; the epsilon avoids division by
    # zero when every mass is zero.
    vmax = masses.max() + 1e-12

    # (subset, 2, 2): x/y coordinates of both endpoints of every selected edge.
    segments = coords[edges[top_indices]][:, :, :2]
    # Up to 20 evenly spaced time indices, including the first and last step.
    times_to_plot = np.linspace(0, len(times) - 1, num=min(len(times), 20), dtype=int)

    frames = [
        _render_frame(
            coords,
            segments,
            plt.cm.inferno(masses[idx] / vmax),
            f"Metric density (subset {subset} edges) at t={times[idx]:.1f}s",
        )
        for idx in times_to_plot
    ]

    out_path.parent.mkdir(parents=True, exist_ok=True)
    imageio.mimsave(out_path, frames, fps=4)
    print(f"[dumux density gif] Wrote {out_path}")


# Script entry point: render the GIF only when executed directly, not on import.
if __name__ == "__main__":
    main()
