"""
Create a GIF comparing metric-run per-edge density profiles against DuMuX tracer values on the edges with the largest final mass (ranked by the metric run or by DuMuX).
"""

from pathlib import Path

import imageio.v2 as imageio
import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
from gitbud.gitbud import inject_repo_into_sys_path

inject_repo_into_sys_path()

from experiments.dumux_tracer.compare_config import config as compare_config
from experiments.dumux_tracer.compare_metric_to_dumux import match_times

# Repository root; assumes inject_repo_into_sys_path() returns the repo path it
# injects (TODO confirm). Used as the base directory for the output figs/ folder.
REPO_ROOT = Path(inject_repo_into_sys_path())


def bin_centers(edge_length: float, num_bins: int) -> np.ndarray:
    """Return the midpoint of each of ``num_bins`` equal bins spanning ``[0, edge_length]``."""
    half_offsets = np.arange(num_bins, dtype=float) + 0.5
    fractions = half_offsets / num_bins
    return fractions * edge_length


def edge_mass(density: np.ndarray, edge_lengths: np.ndarray) -> np.ndarray:
    """Integrate binned density along each edge.

    Args:
        density: array indexed (time, edge, bin).
        edge_lengths: one length per edge.

    Returns:
        Array indexed (time, edge): the sum over bins of density times the bin
        width ``edge_length / num_bins``.
    """
    bin_width = edge_lengths / density.shape[2]
    weighted = density * bin_width[np.newaxis, :, np.newaxis]
    return weighted.sum(axis=2)


def _nearest_index(times: np.ndarray, t: float) -> int:
    """Return the index of the entry in ``times`` closest to ``t``."""
    return int(np.argmin(np.abs(times - t)))


def _select_top_edges(
    select_by: str,
    final_mass_metric: np.ndarray,
    dumux_mass: np.ndarray,
    count: int = 4,
) -> np.ndarray:
    """Return up to ``count`` edge ids with the largest final mass, descending.

    ``select_by == "dumux"`` ranks by the DuMuX final tracer value; any other
    value ranks by the metric run's final mass. Only edge ids present in both
    arrays are candidates, so every returned id can index both datasets.

    Raises:
        RuntimeError: if the two datasets share no edges.
    """
    num_common = min(final_mass_metric.shape[0], dumux_mass.shape[0])
    if num_common == 0:
        raise RuntimeError("No edges shared between metric and DuMuX data.")
    ranking = dumux_mass if select_by == "dumux" else final_mass_metric
    # Restrict to the shared id range *before* ranking so that filtering
    # cannot shrink the selection below ``count``.
    return np.argsort(-ranking[:num_common])[:count]


def _render_frame(
    idx: int,
    top_edges: np.ndarray,
    density_metric: np.ndarray,
    times_metric: np.ndarray,
    edge_lengths: np.ndarray,
    dumux_values: np.ndarray,
    final_mass_metric: np.ndarray,
    dumux_mass: np.ndarray,
    y_max: float,
) -> np.ndarray:
    """Draw metric frame ``idx`` as a 2x2 grid of edge panels and return it as RGB.

    ``dumux_values`` is the DuMuX tracer row (one value per edge) paired with
    this frame. It is drawn as a flat dashed line across each edge.
    """
    fig, axes = plt.subplots(2, 2, figsize=(8, 6), sharey=True)
    try:
        axes = axes.flatten()
        for ax, edge_id in zip(axes, top_edges):
            profile = density_metric[idx, edge_id]
            x = bin_centers(edge_lengths[edge_id], profile.shape[0]) * 1e6  # microns
            ax.plot(x, profile, lw=2, label="Metric")
            dumux_val = dumux_values[edge_id]
            ax.plot(
                [x.min(), x.max()],
                [dumux_val, dumux_val],
                lw=2,
                ls="--",
                label="DuMuX",
            )
            ax.set_title(
                f"Edge {edge_id}, L={edge_lengths[edge_id]*1e6:.2f} µm\n"
                f"Mass(metric)={final_mass_metric[edge_id]:.3e} Mass(du)={dumux_mass[edge_id]:.3e}"
            )
            ax.set_xlabel("Position (µm)")
            ax.legend(fontsize=8)
            ax.set_ylim(0, y_max * 1.05)
        # Hide panels left empty when fewer than four edges were selected.
        for ax in axes[len(top_edges):]:
            ax.set_visible(False)
        axes[0].set_ylabel("Density")
        axes[2].set_ylabel("Density")
        fig.suptitle(f"Metric edge profiles at t={times_metric[idx]:.1f}s")
        fig.tight_layout()
        fig.canvas.draw()
        # Copy out of the canvas buffer so the frame does not depend on the figure.
        return np.ascontiguousarray(np.asarray(fig.canvas.buffer_rgba())[:, :, :3])
    finally:
        plt.close(fig)


def main(select_by: str = "metric") -> None:
    """Render the edge-profile GIF and print the metric run's top initial-mass edges.

    Args:
        select_by: ``"dumux"`` plots the four edges with the largest DuMuX final
            tracer value; any other value uses the metric run's final mass.
            It also selects the output filename suffix.

    Raises:
        RuntimeError: if no metric/DuMuX times match within
            ``compare_config.time_tolerance``, or if the datasets share no edges.
    """
    metric_npz = np.load(compare_config.metric_npz)
    density_metric = np.asarray(metric_npz["density"], dtype=float)  # (T, E, B)
    times_metric = np.asarray(metric_npz["times"], dtype=float)
    edge_lengths = np.asarray(metric_npz["edge_lengths"], dtype=float)

    dumux_npz = np.load(compare_config.dumux_npz)
    dumux_times = np.asarray(dumux_npz["times"], dtype=float)
    dumux_tracer = np.asarray(dumux_npz["tracer"], dtype=float)  # (T, E)

    # Refuse to compare runs whose time grids never coincide within tolerance.
    pairs = match_times(times_metric, dumux_times, tol=compare_config.time_tolerance)
    if not pairs:
        raise RuntimeError("No matched times; widen tolerance.")

    masses_metric = edge_mass(density_metric, edge_lengths)  # (T, E)
    final_mass_metric = masses_metric[-1]
    dumux_mass = dumux_tracer[-1]
    top_edges = _select_top_edges(select_by, final_mass_metric, dumux_mass)

    # Up to 40 frames spread evenly over the metric time series.
    num_frames = min(40, len(times_metric))
    frame_indices = np.linspace(0, len(times_metric) - 1, num=num_frames, dtype=int)
    # Pair each frame with the DuMuX snapshot nearest to it in time.
    dumux_indices = np.array(
        [_nearest_index(dumux_times, t) for t in times_metric[frame_indices]],
        dtype=int,
    )

    # One y-limit shared by all frames keeps the animation from rescaling.
    metric_peak = float(density_metric[np.ix_(frame_indices, top_edges)].max())
    dumux_peak = float(dumux_tracer[np.ix_(dumux_indices, top_edges)].max())
    y_max = max(metric_peak, dumux_peak, 0.0)
    if y_max <= 0:
        y_max = 1.0

    frames = [
        _render_frame(
            idx,
            top_edges,
            density_metric,
            times_metric,
            edge_lengths,
            dumux_tracer[d_idx],
            final_mass_metric,
            dumux_mass,
            y_max,
        )
        for idx, d_idx in zip(frame_indices, dumux_indices)
    ]

    suffix = "dumux" if select_by == "dumux" else "metric"
    out_path = REPO_ROOT / "figs" / f"dumux_metric_edge_profiles_{suffix}.gif"
    # Create figs/ if needed; writing into a missing directory fails.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    imageio.mimsave(out_path, frames, fps=30)
    print(f"[edge profile gif] Wrote {out_path}")

    # Report where the metric run's initial condition concentrates mass.
    init_mass = masses_metric[0]
    top_init = np.argsort(-init_mass)[:5]
    print(f"[edge profile gif] Initial mass top edges (metric IC, mode={select_by}):")
    for eid in top_init:
        print(
            f"  edge {eid}: mass={init_mass[eid]:.3e}, length={edge_lengths[eid]*1e6:.2f} µm"
        )


if __name__ == "__main__":
    import sys

    # Optional first CLI argument picks the edge-ranking mode ("dumux" or "metric").
    main(select_by=sys.argv[1] if len(sys.argv) > 1 else "metric")
