"""
Make quick GIFs for the conservative DuMuX 1D tracer baseline.

Outputs:
    figs/dumux_baseline_edge_profiles.gif    # per-edge profiles (top edges by final mass)
    figs/dumux_baseline_network.gif          # 2D network layout colored by per-cell tracer value
"""

from pathlib import Path

import imageio.v2 as imageio
import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
from gitbud.gitbud import inject_repo_into_sys_path

inject_repo_into_sys_path()

from experiments.dumux_tracer.config import config as dumux_config


def _load_dumux(path: Path):
    data = np.load(path)
    tracer = np.asarray(data["tracer"], dtype=float)  # (time, cells)
    times = np.asarray(data["times"], dtype=float)
    cell_edges = np.asarray(data["cell_edges"], dtype=int)
    cell_points = np.asarray(data["cell_points"], dtype=float)
    cell_lengths = np.asarray(data["cell_lengths"], dtype=float)
    cell_radii = np.asarray(data["cell_radii"], dtype=float) if "cell_radii" in data else None
    cell_to_edge = np.asarray(data["cell_to_edge"], dtype=int)
    cell_start = np.asarray(data["cell_start"], dtype=float)
    orig_edges = np.asarray(data["orig_edges"], dtype=int)
    return tracer, times, cell_edges, cell_points, cell_lengths, cell_radii, cell_to_edge, cell_start, orig_edges


def _cell_mass(tracer: np.ndarray, lengths: np.ndarray, radii: np.ndarray | None) -> np.ndarray:
    """Mass per cell."""
    if radii is not None and radii.size == tracer.shape[-1]:
        area = np.pi * radii * radii
        molar_density = 1000.0 / 0.018  # consistent with params.input defaults
        return tracer * molar_density * area * lengths
    return tracer * lengths


def _edge_mass_from_cells(
    tracer: np.ndarray,
    cell_lengths: np.ndarray,
    cell_radii: np.ndarray | None,
    cell_to_edge: np.ndarray,
    num_edges: int,
) -> np.ndarray:
    """Sum per-cell masses into per-original-edge totals for every time step.

    Returns a float array of shape ``(tracer.shape[0], num_edges)`` whose entry
    ``[t, e]`` is the sum of ``_cell_mass`` over all cells with
    ``cell_to_edge == e`` at time row ``t``.
    """
    masses = _cell_mass(tracer, cell_lengths, cell_radii)
    totals = np.zeros((tracer.shape[0], num_edges), dtype=float)
    # Each ``row`` is a view into ``totals``, so slice-assignment fills it in place.
    for row, mass_row in zip(totals, masses):
        row[:] = np.bincount(cell_to_edge, weights=mass_row, minlength=num_edges)
    return totals


def _downsample_indices(n: int, max_frames: int) -> np.ndarray:
    if n <= max_frames:
        return np.arange(n, dtype=int)
    return np.linspace(0, n - 1, num=max_frames, dtype=int)


def _plot_edge_profiles_gif(
    tracer: np.ndarray,
    times: np.ndarray,
    cell_edges: np.ndarray,
    cell_lengths: np.ndarray,
    cell_radii: np.ndarray | None,
    cell_to_edge: np.ndarray,
    cell_start: np.ndarray,
    orig_edges: np.ndarray,
    out_path: Path,
    k: int = 6,
    max_frames: int = 60,
):
    """Write a GIF of step-function tracer profiles along the ``k`` heaviest edges.

    Edges are ranked by their aggregated mass at the final time step. Each
    frame has one subplot per selected edge: the edge's cells, ordered by
    ``cell_start``, are drawn as a piecewise-constant profile over cumulative
    cell length. The y-limit is shared by all frames and subplots.

    Args:
        tracer: Per-cell tracer values indexed ``[time, cell]``.
        times: Simulation time of each row of ``tracer``.
        cell_edges: Unused here; kept so the signature mirrors the loader output.
        cell_lengths: Length of each cell.
        cell_radii: Optional per-cell radii forwarded to ``_cell_mass``.
        cell_to_edge: Original-edge id of each cell.
        cell_start: Per-cell key used to order the cells along their edge.
        orig_edges: Original edge list; only its length (the edge count) is used.
        out_path: Destination GIF path; parent directories are created.
        k: Number of edges to plot.
        max_frames: Upper bound on the number of rendered frames.
    """
    num_edges = orig_edges.shape[0]
    edge_mass = _edge_mass_from_cells(tracer, cell_lengths, cell_radii, cell_to_edge, num_edges)
    top_edges = np.argsort(-edge_mass[-1])[:k]

    # Cell indices of each selected edge, ordered along the edge. They do not
    # change between frames, so compute them once up front.
    edge_cells = {}
    for edge_id in top_edges:
        members = np.flatnonzero(cell_to_edge == edge_id)
        edge_cells[edge_id] = members[np.argsort(cell_start[members])]

    # Global max over the selected edges keeps the y-limits stable across
    # frames. Edges without cells are skipped: ``.max()`` of an empty
    # selection raises ValueError.
    vmax = 0.0
    for cells in edge_cells.values():
        if cells.size:
            vmax = max(vmax, float(tracer[:, cells].max()))
    if vmax <= 0:
        vmax = 1.0

    frames = []
    frame_indices = _downsample_indices(tracer.shape[0], max_frames)

    for step in frame_indices:
        fig, axes = plt.subplots(len(top_edges), 1, figsize=(6, 1.6 * len(top_edges)), sharex=False)
        if len(top_edges) == 1:
            axes = [axes]
        total_mass = float(_cell_mass(tracer[step], cell_lengths, cell_radii).sum())
        for ax, edge_id in zip(axes, top_edges):
            cells = edge_cells[edge_id]
            if not cells.size:
                continue
            seg_lengths = cell_lengths[cells]
            vals = tracer[step, cells]
            bounds = np.concatenate([[0.0], np.cumsum(seg_lengths)])
            # Repeat the last value so ``where="post"`` draws the final cell's step.
            y = np.concatenate([vals, [vals[-1]]])
            ax.step(bounds, y, where="post", color="tab:blue", linewidth=2)
            ax.set_ylim(0, vmax * 1.05)
            ax.set_ylabel(f"edge {edge_id}")
            ax.set_title(
                f"cells={len(seg_lengths)}, mass={edge_mass[step, edge_id]:.3e} mol"
            )
        axes[-1].set_xlabel("arc length [m]")
        fig.suptitle(f"t = {times[step]:.2f} s | total mass = {total_mass:.3e} mol")
        fig.tight_layout()
        fig.canvas.draw()
        # Copy the Agg RGBA buffer: it belongs to the canvas, which is
        # discarded when the figure is closed.
        frames.append(np.array(fig.canvas.buffer_rgba(), dtype=np.uint8))
        plt.close(fig)

    out_path.parent.mkdir(parents=True, exist_ok=True)
    imageio.mimsave(out_path, frames, fps=30)
    print(f"[dumux gif] wrote {out_path} ({len(frames)} frames)")


def _plot_network_gif(tracer, times, cell_edges, cell_points, cell_lengths, cell_radii, out_path: Path, max_frames: int = 60):
    """
    2D projection (x-y) of the network with per-cell color = tracer value.

    Each row of ``cell_edges`` is a pair of indices into ``cell_points``. The
    pair is drawn as one line segment in the x-y plane (further coordinates
    are dropped) and colored by that cell's tracer value. The color scale is
    fixed for the whole animation: 0 up to the 99th percentile of all tracer
    values, or 1.0 when that percentile is not positive or ``tracer`` is empty.

    Args:
        tracer: Per-cell tracer values indexed ``[time, cell]``.
        times: Simulation time of each row of ``tracer``.
        cell_edges: Endpoint index pairs into ``cell_points``, one row per cell.
        cell_points: Point coordinates; only the first two columns are used.
        cell_lengths: Per-cell lengths forwarded to ``_cell_mass``.
        cell_radii: Optional per-cell radii forwarded to ``_cell_mass``.
        out_path: Destination GIF path; parent directories are created.
        max_frames: Upper bound on the number of rendered frames.
    """
    # Import the submodules explicitly instead of relying on
    # ``matplotlib.pyplot`` having already loaded them as attributes.
    from matplotlib.collections import LineCollection
    from matplotlib.colors import Normalize

    frames = []
    frame_indices = _downsample_indices(tracer.shape[0], max_frames)
    vmax = np.percentile(tracer, 99) if tracer.size else 1.0
    color_max = vmax if vmax > 0 else 1.0

    # The geometry is the same in every frame, so build the (n, 2, 2)
    # segment array once. As with a pairwise zip, only as many segments as
    # there are cell values get drawn.
    n_segments = min(len(cell_edges), tracer.shape[-1])
    xy = cell_points[:, :2]
    segments = xy[cell_edges[:n_segments]]

    for step in frame_indices:
        fig, ax = plt.subplots(figsize=(6, 6))
        total_mass = float(_cell_mass(tracer[step], cell_lengths, cell_radii).sum())
        lc = LineCollection(
            segments,
            cmap="inferno",
            norm=Normalize(vmin=0.0, vmax=color_max),
            linewidths=2.0,
        )
        lc.set_array(np.asarray(tracer[step, :n_segments]))
        ax.add_collection(lc)
        ax.autoscale()
        ax.set_aspect("equal")
        ax.set_xticks([])
        ax.set_yticks([])
        cbar = fig.colorbar(lc, ax=ax, fraction=0.046, pad=0.04)
        cbar.set_label("tracer (mole fraction)")
        ax.set_title(f"t = {times[step]:.2f} s | total mass = {total_mass:.3e} mol")
        fig.tight_layout()
        fig.canvas.draw()
        # Copy the Agg RGBA buffer: it belongs to the canvas, which is
        # discarded when the figure is closed.
        frames.append(np.array(fig.canvas.buffer_rgba(), dtype=np.uint8))
        plt.close(fig)

    out_path.parent.mkdir(parents=True, exist_ok=True)
    imageio.mimsave(out_path, frames, fps=10)
    print(f"[dumux gif] wrote {out_path} ({len(frames)} frames)")


def main() -> None:
    """Load the DuMuX baseline, report total-mass drift, and write both GIFs.

    When the total mass at the first time step is positive, the tracer is
    rescaled so that this initial mass equals 1. The printed min/max/drift
    are then relative to the initial mass. The GIFs go to ``figs/``, relative
    to the current working directory.
    """
    (
        tracer,
        times,
        cell_edges,
        cell_points,
        cell_lengths,
        cell_radii,
        cell_to_edge,
        cell_start,
        orig_edges,
    ) = _load_dumux(dumux_config.output_path)

    mass_history = _cell_mass(tracer, cell_lengths, cell_radii).sum(axis=1)
    initial_mass = mass_history[0]
    if initial_mass > 0:
        factor = 1.0 / initial_mass
        tracer = tracer * factor
        mass_history *= factor

    lowest = mass_history.min()
    highest = mass_history.max()
    print(
        f"[dumux baseline] mass conservation check: "
        f"min={lowest:.4e}, max={highest:.4e}, "
        f"drift={(highest - lowest):.4e}"
    )

    out_dir = Path("figs")
    _plot_edge_profiles_gif(
        tracer, times, cell_edges, cell_lengths, cell_radii,
        cell_to_edge, cell_start, orig_edges,
        out_dir / "dumux_baseline_edge_profiles.gif",
    )
    _plot_network_gif(
        tracer, times, cell_edges, cell_points, cell_lengths, cell_radii,
        out_dir / "dumux_baseline_network.gif",
    )


if __name__ == "__main__":
    main()
