"""
Load the DuMuX embedded network geometry and derive metric-graph parameters.

This module parses `network.dgf` (pressures + radii), computes edge-wise
velocities via Poiseuille flow, and builds vertex transition tables compatible
with the CUDA Langevin sampler.
"""

from dataclasses import dataclass
from pathlib import Path

import numpy as np
from gitbud.gitbud import inject_repo_into_sys_path

inject_repo_into_sys_path()


@dataclass
class DumuxNetwork:
    """Metric-graph description of a DuMuX network plus sampler lookup tables.

    Per-vertex arrays are indexed by vertex id and per-edge arrays by edge id.
    The four ``vertex_edge_*`` arrays form a CSR-style table: the entries that
    belong to vertex ``i`` live in the slice
    ``vertex_edge_offsets[i]:vertex_edge_offsets[i + 1]`` of the other three.
    """

    # Geometry.
    points: np.ndarray  # vertex coordinates, one row per vertex
    edges: np.ndarray  # (vertex_a, vertex_b) pair per edge
    radii: np.ndarray  # per-edge radius

    # Physics / per-edge dynamics.
    velocities: np.ndarray  # per-edge values
    pressures: np.ndarray  # per-vertex pressure
    edge_lengths: np.ndarray  # per-edge Euclidean length
    drift_coeffs: np.ndarray  # per-edge drift coefficient
    jump_weights: np.ndarray  # per-edge normalised weight

    # CSR-style vertex -> incident-edge transition tables.
    vertex_edge_offsets: np.ndarray  # slice boundaries, length num_vertices + 1
    vertex_edge_indices: np.ndarray  # incident edge id
    vertex_edge_orientations: np.ndarray  # which endpoint of that edge the vertex is
    vertex_edge_cumweights: np.ndarray  # cumulative selection probabilities


def _refine_edges(
    points: np.ndarray,
    pressures: np.ndarray,
    edges: np.ndarray,
    radii: np.ndarray,
    velocities: np.ndarray,
    segments: int,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Uniformly subdivide each edge into `segments` smaller edges."""
    if segments <= 1:
        return points, pressures, edges, radii, velocities

    new_points = points.tolist()
    new_pressures = pressures.tolist()
    new_edges: list[list[int]] = []
    new_radii: list[float] = []
    new_velocities: list[float] = []

    for (u, v), r, vel in zip(edges, radii, velocities):
        start = np.asarray(points[u])
        end = np.asarray(points[v])
        p_start = float(pressures[u])
        p_end = float(pressures[v])
        prev_idx = u
        for s in range(1, segments):
            t = s / segments
            coords = (1.0 - t) * start + t * end
            pressure = (1.0 - t) * p_start + t * p_end
            idx = len(new_points)
            new_points.append(coords.tolist())
            new_pressures.append(pressure)
            new_edges.append([prev_idx, idx])
            new_radii.append(r)
            new_velocities.append(vel)
            prev_idx = idx
        new_edges.append([prev_idx, v])
        new_radii.append(r)
        new_velocities.append(vel)

    return (
        np.asarray(new_points, dtype=float),
        np.asarray(new_pressures, dtype=float),
        np.asarray(new_edges, dtype=int),
        np.asarray(new_radii, dtype=float),
        np.asarray(new_velocities, dtype=float),
    )

def _parse_dgf(dgf_path: Path) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Parse the DuMuX `network.dgf` file (pressures + radii)."""
    lines = Path(dgf_path).read_text().splitlines()
    state = None
    coords: list[list[float]] = []
    pressures: list[float] = []
    edges: list[list[int]] = []
    radii: list[float] = []
    velocities: list[float] = []

    for line in lines:
        if not line or line.startswith("%"):
            continue
        stripped = line.strip()
        if stripped == "Vertex":
            state = "vertex"
            continue
        if stripped == "SIMPLEX":
            state = "edge_header"
            continue
        if stripped == "#":
            state = None
            continue
        if state == "vertex":
            if stripped.startswith("parameters"):
                continue
            parts = stripped.split()
            coords.append([float(parts[0]), float(parts[1]), float(parts[2])])
            pressures.append(float(parts[3]))
        elif state == "edge_header":
            if stripped.startswith("parameters"):
                state = "edge"
            continue
        elif state == "edge":
            parts = stripped.split()
            edges.append([int(parts[0]), int(parts[1])])
            radii.append(float(parts[2]))
            if len(parts) >= 4:
                velocities.append(float(parts[3]))
            else:
                velocities.append(0.0)

    return (
        np.asarray(coords, dtype=float),
        np.asarray(pressures, dtype=float),
        np.asarray(edges, dtype=int),
        np.asarray(radii, dtype=float),
        np.asarray(velocities, dtype=float),
    )


def _build_vertex_transitions(
    num_vertices: int, edges: np.ndarray, weights: np.ndarray
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Create CSR-like tables for vertex-edge transitions."""
    adjacency: list[list[tuple[int, int, float]]] = [[] for _ in range(num_vertices)]
    for edge_id, (u, v) in enumerate(edges):
        w = float(weights[edge_id])
        adjacency[u].append((edge_id, 0, w))
        adjacency[v].append((edge_id, 1, w))

    offsets = [0]
    indices: list[int] = []
    orientations: list[int] = []
    cumweights: list[float] = []

    for neighbors in adjacency:
        offsets.append(offsets[-1] + len(neighbors))
        if not neighbors:
            continue
        total = sum(n[2] for n in neighbors)
        running = 0.0
        for edge_id, orientation, weight in neighbors:
            prob = weight / total if total > 0 else 1.0 / len(neighbors)
            running += prob
            running = min(running, 1.0)
            indices.append(edge_id)
            orientations.append(orientation)
            cumweights.append(running)

    return (
        np.asarray(offsets, dtype=np.int32),
        np.asarray(indices, dtype=np.int32),
        np.asarray(orientations, dtype=np.int32),
        np.asarray(cumweights, dtype=np.float32),
    )


def load_dumux_network(
    dgf_path: Path,
    viscosity: float = 1e-3,
    weight_power: float = 4.0,
    refine_segments: int = 1,
    weight_by_length: bool = True,
) -> DumuxNetwork:
    """Load geometry + physics from DuMuX `network.dgf`.

    Args:
        dgf_path: Path to the DuMuX ``network.dgf`` file.
        viscosity: Currently unused; kept for API compatibility. Edge
            velocities are read from the file, not recomputed from pressures.
        weight_power: Exponent applied to edge radii to form the edge
            weights ``radius ** weight_power``.
        refine_segments: Number of uniform sub-edges per original edge
            (``<= 1`` disables refinement).
        weight_by_length: If true, drift coefficients are the edge velocities
            divided by the edge length. Otherwise they equal the velocities.

    Returns:
        A populated :class:`DumuxNetwork`:
          * ``velocities`` holds the per-edge velocities read from the file
            (repeated onto sub-edges after refinement).
          * ``drift_coeffs`` holds the possibly length-scaled drift.
          * ``jump_weights`` are the edge weights normalised to sum to one.
            If the weights sum to zero, they are uniform over the edges.
    """
    points_raw, pressures_raw, edges_raw, radii_raw, velocities_raw = _parse_dgf(dgf_path)
    points, pressures, edges, radii, velocities = _refine_edges(
        points_raw, pressures_raw, edges_raw, radii_raw, velocities_raw, refine_segments
    )
    # Guarantee a (num_edges, 2) layout so the column indexing below also
    # works for an edgeless network.
    edges = np.asarray(edges, dtype=int).reshape(-1, 2)
    velocities = np.asarray(velocities, dtype=float)

    edge_vec = points[edges[:, 1]] - points[edges[:, 0]]
    edge_lengths = np.linalg.norm(edge_vec, axis=1)
    # Clamp to avoid dividing by zero for degenerate (zero-length) edges.
    edge_lengths = np.maximum(edge_lengths, 1e-12)

    drift_coeffs = velocities.astype(float)
    if weight_by_length:
        drift_coeffs = drift_coeffs / edge_lengths

    edge_weights = np.asarray(radii, dtype=float) ** weight_power
    total_weight = edge_weights.sum()
    if total_weight > 0:
        jump_weights = edge_weights / total_weight
    elif edge_weights.size:
        # All-zero weights: fall back to a uniform choice (mirrors the
        # per-vertex fallback in _build_vertex_transitions) instead of 0/0 = NaN.
        jump_weights = np.full_like(edge_weights, 1.0 / edge_weights.size)
    else:
        jump_weights = edge_weights

    vertex_edge_offsets, vertex_edge_indices, vertex_edge_orientations, vertex_edge_cumweights = _build_vertex_transitions(
        num_vertices=points.shape[0],
        edges=edges,
        weights=edge_weights,
    )
    return DumuxNetwork(
        points=points,
        edges=edges,
        radii=radii,
        # Store the parsed velocities; the scaled values go in drift_coeffs.
        velocities=velocities,
        pressures=pressures,
        edge_lengths=edge_lengths,
        drift_coeffs=drift_coeffs,
        jump_weights=jump_weights.astype(np.float32),
        vertex_edge_offsets=vertex_edge_offsets,
        vertex_edge_indices=vertex_edge_indices,
        vertex_edge_orientations=vertex_edge_orientations,
        vertex_edge_cumweights=vertex_edge_cumweights,
    )
