"""
Diagnostic plot: edge-level velocity vs mass for DuMuX and MG (final snapshot).

Outputs under figs/advection/ (or settings.plot_dir):
  - velocity_mass.png : hist of velocities + scatter of velocity vs MG/DU mass ratio
"""

from __future__ import annotations

from pathlib import Path

import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
from gitbud.gitbud import inject_repo_into_sys_path

inject_repo_into_sys_path()

from experiments.dumux_tracer.config_loader import load_settings


def load_data(settings=None):
    """Load final-snapshot DuMuX/MG data and aggregate it per edge.

    Args:
        settings: Object exposing the ``dumux_input_npz``, ``dumux_tracer_output``
            and ``metric_output`` archive paths. Defaults to ``load_settings()``.

    Returns:
        Tuple ``(edge_vel, du_edge, mg_edge, cell_to_edge, cell_lengths,
        cell_radii, tracer_final, density_final, area, metric_npz)`` where
        ``edge_vel`` is the length-weighted mean cell velocity per edge,
        ``du_edge`` / ``mg_edge`` are the final DuMuX / MG masses summed per
        edge, and ``metric_npz`` is the metric archive, returned still open
        (the caller owns closing it).

    Raises:
        ValueError: if the input archive contains no cells.
    """
    if settings is None:
        settings = load_settings()

    with np.load(settings.dumux_input_npz) as input_npz:
        cell_to_edge = np.asarray(input_npz["cell_to_edge"], dtype=int)
        cell_lengths = np.asarray(input_npz["cell_lengths"], dtype=float)
        # Optional fields: unit radii / zero velocity when absent from the archive.
        cell_radii = np.asarray(input_npz.get("cell_radii", np.ones_like(cell_lengths)), dtype=float)
        cell_vel = np.asarray(input_npz.get("cell_velocities", np.zeros_like(cell_lengths)), dtype=float)
    if cell_to_edge.size == 0:
        raise ValueError(f"no cells in {settings.dumux_input_npz}")
    num_edges = int(np.max(cell_to_edge)) + 1

    # Read the tracer history once; every NpzFile key access re-reads the member.
    with np.load(settings.dumux_tracer_output) as tracer_npz:
        tracer = np.asarray(tracer_npz["tracer"], dtype=float)

    # Returned to the caller, so intentionally not closed here.
    metric_npz = np.load(settings.metric_output)
    density_final = np.asarray(metric_npz["density"], dtype=float)[-1]

    # Per-edge velocity: length-weighted average of cell velocities
    edge_len = np.bincount(cell_to_edge, weights=cell_lengths, minlength=num_edges)
    edge_vel = np.divide(
        np.bincount(cell_to_edge, weights=cell_vel * cell_lengths, minlength=num_edges),
        np.maximum(edge_len, 1e-20),
    )

    # Final masses. DuMuX: tracer * cross-section * length per cell.
    area = np.pi * cell_radii * cell_radii
    cell_volume = area * cell_lengths
    tracer_final = tracer[-1]
    du_mass_cell = tracer_final * cell_volume
    # Initial total DuMuX mass, used to scale the MG density into mass units.
    total0 = float(np.sum(tracer[0] * cell_volume))
    mg_mass_cell = density_final * cell_lengths * total0

    du_edge = np.bincount(cell_to_edge, weights=du_mass_cell, minlength=num_edges)
    mg_edge = np.bincount(cell_to_edge, weights=mg_mass_cell, minlength=num_edges)

    return edge_vel, du_edge, mg_edge, cell_to_edge, cell_lengths, cell_radii, tracer_final, density_final, area, metric_npz


def plot(edge_vel: np.ndarray, du_edge: np.ndarray, mg_edge: np.ndarray, out_dir: Path) -> None:
    """Write ``velocity_mass.png``: edge-velocity histogram + velocity vs MG/DuMuX mass ratio.

    Args:
        edge_vel: Per-edge velocity.
        du_edge: Final DuMuX mass per edge.
        mg_edge: Final MG mass per edge (same shape as ``du_edge``).
        out_dir: Output directory; created if missing (a ``str`` is accepted).
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    du_edge = np.asarray(du_edge, dtype=float)
    mg_edge = np.asarray(mg_edge, dtype=float)
    # Edges without (positive) DuMuX mass have no meaningful ratio. Flooring the
    # denominator at 1e-20 would put them near 1e20 and squash every other point,
    # so they are left as NaN, which scatter does not draw.
    has_du = du_edge > 1e-20
    ratio = np.divide(mg_edge, du_edge, out=np.full_like(mg_edge, np.nan), where=has_du)
    n_skipped = int(np.count_nonzero(~has_du))
    if n_skipped:
        print(f"[velocity-mass] {n_skipped} edge(s) without DuMuX mass omitted from ratio scatter")

    fig, axes = plt.subplots(1, 2, figsize=(10, 4))

    axes[0].hist(edge_vel, bins=50, color="C0", alpha=0.8)
    axes[0].axvline(0.0, color="k", linestyle="--", linewidth=1)
    axes[0].set_xlabel("Edge velocity (m/s)")
    axes[0].set_ylabel("Count")
    axes[0].set_title("Velocity distribution")

    sc = axes[1].scatter(edge_vel, ratio, s=12, alpha=0.6, c=np.abs(edge_vel), cmap="viridis")
    axes[1].axhline(1.0, color="k", linestyle="--", linewidth=1)
    axes[1].axvline(0.0, color="k", linestyle="--", linewidth=1)
    axes[1].set_xlabel("Edge velocity (m/s)")
    axes[1].set_ylabel("MG / DuMuX edge mass (final)")
    axes[1].set_title("Mass ratio vs velocity")
    fig.colorbar(sc, ax=axes[1], label="|velocity|")
    fig.tight_layout()
    out = out_dir / "velocity_mass.png"
    fig.savefig(out, dpi=200)
    plt.close(fig)
    print(f"[velocity-mass] wrote {out}")


def plot_profiles(edge_vel, cell_to_edge, cell_lengths, tracer_final, density_final, area, out_dir: Path) -> None:
    """Plot DuMuX tracer vs MG per-cell profiles along the six fastest edges.

    Edges that own at least one cell are ranked by ``|edge_vel|``. The top six
    are drawn on a 2x3 grid of step plots against cumulative arc length.
    Writes ``velocity_top_edges_profiles.png`` into ``out_dir``.

    Args:
        edge_vel: Per-edge velocity, indexed by edge id.
        cell_to_edge: Edge id of each cell (non-negative integers).
        cell_lengths: Length of each cell.
        tracer_final: DuMuX tracer value per cell (final snapshot).
        density_final: MG density per cell (final snapshot).
        area: Cross-sectional area per cell.
        out_dir: Output directory; created if missing (a ``str`` is accepted).
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    cell_to_edge = np.asarray(cell_to_edge, dtype=int)
    cell_lengths = np.asarray(cell_lengths, dtype=float)
    num_edges = int(np.max(cell_to_edge)) + 1
    edge_len = np.bincount(cell_to_edge, weights=cell_lengths, minlength=num_edges)
    cells_per_edge = np.bincount(cell_to_edge, minlength=num_edges)

    # Length-weighting a per-edge constant over that edge's own cells returns
    # the constant, so edge_vel is ranked directly; zero-length edges count as 0.
    vel = np.where(edge_len > 0, np.asarray(edge_vel, dtype=float)[:num_edges], 0.0)
    # Rank only edges that own cells, so every selected panel has data.
    candidates = np.flatnonzero(cells_per_edge > 0)
    top_edges = candidates[np.argsort(-np.abs(vel[candidates]), kind="stable")][:6]

    fig, axes = plt.subplots(2, 3, figsize=(12, 6), sharey=True)
    axes = axes.ravel()
    for ax, e in zip(axes, top_edges):
        mask = cell_to_edge == e
        # NOTE(review): cells are drawn in storage order, which is assumed to
        # follow position along the edge; nothing here verifies that.
        segs = cell_lengths[mask]
        x = np.concatenate([[0.0], np.cumsum(segs)])
        du = tracer_final[mask]
        seg_area = area[mask]
        # NOTE(review): for area > 1e-20 the area cancels, leaving
        # density * cell_length, a per-cell amount that appears not directly
        # comparable to the DuMuX tracer value; confirm the intended scaling.
        mg = density_final[mask] * (seg_area * segs) / np.maximum(seg_area, 1e-20)
        # Repeat the last value so the final "post" step spans the last segment.
        ax.step(x, np.append(du, du[-1]), where="post", label="DuMuX tracer", lw=2)
        ax.step(x, np.append(mg, mg[-1]), where="post", label="MG mass dens", lw=2, ls="--")
        ax.set_title(f"edge {e} |v|={abs(vel[e]):.1e}")
        ax.set_xlabel("arc length (m)")
    # Hide panels left over when fewer than six edges have cells.
    for ax in axes[len(top_edges):]:
        ax.set_visible(False)
    axes[0].set_ylabel("density / mass density")
    axes[3].set_ylabel("density / mass density")
    if len(top_edges):
        axes[0].legend(fontsize=8)
    fig.tight_layout()
    out = out_dir / "velocity_top_edges_profiles.png"
    fig.savefig(out, dpi=200)
    plt.close(fig)
    print(f"[velocity-mass] wrote {out}")


def main() -> None:
    """Load the final-snapshot data and write both velocity/mass figures to ``settings.plot_dir``."""
    settings = load_settings()
    (
        edge_vel,
        du_edge,
        mg_edge,
        cell_to_edge,
        cell_lengths,
        _cell_radii,
        tracer_final,
        density_final,
        area,
        metric_npz,
    ) = load_data()
    # Only the already-extracted arrays are used below; load_data returns the
    # metric archive still open, so release its file handle now.
    metric_npz.close()

    out_dir = Path(settings.plot_dir)
    plot(edge_vel, du_edge, mg_edge, out_dir)
    plot_profiles(edge_vel, cell_to_edge, cell_lengths, tracer_final, density_final, area, out_dir)


if __name__ == "__main__":
    main()
