"""
Metric-graph–only visualizations. Useful when DuMuX fails (e.g. synthetic high drift).

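Outputs (under settings.plot_dir; file names may carry a run suffix such as "_p10k_flux1e-03"):
  - network_mg_only.gif          : 2D layout colored by per-edge mass
  - edge_profiles_mg_only.gif    : step profiles on up to 12 selected edges (preferring the strongest
                                   drift when per-edge drift coefficients are stored, otherwise the
                                   largest final mass)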
"""

from pathlib import Path
from typing import Iterable
import os

import imageio.v2 as imageio
import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
from gitbud.gitbud import inject_repo_into_sys_path

inject_repo_into_sys_path()

from experiments.dumux_tracer.config_loader import load_settings

# Resolved once at import time: project settings plus optional environment overrides
# for the input files that main() reads.
settings = load_settings()
# When set, used instead of settings.metric_output as the metric-graph .npz path.
METRIC_PATH_OVERRIDE = os.environ.get("DUMUX_METRIC_OUTPUT")
# When set, used instead of settings.dumux_input_npz as the network-input .npz path.
INPUT_PATH_OVERRIDE = os.environ.get("DUMUX_INPUT_NPZ")
# NOTE(review): inject_repo_into_sys_path() was already called above; this second call is
# made for its return value (presumably the repository root — confirm). REPO_ROOT is not
# referenced elsewhere in this script.
REPO_ROOT = Path(inject_repo_into_sys_path())


def _downsample_indices(n: int, max_frames: int) -> np.ndarray:
    """Choose which of ``n`` timesteps to render, capped at ``max_frames``.

    When ``n`` fits within the budget every index ``0..n-1`` is kept. Otherwise
    ``max_frames`` positions are spread evenly from 0 to ``n - 1`` (both ends
    included) and floored to integer indices.
    """
    keep_all = n <= max_frames
    if not keep_all:
        # Evenly spaced from the first to the last timestep; integer dtype floors.
        return np.linspace(0, n - 1, num=max_frames, dtype=int)
    return np.arange(n, dtype=int)


def _edge_mass_per_timestep(cell_mass: np.ndarray, cell_to_edge: np.ndarray, num_edges: int) -> np.ndarray:
    """Sum per-cell mass into per-edge totals for every timestep.

    Args:
        cell_mass: Array indexed as ``cell_mass[t, c]``.
        cell_to_edge: Edge id of each cell ``c`` (non-negative integers).
        num_edges: Number of output columns; edges without cells get 0.

    Returns:
        Float array of shape ``(cell_mass.shape[0], num_edges)`` where entry
        ``[t, e]`` is the sum of ``cell_mass[t, c]`` over cells with edge id ``e``.
    """
    n_steps = cell_mass.shape[0]
    totals = np.zeros((n_steps, num_edges), dtype=float)
    for step, row in enumerate(cell_mass):
        # bincount with weights performs the per-edge scatter-add for one timestep.
        totals[step] = np.bincount(cell_to_edge, weights=row, minlength=num_edges)
    return totals


def _select_top_edges(edge_mass: np.ndarray, k: int) -> np.ndarray:
    """Return the indices of the (at most) ``k`` edges with the most mass at the last timestep.

    Args:
        edge_mass: Array indexed as ``edge_mass[t, e]``; only the final row is used.
        k: Maximum number of edges to return. ``k <= 0`` yields an empty selection
            (a negative slice bound would otherwise drop edges from the end instead).

    Returns:
        Integer edge indices ordered by decreasing final mass. Ties keep ascending
        edge-index order, so the selection is reproducible. Empty when there are
        no timesteps.
    """
    if k <= 0 or edge_mass.shape[0] == 0:
        return np.empty(0, dtype=np.intp)
    final_mass = edge_mass[-1]
    # Negating (rather than reversing an ascending sort) plus a stable sort keeps
    # tied edges in ascending index order.
    order = np.argsort(-final_mass, kind="stable")
    return order[:k]


def _plot_network_gif(
    edge_mass: np.ndarray,
    times: np.ndarray,
    points: np.ndarray,
    edges: np.ndarray,
    out_path: Path,
    max_frames: int = 60,
) -> None:
    """Render the network layout as an animated GIF colored by per-edge mass.

    Each frame draws edge ``e = (u, v)`` as a straight segment between the first two
    coordinates of ``points[u]`` and ``points[v]``, colored by ``edge_mass[t, e]``.
    All frames share one color scale (0 to the 99th percentile of every mass value),
    so they are directly comparable.

    Args:
        edge_mass: Per-edge mass indexed as ``edge_mass[t, e]``.
        times: Time of each row of ``edge_mass``; shown in the frame titles.
        points: Node coordinates; only columns 0 and 1 are drawn.
        edges: Node-index pairs, one per edge, aligned with the columns of ``edge_mass``.
        out_path: Destination GIF; parent directories are created if needed.
        max_frames: Upper bound on rendered frames (timesteps are evenly subsampled).
    """
    from matplotlib.collections import LineCollection
    from matplotlib.colors import Normalize

    frame_indices = _downsample_indices(edge_mass.shape[0], max_frames)
    if frame_indices.size == 0:
        print(f"[metric-only] no timesteps to render; skipped {out_path}")
        return

    vmax = np.percentile(edge_mass, 99) if edge_mass.size else 1.0
    # A non-positive (or NaN) percentile would give a degenerate scale; fall back to [0, 1].
    vmax = vmax if vmax > 0 else 1.0

    # Geometry is identical in every frame, so build the segments once.
    # Only edges that have a mass column are drawn (pairs edges with masses like zip()).
    xy = points[:, :2]
    edge_pairs = np.asarray(edges, dtype=int).reshape(-1, 2)
    n_drawn = min(edge_pairs.shape[0], edge_mass.shape[1])
    segments = xy[edge_pairs[:n_drawn]]  # (n_drawn, 2 endpoints, 2 coords)

    frames: list[np.ndarray] = []
    for idx in frame_indices:
        masses = edge_mass[idx]
        fig, ax = plt.subplots(figsize=(6, 6))
        lc = LineCollection(
            segments,
            cmap="inferno",
            norm=Normalize(vmin=0.0, vmax=vmax),
            linewidths=2.0,
        )
        lc.set_array(np.asarray(masses[:n_drawn]))
        ax.add_collection(lc)
        ax.scatter(xy[:, 0], xy[:, 1], s=4, color="blue", alpha=0.7, zorder=3)
        ax.autoscale()
        ax.set_aspect("equal")
        ax.set_xticks([])
        ax.set_yticks([])
        cbar = fig.colorbar(lc, ax=ax, fraction=0.046, pad=0.04)
        cbar.set_label("edge mass (arb.)")
        ax.set_title(f"Metric tracer (MG only) t={times[idx]:.2f}s | total mass={masses.sum():.3e}")
        fig.tight_layout()
        fig.canvas.draw()
        # Copy out of the Agg buffer so the frame does not keep the figure's renderer alive.
        frames.append(np.array(fig.canvas.buffer_rgba(), dtype=np.uint8))
        plt.close(fig)

    out_path.parent.mkdir(parents=True, exist_ok=True)
    imageio.mimsave(out_path, frames, fps=10)
    print(f"[metric-only] wrote {out_path} ({len(frames)} frames)")


def _plot_edge_profiles_gif(
    density: np.ndarray,
    times: np.ndarray,
    cell_lengths: np.ndarray,
    cell_start: np.ndarray,
    cell_to_edge: np.ndarray,
    top_edges: Iterable[int],
    out_path: Path,
    max_frames: int = 80,
) -> None:
    """Animate the density along each selected edge as step profiles in a GIF.

    Each selected edge gets one panel, laid out in a grid of at most 4 columns. A panel
    shows ``density[t, c]`` for the edge's cells, ordered by ``cell_start``, as a step
    function over arc length anchored at 0. All panels share one y-range (the largest
    finite density on any selected edge over all timesteps), so frames are comparable.

    Args:
        density: Per-cell density indexed as ``density[t, c]``.
        times: Time of each row of ``density``; shown in the frame titles.
        cell_lengths: Length of each cell.
        cell_start: Arc-length start of each cell on its edge.
        cell_to_edge: Edge id of each cell.
        top_edges: Edge ids to plot; edges without cells get a blank panel.
        out_path: Destination GIF; parent directories are created if needed.
        max_frames: Upper bound on rendered frames (timesteps are evenly subsampled).
    """
    top_edges = list(top_edges)
    if not top_edges:
        print(f"[metric-only] no edges selected; skipped {out_path}")
        return
    frame_indices = _downsample_indices(density.shape[0], max_frames)
    if frame_indices.size == 0:
        print(f"[metric-only] no timesteps to render; skipped {out_path}")
        return

    # Per-edge geometry does not change between frames; resolve it once.
    # profiles[i] is (cells sorted by start, step bounds) for top_edges[i], or None if empty.
    profiles: list[tuple[np.ndarray, np.ndarray] | None] = []
    y_max = 0.0
    for edge_id in top_edges:
        cells = np.flatnonzero(cell_to_edge == edge_id)
        if cells.size == 0:
            profiles.append(None)
            continue
        cells = cells[np.argsort(cell_start[cells])]
        starts = cell_start[cells] - cell_start[cells].min()  # anchor at 0
        bounds = np.append(starts, starts[-1] + cell_lengths[cells[-1]])
        profiles.append((cells, bounds))
        # Ignore NaN/inf so a single bad value cannot break the shared y-limits.
        values = density[:, cells]
        finite = values[np.isfinite(values)]
        if finite.size:
            y_max = max(y_max, float(finite.max()))
    if y_max <= 0:
        y_max = 1.0

    n_panels = len(top_edges)
    cols = min(4, n_panels)
    rows = -(-n_panels // 4)  # ceil(n_panels / 4)

    frames: list[np.ndarray] = []
    for idx in frame_indices:
        fig, axes = plt.subplots(rows, cols, figsize=(3 * cols, 2 * rows), squeeze=False, sharey=True)
        axes_flat = axes.flatten()
        for ax, edge_id, profile in zip(axes_flat, top_edges, profiles):
            if profile is None:
                ax.axis("off")
                continue
            cells, bounds = profile
            vals = density[idx, cells]
            # Repeat the last value so the final step spans the last cell.
            ax.step(bounds, np.append(vals, vals[-1]), where="post", color="tab:blue", linewidth=2)
            ax.set_title(f"edge {edge_id} | cells={cells.size}")
            ax.set_ylim(0, y_max * 1.05)
            ax.set_xlabel("arc length [m]")
        axes_flat[0].set_ylabel("density [prob/length]")
        # Hide any unused subplots
        for ax in axes_flat[n_panels:]:
            ax.axis("off")
        fig.suptitle(f"Metric edge profiles (MG only) t={times[idx]:.2f}s")
        fig.tight_layout()
        fig.canvas.draw()
        # Copy out of the Agg buffer so the frame does not keep the figure's renderer alive.
        frames.append(np.array(fig.canvas.buffer_rgba(), dtype=np.uint8))
        plt.close(fig)

    out_path.parent.mkdir(parents=True, exist_ok=True)
    imageio.mimsave(out_path, frames, fps=15)
    print(f"[metric-only] wrote {out_path} ({len(frames)} frames)")


def _select_profile_edges(
    edge_mass: np.ndarray,
    cell_to_edge: np.ndarray,
    drift: np.ndarray,
    k: int = 12,
    min_cells: int = 3,
) -> np.ndarray:
    """Pick up to ``k`` edges for the profile GIF.

    When ``drift`` holds one coefficient per edge, prefer edges with at least ``min_cells``
    cells, ranked by ``|drift|`` and then by cell count, largest first, to highlight advection.
    Otherwise, or if no edge qualifies, fall back to the edges with the most final mass.
    """
    if drift.size > 0 and drift.size == edge_mass.shape[1]:
        counts = np.bincount(cell_to_edge, minlength=drift.size)
        candidates = [i for i, c in enumerate(counts) if c >= min_cells]
        ranked = sorted(candidates, key=lambda i: (abs(drift[i]), counts[i]), reverse=True)
        if ranked:
            return np.array(ranked[:k], dtype=int)
    return _select_top_edges(edge_mass, k=k)


def _output_suffix(num_particles: int, flux: "float | None") -> str:
    """Build the run suffix appended to output file names (e.g. ``"_p10k_flux1e-03"``).

    The particle count is included only when positive, and the flux scale only when not None.
    Returns an empty string when neither applies.
    """
    parts = []
    if num_particles > 0:
        # Below 1000 particles, integer division by 1000 would give the ambiguous "p0k".
        parts.append(f"p{num_particles // 1000}k" if num_particles >= 1000 else f"p{num_particles}")
    if flux is not None:
        parts.append(f"flux{flux:.0e}")
    return "_" + "_".join(parts) if parts else ""


def main() -> None:
    """Load the metric-graph outputs and write the MG-only network and edge-profile GIFs.

    Input paths come from ``DUMUX_METRIC_OUTPUT`` / ``DUMUX_INPUT_NPZ`` when set, and otherwise
    from ``settings.metric_output`` / ``settings.dumux_input_npz``. GIFs go to ``settings.plot_dir``.
    """
    metric_path = Path(METRIC_PATH_OVERRIDE) if METRIC_PATH_OVERRIDE else settings.metric_output
    input_path = Path(INPUT_PATH_OVERRIDE) if INPUT_PATH_OVERRIDE else settings.dumux_input_npz

    # np.load returns an NpzFile for .npz archives; the `with` blocks close the file handles
    # once every needed array has been read into memory.
    with np.load(metric_path) as metric_npz:
        density = np.asarray(metric_npz["density"], dtype=float)  # (T, cells)
        times = np.asarray(metric_npz["times"], dtype=float)
        cell_lengths = np.asarray(metric_npz["cell_lengths"], dtype=float)
        cell_to_edge = np.asarray(metric_npz["cell_to_edge"], dtype=int)
        cell_start = np.asarray(metric_npz["cell_start"], dtype=float)
        drift = np.asarray(metric_npz.get("drift_coeffs", []), dtype=float)
        num_particles = int(metric_npz.get("num_particles", -1))
    with np.load(input_path) as input_npz:
        orig_points = np.asarray(input_npz["orig_points"], dtype=float)
        orig_edges = np.asarray(input_npz["orig_edges"], dtype=int)

    # Mass per cell = per-length density * cell length, scaled by the total probability mass
    # (fixed at 1.0 so runs stay comparable).
    total_mass = 1.0
    cell_mass = density * cell_lengths[None, :] * total_mass
    edge_mass = _edge_mass_per_timestep(cell_mass, cell_to_edge, orig_edges.shape[0])

    top_edges = _select_profile_edges(edge_mass, cell_to_edge, drift, k=12)
    suffix = _output_suffix(num_particles, getattr(settings, "flux_scale", None))

    out_dir = settings.plot_dir
    out_dir.mkdir(parents=True, exist_ok=True)
    _plot_network_gif(edge_mass, times, orig_points, orig_edges, out_dir / f"network_mg_only{suffix}.gif")
    _plot_edge_profiles_gif(
        density,
        times,
        cell_lengths,
        cell_start,
        cell_to_edge,
        top_edges,
        out_dir / f"edge_profiles_mg_only{suffix}.gif",
    )


# Render only when executed as a script. Importing the module still runs the module-level
# settings/env loading above, but produces no plots.
if __name__ == "__main__":
    main()
