"""
Preprocess the DuMuX network.dgf:
- keep only the largest connected component of the original graph
- choose a base cell size so the shortest edge gets at least `min_cells_per_min_edge` cells
- subdivide each original edge into equal cells with lengths in [cell_size, 2*cell_size)
  (an edge shorter than an explicitly requested cell size becomes one shorter cell)
- write both:
    * a DGF for DuMuX (network_filtered.dgf) using the subdivided cells
    * a compact NPZ (dumux_network_input.npz) containing the original graph and the cell mapping
"""

from pathlib import Path
from typing import Iterable, Tuple

import numpy as np
import scipy.sparse as sp
import scipy.sparse.linalg as spla
from gitbud.gitbud import inject_repo_into_sys_path

inject_repo_into_sys_path()

from experiments.dumux_tracer.network import _parse_dgf
from experiments.dumux_tracer.config_loader import load_settings

# Experiment settings, loaded once at import time; they supply the defaults of main() and the CLI.
settings = load_settings()

# NOTE(review): assumes inject_repo_into_sys_path() returns the repository root path — confirm.
REPO_ROOT = Path(inject_repo_into_sys_path())
# Input network parsed by main().
SOURCE_DGF = REPO_ROOT / "external" / "dumux" / "examples" / "network_tracer_1d" / "network.dgf"
# Filtered and subdivided network written by main(), located under settings.dune_root.
OUTPUT_DGF = settings.dune_root / "dumux" / "examples" / "network_tracer_1d" / "network_filtered.dgf"
# Default NPZ output path, taken from the settings.
OUTPUT_NPZ = settings.dumux_input_npz


def connected_components(n_nodes: int, edges: np.ndarray) -> list[np.ndarray]:
    """
    Group the nodes of an undirected graph into connected components.

    Nodes are numbered 0..n_nodes-1 and `edges` yields (u, v) index pairs.
    Components come back ordered by their lowest-numbered node; inside a component the
    nodes are listed in depth-first visiting order (not sorted).
    """
    neighbours: list[list[int]] = [[] for _ in range(n_nodes)]
    for a, b in edges:
        neighbours[a].append(b)
        neighbours[b].append(a)

    seen = np.zeros(n_nodes, dtype=bool)
    components: list[np.ndarray] = []
    for root in range(n_nodes):
        if seen[root]:
            continue
        seen[root] = True
        pending = [root]
        members = []
        while pending:
            current = pending.pop()
            members.append(current)
            for other in neighbours[current]:
                if seen[other]:
                    continue
                # Mark on push so a node is never queued twice (also covers parallel edges).
                seen[other] = True
                pending.append(other)
        components.append(np.array(members, dtype=int))
    return components


def filter_largest_component(
    points: np.ndarray,
    pressures: np.ndarray,
    edges: np.ndarray,
    radii: np.ndarray,
    velocities: np.ndarray,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Restrict the graph to its largest connected component and renumber the nodes.

    Kept nodes retain their relative order (ascending original index) and are renumbered
    0..k-1; kept edges retain their original order. If several components share the largest
    size, the one returned first by `connected_components` is used.

    Returns (points, pressures, edges, radii, velocities) for the kept sub-graph; the edge
    array always has shape (n_kept_edges, 2), including when no edge is kept.

    Raises ValueError when the graph has no nodes.
    """
    n_nodes = points.shape[0]
    if n_nodes == 0:
        raise ValueError("cannot select the largest component of a graph with no nodes")
    # Normalise to an (E, 2) integer array so an edge-less graph is handled uniformly.
    edges = np.asarray(edges, dtype=int).reshape(-1, 2)

    comps = connected_components(n_nodes, edges)
    # max() returns the first maximal element, matching an argmax over the component sizes.
    keep = np.sort(max(comps, key=len))

    # Old index -> new index; -1 marks nodes that are dropped.
    node_map = np.full(n_nodes, -1, dtype=int)
    node_map[keep] = np.arange(keep.size)

    # An edge survives only when both of its endpoints are kept.
    mask = (node_map[edges] >= 0).all(axis=1)
    new_edges = node_map[edges[mask]]

    return (
        points[keep],
        pressures[keep],
        new_edges,
        np.asarray(np.asarray(radii)[mask], dtype=float),
        np.asarray(np.asarray(velocities)[mask], dtype=float),
    )


def _choose_target_dx(points: np.ndarray, edges: np.ndarray, min_cells_per_min_edge: int) -> float:
    diff = points[edges[:, 1]] - points[edges[:, 0]]
    lengths = np.linalg.norm(diff, axis=1)
    min_len = float(lengths.min())
    return min_len / float(min_cells_per_min_edge)


def subdivide_edges_variable(
    points: np.ndarray,
    pressures: np.ndarray,
    edges: np.ndarray,
    radii: np.ndarray,
    base_dx: float,
    min_cells_per_min_edge: int,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, float, np.ndarray]:
    """
    Split every edge into equally long cells of roughly `base_dx`.

    An edge of length L receives max(1, floor(L / base_dx), min_cells_per_min_edge) cells, so
    its cells have length in [base_dx, 2*base_dx) unless the edge is shorter than `base_dx`
    or `min_cells_per_min_edge` forces a finer split. Interior node positions and pressures
    are linearly interpolated between the edge's end nodes; the original nodes keep their
    indices and the new nodes are appended after them.

    Returns an 8-tuple:
        points, pressures : node arrays of the subdivided graph
        edges, radii      : one row/value per cell (cells inherit the radius of their edge)
        cell_to_edge      : index of the original edge each cell belongs to
        cell_start        : distance from the edge's first node to the start of the cell
        base_dx           : the cell size that was used (passed through unchanged)
        segments_per_edge : number of cells created for each original edge

    Raises ValueError if `base_dx` is not a positive, finite number.
    """
    # A zero/negative/non-finite cell size would make the inf/nan -> int cast below undefined.
    if not np.isfinite(base_dx) or base_dx <= 0:
        raise ValueError(f"base_dx must be a positive, finite length, got {base_dx!r}")

    diff = points[edges[:, 1]] - points[edges[:, 0]]
    lengths = np.linalg.norm(diff, axis=1)
    segments_per_edge = np.maximum(1, np.floor(lengths / base_dx)).astype(int)
    segments_per_edge = np.maximum(segments_per_edge, min_cells_per_min_edge)

    new_points = points.tolist()
    new_pressures = pressures.tolist()
    new_edges: list[list[int]] = []
    new_radii: list[float] = []
    cell_to_edge: list[int] = []
    cell_start: list[float] = []
    segments_per_edge_out: list[int] = []

    for edge_idx, ((u, v), r, segs, edge_len) in enumerate(zip(edges, radii, segments_per_edge, lengths)):
        segs = int(segs)  # already >= 1 thanks to the np.maximum(1, ...) above
        p0, p1 = points[u], points[v]
        pr0, pr1 = pressures[u], pressures[v]
        seg_len = edge_len / segs
        segments_per_edge_out.append(segs)

        # Node chain u -> interior nodes -> v; interior nodes are appended to the node lists.
        chain = [u]
        for s in range(1, segs):
            alpha = s / segs
            chain.append(len(new_points))
            new_points.append(((1 - alpha) * p0 + alpha * p1).tolist())
            new_pressures.append(float((1 - alpha) * pr0 + alpha * pr1))
        chain.append(v)

        # Consecutive chain nodes form the cells of this edge.
        for cell_no, (a, b) in enumerate(zip(chain[:-1], chain[1:])):
            new_edges.append([a, b])
            new_radii.append(r)
            cell_to_edge.append(edge_idx)
            cell_start.append(cell_no * seg_len)

    return (
        np.asarray(new_points, dtype=float),
        np.asarray(new_pressures, dtype=float),
        np.asarray(new_edges, dtype=int),
        np.asarray(new_radii, dtype=float),
        np.asarray(cell_to_edge, dtype=int),
        np.asarray(cell_start, dtype=float),
        base_dx,
        np.asarray(segments_per_edge_out, dtype=int),
    )


def write_dgf(points: np.ndarray, pressures: np.ndarray, edges: np.ndarray, radii: np.ndarray, velocities: np.ndarray, path: Path) -> None:
    """
    Write the graph to `path` as a DGF text file and print a one-line summary.

    The vertex section holds one line per node (x, y, z, pressure); the simplex section holds
    one line per edge (start node, end node, radius, velocity). Floats use 17 significant
    digits in scientific notation.
    """
    vertex_rows = [
        f"{xyz[0]:.16e} {xyz[1]:.16e} {xyz[2]:.16e} {value:.16e}"
        for xyz, value in zip(points, pressures)
    ]
    element_rows = [
        f"{a:d} {b:d} {radius:.16e} {speed:.16e}"
        for (a, b), radius, speed in zip(edges, radii, velocities)
    ]
    content = [
        "DGF",
        "Vertex",
        "parameters 1 # pressures",
        *vertex_rows,
        "#",
        "SIMPLEX",
        "parameters 2 # radius velocity",
        *element_rows,
        "#",
    ]
    path.write_text("\n".join(content) + "\n")
    print(f"[preprocess] wrote {path} (nodes={points.shape[0]}, edges={edges.shape[0]})")


def _write_input_npz(
    path: Path,
    orig_points: np.ndarray,
    orig_edges: np.ndarray,
    orig_radii: np.ndarray,
    orig_pressures: np.ndarray,
    orig_edge_lengths: np.ndarray,
    cell_points: np.ndarray,
    cell_edges: np.ndarray,
    cell_radii: np.ndarray,
    cell_to_edge: np.ndarray,
    cell_start: np.ndarray,
    cell_velocities: np.ndarray,
    base_dx: float,
    min_cells_per_min_edge: int,
    flux_scale: float,
) -> None:
    cell_lengths = np.linalg.norm(cell_points[cell_edges[:, 1]] - cell_points[cell_edges[:, 0]], axis=1)
    path.parent.mkdir(parents=True, exist_ok=True)
    np.savez_compressed(
        path,
        orig_points=orig_points,
        orig_edges=orig_edges,
        orig_radii=orig_radii,
        orig_pressures=orig_pressures,
        orig_edge_lengths=orig_edge_lengths,
        cell_points=cell_points,
        cell_edges=cell_edges,
        cell_radii=cell_radii,
        cell_lengths=cell_lengths,
        cell_to_edge=cell_to_edge,
        cell_start=cell_start,
        cell_velocities=cell_velocities,
        base_dx=base_dx,
        min_cells_per_min_edge=min_cells_per_min_edge,
        flux_scale=flux_scale,
    )
    print(f"[preprocess] wrote {path} (cells={cell_edges.shape[0]})")


def _divergence_free_flux(points: np.ndarray, edges: np.ndarray, flux_scale: float, radii: np.ndarray) -> np.ndarray:
    """
    Generate a divergence-free flux on the graph (m^3/s) via projection.
    """
    num_vertices = points.shape[0]
    num_edges = edges.shape[0]
    data = []
    rows = []
    cols = []
    for e_idx, (u, v) in enumerate(edges):
        rows.extend([u, v])
        cols.extend([e_idx, e_idx])
        data.extend([1.0, -1.0])
    B = sp.csr_matrix((data, (rows, cols)), shape=(num_vertices, num_edges))

    rng = np.random.default_rng(0)
    flux0 = rng.standard_normal(num_edges) * flux_scale
    div = B @ flux0
    BBT = (B @ B.T).tocsc()
    reg = 1e-12
    BBT = BBT + reg * sp.eye(num_vertices, format="csc")
    x = spla.spsolve(BBT, div)
    flux = flux0 - B.T @ x

    # Scale to target magnitude
    area = np.pi * radii * radii
    vel = flux / np.maximum(area, 1e-20)
    max_vel = np.max(np.abs(vel))
    if max_vel > 0:
        flux *= (flux_scale / max_vel)
    return flux


def main(
    segments: int | None = None,
    target_dx: float | None = settings.target_dx,
    min_cells_per_min_edge: int = settings.min_cells_per_min_edge,
    output_npz: Path = OUTPUT_NPZ,
    flux_scale: float = settings.flux_scale,
    synthetic_velocity: float | None = None,
) -> None:
    """
    Run the preprocessing pipeline: parse SOURCE_DGF, keep the largest connected component,
    subdivide the edges into cells, assign cell velocities, then write OUTPUT_DGF and the NPZ.

    Parameters:
        segments: deprecated and ignored; a DeprecationWarning is emitted when it is given.
        target_dx: explicit cell size; when None it is derived from the shortest edge and
            `min_cells_per_min_edge`.
        min_cells_per_min_edge: minimum number of cells on the shortest edge (only enforced
            when `target_dx` is None).
        output_npz: destination of the NPZ archive.
        flux_scale: scale passed to the divergence-free flux construction.
        synthetic_velocity: if given, skip the flux construction and give every original
            edge this speed with a random (fixed-seed) sign.
    """
    if segments is not None:
        import warnings

        # The argument is never read below; warn rather than ignore it silently.
        warnings.warn(
            "the 'segments' argument is deprecated and ignored; "
            "use target_dx / min_cells_per_min_edge instead",
            DeprecationWarning,
            stacklevel=2,
        )

    points, pressures, edges, radii, velocities = _parse_dgf(SOURCE_DGF)
    points, pressures, edges, radii, velocities = filter_largest_component(points, pressures, edges, radii, velocities)
    base_lengths = np.linalg.norm(points[edges[:, 1]] - points[edges[:, 0]], axis=1)

    if target_dx is None:
        base_dx = _choose_target_dx(points, edges, min_cells_per_min_edge)
        effective_min_cells = min_cells_per_min_edge
    else:
        base_dx = target_dx
        effective_min_cells = 1  # honor explicit dx even if it exceeds edge length

    points_sub, pressures_sub, edges_sub, radii_sub, cell_to_edge, cell_start, base_dx, segments_per_edge = (
        subdivide_edges_variable(
            points,
            pressures,
            edges,
            radii,
            base_dx=base_dx,
            min_cells_per_min_edge=effective_min_cells,
        )
    )

    if synthetic_velocity is not None:
        rng = np.random.default_rng(0)
        # One random sign per *original* edge to keep velocity constant along the edge.
        edge_signs = rng.choice([-1.0, 1.0], size=edges.shape[0])
        edge_velocities = edge_signs * synthetic_velocity
        velocities_sub = edge_velocities[cell_to_edge]
        print(f"[preprocess] using synthetic velocities with |v|={synthetic_velocity:g}")
    else:
        # Divergence-free flux on the subdivided (cell) graph
        flux = _divergence_free_flux(points_sub, edges_sub, flux_scale=flux_scale, radii=radii_sub)
        # Convert flux to velocity by dividing by the cell's cross-section area.
        velocities_sub = flux / (np.pi * np.power(radii_sub, 2))

    sub_lengths = np.linalg.norm(points_sub[edges_sub[:, 1]] - points_sub[edges_sub[:, 0]], axis=1)
    print(
        f"[preprocess] base_dx={base_dx:.3e} m (min edge len={base_lengths.min():.3e}, "
        f"min_cells_per_min_edge={min_cells_per_min_edge})"
    )
    print(
        f"[preprocess] segments stats: min={segments_per_edge.min()}, median={np.median(segments_per_edge):.1f}, "
        f"max={segments_per_edge.max()}, mean={segments_per_edge.mean():.2f}"
    )
    print(
        f"[preprocess] cell length stats: min={sub_lengths.min():.3e}, "
        f"median={np.median(sub_lengths):.3e}, max={sub_lengths.max():.3e}, mean={sub_lengths.mean():.3e}"
    )
    write_dgf(points_sub, pressures_sub, edges_sub, radii_sub, velocities_sub, OUTPUT_DGF)
    _write_input_npz(
        output_npz,
        orig_points=points,
        orig_edges=edges,
        orig_radii=radii,
        orig_pressures=pressures,
        orig_edge_lengths=base_lengths,
        cell_points=points_sub,
        cell_edges=edges_sub,
        cell_radii=radii_sub,
        cell_to_edge=cell_to_edge,
        cell_start=cell_start,
        cell_velocities=velocities_sub,
        base_dx=base_dx,
        min_cells_per_min_edge=min_cells_per_min_edge,
        flux_scale=flux_scale,
    )


if __name__ == "__main__":
    import argparse

    # Command-line front end: each option maps one-to-one onto a keyword argument of main().
    cli = argparse.ArgumentParser(description="Filter DuMuX network and subdivide edges.")
    option_table = [
        (
            "segments",
            dict(
                type=int,
                nargs="?",
                default=None,
                help="(deprecated) uniform segments per edge; prefer target_dx/min_cells (defaults come from experiment_settings.py)",
            ),
        ),
        (
            "--target-dx",
            dict(
                type=float,
                default=settings.target_dx,
                help=f"explicit target cell length in meters; if omitted, derived from min edge / min_cells_per_min_edge (default from settings: {settings.target_dx})",
            ),
        ),
        (
            "--min-cells-per-min-edge",
            dict(
                type=int,
                default=settings.min_cells_per_min_edge,
                help="ensure the shortest edge gets at least this many cells (default from settings)",
            ),
        ),
        (
            "--flux-scale",
            dict(
                type=float,
                default=settings.flux_scale,
                help="target characteristic volumetric flux magnitude (m^3/s) used to build a divergence-free field (default from settings)",
            ),
        ),
        (
            "--synthetic-velocity",
            dict(
                type=float,
                default=None,
                help="if set, bypass divergence-free flux and assign a constant-magnitude random-sign velocity per edge (m/s)",
            ),
        ),
        (
            "--output-npz",
            dict(
                type=Path,
                default=OUTPUT_NPZ,
                help="path to write NPZ with original graph + cell mapping (default: data/dumux_network_input.npz)",
            ),
        ),
    ]
    # Register the options in table order so the generated help lists them as before.
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)

    opts = cli.parse_args()
    main(
        segments=opts.segments,
        target_dx=opts.target_dx,
        min_cells_per_min_edge=opts.min_cells_per_min_edge,
        output_npz=opts.output_npz,
        flux_scale=opts.flux_scale,
        synthetic_velocity=opts.synthetic_velocity,
    )
