"""
Compare DuMuX tracer output with the CUDA metric-graph run on the same network
using the **cell-level** discretization (no edge/cell conflation).
"""

from pathlib import Path

import numpy as np
from gitbud.gitbud import inject_repo_into_sys_path

# Repository root as a Path. Presumably inject_repo_into_sys_path() also puts the repo
# on sys.path so the project import below can resolve — TODO confirm; keep this line
# above that import.
REPO_ROOT = Path(inject_repo_into_sys_path())

from experiments.dumux_tracer.config_loader import load_compare_config
# Comparison settings loaded once at import time; main() reads dumux_npz, metric_npz,
# time_tolerance and min_compare_time from it.
config = load_compare_config()

# Molar density used to convert mole fractions into moles per volume:
# 1000 kg/m^3 divided by 0.018 (presumably kg/mol, per params.input) ~= 5.56e4 mol/m^3.
MOLAR_DENSITY = 1000.0 / 0.018  # kg/m^3 divided by tracer molar mass (params.input)


def _cell_mass(tracer: np.ndarray, lengths: np.ndarray, radii: np.ndarray | None) -> np.ndarray:
    """
    Convert tracer mole fraction to moles per cell.

    moles = x * MOLAR_DENSITY * area * length, with ``area = pi * r**2`` when radii
    matching ``lengths`` are given. Without usable radii a unit cross-section
    (area = 1) is assumed, so the result is moles per unit area. ``MOLAR_DENSITY``
    is applied in every case, so the result can always be inverted with
    ``mass / (MOLAR_DENSITY * area * length)``.

    Args:
        tracer: (time, cells) or (cells,) mole fraction
        lengths: (cells,) cell arc lengths
        radii: (cells,) radii (optional); ignored when its shape differs from ``lengths``

    Returns:
        Moles per cell, with the broadcast shape of ``tracer`` and ``lengths``.
    """
    tracer = np.asarray(tracer, dtype=float)
    lengths = np.asarray(lengths, dtype=float)
    area = np.ones_like(lengths)
    if radii is not None:
        radii = np.asarray(radii, dtype=float)
        if radii.shape == lengths.shape:
            area = np.pi * radii * radii
    # Previously the no-radii path dropped MOLAR_DENSITY. That made DuMuX masses
    # inconsistent with the metric-side inverse conversion, which always divides
    # by MOLAR_DENSITY.
    return tracer * MOLAR_DENSITY * area * lengths


def _align_edges(source_edges: np.ndarray, source_points: np.ndarray, target_edges: np.ndarray, target_points: np.ndarray) -> np.ndarray:
    """
    Return an index array ``perm`` such that ``source_edges[perm]`` lines up with
    ``target_edges``, matching edges by midpoint proximity.

    ``perm[j]`` is the source edge whose midpoint is closest to the midpoint of
    target edge ``j``. Apply it as ``values[..., perm]`` to reorder per-source-edge
    data into target order. Falls back to identity if the edge counts differ or
    the edge lists are already identical.

    Args:
        source_edges: (E, 2) vertex indices into ``source_points``.
        source_points: (V, dim) vertex coordinates.
        target_edges: (E', 2) vertex indices into ``target_points``.
        target_points: (V', dim) vertex coordinates.

    Returns:
        (E,) int array. This is not guaranteed to be a true permutation when the
        geometries disagree, because two target edges may share a nearest source edge.
    """
    n_source = source_edges.shape[0]
    if n_source != target_edges.shape[0] or np.array_equal(source_edges, target_edges):
        return np.arange(n_source, dtype=int)

    source_mid = (source_points[source_edges[:, 0]] + source_points[source_edges[:, 1]]) / 2.0
    target_mid = (target_points[target_edges[:, 0]] + target_points[target_edges[:, 1]]) / 2.0
    # dist2[i, j] = squared distance between source midpoint i and target midpoint j.
    dist2 = np.sum((source_mid[:, None, :] - target_mid[None, :, :]) ** 2, axis=2)
    # Reduce over sources (axis 0) to pick, for each target edge, its nearest source edge.
    # Reducing over axis 1 would yield the inverse permutation, which scrambles
    # data whenever it is used as ``values[:, perm]``.
    return np.argmin(dist2, axis=0).astype(int)


def load_dumux(path: Path) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray | None]:
    """
    Load a DuMuX tracer archive and convert its mole fractions to moles per cell.

    Args:
        path: ``.npz`` file with ``tracer`` (time, cells) mole fraction, ``times``,
            ``cell_lengths``, ``cell_edges``, ``cell_points`` and optionally ``cell_radii``.

    Returns:
        (times, mass, totals, cell_edges, cell_points, cell_lengths, cell_radii).
        ``mass`` is ``_cell_mass`` of the tracer, ``totals`` is its sum over cells for
        each snapshot, and ``cell_radii`` is ``None`` when the archive has no radii.
    """
    # The context manager closes the archive's file handle. Every array is copied
    # into memory with np.asarray before the archive is closed.
    with np.load(path) as data:
        tracer = np.asarray(data["tracer"], dtype=float)  # (time, cells) mole fraction
        times = np.asarray(data["times"], dtype=float)
        cell_lengths = np.asarray(data["cell_lengths"], dtype=float)
        cell_radii = np.asarray(data["cell_radii"], dtype=float) if "cell_radii" in data else None
        cell_edges = np.asarray(data["cell_edges"], dtype=int)
        cell_points = np.asarray(data["cell_points"], dtype=float)

    mass = _cell_mass(tracer, cell_lengths, cell_radii)
    totals = mass.sum(axis=1)
    return times, mass, totals, cell_edges, cell_points, cell_lengths, cell_radii


def _metric_edges_to_cells(
    data,
    density: np.ndarray,
    cell_edges: np.ndarray,
    cell_points: np.ndarray,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Reorder edge-based metric density into the DuMuX edge order and flatten it to one column per cell.

    Args:
        data: open ``np.load`` archive of the metric run; must provide ``edges``,
            ``points`` and ``edge_lengths``; ``num_bins`` is optional.
        density: (time, edges) or (time, edges, bins) density from the same archive.
        cell_edges: DuMuX edges used as the alignment target.
        cell_points: DuMuX vertex coordinates used as the alignment target.

    Returns:
        (density, lengths): density shaped (time, edges * bins) and the matching
        (edges * bins,) per-column lengths.

    Raises:
        RuntimeError: if geometry keys are missing or the resulting shapes disagree.
    """
    edges = np.asarray(data["edges"], dtype=int) if "edges" in data else None
    points = np.asarray(data["points"], dtype=float) if "points" in data else None
    if edges is None or points is None:
        raise RuntimeError("Metric output missing edges/points; cannot align geometries")
    edge_lengths = np.asarray(data.get("edge_lengths", []), dtype=float)
    if edge_lengths.size == 0:
        raise RuntimeError("Metric output missing edge_lengths; cannot rebuild cell lengths")
    num_bins = int(np.asarray(data.get("num_bins", density.shape[-1] if density.ndim > 2 else 1)).item())

    if density.ndim == 2:
        density = density[:, :, None]
    mapping = _align_edges(edges, points, cell_edges, cell_points)
    density = density[:, mapping]
    # Flatten (time, edges, bins) -> (time, edges * bins). The bins of one edge become
    # consecutive columns, which matches ``repeat`` below. Without this, the later
    # ``density * lengths[None, :]`` broadcast a 3-D array into (time, edges, edges*bins).
    density = density.reshape(density.shape[0], -1)
    # NOTE(review): every bin is assigned the full edge length. If bins split an edge
    # into equal parts, the per-bin length would be edge_length / num_bins. Confirm
    # against the metric writer; main()'s probability-mass ratio should be ~1.
    lengths = edge_lengths[mapping].repeat(num_bins)
    if lengths.shape[0] != density.shape[1]:
        raise RuntimeError(
            f"Metric density has {density.shape[1]} columns after alignment but "
            f"{lengths.shape[0]} cell lengths (num_bins={num_bins})"
        )
    return density, lengths


def load_metric(
    path: Path,
    cell_edges: np.ndarray,
    cell_points: np.ndarray,
    cell_lengths: np.ndarray,
    cell_radii: np.ndarray | None,
    target_total: float,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Load the metric-graph run and express it on the DuMuX cells.

    ``density * length`` is treated as the probability in each cell and scaled by
    ``target_total`` to give mass per cell. The tracer mole fraction is recovered
    with ``mass / (MOLAR_DENSITY * area * length)``, where ``area = pi * r**2`` when
    ``cell_radii`` matches the cell count and 1 otherwise.

    Args:
        path: metric ``.npz`` with ``times`` and ``density``. It may also contain
            ``cell_lengths``; otherwise it must contain ``edges``, ``points`` and
            ``edge_lengths`` (plus optional ``num_bins``) so edge data can be aligned.
        cell_edges: DuMuX edges used to align edge-based output.
        cell_points: DuMuX vertex coordinates used to align edge-based output.
        cell_lengths: DuMuX cell lengths, used when the metric output is already per cell.
        cell_radii: DuMuX cell radii, or ``None``.
        target_total: total mass the metric probabilities are scaled to.

    Returns:
        (times, mass, totals, tracer), where totals is ``mass`` summed over cells.

    Raises:
        RuntimeError: if the metric output cannot be mapped onto cells.
    """
    # The context manager closes the archive; all reads happen inside it.
    with np.load(path) as data:
        times = np.asarray(data["times"], dtype=float)
        density = np.asarray(data["density"], dtype=float)  # (time, cells) or edge-based
        metric_lengths = np.asarray(data.get("cell_lengths", []), dtype=float)
        if metric_lengths.size == 0 or metric_lengths.shape[0] != density.shape[1]:
            # Edge-based output: align edges to the DuMuX geometry and flatten bins.
            density, lengths = _metric_edges_to_cells(data, density, cell_edges, cell_points)
        else:
            # Per-cell output: the metric file's own lengths are only used to check the
            # cell count; the DuMuX lengths are used (presumably so both sides share one geometry).
            lengths = cell_lengths

    # mass per cell: probability (density * length) * total_mass
    mass = density * lengths[None, :] * target_total
    totals = mass.sum(axis=1)

    area = np.ones_like(lengths)
    if cell_radii is not None and cell_radii.size == lengths.size:
        area = np.pi * cell_radii * cell_radii
    tracer = mass / (MOLAR_DENSITY * area[None, :] * lengths[None, :])

    return times, mass, totals, tracer


def match_times(metric_times: np.ndarray, dumux_times: np.ndarray, tol: float) -> list[tuple[int, int]]:
    """
    Match each DuMuX time to the nearest metric time within `tol` seconds.

    Args:
        metric_times: 1-D metric snapshot times.
        dumux_times: 1-D DuMuX snapshot times.
        tol: maximum allowed absolute time difference.

    Returns:
        ``(metric_index, dumux_index)`` pairs in DuMuX order. DuMuX times with no metric
        time within ``tol`` are skipped. On ties the earliest metric index wins.
        The result is empty if either input is empty.
    """
    metric_times = np.asarray(metric_times, dtype=float)
    dumux_times = np.asarray(dumux_times, dtype=float)
    # np.argmin raises on an empty array, so handle empty inputs explicitly.
    if metric_times.size == 0 or dumux_times.size == 0:
        return []

    pairs: list[tuple[int, int]] = []
    for d_idx, t in enumerate(dumux_times):
        gaps = np.abs(metric_times - t)
        m_idx = int(np.argmin(gaps))
        if gaps[m_idx] <= tol:
            pairs.append((m_idx, d_idx))
    return pairs


def _format_stats(label: str, values: np.ndarray, suffix: str = "") -> str:
    """Return ``"  <label>: min=... median=... max=...<suffix>"`` for a non-empty array."""
    return (
        f"  {label}: min={values.min():.4e} median={np.median(values):.4e} "
        f"max={values.max():.4e}{suffix}"
    )


def _format_span(values: np.ndarray) -> str:
    """Return ``"<min>/<max>"`` with 3 decimals, or ``"n/a"`` for an empty array."""
    if values.size == 0:
        return "n/a"
    return f"{values.min():.3f}/{values.max():.3f}"


def main() -> None:
    """
    Compare DuMuX and metric-graph snapshots at matching times and print error statistics.

    Paths and tolerances come from the module-level ``config``. For each matched
    (metric, DuMuX) snapshot pair, this computes:

    * the L2 error of the per-cell mass, absolute and relative to the DuMuX total;
    * the L2 error of the tracer mole fraction, absolute and relative to the DuMuX norm;
    * the ratio of the metric total mass to the DuMuX total mass.

    Raises:
        RuntimeError: if the two runs do not have the same number of cells.
    """
    cfg = config

    dumux_times, dumux_mass, dumux_total, cell_edges, cell_points, cell_lengths, cell_radii = load_dumux(cfg.dumux_npz)
    # load_dumux returns moles per cell only, so the raw mole fractions are read again.
    # The context manager closes the archive afterwards.
    with np.load(cfg.dumux_npz) as dumux_data:
        dumux_tracer = np.asarray(dumux_data["tracer"], dtype=float)

    metric_times, metric_mass, metric_total, metric_tracer = load_metric(
        cfg.metric_npz,
        cell_edges=cell_edges,
        cell_points=cell_points,
        cell_lengths=cell_lengths,
        cell_radii=cell_radii,
        # Metric probabilities are scaled to the first DuMuX snapshot's total mass.
        target_total=float(dumux_total[0] if dumux_total.size else 1.0),
    )

    pairs = match_times(metric_times, dumux_times, tol=cfg.time_tolerance)
    if cfg.min_compare_time > 0:
        pairs = [(m, d) for (m, d) in pairs if dumux_times[d] >= cfg.min_compare_time]
    if not pairs:
        print(
            "[dumux compare] No matching times within tolerance "
            f"{cfg.time_tolerance}s (metric min/max {_format_span(metric_times)}, "
            f"dumux min/max {_format_span(dumux_times)})"
        )
        return

    # Fail with a clear message instead of letting numpy broadcast or reject
    # per-cell arrays of different lengths.
    if metric_mass.shape[1] != dumux_mass.shape[1]:
        raise RuntimeError(
            f"[dumux compare] cell count mismatch: metric has {metric_mass.shape[1]} cells, "
            f"DuMuX has {dumux_mass.shape[1]}"
        )

    errors = []
    rel_errors = []
    tracer_errors = []
    tracer_rel_errors = []
    mass_ratio = []
    for m_idx, d_idx in pairs:
        err = np.linalg.norm(metric_mass[m_idx] - dumux_mass[d_idx])
        # Floors guard against division by zero for empty or zero snapshots.
        total = max(dumux_total[d_idx], 1e-20)
        errors.append(err)
        rel_errors.append(err / total)
        mass_ratio.append(metric_total[m_idx] / total)
        # Tracer (mole fraction) comparison
        t_err = np.linalg.norm(metric_tracer[m_idx] - dumux_tracer[d_idx])
        t_norm = max(np.linalg.norm(dumux_tracer[d_idx]), 1e-20)
        tracer_errors.append(t_err)
        tracer_rel_errors.append(t_err / t_norm)

    print(f"[dumux compare] matched {len(pairs)} snapshots (tol={cfg.time_tolerance}s)")
    print(_format_stats("L2 mass error", np.asarray(errors)))
    print(_format_stats("Relative L2 (per total mass)", np.asarray(rel_errors)))
    print(_format_stats("Tracer L2 error", np.asarray(tracer_errors)))
    print(_format_stats("Relative tracer L2", np.asarray(tracer_rel_errors)))
    print(_format_stats("Metric probability mass", np.asarray(mass_ratio), suffix=" (should be ~1)"))


if __name__ == "__main__":
    # Run the comparison only when executed as a script (config is still loaded at import).
    main()
