"""
Extract 1D network geometry and tracer fields from the DuMuX
`network_tracer_1d` example into a compact `.npz`.

Usage:
    - Edit `experiments/dumux_tracer/config.py` (or set env vars noted there).
    - Run `python -m experiments.dumux_tracer.extract_dumux_tracer`.
"""

import re
from pathlib import Path
from typing import Iterable

import meshio
import numpy as np
from gitbud.gitbud import inject_repo_into_sys_path

inject_repo_into_sys_path()

from experiments.dumux_tracer.config import config

try:
    import pyvista as pv
except ImportError:
    pv = None

def discover_network_files(vtk_dir: Path) -> list[Path]:
    """Return the VTK files containing the 1D network, in time-step order.

    All ``*.vtu`` and ``*.vtp`` files directly inside ``vtk_dir`` are
    collected. Files whose name contains ``"network"`` or ``"1d"``
    (case-insensitive) are preferred; when none match, every VTK file found
    is returned instead.

    The result is sorted naturally: digit runs in the file name compare as
    integers, so ``net-2.vtu`` precedes ``net-10.vtu`` even when the step
    counter is not zero-padded. For zero-padded names this is the same order
    as a plain lexicographic sort.

    Raises:
        FileNotFoundError: if ``vtk_dir`` does not exist or holds no
            ``.vtu``/``.vtp`` files.
    """
    if not vtk_dir.exists():
        raise FileNotFoundError(f"VTK directory {vtk_dir} does not exist")

    def natural_key(path: Path) -> tuple[list, str]:
        # Splitting on a capturing group yields text chunks at even indices
        # and digit runs at odd indices, so chunks at the same position
        # always have the same type and compare safely.
        chunks = re.split(r"(\d+)", path.name)
        key = [int(chunk) if i % 2 else chunk for i, chunk in enumerate(chunks)]
        # The raw name breaks ties such as "a01" vs "a1" deterministically.
        return key, path.name

    all_vtk = [*vtk_dir.glob("*.vtu"), *vtk_dir.glob("*.vtp")]
    network_files = [
        f
        for f in all_vtk
        if ("network" in f.name.lower()) or ("1d" in f.name.lower())
    ]
    if not network_files:
        # No file looks like a network file by name: fall back to all of them.
        network_files = all_vtk
    if not network_files:
        raise FileNotFoundError(f"No .vtu/.vtp files found in {vtk_dir}")
    return sorted(network_files, key=natural_key)


def _split_polyline_edges(edges: np.ndarray) -> np.ndarray:
    """Split polyline cells into 2-node edges (vectorized over rows)."""
    if edges.shape[1] <= 2:
        return edges.astype(int)
    segments = np.column_stack([edges[:, :-1].ravel(), edges[:, 1:].ravel()])
    return segments.astype(int)


def _read_mesh(path: Path) -> meshio.Mesh:
    """Load a mesh, handling VTP explicitly when meshio cannot infer it.

    ``.vtp`` files are read with pyvista and converted to a ``meshio.Mesh``
    with a single ``"line"`` cell block carrying the pyvista cell data; every
    other suffix is passed straight to ``meshio.read``.

    Raises:
        ImportError: if a ``.vtp`` file is requested but pyvista is missing.
        RuntimeError: if the ``.vtp`` file has no line cells, or has line
            cells that are not plain 2-node segments.
    """
    if path.suffix.lower() != ".vtp":
        return meshio.read(path)

    if pv is None:
        raise ImportError(
            "pyvista (with vtk) is required to read .vtp network files. "
            "Install via `uv pip install pyvista`."
        )
    poly = pv.read(path)
    lines = np.asarray(poly.lines)
    if lines.size == 0:
        raise RuntimeError(f"No line cells found in {path}")

    # The connectivity is a flat array of [n_nodes, id_0, ..., id_{n-1}]
    # records. Viewing it as rows of 3 is only valid when every record is a
    # 2-node segment; otherwise the reshape silently scrambles node ids.
    if lines.size % 3 != 0 or np.any(lines.reshape(-1, 3)[:, 0] != 2):
        raise RuntimeError(
            f"Only 2-node line cells are supported, found longer polylines in {path}"
        )
    edges = lines.reshape(-1, 3)[:, 1:].astype(int)

    # meshio expects one array per cell block; there is exactly one block here.
    cell_data = {name: [np.asarray(arr)] for name, arr in poly.cell_data.items()}
    return meshio.Mesh(
        points=np.asarray(poly.points),
        cells=[("line", edges)],
        cell_data=cell_data,
    )


def extract_geometry(mesh: meshio.Mesh) -> tuple[np.ndarray, np.ndarray, np.ndarray, int]:
    """
    Pull node coordinates, 2-node connectivity and segment lengths from a mesh.

    The first cell block of type ``"line"``, ``"line3"`` or ``"polyline"`` is
    used; later line blocks are ignored.

    Returns:
        points: (n_points, 3) coordinates
        edges: (n_edges, 2) connectivity
        edge_lengths: (n_edges,) Euclidean lengths
        block_idx: index of the mesh cell block that holds the 1D network

    Raises:
        RuntimeError: if the mesh has no line-like cell block.
    """
    coords = np.asarray(mesh.points, dtype=float)

    line_types = ("line", "line3", "polyline")
    match = next(
        (
            (position, block)
            for position, block in enumerate(mesh.cells)
            if block.type in line_types
        ),
        None,
    )
    if match is None:
        raise RuntimeError("No 1D line / polyline cells found in mesh")
    block_position, line_block = match

    # NOTE(review): rows with more than two nodes are split into consecutive
    # node pairs; confirm this matches the node ordering of "line3" cells.
    connectivity = _split_polyline_edges(np.asarray(line_block.data, dtype=int))

    start_xyz = coords[connectivity[:, 0]]
    end_xyz = coords[connectivity[:, 1]]
    lengths = np.linalg.norm(end_xyz - start_xyz, axis=1)
    return coords, connectivity, lengths, block_position


def extract_scalar_cell_field(
    mesh: meshio.Mesh,
    block_idx: int,
    preferred_names: Iterable[str] = (),
) -> tuple[str, np.ndarray]:
    """
    Extract scalar cell field for the given cell block index.

    Tries the preferred names first (in the order given), then falls back to
    the first scalar field found in ``mesh.cell_data``. A field counts as
    scalar when its array for the block is 1D or has shape ``(n, 1)``; the
    returned values are always a 1D float array.

    Fields that have no array for ``block_idx`` are skipped on both paths, so
    a preferred name lacking that block no longer raises ``IndexError``.

    Returns:
        (name, values): name of the field that was used and its values.

    Raises:
        RuntimeError: if no scalar field exists for ``block_idx``.
    """

    def scalar_values(blocks) -> np.ndarray | None:
        # Return the block's values as a 1D float array, or None when the
        # field has no such block or is not scalar.
        if block_idx >= len(blocks):
            return None
        arr = np.asarray(blocks[block_idx])
        if arr.ndim == 1:
            return arr.astype(float)
        if arr.ndim == 2 and arr.shape[1] == 1:
            return arr[:, 0].astype(float)
        return None

    cell_data = mesh.cell_data
    # Preferred names that exist come first; every field follows as fallback.
    candidates = [name for name in preferred_names if name in cell_data]
    candidates.extend(cell_data)

    for name in candidates:
        values = scalar_values(cell_data[name])
        if values is not None:
            return name, values

    raise RuntimeError(f"Could not find scalar cell field for block index {block_idx}")


def _extract_optional_field(
    mesh: meshio.Mesh,
    block_idx: int,
    preferred_names: Iterable[str],
) -> tuple[str | None, np.ndarray | None]:
    """Return the first scalar cell field matching one of ``preferred_names``.

    Unlike ``extract_scalar_cell_field`` there is no fallback to other fields
    and no error: when none of the names yields a scalar array (1D or shape
    ``(n, 1)``) for ``block_idx``, ``(None, None)`` is returned. A field that
    has no array for ``block_idx`` is treated as missing rather than raising
    ``IndexError``.

    Returns:
        (name, values) with ``values`` a 1D float array, or ``(None, None)``.
    """
    for name in preferred_names:
        if name not in mesh.cell_data:
            continue
        blocks = mesh.cell_data[name]
        if block_idx >= len(blocks):
            # Field exists but not for this cell block: optional, so skip it.
            continue
        arr = np.asarray(blocks[block_idx])
        if arr.ndim == 1:
            return name, arr.astype(float)
        if arr.ndim == 2 and arr.shape[1] == 1:
            return name, arr[:, 0].astype(float)
    return None, None


def _load_times(vtk_dir: Path, num_steps: int, sample_every: int = 1) -> np.ndarray:
    """
    Load physical times from the DuMuX clearance output if available.

    Falls back to integer indices when the text log is missing or shorter
    than the number of VTK files.
    """
    candidates = [vtk_dir / "clearance_tracer_amounts.dat"]
    candidates.extend(sorted(vtk_dir.glob("*_tracer_amounts.dat")))
    for path in candidates:
        if not path.exists():
            continue
        data = np.loadtxt(path)
        if data.ndim == 1:
            # single time value
            if data.size >= 1:
                times = np.asarray([data.item(0)], dtype=float)
            else:
                times = np.array([], dtype=float)
        else:
            if data.shape[1] == 0:
                times = np.array([], dtype=float)
            else:
                times = np.asarray(data[:, 0], dtype=float)
        if sample_every > 1:
            times = times[::sample_every]
        if times.shape[0] >= num_steps:
            return times[:num_steps]
    return np.arange(num_steps, dtype=float)


def main() -> None:
    """Combine the preprocessing arrays with the DuMuX tracer output into one `.npz`.

    Steps:
        1. Read the preprocessing arrays from ``config.input_path``.
        2. Discover the network VTK files in ``config.vtk_dir``, thinned by
           ``config.sample_every`` and capped by ``config.max_steps``.
        3. Take geometry and the tracer field name from the first file, then
           read the same field from every further file. Files with a
           different cell count or without that field are skipped.
        4. Look up the time of every kept file and write everything to
           ``config.output_path`` with ``np.savez_compressed``.

    Raises:
        RuntimeError: if no VTK file is left after sampling, or if the VTK
            cell count differs from the preprocessing cell count.
    """
    cfg = config

    # np.load returns a lazily-read archive that holds the file open; the
    # context manager closes it once all arrays have been materialised.
    with np.load(cfg.input_path) as input_data:
        orig_points = np.asarray(input_data["orig_points"], dtype=float)
        orig_edges = np.asarray(input_data["orig_edges"], dtype=int)
        orig_radii = np.asarray(input_data["orig_radii"], dtype=float)
        orig_pressures = np.asarray(input_data["orig_pressures"], dtype=float)
        orig_edge_lengths = np.asarray(input_data["orig_edge_lengths"], dtype=float)
        cell_points_input = np.asarray(input_data["cell_points"], dtype=float)
        cell_edges_input = np.asarray(input_data["cell_edges"], dtype=int)
        cell_radii = np.asarray(input_data["cell_radii"], dtype=float)
        # Optional key: a membership test works on every NumPy version,
        # whereas the archive object only gained `.get` more recently.
        cell_velocities = np.asarray(
            input_data["cell_velocities"] if "cell_velocities" in input_data else [],
            dtype=float,
        )
        cell_to_edge = np.asarray(input_data["cell_to_edge"], dtype=int)
        cell_start = np.asarray(input_data["cell_start"], dtype=float)
        base_dx = float(input_data["base_dx"])
        min_cells_per_min_edge = int(input_data["min_cells_per_min_edge"])

    files = discover_network_files(cfg.vtk_dir)
    if cfg.sample_every > 1:
        files = files[:: cfg.sample_every]
    if cfg.max_steps is not None:
        files = files[: cfg.max_steps]
    if not files:
        raise RuntimeError(
            f"No network VTK files left to process in {cfg.vtk_dir} "
            "after applying sample_every/max_steps"
        )

    print(f"[dumux extract] Loading {len(files)} network VTK files from {cfg.vtk_dir}")
    first_mesh = _read_mesh(files[0])
    _, edges, edge_lengths, block_idx = extract_geometry(first_mesh)

    if edges.shape[0] != cell_edges_input.shape[0]:
        raise RuntimeError(
            f"Cell count mismatch: VTK edges={edges.shape[0]} vs preprocess={cell_edges_input.shape[0]}"
        )

    tracer_name, tracer0 = extract_scalar_cell_field(
        first_mesh,
        block_idx=block_idx,
        preferred_names=("x_B", "tracer", "Tracer", "xb", "X_B"),
    )

    n_cells = tracer0.shape[0]
    tracer_list = [tracer0]
    kept_files = [files[0].name]
    # Positions (within `files`) of the steps that were kept, so each one can
    # be matched with its own time even when files in between are skipped.
    kept_indices = [0]
    radius_name, radius0 = _extract_optional_field(
        first_mesh,
        block_idx=block_idx,
        preferred_names=("vessel radius (m)", "vessel_radius", "radius"),
    )

    for idx, path in enumerate(files[1:], start=1):
        mesh = _read_mesh(path)
        found_name, vals = extract_scalar_cell_field(
            mesh,
            block_idx=block_idx,
            preferred_names=(tracer_name,),
        )
        if found_name != tracer_name:
            # The lookup fell back to some other scalar field; stacking it
            # would silently mix unrelated data into the tracer series.
            print(
                f"[dumux extract] skipping {path.name}: field '{tracer_name}' "
                f"not found (first scalar field is '{found_name}')"
            )
            continue
        if vals.shape[0] != n_cells:
            print(
                f"[dumux extract] skipping {path.name} due to cell count mismatch "
                f"(expected {n_cells}, got {vals.shape[0]})"
            )
            continue
        tracer_list.append(vals)
        kept_files.append(path.name)
        kept_indices.append(idx)
        if (idx + 1) % 10 == 0 or idx == len(files) - 1:
            print(f"[dumux extract] Loaded {len(tracer_list)}/{len(files)} time steps (kept so far)")

    tracer = np.stack(tracer_list, axis=0)
    # Fetch one time per candidate file and pick the kept ones; taking just
    # the first len(tracer_list) times would shift every step after a skip.
    all_times = _load_times(cfg.vtk_dir, len(files), sample_every=cfg.sample_every)
    times = all_times[np.asarray(kept_indices, dtype=int)]

    cfg.output_path.parent.mkdir(parents=True, exist_ok=True)
    np.savez_compressed(
        cfg.output_path,
        orig_points=orig_points,
        orig_edges=orig_edges,
        orig_radii=orig_radii,
        orig_pressures=orig_pressures,
        orig_edge_lengths=orig_edge_lengths,
        cell_points=cell_points_input,
        cell_edges=cell_edges_input,
        cell_lengths=edge_lengths,
        cell_radii=cell_radii,
        cell_to_edge=cell_to_edge,
        cell_start=cell_start,
        base_dx=base_dx,
        min_cells_per_min_edge=min_cells_per_min_edge,
        tracer=tracer,
        velocities=cell_velocities if cell_velocities.size else np.array([]),
        times=times,
        tracer_name=tracer_name,
        vtk_dir=str(cfg.vtk_dir),
        files=np.array(kept_files),
        radius=radius0 if radius0 is not None else np.array([]),
        radius_name=radius_name if radius_name else "",
        # Per-cell vector from the first to the second node of each cell.
        extrusion=np.asarray(cell_points_input[cell_edges_input[:, 1]] - cell_points_input[cell_edges_input[:, 0]]),
    )

    print(f"[dumux extract] Wrote {cfg.output_path}")
    print(f"  orig points:  {orig_points.shape}")
    print(f"  orig edges:   {orig_edges.shape}")
    print(f"  cells:        {cell_edges_input.shape}")
    print(f"  tracer:       {tracer.shape} (time, cell)")
    print(f"  times:        {times.shape}")
    print(f"  field name:   {tracer_name}")


# Run the extraction only when this file is executed as a script
# (e.g. `python -m experiments.dumux_tracer.extract_dumux_tracer`).
if __name__ == "__main__":
    main()
