"""
Run the CUDA Langevin sampler on the original DuMuX network (diffusion-only),
while histogramming on the finer DuMuX cell partition.
"""

import sys
import time
import os
from pathlib import Path

import numpy as np
from gitbud.gitbud import inject_repo_into_sys_path

REPO_ROOT = Path(inject_repo_into_sys_path())
sys.path.append(str(REPO_ROOT / "main" / "langevin-gpu" / "python"))

from experiments.dumux_tracer.config_loader import load_settings, load_metric_config
settings = load_settings()
config = load_metric_config()
from experiments.dumux_tracer.network import _build_vertex_transitions
from langevin_simulator import LangevinSimulator, STEPS_PER_KERNEL


def build_bin_metadata(cell_lengths, cell_start, cell_to_edge, num_edges):
    """Group DuMuX cells into per-edge histogram bins ordered along each edge.

    Every original network edge gets one bin per cell mapped onto it, ordered by
    ``cell_start`` so that bin ``k`` of edge ``e`` is the ``k``-th cell along it.

    Args:
        cell_lengths: Length of every cell (1-D, one entry per cell).
        cell_start: Start coordinate of every cell along its parent edge.
        cell_to_edge: Index of the original edge each cell belongs to.
        num_edges: Number of original edges.

    Returns:
        Tuple ``(bin_offsets, edge_bin_counts, bin_to_cell, bin_lengths_flat,
        edge_bin_widths, uniform_ok)``:

        - ``bin_offsets`` (int32, ``num_edges + 1``): edge ``e`` owns flat bins
          ``bin_offsets[e]:bin_offsets[e + 1]``.
        - ``edge_bin_counts`` (int32, ``num_edges``): bins per edge.
        - ``bin_to_cell`` (int32): cell index of every flat bin.
        - ``bin_lengths_flat`` (float32): length of every flat bin.
        - ``edge_bin_widths`` (float32, ``num_edges``): width of each edge's first bin.
        - ``uniform_ok`` (bool): True when every edge satisfies the uniform-bin rule
          (all bins equal, except the last may lie in ``[w0, 2*w0]``).

    Raises:
        RuntimeError: if an edge has no cells, if any cell maps to an edge id
            outside ``[0, num_edges)``, or if the per-edge consistency checks fail.
    """
    cell_lengths = np.asarray(cell_lengths)
    cell_start = np.asarray(cell_start)
    cell_to_edge = np.asarray(cell_to_edge)

    # Sort all cells once by (edge, start); np.lexsort uses the LAST key as the
    # primary key. This avoids a `cell_to_edge == e` scan per edge, which costs
    # O(num_edges * num_cells).
    order_all = np.lexsort((cell_start, cell_to_edge))
    # order_all[edge_bounds[e]:edge_bounds[e + 1]] are edge e's cells. Cells with an
    # edge id outside [0, num_edges) fall outside every slice and trip the size
    # check further down.
    edge_bounds = np.searchsorted(
        cell_to_edge[order_all], np.arange(num_edges + 1), side="left"
    )

    edge_bin_counts = np.zeros(num_edges, dtype=np.int32)
    edge_bin_widths = np.zeros(num_edges, dtype=np.float32)
    nonuniform_edges = 0
    for e in range(num_edges):
        order = order_all[edge_bounds[e] : edge_bounds[e + 1]]
        if order.size == 0:
            raise RuntimeError(f"Edge {e} has no cells; cannot histogram")
        edge_bin_counts[e] = order.size
        lengths = cell_lengths[order]
        width0 = float(lengths[0])
        edge_bin_widths[e] = width0
        # Uniform-bin rule: every bin has width0, except that the last bin may lie
        # in [width0, 2*width0]. A single-bin edge is trivially uniform.
        if lengths.size > 1:
            main_ok = np.allclose(lengths[:-1], width0, rtol=1e-6, atol=1e-12)
            last = lengths[-1]
            last_ok = (last >= width0 * (1 - 1e-6)) and (last <= width0 * 2 * (1 + 1e-6))
            if not (main_ok and last_ok):
                nonuniform_edges += 1
    uniform_ok = nonuniform_edges == 0

    bin_offsets = np.zeros(num_edges + 1, dtype=np.int32)
    bin_offsets[1:] = np.cumsum(edge_bin_counts, dtype=np.int32)
    # Edges 0..num_edges-1 occupy one contiguous run of order_all.
    bin_to_cell = order_all[edge_bounds[0] : edge_bounds[num_edges]].astype(np.int32)
    bin_lengths_flat = cell_lengths[bin_to_cell].astype(np.float32)

    # Sanity check: every cell must land in exactly one bin.
    if bin_to_cell.size != cell_lengths.size:
        raise RuntimeError(f"bin_to_cell size {bin_to_cell.size} != num cells {cell_lengths.size}")

    # Verify per-edge offsets cover the right cells and bin lengths match cell lengths.
    eps32 = float(np.finfo(np.float32).eps)
    for e in range(num_edges):
        s, t = bin_offsets[e], bin_offsets[e + 1]
        cells = bin_to_cell[s:t]
        if not np.all(cell_to_edge[cells] == e):
            raise RuntimeError(f"bin_to_cell mapping mismatch on edge {e}")
        # Accumulate in float64 so that only per-entry float32 rounding remains.
        length_sum = float(bin_lengths_flat[s:t].sum(dtype=np.float64))
        orig_len = float(cell_lengths[cells].sum())
        if not np.isfinite(length_sum) or length_sum <= 0:
            raise RuntimeError(f"edge {e} has nonpositive bin length sum {length_sum}")
        # Each float32 entry carries up to eps32/2 relative rounding error, so a
        # fixed 1e-9 relative tolerance is tighter than float32 precision and would
        # reject valid edges; keep 1e-9 only as the absolute floor for tiny edges.
        tol = max(1e-9 * max(1.0, orig_len), 2.0 * eps32 * abs(orig_len))
        if abs(length_sum - orig_len) > tol:
            raise RuntimeError(f"edge {e} bin lengths ({length_sum}) != sum cell lengths ({orig_len})")
        # Report the first few edges for debugging.
        if e < 3:
            starts = cell_start[cells]
            print(f"[dbg edge {e}] bins={cells.size} len_sum={length_sum:.3e} edge_len={orig_len:.3e}")
            print(f"  starts(min/max)={starts.min():.3e}/{starts.max():.3e}")

    if not uniform_ok:
        print(f"[histogram] uniform bins rejected on {nonuniform_edges} edges (last-bin rule)")
    else:
        print("[histogram] using uniform-bin histogram kernel (last bin allowed up to 2x width)")
    return bin_offsets, edge_bin_counts, bin_to_cell, bin_lengths_flat, edge_bin_widths, uniform_ok


def _initial_cell_weights(cfg, cell_to_edge, cell_lengths, blob_edge_env: "str | None" = None) -> np.ndarray:
    """Return per-cell weights used to sample initial particle cells.

    The weights need not sum to one; the caller normalizes them.

    Precedence:
      1. ``blob_edge_env`` (the value of ``DUMUX_INITIAL_BLOB_EDGE``) parsed as an
         edge index: all mass is spread evenly over that edge's cells.
      2. ``cfg.initial_from_dumux``: weights proportional to the DuMuX tracer
         snapshot ``cfg.initial_snapshot_idx`` (clamped to the last snapshot)
         times cell cross-section area and cell length.
      3. Otherwise, uniform weights (all ones).
    Any unusable input falls back to uniform weights and prints why.
    """
    num_cells = cell_lengths.shape[0]
    uniform = np.ones(num_cells, dtype=np.float64)

    if blob_edge_env is not None:
        try:
            blob_edge = int(blob_edge_env)
        except ValueError:
            print(f"[dumux metric graph] invalid DUMUX_INITIAL_BLOB_EDGE={blob_edge_env}, ignoring")
            return uniform
        mask = cell_to_edge == blob_edge
        if not mask.any():
            print(
                f"[dumux metric graph] DUMUX_INITIAL_BLOB_EDGE={blob_edge} matches no cells; "
                "falling back to uniform cell weights"
            )
            return uniform
        blob = np.zeros(num_cells, dtype=np.float64)
        blob[mask] = 1.0 / mask.sum()
        print(f"[dumux metric graph] initializing all mass on edge {blob_edge} (cells={mask.sum()})")
        return blob

    if not cfg.initial_from_dumux:
        return uniform
    if not cfg.tracer_npz_path.exists():
        print(
            f"[dumux metric graph] initial_from_dumux requested but {cfg.tracer_npz_path} "
            "does not exist; falling back to uniform cell weights"
        )
        return uniform

    # Context manager closes the archive's file handle once the arrays are read.
    with np.load(cfg.tracer_npz_path) as data:
        tracer_all = np.asarray(data["tracer"], dtype=float)
        if not (tracer_all.ndim == 2 and tracer_all.shape[0] > 0 and tracer_all.shape[1] == num_cells):
            print(
                "[dumux metric graph] initial_from_dumux requested but tracer shape mismatches; "
                "falling back to uniform cell weights"
            )
            return uniform
        tracer0 = tracer_all[min(cfg.initial_snapshot_idx, tracer_all.shape[0] - 1)]
        radius = np.asarray(data.get("cell_radii", []), dtype=float)

    if radius.size == tracer0.size:
        area = np.pi * radius * radius
    else:
        area = np.ones_like(tracer0)
    # Constant factor consistent with params.input; it cancels in the normalization.
    molar_density = 1000.0 / 0.018
    weights_cells = tracer0 * area * cell_lengths * molar_density
    total_w = weights_cells.sum()
    if total_w > 0:
        return weights_cells / total_w
    print("[dumux metric graph] initial tracer mass is not positive; falling back to uniform cell weights")
    return uniform


def run() -> None:
    """Simulate particles on the original DuMuX network and save per-cell densities.

    Reads the network and the cell partition from ``settings.dumux_input_npz``.
    It then runs ``LangevinSimulator`` (CUDA backend) for ``cfg.steps`` kernel
    calls, or for enough calls to cover ``cfg.time_horizon``. Every
    ``cfg.record_interval`` calls, particle positions are histogrammed onto the
    cell partition. Densities, wall-clock timings, bounce statistics and run
    metadata are written to ``cfg.output_path`` as a compressed ``.npz``.

    Setting the environment variable ``DUMUX_INITIAL_BLOB_EDGE`` to an edge index
    places every initial particle on that edge.

    Raises:
        RuntimeError: if the histogram metadata is inconsistent, or (when
            ``cfg.validate_hist`` is set) if the GPU histogram disagrees with the
            NumPy reference.
    """
    cfg = config
    # Read all arrays and close the archive right away (an NpzFile keeps its file
    # handle open until closed).
    with np.load(settings.dumux_input_npz) as input_npz:
        flux_scale = float(np.asarray(input_npz.get("flux_scale", 0.0)))
        orig_points = np.asarray(input_npz["orig_points"], dtype=float)
        orig_edges = np.asarray(input_npz["orig_edges"], dtype=int)
        orig_radii = np.asarray(input_npz["orig_radii"], dtype=float)
        orig_edge_lengths = np.asarray(input_npz["orig_edge_lengths"], dtype=float)
        cell_lengths = np.asarray(input_npz["cell_lengths"], dtype=float)
        cell_to_edge = np.asarray(input_npz["cell_to_edge"], dtype=int)
        cell_start = np.asarray(input_npz["cell_start"], dtype=float)
        cell_velocities = np.asarray(input_npz.get("cell_velocities", np.zeros_like(cell_lengths)), dtype=float)
        cell_radii = np.asarray(input_npz.get("cell_radii", np.ones_like(cell_lengths)), dtype=float)

    num_edges = orig_edges.shape[0]
    num_cells = cell_lengths.shape[0]

    bin_offsets, edge_bin_counts, bin_to_cell, bin_widths_flat, edge_bin_widths, uniform_ok = build_bin_metadata(
        cell_lengths, cell_start, cell_to_edge, num_edges
    )
    total_bins = int(bin_offsets[-1])

    rng = np.random.default_rng(cfg.rng_seed)

    # Jump weights proportional to r**cfg.weight_power, normalized to sum to 1.
    # The unnormalized `weights` also drive the vertex transition tables below.
    weights = np.power(orig_radii, cfg.weight_power)
    jump_weights = weights / weights.sum()

    # Drift per edge: length-weighted average cell velocity along that edge, scaled by cfg.velocity_scale.
    edge_len_sum = np.bincount(cell_to_edge, weights=cell_lengths, minlength=num_edges)
    edge_vel_sum = np.bincount(cell_to_edge, weights=cell_velocities * cell_lengths, minlength=num_edges)
    edge_vel = np.divide(edge_vel_sum, np.maximum(edge_len_sum, 1e-20))
    drift_coeffs = (edge_vel * cfg.velocity_scale).astype(np.float32)

    # Initial particle placement: sample cells proportional to the initial weights,
    # then place each particle uniformly inside its cell.
    init_weights_cells = _initial_cell_weights(
        cfg, cell_to_edge, cell_lengths, os.environ.get("DUMUX_INITIAL_BLOB_EDGE")
    )
    cell_indices = rng.choice(
        num_cells,
        size=cfg.num_particles,
        p=init_weights_cells / init_weights_cells.sum(),
    )
    edge_indices = cell_to_edge[cell_indices].astype(np.int32)
    positions = (
        cell_start[cell_indices] + rng.random(cfg.num_particles) * cell_lengths[cell_indices]
    ).astype(np.float32)

    # Number of kernel calls; each call advances STEPS_PER_KERNEL sub-steps of cfg.dt.
    steps = (
        cfg.steps
        if cfg.steps is not None
        else int(np.ceil(cfg.time_horizon / (cfg.dt * STEPS_PER_KERNEL)))
    )

    ve_offsets, ve_indices, ve_orients, ve_cumweights = _build_vertex_transitions(
        num_vertices=orig_points.shape[0],
        edges=orig_edges,
        weights=weights,
    )

    sim = LangevinSimulator(
        num_particles=cfg.num_particles,
        num_edges=num_edges,
        edge_lengths=orig_edge_lengths,
        jump_weights=jump_weights,
        drift_coeffs=drift_coeffs,
        potential_type="linear",
        backend="cuda",
        edge_vertices=orig_edges,
        vertex_edge_offsets=ve_offsets,
        vertex_edge_indices=ve_indices,
        vertex_edge_orientations=ve_orients,
        vertex_edge_cumweights=ve_cumweights,
    )
    sim.upload_initial_state(edge_indices.astype(np.int32), positions)

    # Counts per bin -> density per unit length per particle.
    normalizing_factors = cfg.num_particles * bin_widths_flat
    interval = max(1, cfg.record_interval)
    snapshots: list[np.ndarray] = []  # per-cell densities
    times: list[float] = []
    step_times: list[float] = []
    hist_times: list[float] = []
    log_interval = max(1, steps // 20)

    def maybe_validate(hist_flat: np.ndarray, label: str) -> None:
        """Compare the GPU histogram with the NumPy reference and per-edge particle counts."""
        if not cfg.validate_hist:
            return
        snap_idx = len(snapshots)
        if cfg.validate_hist_every > 0 and (snap_idx % cfg.validate_hist_every) != 0:
            return
        ref = sim.compute_histograms_numpy(bin_offsets, bin_widths_flat)
        max_diff = float(np.abs(ref - hist_flat).max())
        if max_diff != 0.0:
            idx = int(np.argmax(np.abs(ref - hist_flat)))
            ref_val = float(ref[idx])
            gpu_val = float(hist_flat[idx])
            print(
                f"[validate] mismatch at {label}: max diff {max_diff} "
                f"(bin {idx}, ref={ref_val}, gpu={gpu_val})"
            )
            raise RuntimeError("histogram validation failed")
        # Cross-check per-edge counts against particle counts.
        edges_host, *_ = sim.get_state()
        edge_counts = np.bincount(edges_host, minlength=num_edges)
        hist_counts = np.array(
            [hist_flat[bin_offsets[e] : bin_offsets[e + 1]].sum() for e in range(num_edges)],
            dtype=np.float64,
        )
        edge_diff = float(np.abs(edge_counts - hist_counts).max())
        if edge_diff != 0.0:
            raise RuntimeError(f"edge count mismatch at {label}: max diff {edge_diff}")
        print(f"[dbg] validated histogram ({label}) against numpy fallback")

    def histogram_to_cells(flat_counts: np.ndarray) -> np.ndarray:
        """Scatter flat per-bin counts into per-cell densities (cell index order)."""
        densities_bin = flat_counts / normalizing_factors
        out = np.zeros(num_cells, dtype=np.float32)
        out[bin_to_cell] = densities_bin
        return out

    # Initialize histogram buffers (prefer uniform bins if per-edge widths are constant).
    if uniform_ok:
        sim.set_histogram_bins_uniform(bin_offsets, edge_bin_counts, edge_bin_widths)
        hist_fn = sim.compute_histograms_cached
    else:
        # Fall back to variable-length kernel with caching.
        sim.compute_histograms_variable(bin_offsets, bin_widths_flat)
        hist_fn = sim.compute_histograms_cached

    # Record initial state at t=0.
    hist_start = time.perf_counter()
    hist0 = hist_fn()
    maybe_validate(hist0, "t=0")
    densities0 = histogram_to_cells(hist0)
    print(f"[dbg] initial hist sum={hist0.sum()} per-bin max={hist0.max()} min={hist0.min()}")
    hist_times.append(time.perf_counter() - hist_start)
    snapshots.append(densities0.astype(np.float32))
    times.append(0.0)

    total_start = time.perf_counter()
    for step in range(steps):
        step_start = time.perf_counter()
        sim.multi_step(base_dt=cfg.dt, sigma=cfg.sigma)
        step_times.append(time.perf_counter() - step_start)
        if ((step + 1) % interval == 0) or (step == steps - 1):
            hist_start = time.perf_counter()
            hist = hist_fn()
            maybe_validate(hist, f"step {step}")
            if step % log_interval == 0 or step == steps - 1:
                print(f"[dbg] step {step}: hist sum={hist.sum()} max={hist.max()} min={hist.min()}")
            densities = histogram_to_cells(hist)
            hist_times.append(time.perf_counter() - hist_start)
            snapshots.append(densities.astype(np.float32))
            times.append((step + 1) * STEPS_PER_KERNEL * cfg.dt)
    total_wall = time.perf_counter() - total_start

    # Bounce diagnostics (per-particle bounce counters) distilled to summary stats to avoid huge outputs.
    bounces_mean = bounces_median = bounces_max = bounces_min = bounces_p90 = bounces_p99 = 0.0
    bounce_instances_mean = bounce_instances_median = bounce_instances_max = 0.0
    bounce_hist = None
    total_substeps = steps * STEPS_PER_KERNEL if steps and STEPS_PER_KERNEL else 0
    try:
        bounces, bounce_instances = sim.get_bounces()
        if bounces.size:
            bounces_mean = float(np.mean(bounces))
            bounces_median = float(np.median(bounces))
            bounces_max = int(np.max(bounces))
            bounces_min = int(np.min(bounces))
            bounces_p90 = float(np.percentile(bounces, 90))
            bounces_p99 = float(np.percentile(bounces, 99))
            # Histogram over average bounces per sub-step (what we care about for stability).
            if total_substeps > 0:
                per_substep = bounces.astype(np.float64) / float(total_substeps)
                hist_counts, hist_edges = np.histogram(per_substep, bins=100, range=(0.0, max(per_substep.max(), 1e-6)))
                bounce_hist = (hist_counts.astype(np.int64), hist_edges.astype(np.float64))
        if bounce_instances.size:
            bounce_instances_mean = float(np.mean(bounce_instances))
            bounce_instances_median = float(np.median(bounce_instances))
            bounce_instances_max = int(np.max(bounce_instances))
        print(
            f"[bounce stats] mean={bounces_mean:.2f}, median={bounces_median:.2f}, "
            f"p90={bounces_p90:.2f}, p99={bounces_p99:.2f}, max={bounces_max}, min={bounces_min}"
        )
    except Exception as e:  # pragma: no cover - diagnostic only
        print(f"[bounce stats] failed to collect: {e}")

    # Mean bounce rate per sub-step (one base_dt) and per kernel call (STEPS_PER_KERNEL * base_dt).
    bounce_rate_per_substep = 0.0
    bounce_rate_per_kernel = 0.0
    if steps > 0 and STEPS_PER_KERNEL > 0:
        total_substeps = steps * STEPS_PER_KERNEL
        bounce_rate_per_substep = bounces_mean / float(total_substeps)
        bounce_rate_per_kernel = bounces_mean / float(steps)
        print(
            f"[bounce stats] rate per sub-step: {bounce_rate_per_substep:.4f}, "
            f"per kernel (dt_kernel={STEPS_PER_KERNEL*cfg.dt:.3e}s): {bounce_rate_per_kernel:.2f}"
        )

    cfg.output_path.parent.mkdir(parents=True, exist_ok=True)
    np.savez_compressed(
        cfg.output_path,
        times=np.asarray(times, dtype=np.float64),
        density=np.stack(snapshots),  # (T, num_cells)
        edge_lengths=orig_edge_lengths.astype(np.float32),
        edges=orig_edges.astype(np.int32),
        points=orig_points.astype(np.float32),
        drift_coeffs=drift_coeffs.astype(np.float32),
        jump_weights=jump_weights.astype(np.float32),
        vertex_edge_offsets=ve_offsets,
        vertex_edge_indices=ve_indices,
        vertex_edge_orientations=ve_orients,
        vertex_edge_cumweights=ve_cumweights,
        bin_offsets=bin_offsets,
        bin_counts=edge_bin_counts,
        bin_to_cell=bin_to_cell,
        cell_lengths=cell_lengths.astype(np.float32),
        cell_to_edge=cell_to_edge.astype(np.int32),
        cell_start=cell_start.astype(np.float32),
        cell_radii=cell_radii.astype(np.float32),
        bin_lengths=bin_widths_flat.astype(np.float32),
        step_wall_times=np.asarray(step_times, dtype=np.float64),
        hist_wall_times=np.asarray(hist_times, dtype=np.float64),
        total_wall_time=np.asarray(total_wall, dtype=np.float64),
        num_particles=np.asarray(cfg.num_particles, dtype=np.int64),
        dt=np.asarray(cfg.dt, dtype=np.float64),
        steps=np.asarray(steps, dtype=np.int64),
        time_horizon=np.asarray(cfg.time_horizon, dtype=np.float64),
        velocity_scale=np.asarray(cfg.velocity_scale, dtype=np.float64),
        record_interval=np.asarray(interval, dtype=np.int64),
        refine_segments=np.asarray(cfg.refine_segments, dtype=np.int64),
        weight_power=np.asarray(cfg.weight_power, dtype=np.float64),
        decay_rate=np.asarray(cfg.decay_rate, dtype=np.float64),
        flux_scale=np.asarray(flux_scale, dtype=np.float64),
        initial_from_dumux=np.asarray(cfg.initial_from_dumux, dtype=np.int64),
        initial_snapshot_idx=np.asarray(cfg.initial_snapshot_idx, dtype=np.int64),
        total_bins=np.asarray(total_bins, dtype=np.int64),
        bounces_mean=np.asarray(bounces_mean, dtype=np.float64),
        bounces_median=np.asarray(bounces_median, dtype=np.float64),
        bounces_max=np.asarray(bounces_max, dtype=np.int64),
        bounces_min=np.asarray(bounces_min, dtype=np.int64),
        bounces_p90=np.asarray(bounces_p90, dtype=np.float64),
        bounces_p99=np.asarray(bounces_p99, dtype=np.float64),
        bounce_instances_mean=np.asarray(bounce_instances_mean, dtype=np.float64),
        bounce_instances_median=np.asarray(bounce_instances_median, dtype=np.float64),
        bounce_instances_max=np.asarray(bounce_instances_max, dtype=np.int64),
        bounce_rate_per_substep=np.asarray(bounce_rate_per_substep, dtype=np.float64),
        bounce_rate_per_kernel=np.asarray(bounce_rate_per_kernel, dtype=np.float64),
        bounce_hist_counts=(bounce_hist[0] if bounce_hist else np.asarray([], dtype=np.int64)),
        bounce_hist_edges=(bounce_hist[1] if bounce_hist else np.asarray([], dtype=np.float64)),
    )
    print(f"[dumux metric graph] Wrote {cfg.output_path}")
    print(f"  snapshots: {len(times)} spanning {times[-1]:.4f} s")
    print(f"  total bins: {total_bins} (cells={num_cells}, edges={num_edges})")


if __name__ == "__main__":
    run()
