"""
Generate GIFs for all runs in a diffusion manifest (FVM + MG).

For each MG NPZ: create a network GIF of tracer density.
For the FVM NPZ: create a network GIF as well.

Outputs go to figs/revision/diffusion/<timestamp>/, where <timestamp> is the manifest's `timestamp` field.
"""

from __future__ import annotations

import argparse
import json
from pathlib import Path
from typing import Any, Dict, List

from tqdm import tqdm
import imageio.v2 as imageio
import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
from gitbud.gitbud import inject_repo_into_sys_path

# Repository root, built from the value returned by gitbud's
# `inject_repo_into_sys_path()`.
# NOTE(review): the helper's name suggests it also adds the repo to sys.path
# as an import-time side effect — confirm against gitbud.
REPO_ROOT = Path(inject_repo_into_sys_path())
# Base directory for the generated diffusion figures.
FIGS_ROOT = REPO_ROOT / "figs" / "revision" / "diffusion"


def _load_manifest(path: Path) -> Dict[str, Any]:
    return json.loads(path.read_text())


def _downsample_indices(n: int, max_frames: int = 80) -> np.ndarray:
    if n <= max_frames:
        return np.arange(n, dtype=int)
    return np.linspace(0, n - 1, num=max_frames, dtype=int)


def _gif_network(
    tracer: np.ndarray,
    times: np.ndarray,
    cell_edges: np.ndarray,
    cell_points: np.ndarray,
    total_mass: np.ndarray,
    out: Path,
    title_prefix: str,
    max_frames: int = 80,
    show_progress: bool = True,
) -> None:
    """Render a tracer time series on the network as an animated GIF.

    Each frame draws every edge as a line segment coloured by its tracer
    value. The colour scale is shared by all frames: it runs from 0 to the
    99th percentile of the whole ``tracer`` array, falling back to 1.0 when
    that percentile is not positive or ``tracer`` is empty.

    Args:
        tracer: Per-frame edge values; ``tracer[i]`` is paired element-wise
            with ``cell_edges`` (assumed 2-D, frames x edges — TODO confirm).
        times: Per-frame time stamps, shown in the title with an "s" suffix.
        cell_edges: Pairs of row indices into ``cell_points``, one per edge.
        cell_points: Point coordinates; only the first two columns are used.
        total_mass: Per-frame scalar shown in the frame title.
        out: Destination GIF path; missing parent directories are created.
        title_prefix: Text placed before the time/mass part of the title.
        max_frames: Upper bound on the number of rendered frames; longer
            series are subsampled evenly.
        show_progress: Whether to show a tqdm progress bar while rendering.
    """
    frames_idx = _downsample_indices(tracer.shape[0], max_frames)
    vmax = np.percentile(tracer, 99) if tracer.size else 1.0
    # A zero, negative or NaN percentile would give a degenerate colour scale.
    color_max = vmax if vmax > 0 else 1.0
    xy = cell_points[:, :2]

    frames = []
    # The edge geometry does not change between frames, so it is built once
    # (on the first frame, when the row length is known) and then reused.
    segments = None
    n_seg = 0
    iterator = tqdm(frames_idx, desc=f"gif {out.name}", disable=not show_progress)
    for idx in iterator:
        vals = tracer[idx]
        if segments is None:
            # Pair edges with values up to the shorter of the two, then gather
            # both endpoints of every edge in a single indexing operation.
            n_seg = min(len(cell_edges), len(vals))
            segments = xy[cell_edges[:n_seg]]
        tmass = float(total_mass[idx])

        fig, ax = plt.subplots(figsize=(6, 6))
        lc = matplotlib.collections.LineCollection(
            segments,
            cmap="inferno",
            norm=matplotlib.colors.Normalize(vmin=0.0, vmax=color_max),
            linewidths=1.5,
        )
        lc.set_array(np.asarray(vals[:n_seg]))
        ax.add_collection(lc)
        ax.autoscale()
        ax.set_aspect("equal")
        ax.set_xticks([])
        ax.set_yticks([])
        cbar = fig.colorbar(lc, ax=ax, fraction=0.046, pad=0.04)
        cbar.set_label("tracer density (normalized)")
        ax.set_title(f"{title_prefix} t = {times[idx]:.1f} s | total mass = {tmass:.3e}")
        fig.tight_layout()
        # Rasterise the figure and copy the RGBA buffer before closing it.
        fig.canvas.draw()
        frames.append(np.array(fig.canvas.renderer.buffer_rgba(), dtype=np.uint8))
        plt.close(fig)

    out.parent.mkdir(parents=True, exist_ok=True)
    imageio.mimsave(out, frames, fps=10)
    print(f"[diffusion gifs] wrote {out} ({len(frames)} frames)")


def _load_and_make_gif(
    npz_path: Path,
    input_npz: Path,
    title_prefix: str,
    out: Path,
    show_progress: bool = True,
) -> None:
    """Load one run's NPZ output and write its network GIF.

    The run archive must contain ``cell_lengths``, ``times`` and either
    ``density`` or ``tracer``; ``cell_radii`` is optional. The network archive
    must contain ``cell_edges`` and ``cell_points``.

    Args:
        npz_path: Archive with the run's time series.
        input_npz: Archive with the network geometry.
        title_prefix: Text placed at the start of each frame title.
        out: Destination GIF path.
        show_progress: Forwarded to ``_gif_network`` to toggle the tqdm bar.

    Raises:
        KeyError: if a required array is missing from either archive.
    """
    # Use context managers so the underlying zip files are closed; the arrays
    # are fully materialised by np.asarray before the archive goes away.
    with np.load(npz_path) as data:
        cell_lengths = np.asarray(data["cell_lengths"], dtype=float)
        has_density = "density" in data
        series = np.asarray(data["density" if has_density else "tracer"], dtype=float)
        times = np.asarray(data["times"], dtype=float)
        # Explicit membership test instead of `.get`, which is only available
        # when NpzFile implements the Mapping interface.
        radii_src = data["cell_radii"] if "cell_radii" in data else []
        cell_radii = np.asarray(radii_src, dtype=float)

    with np.load(input_npz) as input_data:
        cell_edges = np.asarray(input_data["cell_edges"], dtype=int)
        cell_points = np.asarray(input_data["cell_points"], dtype=float)

    # Circular cross-section when one radius per cell is available; otherwise
    # fall back to unit area.
    if cell_radii.size == cell_lengths.size:
        area = np.pi * cell_radii * cell_radii
    else:
        area = np.ones_like(cell_lengths)
    # Guard against division by a zero area.
    safe_area = np.maximum(area[None, :], 1e-20)

    if has_density:
        # Convert per-length density to concentration (per volume) for visualization.
        tracer = series / safe_area
        total_mass_series = (series * cell_lengths[None, :]).sum(axis=1)
    else:
        # NOTE(review): 1000 / 0.018 looks like the molar density of water
        # (kg/m^3 over kg/mol) — confirm units with the solver output.
        molar_density = 1000.0 / 0.018
        tracer = series * (1.0 / (molar_density * safe_area))
        total_mass_series = (series * cell_lengths[None, :] * area[None, :] * molar_density).sum(axis=1)

    _gif_network(
        tracer,
        times,
        cell_edges,
        cell_points,
        total_mass_series,
        out=out,
        title_prefix=title_prefix,
        show_progress=show_progress,
    )


def main() -> None:
    """Command-line entry point: build the FVM GIF and one GIF per MG run.

    Reads the manifest given by ``--manifest``, locates the network input NPZ
    (next to the manifest first, then under the repository's ``data/``
    directory) and writes the GIFs into ``FIGS_ROOT / manifest["timestamp"]``.

    Raises:
        FileNotFoundError: if no network input NPZ can be found.
    """
    cli = argparse.ArgumentParser(description="Generate GIFs for diffusion runs in a manifest.")
    cli.add_argument("--manifest", type=Path, required=True, help="Path to diffusion manifest.json")
    cli.add_argument("--no-progress", action="store_true", help="Disable tqdm progress bars")
    opts = cli.parse_args()

    manifest = _load_manifest(opts.manifest)
    run_dir = opts.manifest.parent
    with_progress = not opts.no_progress

    # Prefer an input file sitting next to the manifest, under either name.
    local_candidates = (
        run_dir / "network_input.npz",
        run_dir / "dumux_network_input.npz",
    )
    input_npz = next((path for path in local_candidates if path.exists()), None)
    if input_npz is None:
        # Otherwise fall back to the repository-level input under data/.
        input_npz = REPO_ROOT / "data" / "dumux_network_input.npz"
        if not input_npz.exists():
            raise FileNotFoundError("No input NPZ found (network_input.npz or dumux_network_input.npz)")

    out_dir = FIGS_ROOT / manifest["timestamp"]
    out_dir.mkdir(parents=True, exist_ok=True)

    # FVM GIF. `show_progress` is forwarded by keyword, so
    # `_load_and_make_gif` must accept a parameter of that name.
    fvm_npz = run_dir / manifest["fvm"]["npz"]
    if not fvm_npz.exists():
        print("[diffusion gifs] FVM NPZ missing, skipping FVM gif")
    else:
        _load_and_make_gif(
            fvm_npz,
            input_npz,
            title_prefix="FVM diffusion |",
            out=out_dir / "fvm_diffusion.gif",
            show_progress=with_progress,
        )

    # MG GIFs: only runs flagged "ok" whose NPZ file actually exists.
    for run in manifest["mg_runs"]:
        if run.get("status") != "ok":
            continue
        mg_npz = run_dir / run["npz"]
        if not mg_npz.exists():
            print(f"[diffusion gifs] missing MG NPZ {mg_npz}, skipping")
            continue
        _load_and_make_gif(
            mg_npz,
            input_npz,
            title_prefix=f"MG diffusion p={run['particles']} seed={run['seed']} |",
            out=out_dir / f"mg_p{run['particles']}_s{run['seed']}.gif",
            show_progress=with_progress,
        )


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported.
if __name__ == "__main__":
    main()
