"""
Diffusion-only comparison: DuMuX (FVM) vs CUDA metric-graph sampler.

Outputs (under DUMUX_PLOT_DIR, default: figs/):
  - rel_error.png               : relative L2 mass error vs time
  - total_mass.png              : total mass over time (both solvers, normalized)
  - network_fvm.gif             : DuMuX tracer on the 2D layout
  - network_mg.gif              : metric-graph tracer on the 2D layout
  - edge_profiles_overlay.gif   : per-edge profiles (DuMuX vs metric) for top-mass edges
"""

from pathlib import Path

import imageio.v2 as imageio
import matplotlib

# Select the non-interactive Agg backend before pyplot is imported; the GIF
# writers below rasterize figures by reading the rendered canvas buffer.
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
from gitbud.gitbud import inject_repo_into_sys_path

# Presumably puts the repository on sys.path (per the helper's name) so the
# ``experiments`` imports below resolve; it must run before those imports.
inject_repo_into_sys_path()

from experiments.dumux_tracer.config_loader import load_compare_config, load_settings

# Module-level configuration shared by the plotting functions in this file:
# ``compare_config`` provides the input .npz paths and the time-matching
# tolerance, ``settings`` provides ``flux_scale`` and ``plot_dir``.
compare_config = load_compare_config()
settings = load_settings()
from experiments.dumux_tracer.compare_metric_to_dumux import (
    MOLAR_DENSITY,
    _align_edges,  # NOTE(review): not referenced anywhere in this file
    _cell_mass,
    match_times,
)


def _downsample_indices(n: int, max_frames: int) -> np.ndarray:
    if n <= max_frames:
        return np.arange(n, dtype=int)
    return np.linspace(0, n - 1, num=max_frames, dtype=int)


def _edge_mass_per_step(cell_mass: np.ndarray, cell_to_edge: np.ndarray, num_edges: int) -> np.ndarray:
    """Sum per-cell masses onto their parent edges, independently for every time row."""
    out = np.zeros((cell_mass.shape[0], num_edges), dtype=float)
    for t, row in enumerate(cell_mass):
        out[t] = np.bincount(cell_to_edge, weights=row, minlength=num_edges)
    return out


def load_data():
    """
    Load the DuMuX and metric-graph outputs and bring them into a common form.

    Reads ``compare_config.dumux_npz`` and ``compare_config.metric_npz``. The
    metric density is converted to per-cell mass (scaled by the initial DuMuX
    total mass) and to a tracer value for visualization. Both runs are trimmed
    to the shorter number of snapshots, and cell masses are aggregated onto
    the original network edges.

    Returns
    -------
    dict
        The per-time arrays ``tracer_*``, ``times_*``, ``mass_*``, ``total_*``
        and ``edge_mass_*`` all share a leading axis of length
        ``min(T_dumux, T_metric)``. The ``cell_*`` geometry arrays are per cell.
    """
    # Use the NpzFile objects as context managers so their file handles are
    # closed; every array is materialized with np.asarray inside the block.
    with np.load(compare_config.dumux_npz) as dumux:
        tracer_du = np.asarray(dumux["tracer"], dtype=float)  # (T, cells)
        times_du = np.asarray(dumux["times"], dtype=float)
        cell_edges = np.asarray(dumux["cell_edges"], dtype=int)
        cell_points = np.asarray(dumux["cell_points"], dtype=float)
        cell_lengths = np.asarray(dumux["cell_lengths"], dtype=float)
        cell_radii = np.asarray(dumux.get("cell_radii", []), dtype=float)
        cell_to_edge = np.asarray(dumux["cell_to_edge"], dtype=int)
        cell_start = np.asarray(dumux["cell_start"], dtype=float)
        num_orig_edges = int(np.asarray(dumux["orig_edges"]).shape[0])

    with np.load(compare_config.metric_npz) as metric:
        density_mg = np.asarray(metric["density"], dtype=float)  # (T, cells)
        times_mg = np.asarray(metric["times"], dtype=float)

    mass_du = _cell_mass(tracer_du, cell_lengths, cell_radii)
    initial_totals = mass_du.sum(axis=1)
    # Reference mass used to turn the metric density into an absolute mass.
    total0 = float(initial_totals[0]) if initial_totals.size else 1.0

    # Mass per cell: density [prob/len] * length * total mass.
    mass_mg = density_mg * cell_lengths[None, :] * total0

    # Convert metric density to tracer (mole fraction) for visualization.
    area = np.ones_like(cell_lengths)
    if cell_radii.size == cell_lengths.size:
        area = np.pi * cell_radii * cell_radii
    tracer_mg = density_mg * (total0 / (MOLAR_DENSITY * area))[None, :]

    # Trim both runs to a common number of snapshots so that every per-time
    # array in the result (including the totals) has the same length.
    num_t = min(tracer_du.shape[0], mass_mg.shape[0])
    tracer_du = tracer_du[:num_t]
    tracer_mg = tracer_mg[:num_t]
    mass_du = mass_du[:num_t]
    mass_mg = mass_mg[:num_t]
    times_du = times_du[:num_t]
    times_mg = times_mg[:num_t]
    total_du = mass_du.sum(axis=1)
    total_mg = mass_mg.sum(axis=1)

    # Aggregate mass to original edges for edge selection.
    edge_mass_du = _edge_mass_per_step(mass_du, cell_to_edge, num_orig_edges)
    edge_mass_mg = _edge_mass_per_step(mass_mg, cell_to_edge, num_orig_edges)

    return dict(
        tracer_du=tracer_du,
        tracer_mg=tracer_mg,
        times_du=times_du,
        times_mg=times_mg,
        mass_du=mass_du,
        mass_mg=mass_mg,
        total_du=total_du,
        total_mg=total_mg,
        cell_edges=cell_edges,
        cell_points=cell_points,
        cell_lengths=cell_lengths,
        cell_to_edge=cell_to_edge,
        cell_start=cell_start,
        edge_mass_du=edge_mass_du,
        edge_mass_mg=edge_mass_mg,
    )


def plot_errors(data, out_dir: Path) -> None:
    """
    Plot the relative L2 mass error between the two solvers over time.

    Snapshots are paired with ``match_times`` using
    ``compare_config.time_tolerance``. For each pair, the L2 norm of the
    per-cell mass difference is divided by the DuMuX total mass at that time.
    Writes ``rel_error.png`` into *out_dir*.
    """
    flux = float(getattr(settings, "flux_scale", 0.0))
    has_flux = flux != 0.0
    mode = "advection-diffusion" if has_flux else "diffusion"
    pairs = match_times(data["times_mg"], data["times_du"], tol=compare_config.time_tolerance)

    rel = []
    t_plot = []
    for m_idx, d_idx in pairs:
        err = np.linalg.norm(data["mass_mg"][m_idx] - data["mass_du"][d_idx])
        # Floor the reference mass so a vanishing total does not divide by zero.
        tot = max(data["total_du"][d_idx], 1e-20)
        rel.append(err / tot)
        t_plot.append(data["times_du"][d_idx])
    if not rel:
        print("[diffusion compare] warning: no matching time pairs; rel_error.png will be empty")

    fig, ax = plt.subplots(figsize=(6, 4))
    ax.plot(t_plot, rel, lw=2)
    ax.set_xlabel("time [s]")
    ax.set_ylabel("relative L2 mass error")
    ax.set_yscale("log")
    title = f"DuMuX vs metric ({mode})"
    if has_flux:
        title += f" | flux_scale={flux:.1e}"
    ax.set_title(title)
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    out = out_dir / "rel_error.png"
    fig.savefig(out, dpi=200)
    plt.close(fig)
    print(f"[diffusion compare] wrote {out}")


def plot_uniform_error(data, out_dir: Path) -> None:
    """
    Plot each solver's L2 distance to a uniform mass-per-length state.

    For pure diffusion on a closed network the steady state is a uniform
    mass density along arc length. The plot is skipped quietly when
    drift/flux is present (non-zero ``flux_scale`` in the solver settings or
    non-zero ``velocity_scale`` in the metric config). It is also skipped,
    with a message, when the network has zero total length or there are no
    snapshots. Writes ``uniform_convergence.png`` into *out_dir*.
    """
    # Only the metric config is loaded here. The solver settings are the
    # module-level ``settings`` that the other plotting functions use too.
    from experiments.dumux_tracer.config_loader import load_metric_config

    mg_cfg = load_metric_config()
    has_drift = (getattr(settings, "flux_scale", 0.0) != 0.0) or (getattr(mg_cfg, "velocity_scale", 0.0) != 0.0)
    if has_drift:
        return

    total_length = float(data["cell_lengths"].sum())
    if total_length <= 0.0 or data["total_du"].size == 0:
        print("[diffusion compare] skipped uniform_convergence.png (empty network or no snapshots)")
        return

    total_mass0 = float(data["total_du"][0])
    uniform_mass_density = total_mass0 / total_length
    target = uniform_mass_density * data["cell_lengths"][None, :]

    # Floor the normalizer so a zero initial mass does not divide by zero.
    norm_ref = max(total_mass0, 1e-20)
    rel_du = np.linalg.norm(data["mass_du"] - target, axis=1) / norm_ref
    rel_mg = np.linalg.norm(data["mass_mg"] - target, axis=1) / norm_ref

    fig, ax = plt.subplots(figsize=(6, 4))
    ax.plot(data["times_du"], rel_du, label="DuMuX", lw=2)
    ax.plot(data["times_mg"], rel_mg, label="Metric", lw=2, ls="--")
    ax.set_xlabel("time [s]")
    ax.set_ylabel("relative L2 to uniform mass")
    ax.set_yscale("log")
    ax.set_title("Convergence to steady state (diffusion)")
    ax.grid(True, alpha=0.3)
    ax.legend()
    fig.tight_layout()
    out = out_dir / "uniform_convergence.png"
    fig.savefig(out, dpi=200)
    plt.close(fig)
    print(f"[diffusion compare] wrote {out}")


def _normalized_to_first(values: np.ndarray) -> np.ndarray:
    """Divide *values* by their first entry; return them unchanged if empty or the first entry is 0."""
    if values.size == 0 or values[0] == 0:
        return values
    return values / values[0]


def plot_total_mass(data, out_dir: Path) -> None:
    """
    Plot each solver's total mass over time, normalized to its first snapshot.

    Each curve is drawn on its own solver's time axis; no interpolation
    between the two time grids is done. Writes ``total_mass.png`` into
    *out_dir*.
    """
    mode = "advection-diffusion" if getattr(settings, "flux_scale", 0.0) != 0.0 else "diffusion"
    fig, ax = plt.subplots(figsize=(6, 4))
    # Tolerate time and total arrays of different lengths.
    tlen = min(data["times_du"].shape[0], data["total_du"].shape[0])
    mlen = min(data["times_mg"].shape[0], data["total_mg"].shape[0])
    ax.plot(data["times_du"][:tlen], _normalized_to_first(data["total_du"][:tlen]), label="DuMuX", lw=2)
    ax.plot(data["times_mg"][:mlen], _normalized_to_first(data["total_mg"][:mlen]), label="Metric", lw=2, ls="--")
    ax.set_xlabel("time [s]")
    ax.set_ylabel("total mass (normalized)")
    ax.set_title(f"Mass conservation ({mode})")
    ax.legend()
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    out = out_dir / "total_mass.png"
    fig.savefig(out, dpi=200)
    plt.close(fig)
    print(f"[diffusion compare] wrote {out}")


def gif_edge_profiles(data, out_dir: Path, k: int = 12, max_frames: int = 80) -> None:
    """
    Animate per-edge tracer profiles of both solvers for the heaviest edges.

    The ``k`` original edges with the largest DuMuX mass at the last snapshot
    each get one panel, laid out 4 per row. Each panel shows the profiles as
    piecewise-constant steps along arc length, with cells ordered by
    ``cell_start``. Frames come from the metric/DuMuX time pairs returned by
    ``match_times``, downsampled to at most ``max_frames``.

    Writes ``edge_profiles_overlay.gif`` into *out_dir*. Nothing is written
    when there are no snapshots or no matched time pairs.
    """
    mode = "advection" if getattr(settings, "flux_scale", 0.0) != 0.0 else "diffusion"
    edge_mass_du = data["edge_mass_du"]
    cell_to_edge = data["cell_to_edge"]
    cell_start = data["cell_start"]
    cell_lengths = data["cell_lengths"]
    tracer_du = data["tracer_du"]
    tracer_mg = data["tracer_mg"]
    times_du = data["times_du"]
    pairs = match_times(data["times_mg"], data["times_du"], tol=compare_config.time_tolerance)

    frame_indices = _downsample_indices(len(pairs), max_frames)
    if edge_mass_du.shape[0] == 0 or len(frame_indices) == 0:
        print("[diffusion compare] skipped edge_profiles_overlay.gif (no snapshots to plot)")
        return

    # Select top edges by DuMuX final mass (original edges).
    top_edges = np.argsort(-edge_mass_du[-1])[:k]

    # The cells of each edge, their arc-length order and the cell boundaries
    # are the same in every frame, so compute them once.
    panels = []
    for edge_id in top_edges:
        cells = np.flatnonzero(cell_to_edge == edge_id)
        if cells.size == 0:
            panels.append((edge_id, None, None))  # keep the slot; the panel stays empty
            continue
        cells = cells[np.argsort(cell_start[cells])]
        bounds = np.concatenate([[0.0], np.cumsum(cell_lengths[cells])])
        panels.append((edge_id, cells, bounds))

    # Shared y-limit from the 99th percentile of all plotted values, for readability.
    ymax = 1.0
    used = [cells for _, cells, _ in panels if cells is not None]
    if used:
        idx = np.concatenate(used)
        all_vals = np.concatenate([tracer_du[:, idx].ravel(), tracer_mg[:, idx].ravel()])
        ymax = float(np.percentile(all_vals, 99)) * 1.1
    if not np.isfinite(ymax) or ymax <= 0.0:
        # A zero or NaN limit would give a singular y-axis.
        ymax = 1.0

    # Enough rows for every selected edge (3 x 4 for the default k=12).
    cols = 4
    rows = max(1, -(-len(panels) // cols))

    frames = []
    for frame_idx in frame_indices:
        m_idx, d_idx = pairs[frame_idx]
        fig, axes = plt.subplots(rows, cols, figsize=(3 * cols, 3 * rows), sharex=False, squeeze=False)
        axes = axes.ravel()
        total_mass = float(data["total_du"][d_idx])
        for ax, (edge_id, cells, bounds) in zip(axes, panels):
            if cells is None:
                continue
            vals_du = tracer_du[d_idx, cells]
            vals_mg = tracer_mg[m_idx, cells]
            # Repeat the last value so the final "post" step spans the last cell.
            ax.step(bounds, np.append(vals_du, vals_du[-1]), where="post", color="tab:blue", lw=2, label="DuMuX")
            ax.step(bounds, np.append(vals_mg, vals_mg[-1]), where="post", color="tab:orange", lw=2, linestyle="--", label="Metric")
            ax.set_ylim(0, ymax * 1.05)
            ax.set_ylabel(f"edge {edge_id}")
            ax.legend(fontsize=8)
        for ax in axes:
            ax.set_xlabel("arc length [m]")
        fig.suptitle(f"{mode} | t = {times_du[d_idx]:.2f} s | total mass = {total_mass:.3e} mol")
        fig.tight_layout()
        fig.canvas.draw()
        # Copy the rendered RGBA buffer before the figure is closed.
        frames.append(np.array(fig.canvas.buffer_rgba(), dtype=np.uint8))
        plt.close(fig)

    out = out_dir / "edge_profiles_overlay.gif"
    out.parent.mkdir(parents=True, exist_ok=True)
    imageio.mimsave(out, frames, fps=20)
    print(f"[diffusion compare] wrote {out} ({len(frames)} frames)")


def _gif_network(tracer, times, cell_edges, cell_points, total_mass, out: Path, max_frames: int = 80, title_prefix: str = ""):
    """
    Write a GIF of per-cell tracer values drawn as colored segments on the 2D layout.

    Parameters
    ----------
    tracer : array, (T, cells)
        Per-cell values. One frame is drawn per selected row.
    times, total_mass : array, (T,)
        Used in each frame's title.
    cell_edges : int array, (cells, 2)
        Endpoint indices into ``cell_points`` for each cell segment.
    cell_points : array, (points, >=2)
        Point coordinates. Only the first two columns are used.
    out : Path
        Destination GIF. Nothing is written if ``tracer`` has no rows.

    The color scale runs from 0 to the 99th percentile of all values, or to
    1 when that percentile is not positive.
    """
    # Explicit imports: relying on pyplot to import these submodules implicitly is fragile.
    from matplotlib.collections import LineCollection
    from matplotlib.colors import Normalize

    frames_idx = _downsample_indices(tracer.shape[0], max_frames)
    if len(frames_idx) == 0:
        print(f"[diffusion compare] skipped {out} (no snapshots to plot)")
        return
    vmax = np.percentile(tracer, 99) if tracer.size else 1.0
    vmax = vmax if vmax > 0 else 1.0

    # Segment geometry does not change between frames, so build it once.
    # Pair edges with values only up to the shorter of the two (as zip would).
    xy = np.asarray(cell_points)[:, :2]
    n_seg = min(len(cell_edges), tracer.shape[1])
    segments = xy[np.asarray(cell_edges, dtype=int)[:n_seg]]  # (n_seg, 2, 2)

    frames = []
    for idx in frames_idx:
        fig, ax = plt.subplots(figsize=(6, 6))
        tmass = float(total_mass[idx])
        lc = LineCollection(
            segments,
            cmap="inferno",
            norm=Normalize(vmin=0.0, vmax=vmax),
            linewidths=2.0,
        )
        lc.set_array(np.asarray(tracer[idx][:n_seg]))
        ax.add_collection(lc)
        ax.autoscale()
        ax.set_aspect("equal")
        ax.set_xticks([])
        ax.set_yticks([])
        cbar = fig.colorbar(lc, ax=ax, fraction=0.046, pad=0.04)
        cbar.set_label("tracer (mole fraction)")
        ax.set_title(f"{title_prefix} t = {times[idx]:.2f} s | total mass = {tmass:.3e} mol")
        fig.tight_layout()
        fig.canvas.draw()
        # Copy the rendered RGBA buffer before the figure is closed.
        frames.append(np.array(fig.canvas.buffer_rgba(), dtype=np.uint8))
        plt.close(fig)

    out.parent.mkdir(parents=True, exist_ok=True)
    imageio.mimsave(out, frames, fps=10)
    print(f"[diffusion compare] wrote {out} ({len(frames)} frames)")


def gif_networks(data, out_dir: Path, max_frames: int = 80) -> None:
    """
    Render one network animation per solver: DuMuX/FVM first, then the metric graph.

    Both animations share the cell geometry from *data*. Each uses its own
    solver's tracer, times and total mass.
    """
    mode = "diffusion" if getattr(settings, "flux_scale", 0.0) == 0.0 else "advection"
    # (tracer key, times key, total-mass key, output file, title label)
    runs = (
        ("tracer_du", "times_du", "total_du", "network_fvm.gif", "FVM"),
        ("tracer_mg", "times_mg", "total_mg", "network_mg.gif", "MG"),
    )
    for tracer_key, times_key, total_key, filename, label in runs:
        _gif_network(
            tracer=data[tracer_key],
            times=data[times_key],
            cell_edges=data["cell_edges"],
            cell_points=data["cell_points"],
            total_mass=data[total_key],
            out=out_dir / filename,
            max_frames=max_frames,
            title_prefix=f"{label} {mode} |",
        )


def main() -> None:
    """Load both solvers' outputs and write every comparison figure and animation."""
    data = load_data()
    # Wrap in Path so a plain-string plot_dir from the settings also works.
    out_dir = Path(settings.plot_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    plot_errors(data, out_dir)
    plot_uniform_error(data, out_dir)
    plot_total_mass(data, out_dir)
    gif_edge_profiles(data, out_dir)
    gif_networks(data, out_dir)


if __name__ == "__main__":
    main()
