"""
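Render an MG-only animated GIF of tracer values on the network for the advection (drift) case.
The GIF is written as ``advection_network_mg.gif`` inside the plot directory.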

Env overrides:
  - DUMUX_METRIC_OUTPUT: path to MG NPZ (default: settings.metric_output)
  - DUMUX_INPUT_NPZ: path to preprocessed network NPZ (default: settings.dumux_input_npz)
  - DUMUX_PLOT_DIR: output directory (default: settings.plot_dir)
"""

from __future__ import annotations

import os
from pathlib import Path

import imageio.v2 as imageio
import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
from gitbud.gitbud import inject_repo_into_sys_path

# Make the repository importable before the project import below; this must run
# first (presumably it prepends the repo root to sys.path — see gitbud).
inject_repo_into_sys_path()

from experiments.dumux_tracer.config_loader import load_settings

# Loaded once at import time; main() reads metric_output, dumux_input_npz and
# plot_dir from it as defaults for the DUMUX_* environment overrides.
settings = load_settings()


def _downsample_indices(n: int, max_frames: int) -> np.ndarray:
    """Return integer indices into ``range(n)``, at most ``max_frames`` of them.

    When ``n`` already fits within ``max_frames`` every index ``0..n-1`` is
    returned. Otherwise ``max_frames`` evenly spaced indices are returned,
    truncated to int. They start at 0 and, for ``max_frames >= 2``, end at
    ``n - 1``.
    """
    if n > max_frames:
        # Spacing is > 1 here, so truncation to int cannot produce duplicates.
        return np.linspace(0, n - 1, num=max_frames, dtype=int)
    return np.arange(n, dtype=int)


def _gif_network(
    tracer,
    times,
    cell_edges,
    cell_points,
    total_mass,
    out: Path,
    max_frames: int = 80,
    title_prefix: str = "",
    fps: float = 10,
):
    """Render per-cell tracer values on the network as an animated GIF.

    Each frame draws one line segment per cell, coloured by that cell's tracer
    value at one time step. All frames share a single colour scale.

    Args:
        tracer: 2-D array where ``tracer[t]`` holds the per-cell values at time
            step ``t`` (assumed shape ``(T, cells)`` — TODO confirm with the
            metric writer).
        times: per-time-step times, shown in each frame title (labelled "s").
        cell_edges: integer pairs ``(u, v)`` indexing into ``cell_points``; one
            segment is drawn per pair.
        cell_points: point coordinates; only the first two columns (x, y) are used.
        total_mass: per-time-step scalar shown in each frame title.
        out: destination GIF path; missing parent directories are created.
        max_frames: upper bound on frames; time steps are evenly subsampled.
        title_prefix: text placed before the time in each frame title.
        fps: GIF playback rate in frames per second.

    Raises:
        ValueError: if there are no time steps to render.
    """
    from matplotlib.collections import LineCollection
    from matplotlib.colors import Normalize

    tracer = np.asarray(tracer)
    frames_idx = _downsample_indices(tracer.shape[0], max_frames)
    if frames_idx.size == 0:
        raise ValueError(f"no time steps to render for {out}")

    # One colour scale for the whole animation. The 99th percentile keeps a few
    # extreme cells from washing out the rest. Fall back to 1.0 when the result
    # is not positive (this includes NaN, since nan > 0 is False).
    vmax = float(np.percentile(tracer, 99)) if tracer.size else 1.0
    vmax = vmax if vmax > 0 else 1.0

    # Geometry is identical for every frame, so build it once. Edges and values
    # are paired up to the shorter of the two (the same pairing zip() gives).
    xy = np.asarray(cell_points)[:, :2]
    edges = np.asarray(cell_edges, dtype=int).reshape(-1, 2)
    n_seg = min(len(edges), tracer.shape[1])
    segments = xy[edges[:n_seg]]  # (n_seg, 2 endpoints, 2 coords)

    frames = []
    for idx in frames_idx:
        fig, ax = plt.subplots(figsize=(6, 6))
        try:
            lc = LineCollection(
                segments,
                cmap="inferno",
                # Fresh norm per figure so no norm callbacks tie figures together.
                norm=Normalize(vmin=0.0, vmax=vmax),
                linewidths=2.0,
            )
            lc.set_array(np.asarray(tracer[idx][:n_seg]))
            ax.add_collection(lc)
            ax.autoscale()
            ax.set_aspect("equal")
            ax.set_xticks([])
            ax.set_yticks([])
            cbar = fig.colorbar(lc, ax=ax, fraction=0.046, pad=0.04)
            cbar.set_label("tracer density (normalized)")
            tmass = float(total_mass[idx])
            ax.set_title(f"{title_prefix} t = {times[idx]:.2f} s | total mass = {tmass:.3e}")
            fig.tight_layout()
            fig.canvas.draw()
            # Copy out of the Agg buffer; it belongs to the canvas we close below.
            frames.append(np.array(fig.canvas.buffer_rgba(), dtype=np.uint8))
        finally:
            plt.close(fig)

    out.parent.mkdir(parents=True, exist_ok=True)
    try:
        imageio.mimsave(out, frames, fps=fps)
    except TypeError:
        # imageio >= 2.28 writes GIFs through its Pillow v3 plugin, which rejects
        # ``fps`` and expects a per-frame ``duration`` in milliseconds instead.
        imageio.mimsave(out, frames, duration=1000.0 / fps)
    print(f"[mg-only advection] wrote {out} ({len(frames)} frames)")


def main() -> None:
    """Load the MG output and the network geometry, then write the advection GIF.

    Input paths and the output directory come from the ``DUMUX_METRIC_OUTPUT``,
    ``DUMUX_INPUT_NPZ`` and ``DUMUX_PLOT_DIR`` environment variables. Each falls
    back to the matching ``settings`` field when unset.
    """
    metric_path = Path(os.environ.get("DUMUX_METRIC_OUTPUT", settings.metric_output))
    input_path = Path(os.environ.get("DUMUX_INPUT_NPZ", settings.dumux_input_npz))
    out_dir = Path(os.environ.get("DUMUX_PLOT_DIR", settings.plot_dir))

    # Use context managers so the underlying zip files get closed. Arrays read
    # from an NpzFile are loaded into memory, so they stay valid afterwards.
    with np.load(metric_path) as metric_npz:
        density = np.asarray(metric_npz["density"], dtype=float)  # (T, cells)
        times = np.asarray(metric_npz["times"], dtype=float)
        cell_lengths = np.asarray(metric_npz["cell_lengths"], dtype=float)
        # Optional: radii may be absent from the metric file.
        cell_radii = np.asarray(metric_npz.get("cell_radii", []), dtype=float)
    with np.load(input_path) as input_npz:
        cell_edges = np.asarray(input_npz["cell_edges"], dtype=int)
        cell_points = np.asarray(input_npz["cell_points"], dtype=float)

    # Circular cross-section area per cell. Fall back to unit area when radii are
    # missing or their count does not match the cells.
    if cell_radii.size == cell_lengths.size:
        area = np.pi * cell_radii * cell_radii
    else:
        area = np.ones_like(cell_lengths)

    molar_density = 1000.0 / 0.018  # kg/m^3 divided by molar mass (kg/mol)
    # Overall scale factor, fixed at 1.0 here.
    # NOTE(review): physical meaning of this normalisation is not visible here — confirm.
    total_mass = 1.0
    # Per-cell displayed value. Area is clamped to avoid division by zero for
    # zero-radius cells.
    tracer = density * (total_mass / (molar_density * np.maximum(area[None, :], 1e-20)))
    # Per-time-step sum over cells of density * cell length (times the scale factor).
    total_mass_series = (density * cell_lengths[None, :] * total_mass).sum(axis=1)

    _gif_network(
        tracer=tracer,
        times=times,
        cell_edges=cell_edges,
        cell_points=cell_points,
        total_mass=total_mass_series,
        out=out_dir / "advection_network_mg.gif",
        max_frames=80,
        title_prefix="MG drift |",
    )


if __name__ == "__main__":
    main()
