"""
Generate GIFs for all runs in a drift (advection) manifest.

- For each MG run that succeeded: create a network GIF of tracer density.
- If an FVM NPZ exists: create a network GIF as well.

Outputs are written to <repo>/figs/revision/drift/<timestamp>/, where <timestamp> is the manifest's `timestamp` field.
"""

from __future__ import annotations

import argparse
import json
from pathlib import Path
from typing import Any, Dict

from tqdm import tqdm
import imageio.v2 as imageio
import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
from gitbud.gitbud import inject_repo_into_sys_path

# Repository root as returned by gitbud's inject_repo_into_sys_path (which,
# per its name, presumably also puts the repo on sys.path -- confirm in gitbud).
# Every GIF produced by this script is written below FIGS_ROOT.
REPO_ROOT = Path(inject_repo_into_sys_path())
FIGS_ROOT = REPO_ROOT.joinpath("figs", "revision", "drift")


def _load_manifest(path: Path) -> Dict[str, Any]:
    return json.loads(path.read_text())


def _downsample_indices(n: int, max_frames: int = 80) -> np.ndarray:
    if n <= max_frames:
        return np.arange(n, dtype=int)
    return np.linspace(0, n - 1, num=max_frames, dtype=int)


def _gif_network(
    tracer: np.ndarray,
    times: np.ndarray,
    cell_edges: np.ndarray,
    cell_points: np.ndarray,
    total_mass: np.ndarray,
    out: Path,
    title_prefix: str,
    max_frames: int = 80,
    show_progress: bool = True,
) -> None:
    """Render a per-cell tracer field on the network as an animated GIF.

    Each rendered frame draws every cell as a line segment between the two
    points named in ``cell_edges`` and colors it by that frame's tracer value.
    The color scale is shared by all frames (0 up to the 99th percentile of
    the whole ``tracer`` array), so frames are visually comparable.

    Args:
        tracer: 2-D array indexed ``tracer[frame, cell]``.
        times: Per-frame time shown in the title (labelled in seconds).
        cell_edges: Integer ``(u, v)`` point-index pairs, one per cell.
        cell_points: Point coordinates; only the first two columns (x, y)
            are used.
        total_mass: Per-frame scalar shown in the title.
        out: Destination GIF path; missing parent directories are created.
        title_prefix: Text placed before the time in each frame's title.
        max_frames: Upper bound on the number of rendered frames
            (see ``_downsample_indices``).
        show_progress: Show a tqdm progress bar while rendering.
    """
    # Import the submodules explicitly instead of relying on pyplot having
    # loaded matplotlib.colors / matplotlib.collections as a side effect.
    from matplotlib.collections import LineCollection
    from matplotlib.colors import Normalize

    n_frames = tracer.shape[0]
    if n_frames == 0:
        # Nothing to animate; don't hand imageio an empty frame list.
        print(f"[drift gifs] no frames for {out}, skipping")
        return

    frames_idx = _downsample_indices(n_frames, max_frames)
    vmax = np.percentile(tracer, 99) if tracer.size else 1.0
    # A NaN or non-positive percentile falls back to a unit color scale.
    norm = Normalize(vmin=0.0, vmax=vmax if vmax > 0 else 1.0)

    # The geometry is identical for every frame, so build the segments once
    # (shape (n_cells, 2, 2): start/end point x/y). Edges and per-cell values
    # are paired up to the shorter of the two, matching zip() semantics.
    xy = cell_points[:, :2]
    edges = np.asarray(cell_edges, dtype=int).reshape(-1, 2)
    n_cells = min(len(edges), tracer.shape[1])
    edges = edges[:n_cells]
    segments = np.stack((xy[edges[:, 0]], xy[edges[:, 1]]), axis=1)

    frames = []
    iterator = tqdm(frames_idx, desc=f"gif {out.name}", disable=not show_progress)
    for idx in iterator:
        fig, ax = plt.subplots(figsize=(6, 6))
        tmass = float(total_mass[idx])
        lc = LineCollection(
            segments,
            cmap="inferno",
            norm=norm,
            linewidths=1.5,
        )
        lc.set_array(np.asarray(tracer[idx][:n_cells]))
        ax.add_collection(lc)
        ax.autoscale()
        ax.set_aspect("equal")
        ax.set_xticks([])
        ax.set_yticks([])
        cbar = fig.colorbar(lc, ax=ax, fraction=0.046, pad=0.04)
        cbar.set_label("tracer density (normalized)")
        ax.set_title(f"{title_prefix} t = {times[idx]:.1f} s | total mass = {tmass:.3e}")
        fig.tight_layout()
        fig.canvas.draw()
        # Public Agg canvas API; canvas.renderer is an internal attribute.
        frames.append(np.array(fig.canvas.buffer_rgba(), dtype=np.uint8))
        plt.close(fig)

    out.parent.mkdir(parents=True, exist_ok=True)
    imageio.mimsave(out, frames, fps=10)
    print(f"[drift gifs] wrote {out} ({len(frames)} frames)")


def _load_and_make_gif(
    npz_path: Path,
    input_npz: Path,
    title_prefix: str,
    out: Path,
    show_progress: bool = True,
) -> None:
    """Load a drift result NPZ plus the network geometry and write its GIF.

    Two result layouts are accepted:

    * ``density`` key present: the plotted field is ``density / area`` and
      the per-frame total mass is ``sum(density * cell_lengths)``.
    * otherwise the ``tracer`` key is read: the plotted field is
      ``tracer / (molar_density * area)`` and the per-frame total mass is
      ``sum(tracer * cell_lengths * area * molar_density)``.

    ``area`` is ``pi * cell_radii**2`` when the result stores exactly one
    radius per cell, otherwise all ones. Both layouts must also provide
    ``cell_lengths`` and ``times``. The geometry (``cell_edges``,
    ``cell_points``) is read from *input_npz*.
    """
    # np.load on an .npz returns a lazily-reading NpzFile that keeps the zip
    # open; the context managers close it once the arrays are materialized.
    with np.load(npz_path) as data:
        cell_lengths = np.asarray(data["cell_lengths"], dtype=float)
        times = np.asarray(data["times"], dtype=float)
        cell_radii = np.asarray(data.get("cell_radii", []), dtype=float)
        has_density = "density" in data
        field = np.asarray(data["density" if has_density else "tracer"], dtype=float)
    with np.load(input_npz) as input_data:
        cell_edges = np.asarray(input_data["cell_edges"], dtype=int)
        cell_points = np.asarray(input_data["cell_points"], dtype=float)

    if cell_radii.size == cell_lengths.size:
        area = np.pi * cell_radii * cell_radii
    else:
        area = np.ones_like(cell_lengths)
    # Floor the area so zero-radius cells do not divide by zero.
    safe_area = np.maximum(area[None, :], 1e-20)

    if has_density:
        tracer = field / safe_area
        total_mass_series = (field * cell_lengths[None, :]).sum(axis=1)
    else:
        # 1000 / 0.018: presumably water's molar density in mol/m^3
        # (1000 kg/m^3 over 0.018 kg/mol) -- confirm against the result writer.
        molar_density = 1000.0 / 0.018
        # NOTE(review): the plotted field divides by molar_density * area while
        # the total mass multiplies by them; confirm which scaling is intended.
        tracer = field * (1.0 / (molar_density * safe_area))
        total_mass_series = (field * cell_lengths[None, :] * area[None, :] * molar_density).sum(axis=1)

    _gif_network(
        tracer,
        times,
        cell_edges,
        cell_points,
        total_mass_series,
        out=out,
        title_prefix=title_prefix,
        show_progress=show_progress,
    )


def _resolve_input_npz(base_dir: Path) -> Path:
    """Return the network-geometry NPZ to use for a manifest in *base_dir*.

    Search order: ``network_input.npz`` and then
    ``dumux_advection_network_input.npz`` next to the manifest, and finally
    ``<repo>/data/dumux_advection_network_input.npz``.

    Raises:
        FileNotFoundError: if none of the candidates exist.
    """
    candidates = (
        base_dir / "network_input.npz",
        base_dir / "dumux_advection_network_input.npz",
        REPO_ROOT / "data" / "dumux_advection_network_input.npz",
    )
    for cand in candidates:
        if cand.exists():
            return cand
    raise FileNotFoundError("No input NPZ found (network_input.npz or dumux_advection_network_input.npz)")


def main() -> None:
    """CLI entry point: write FVM and MG GIFs for every run in a drift manifest.

    GIFs go to ``FIGS_ROOT / manifest["timestamp"]`` as ``fvm_dt<tag>.gif``
    and ``mg_dt<tag>.gif``, where ``<tag>`` is the run's ``dt`` with ``.``
    replaced by ``p``. Result NPZ paths in the manifest are resolved
    relative to the manifest's directory.
    """
    parser = argparse.ArgumentParser(description="Generate GIFs for drift runs in a manifest.")
    parser.add_argument("--manifest", type=Path, required=True, help="Path to drift manifest.json")
    parser.add_argument("--no-progress", action="store_true", help="Disable tqdm progress bars")
    args = parser.parse_args()
    show_progress = not args.no_progress

    manifest = _load_manifest(args.manifest)
    base_dir = args.manifest.parent
    input_npz = _resolve_input_npz(base_dir)

    out_dir = FIGS_ROOT / manifest["timestamp"]
    out_dir.mkdir(parents=True, exist_ok=True)

    for run in manifest["runs"]:
        dt = run["dt"]
        dt_tag = str(dt).replace(".", "p")

        # FVM: rendered whenever the manifest names an NPZ that exists on disk
        # (no status check). `or {}` also tolerates an explicit JSON null.
        fvm = run.get("fvm") or {}
        fvm_npz_name = fvm.get("npz")
        if fvm_npz_name:
            fvm_npz = base_dir / fvm_npz_name
            if fvm_npz.exists():
                _load_and_make_gif(
                    fvm_npz,
                    input_npz,
                    title_prefix=f"FVM drift dt={dt} |",
                    out=out_dir / f"fvm_dt{dt_tag}.gif",
                    show_progress=show_progress,
                )
            else:
                print(f"[drift gifs] missing FVM NPZ {fvm_npz}, skipping")

        # MG: only runs whose status is "ok" and that name an NPZ file.
        mg = run.get("mg") or {}
        if mg.get("status") != "ok":
            continue
        mg_npz_name = mg.get("npz")
        if not mg_npz_name:
            print(f"[drift gifs] MG run dt={dt} has no npz entry, skipping")
            continue
        mg_npz = base_dir / mg_npz_name
        if not mg_npz.exists():
            print(f"[drift gifs] missing MG NPZ {mg_npz}, skipping")
            continue
        _load_and_make_gif(
            mg_npz,
            input_npz,
            title_prefix=f"MG drift dt={dt} |",
            out=out_dir / f"mg_dt{dt_tag}.gif",
            show_progress=show_progress,
        )


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
