#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
build_structures_from_slice_cropped.py

Generate node/edge arrays for GNNs from a (cropped) slice VTK/VTU.
"""

import argparse
from pathlib import Path
from typing import Iterator, Tuple

import numpy as np
import pyvista as pv
import matplotlib.pyplot as plt
from scipy.spatial import cKDTree

# ------------------------------ Morton order ------------------------------ #

def _interleave16(x: np.ndarray) -> np.ndarray:
    x = (x | (x << 8)) & 0x00FF00FF
    x = (x | (x << 4)) & 0x0F0F0F0F
    x = (x | (x << 2)) & 0x33333333
    x = (x | (x << 1)) & 0x55555555
    return x

def morton_permutation(xy: np.ndarray) -> np.ndarray:
    """Return the permutation that orders 2-D points along a Morton (Z-order) curve.

    Coordinates are normalised to their bounding box and quantised to 16 bits
    per axis; the bit-interleaved 32-bit key has x in the even bits and y in
    the odd bits.

    Args:
        xy: (N, 2) point coordinates.

    Returns:
        Integer index array of length N such that ``xy[perm]`` is in Morton
        order. Points with identical keys keep their input order (stable
        sort), so the result does not depend on which sort implementation
        NumPy dispatches to on a given machine. Empty input returns an empty
        permutation.
    """
    xy = np.asarray(xy)
    if xy.shape[0] == 0:
        return np.empty(0, dtype=np.intp)
    xmin, ymin = xy.min(0)
    xmax, ymax = xy.max(0)
    # The tiny epsilon keeps the division finite for a zero-width extent.
    uv = (xy - (xmin, ymin)) / (np.array([xmax - xmin, ymax - ymin]) + 1e-12)
    uv = np.clip(uv * 65535, 0, 65535).astype(np.uint32)
    code = _interleave16(uv[:, 0]) | (_interleave16(uv[:, 1]) << 1)
    # Stable sort: deterministic tie-breaking by original index.
    return np.argsort(code, kind="stable")

# ---------------------- Volume reading & node types ----------------------- #

def read_volumetric_multiblock(case_dir: Path, time_index: int) -> pv.MultiBlock:
    """Read the OpenFOAM case in *case_dir* at time index *time_index*.

    The reader is opened on a ``<case_name>.foam`` file inside the case
    directory; that file is created (empty) if it does not already exist.
    Point data creation from cell data is switched on and all patch arrays
    are enabled before reading.

    Returns:
        The MultiBlock produced by ``pv.OpenFOAMReader.read()``.
    """
    marker = case_dir / f"{case_dir.name}.foam"
    marker.touch(exist_ok=True)

    reader = pv.OpenFOAMReader(str(marker))
    reader.set_active_time_point(time_index)
    reader.cell_to_point_creation = True
    reader.enable_all_patch_arrays()
    multiblock = reader.read()
    return multiblock

def flatten_multiblock(mb: pv.MultiBlock, path: str = ""):
    """Yield ``(path, block)`` for every non-MultiBlock leaf of *mb*, depth first.

    Nested block names are joined with ``"/"`` (e.g. ``"boundary/inlet"``).
    Leaves that are ``None`` (empty slots) are yielded as-is.

    Blocks are visited by index rather than looked up by name, so blocks that
    share a name are all visited (a name lookup would return the first match
    every time). Blocks without a name get a ``"Block-NN"`` placeholder.
    """
    for i in range(mb.n_blocks):
        key = mb.get_block_name(i)
        if key is None:
            key = f"Block-{i:02d}"
        block = mb[i]
        name = f"{path}/{key}" if path else key
        if isinstance(block, pv.MultiBlock):
            yield from flatten_multiblock(block, name)
        else:
            yield name, block

def compute_node_types(vol_mb: pv.MultiBlock, tol: float=1e-6) -> Tuple[np.ndarray, np.ndarray]:
    """Label every internalMesh point of a volumetric case with a node type.

    Codes: 1 = internal (default), 2 = point of a patch whose name contains
    neither "inlet" nor "outlet", 3 = point of a patch whose name contains
    "inlet" or "outlet" (case-insensitive).

    For each patch point, its nearest internalMesh point is found; if that
    distance is below *tol*, the internal point receives the patch's code.
    Patches are applied in traversal order, so a point shared by several
    patches keeps the code of the last one visited.

    Returns:
        (node_types, points): int8 codes and the internalMesh point array.

    Raises:
        RuntimeError: if no block whose name ends in "internalMesh" exists.
    """
    internal_mesh = None
    boundary_blocks = []
    for block_name, block in flatten_multiblock(vol_mb):
        if block_name.endswith("internalMesh"):
            internal_mesh = block
            continue
        if isinstance(block, (pv.PolyData, pv.UnstructuredGrid)):
            boundary_blocks.append((block_name, block))

    if internal_mesh is None:
        raise RuntimeError("internalMesh block not found")

    points = internal_mesh.points
    kd = cKDTree(points)
    node_types = np.full(internal_mesh.n_points, 1, dtype=np.int8)
    for block_name, block in boundary_blocks:
        dist, nearest = kd.query(block.points, k=1)
        lowered = block_name.lower()
        code = 3 if ("inlet" in lowered or "outlet" in lowered) else 2
        node_types[nearest[dist < tol]] = code
    return node_types, points

# ------------------------------- Graph bits -------------------------------- #

def triangulate(poly: pv.PolyData) -> pv.PolyData:
    """Return ``poly.triangulate()`` with all of its cell-data arrays cleared.

    Point data is not touched by this function.
    """
    result = poly.triangulate()
    result.cell_data.clear()
    return result

def compute_undirected_edges_from_tri_mesh(mesh: "pv.PolyData"):
    """Return both directions of every unique edge of a triangle mesh, with [dx, dy, dist].

    Each undirected edge ``{a, b}`` (``a < b``) appears twice, ``a→b``
    immediately followed by ``b→a``; edges are ordered lexicographically by
    ``(a, b)``.

    Args:
        mesh: all-triangle mesh. Only ``mesh.faces`` (flat VTK layout
            ``[3, i, j, k, 3, ...]``) and the x/y columns of ``mesh.points``
            are read.

    Returns:
        (senders, receivers, feats): int32 index arrays of length 2E and a
        float32 (2E, 3) array of ``[dx, dy, dist]`` from sender to receiver.

    Raises:
        ValueError: if the face array is not a sequence of triangles.
    """
    faces = np.asarray(mesh.faces).reshape(-1, 4)  # rows: [3, i, j, k]
    # Guard against non-triangle faces whose total length happens to be a
    # multiple of 4 and would otherwise be silently misread.
    if faces.size and not np.all(faces[:, 0] == 3):
        raise ValueError("compute_undirected_edges_from_tri_mesh requires an all-triangle mesh")
    tris = faces[:, 1:].astype(np.int64)

    # All three edges of every triangle, canonicalised to (min, max).
    pairs = np.concatenate((tris[:, [0, 1]], tris[:, [1, 2]], tris[:, [2, 0]]), axis=0)
    pairs.sort(axis=1)
    if pairs.shape[0]:
        # Deduplicate; rows come back in lexicographic order (same as sorted(set)).
        pairs = np.unique(pairs, axis=0)
    lo, hi = pairs[:, 0], pairs[:, 1]

    # Interleave both directions: lo0→hi0, hi0→lo0, lo1→hi1, hi1→lo1, ...
    s = np.column_stack((lo, hi)).ravel().astype(np.int32)
    r = np.column_stack((hi, lo)).ravel().astype(np.int32)

    pts2 = np.asarray(mesh.points)[:, :2]
    dx = pts2[r, 0] - pts2[s, 0]
    dy = pts2[r, 1] - pts2[s, 1]
    dist = np.hypot(dx, dy)
    feats = np.stack([dx, dy, dist], axis=1).astype(np.float32)
    return s, r, feats

def decimate_to_target(poly: pv.PolyData, target_n: int):
    """Triangulate *poly* and decimate it towards roughly *target_n* points.

    If *target_n* <= 0 or the triangulated mesh already has at most
    *target_n* points, the triangulated mesh is returned unchanged together
    with the identity index map.

    Otherwise ``decimate_pro`` is called with ``reduction = 1 - target_n / n``.
    NOTE(review): decimate_pro's reduction is a requested fraction and is not
    guaranteed to hit *target_n* points exactly.

    Returns:
        (mesh, idx): the (possibly decimated) mesh and, for each of its points,
        an int64 index into the triangulated input's points, found by
        nearest-neighbour search. This assumes decimation keeps surviving
        vertices at their original positions — TODO confirm.
    """
    tri = triangulate(poly)
    n_pts = tri.n_points
    if target_n <= 0 or n_pts <= target_n:
        return tri, np.arange(n_pts, dtype=np.int64)

    fraction_to_remove = 1.0 - float(target_n) / n_pts
    dec = tri.decimate_pro(
        reduction=fraction_to_remove,
        preserve_topology=False,
        boundary_vertex_deletion=True,
    )
    # map each reduced point back to its index in `tri`
    nearest = cKDTree(tri.points).query(dec.points, k=1)[1]
    return dec, nearest.astype(np.int64)

def compute_o2r_edges(xy_o: np.ndarray, xy_r: np.ndarray, keep_idx: np.ndarray, k: int):
    """Build ori→red edges from removed original nodes to their k nearest reduced nodes.

    Only original nodes whose index is *not* in *keep_idx* get edges.

    Args:
        xy_o: (N0, 2+) original node coordinates (columns 0 and 1 are used).
        xy_r: (Nr, 2+) reduced node coordinates.
        keep_idx: indices into *xy_o* of the nodes that survived reduction.
        k: neighbours per removed node. Clamped to Nr, because cKDTree pads
            missing neighbours with the out-of-range index Nr.

    Returns:
        (senders, receivers, feats): int32 original indices, int32 reduced
        indices, and float32 ``[dx, dy, dist]`` from the original node to the
        reduced node. All three are empty when no node was removed.

    Raises:
        ValueError: if nodes were removed and ``k < 1`` or *xy_r* is empty.
    """
    n_orig = xy_o.shape[0]
    keep = np.zeros(n_orig, dtype=bool)
    keep[keep_idx] = True
    removed = np.nonzero(~keep)[0]
    if removed.size == 0:
        return (np.empty((0,), dtype=np.int32),
                np.empty((0,), dtype=np.int32),
                np.empty((0, 3), dtype=np.float32))

    n_red = xy_r.shape[0]
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    if n_red == 0:
        raise ValueError("cannot build ori→red edges: the reduced node set is empty")

    k_eff = min(int(k), n_red)
    _, ridx = cKDTree(xy_r).query(xy_o[removed], k=k_eff)
    # k=1 yields shape (m,), k>1 yields (m, k); normalise to (m, k_eff).
    recv = np.asarray(ridx).reshape(removed.size, k_eff).ravel().astype(np.int32)
    send = np.repeat(removed, k_eff).astype(np.int32)

    dx = xy_r[recv, 0] - xy_o[send, 0]
    dy = xy_r[recv, 1] - xy_o[send, 1]
    dist = np.hypot(dx, dy)
    feats = np.stack([dx, dy, dist], axis=1).astype(np.float32)
    return send, recv, feats

# ------------------------------- Helpers ----------------------------------- #

def z_to_str(z: float) -> str:
    """Format a slice height for use in file and directory names.

    Integral values are printed without a decimal point (``40.0 -> "40"``).
    Other values are rounded to 3 decimals, then trailing zeros and any
    dangling decimal point are stripped (``1.50 -> "1.5"``).
    """
    if not float(z).is_integer():
        text = format(z, ".3f")
        return text.rstrip("0").rstrip(".")
    return str(int(z))

def apply_box_inlet_outlet(nt: np.ndarray, xy: np.ndarray,
                           box: Tuple[float,float,float,float],
                           tol: float) -> np.ndarray:
    """
    Force node_type=3 (inlet/outlet) for points on the crop box boundary.

    A point is on the boundary when its x is within *tol* of xmin or xmax, or
    its y is within *tol* of ymin or ymax. *tol* is purely absolute
    (``rtol=0``). NumPy's default ``rtol=1e-5`` would silently widen the band
    in proportion to the box coordinates, e.g. by 0.01 for a bound at ±1000.

    Args:
        nt: per-point node types, length N (not modified).
        xy: (N, 2+) point coordinates; columns 0 and 1 are used.
        box: (xmin, xmax, ymin, ymax).
        tol: absolute distance tolerance.

    Returns:
        A copy of *nt* with the boundary points set to 3.
    """
    xmin, xmax, ymin, ymax = box
    x = xy[:, 0]
    y = xy[:, 1]

    def _near(values, bound):
        return np.isclose(values, bound, rtol=0.0, atol=tol)

    boundary_mask = _near(x, xmin) | _near(x, xmax) | _near(y, ymin) | _near(y, ymax)
    nt2 = nt.copy()
    nt2[boundary_mask] = 3  # inlet/outlet
    return nt2

# --------------------------------- CLI ------------------------------------ #

def build_parser():
    """Create the command-line parser for this script (defaults shown in --help)."""
    parser = argparse.ArgumentParser(
        description="Build GNN structures from a cropped slice VTK/VTU, forcing crop-box boundaries to inlet/outlet.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    add = parser.add_argument

    # input / output locations
    add("--data_root", type=Path, default=Path("data"))
    add("--slice_root", type=Path, default=Path("data_sliced_cropped_300k"),
        help="Root where cropped slices live (case_*/slice_z_<Z>.vtu)")
    add("--case", type=str, default="case_1")
    add("--time", type=int, default=0)
    add("--z", type=float, default=40.0)
    add("--out_root", type=Path, default=Path("structures_cropped_300k"))

    # graph construction
    add("--k_o2r", type=int, default=1, help="k-NN for ori→red edges")
    add("--target_n_red", type=int, default=48000,
        help="if >0 and <|ori|, make a second reduced graph of this size")

    # crop-box boundary handling
    add("--box", type=float, nargs=4, metavar=("XMIN", "XMAX", "YMIN", "YMAX"),
        default=(-1000.0, 1000.0, -1000.0, 1000.0),
        help="Crop box used when slicing")
    add("--boundary_tol", type=float, default=1e-3,
        help="Absolute tolerance for detecting points on the crop box boundary")
    return parser

# --------------------------------- main ----------------------------------- #

def _find_slice_file(slice_dir: Path, z_str: str) -> Path:
    """Locate the slice file for height label *z_str* inside *slice_dir*.

    ``slice_z_<Z>.vtu`` is preferred. Otherwise any ``slice_z_<Z>.vtu/.vtp/.vtk``
    is accepted, taking the last in sorted order. Only files whose stem is
    exactly ``slice_z_<Z>`` qualify, so ``slice_z_40.5.vtu`` is never picked
    for ``z_str == "40"``.

    Raises:
        FileNotFoundError: if no matching slice exists.
    """
    stem = f"slice_z_{z_str}"
    direct_vtu = slice_dir / f"{stem}.vtu"
    if direct_vtu.exists():
        return direct_vtu
    # legacy patterns fallback
    candidates = sorted(
        p for p in slice_dir.glob(f"{stem}.*")
        if p.stem == stem and p.suffix in (".vtu", ".vtp", ".vtk")
    )
    if not candidates:
        raise FileNotFoundError(f"No slice found in {slice_dir} for z={z_str}")
    return candidates[-1]


def _save_edge_set(out_dir: Path, tag: str, senders: np.ndarray,
                   receivers: np.ndarray, feats: np.ndarray) -> None:
    """Save one edge set as ``edges_<tag>_{senders,receivers,feats,feats_norm}.npy``.

    ``feats_norm`` is *feats* divided by the largest edge length (column 2).
    The divisor falls back to 1.0 when there are no edges or the maximum
    length is not positive, to avoid dividing by zero.
    """
    max_len = feats[:, 2].max() if feats.size else 1.0
    if not max_len > 0:
        max_len = 1.0
    np.save(out_dir / f"edges_{tag}_senders.npy", senders)
    np.save(out_dir / f"edges_{tag}_receivers.npy", receivers)
    np.save(out_dir / f"edges_{tag}_feats.npy", feats)
    np.save(out_dir / f"edges_{tag}_feats_norm.npy", feats / max_len)


def _permute_tri_mesh(mesh: pv.PolyData, perm: np.ndarray) -> pv.PolyData:
    """Return a new triangle mesh whose point ``n`` is point ``perm[n]`` of *mesh*.

    Coordinates, triangle connectivity and every point-data array are
    reordered with the same permutation, so the result describes the same
    triangles as *mesh* (cell data is not carried over). Assumes
    ``triangulate()`` keeps the input point order.
    """
    tri = mesh.triangulate()
    perm = np.asarray(perm)
    inv = np.empty_like(perm)
    inv[perm] = np.arange(perm.size, dtype=perm.dtype)  # old index -> new index
    faces = np.asarray(tri.faces).reshape(-1, 4).copy()  # rows: [3, i, j, k]
    faces[:, 1:] = inv[faces[:, 1:]]
    out = pv.PolyData(np.asarray(tri.points)[perm], faces=faces.ravel())
    for name in tri.point_data.keys():
        out.point_data[name] = np.asarray(tri.point_data[name])[perm]
    return out


def main():
    """Build and save ori (and optionally reduced) graph structures for one slice.

    Reads ``<slice_root>/<case>/slice_z_<Z>.*`` and assigns node types: from
    the slice's ``node_type`` point array if present, otherwise from the 3D
    OpenFOAM case by nearest point. Crop-box boundary nodes are then forced
    to type 3. Node/edge ``.npy`` arrays plus sanity VTK/PNG files are
    written under ``<out_root>/z_<Z>/``.
    """
    args = build_parser().parse_args()
    case_dir = (args.data_root / args.case).resolve()
    z_str = z_to_str(args.z)
    slice_dir = (args.slice_root / args.case).resolve()
    out_dir = (args.out_root / f"z_{z_str}").resolve()
    out_dir.mkdir(parents=True, exist_ok=True)

    # Prefer new cropped .vtu; fall back to .vtp/.vtk with the same name.
    slc_path = _find_slice_file(slice_dir, z_str)
    print(f"Using slice: {slc_path.name}")

    # read slice; non-PolyData inputs are converted via extract_geometry()
    slc = pv.read(str(slc_path))
    if not isinstance(slc, pv.PolyData):
        slc = slc.cast_to_unstructured_grid().extract_geometry()

    # node types: take from slice if present; otherwise compute from 3D case
    if "node_type" in slc.point_data:
        nt_slice = np.asarray(slc.point_data["node_type"]).astype(np.int8)
        print("Node types loaded from slice point_data['node_type'].")
    else:
        print("Computing node types from 3D case (node_type not found on slice).")
        vol_mb = read_volumetric_multiblock(case_dir, args.time)
        nt_vol, vol_pts = compute_node_types(vol_mb)
        # each slice point takes the type of its nearest volume-mesh point
        _, idx2vol = cKDTree(vol_pts).query(slc.points, k=1)
        nt_slice = nt_vol[idx2vol].astype(np.int8)

    # enforce inlet/outlet on crop box boundary
    xy = slc.points[:, :2].astype(np.float64)
    nt_slice = apply_box_inlet_outlet(nt_slice, xy, tuple(args.box), args.boundary_tol)

    # coordinates + normalization by the largest absolute coordinate
    max_coord_slice = np.abs(xy).max()
    xy_norm = (xy / (max_coord_slice if max_coord_slice > 0 else 1.0)).astype(np.float32)

    # save raw + normalized coords/types
    np.save(out_dir / "slice_xy.npy", xy)
    np.save(out_dir / "slice_xy_norm.npy", xy_norm)
    np.save(out_dir / "slice_node_types.npy", nt_slice)

    # dump node types into a VTK for sanity checks
    slc_vtk = slc.copy()
    slc_vtk.point_data["node_type"] = nt_slice
    slc_vtk.save(str(out_dir / "slice_node_types.vtk"))

    # quick sanity plot
    plt.figure(figsize=(5, 5))
    plt.scatter(xy[:, 0], xy[:, 1], c=nt_slice, cmap="tab10", s=1)
    plt.title("slice node types (cropped)")
    plt.axis("equal")
    plt.tight_layout()
    plt.savefig(str(out_dir / "slice_node_types.png"), dpi=150)
    plt.close()

    # -------- ori→ori (undirected): triangulate on the input slice --------
    s_oo, r_oo, f_oo = compute_undirected_edges_from_tri_mesh(slc.triangulate())
    _save_edge_set(out_dir, "oo", s_oo, r_oo, f_oo)

    # ---------------- optional second reduction: red graph -----------------
    need_red = (args.target_n_red > 0) and (slc.n_points > args.target_n_red)
    if need_red:
        dec_mesh, dec2orig = decimate_to_target(slc, args.target_n_red)

        # Morton reorder reduced nodes for cache-friendly ordering. Points,
        # triangle connectivity and point data are renumbered together;
        # permuting only the coordinates would scramble the triangles and
        # corrupt both the saved VTK and the red→red edge features.
        perm = morton_permutation(dec_mesh.points[:, :2].astype(np.float64))
        red_mesh = _permute_tri_mesh(dec_mesh, perm)
        red_xy_raw = red_mesh.points[:, :2].astype(np.float64)
        dec2orig = dec2orig[perm]

        # reduced node types (from ori with boundary override already applied)
        nt_red = nt_slice[dec2orig]

        # normalized
        max_coord_red = np.abs(red_xy_raw).max()
        red_xy_norm = (red_xy_raw / (max_coord_red if max_coord_red > 0 else 1.0)).astype(np.float32)

        np.save(out_dir / "reduced_xy.npy", red_xy_raw)
        np.save(out_dir / "reduced_xy_norm.npy", red_xy_norm)
        np.save(out_dir / "reduced_node_types.npy", nt_red)

        # dump reduced node types VTK
        red_mesh.point_data["node_type"] = nt_red
        red_mesh.save(str(out_dir / "reduced_node_types.vtk"))

        # red→red edges; indices are already in Morton numbering
        s_rr, r_rr, f_rr = compute_undirected_edges_from_tri_mesh(red_mesh)
        _save_edge_set(out_dir, "rr", s_rr, r_rr, f_rr)

        # ori→red edges
        s_o2r, r_o2r, f_o2r = compute_o2r_edges(
            xy_o=xy, xy_r=red_xy_raw, keep_idx=dec2orig, k=args.k_o2r
        )
        _save_edge_set(out_dir, "o2r", s_o2r, r_o2r, f_o2r)

        # red→ori edges: reverse of ori→red (swap endpoints, negate dx/dy)
        f_r2o = np.stack([-f_o2r[:, 0], -f_o2r[:, 1], f_o2r[:, 2]], axis=1).astype(np.float32)
        _save_edge_set(out_dir, "r2o", r_o2r, s_o2r, f_r2o)

        print(f"✔ Wrote ori (~{slc.n_points:,}) and red (~{red_xy_raw.shape[0]:,}) structures to {out_dir}/")
    else:
        print(f"✔ Wrote ori (~{slc.n_points:,}) structures to {out_dir}/ (no second reduction requested)")


if __name__ == "__main__":
    main()
