
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
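Compute normalization parameters (mean/std, min/max, max-abs) for velocity,
node coordinates and edge features from the training split, print sanity
checks for each scaling, and save everything to a NetCDF file.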
"""

import argparse
from pathlib import Path
import numpy as np
import pyvista as pv
import xarray as xr
import re
from tqdm import tqdm

# Edge-set identifiers. Per-set statistics are accumulated and saved in this order.
EDGE_SETS = ["o2o", "o2r", "r2r", "r2o"]

# Edge-feature file (inside each structures ``z_<z>/`` directory) for every edge set.
EDGE_FILE = dict(
    o2o="edges_oo_feats.npy",
    o2r="edges_o2r_feats.npy",
    r2r="edges_rr_feats.npy",
    r2o="edges_r2o_feats.npy",
)

# Names of the three columns of an edge-feature row.
FEAT_NAMES = ["dx", "dy", "d"]

def discover_cases(root: Path):
    """Find the ``case_<N>`` sub-directories of *root*.

    Returns a list of ``(N, path)`` tuples ordered by the integer ``N``.
    Entries sharing the same ``N`` keep the sorted order of their paths.
    Plain files and directories whose name does not match ``case_<digits>``
    are ignored.
    """
    case_re = re.compile(r"case_(\d+)$")
    found = []
    for entry in sorted(root.iterdir()):
        if not entry.is_dir():
            continue
        hit = case_re.match(entry.name)
        if hit is None:
            continue
        found.append((int(hit.group(1)), entry))
    # sorted() is stable, so ties keep the path order established above.
    return sorted(found, key=lambda pair: pair[0])

def load_Uxy(vtp_file: Path):
    """Read a VTK file and return the first two components of point array ``U``.

    The result is a float64 copy of ``U[:, 0:2]`` (used as Ux, Uy by the rest
    of this script). NOTE(review): assumes ``U`` is a 2-D (N, >=2) point array.
    """
    mesh = pv.read(str(vtp_file))
    velocity = np.asarray(mesh.point_data["U"])
    planar = velocity[:, 0:2]
    return planar.astype(np.float64)

def online_stats_start(nc: int):
    """Create an empty running-statistics accumulator for *nc* channels.

    Keys: ``sum`` / ``sq`` (per-channel sums of values and of squared values,
    float64 zeros), ``cnt`` (number of rows seen, int 0), and ``min`` /
    ``max`` (per-channel extrema, initialised to +inf / -inf).
    """
    acc = dict(
        sum=np.zeros(nc, dtype=np.float64),
        sq=np.zeros(nc, dtype=np.float64),
        cnt=0,
    )
    acc["min"] = np.full(nc, np.inf, dtype=np.float64)
    acc["max"] = np.full(nc, -np.inf, dtype=np.float64)
    return acc

def online_stats_update(acc, X):
    """Fold a batch of rows into a running-statistics accumulator, in place.

    Parameters
    ----------
    acc : dict
        Accumulator as built by ``online_stats_start``: per-channel ``sum``,
        ``sq`` (sum of squares), ``min`` and ``max`` arrays plus the row
        count ``cnt``. Updated in place; nothing is returned.
    X : array_like, shape (rows, channels)
        New samples; the channel count must match the accumulator.

    Notes
    -----
    The batch is converted to float64 first, so sums and squares are never
    computed in a narrower dtype (small integer dtypes would overflow when
    squared). A batch with zero rows leaves ``acc`` untouched; calling
    ``min``/``max`` on a zero-size array would raise otherwise.
    """
    X = np.asarray(X, dtype=np.float64)
    if X.shape[0] == 0:
        return
    acc["sum"] += X.sum(axis=0)
    acc["sq"] += np.square(X).sum(axis=0)
    acc["cnt"] += X.shape[0]
    acc["min"] = np.minimum(acc["min"], X.min(axis=0))
    acc["max"] = np.maximum(acc["max"], X.max(axis=0))

def finalize_mean_std(acc):
    """Turn an accumulator into per-channel ``(mean, std)`` arrays.

    The variance is ``E[x^2] - mean^2``, clamped below at 1e-12. This keeps
    std >= 1e-6 and avoids NaN when rounding makes the difference slightly
    negative.

    Raises
    ------
    RuntimeError
        If the accumulator has not seen any rows.
    """
    n = acc["cnt"]
    if not n:
        raise RuntimeError("No samples found for the requested split.")
    mean = acc["sum"] / n
    second_moment = acc["sq"] / n
    variance = np.maximum(second_moment - np.square(mean), 1e-12)
    return mean, np.sqrt(variance)

def print_block(title, stats_dict):
    """Print a ``--- title ---`` header (preceded by a blank line), then one
    ``key: value`` line per entry with the value shown as a NumPy array
    rounded to 6 decimals."""
    print(f"\n--- {title} ---")
    for name, value in stats_dict.items():
        rounded = np.array(value).round(6)
        print(f"{name}: {rounded}")

def _parse_args():
    """Parse the command line and reject an unusable ``--angle_stride``."""
    ap = argparse.ArgumentParser()
    ap.add_argument('--slice_root', type=str, required=True)
    ap.add_argument('--structures_root', type=str, required=True)
    ap.add_argument('--train_z', nargs="+", type=int, required=True)
    ap.add_argument('--fluid_only', action='store_true')
    ap.add_argument('--out', type=str, required=True)
    # argparse %-formats help strings, so a literal '%' must be written '%%'.
    ap.add_argument('--angle_stride', type=int, default=5,
                    help="Use only cases whose (angle - angle_offset) %% angle_stride == 0.")
    ap.add_argument('--angle_offset', type=int, default=1,
                    help="Anchor for stride selection (1..360). Use 1 to select 1,1+stride,...")
    args = ap.parse_args()
    if args.angle_stride < 1:
        # A zero stride would raise ZeroDivisionError in the case filter.
        ap.error("--angle_stride must be >= 1")
    return args


def _select_cases(slice_root, angle_stride, angle_offset):
    """Return the ``(angle, case_dir)`` pairs kept by the stride/offset filter."""
    return [(ang, p) for ang, p in discover_cases(slice_root)
            if (ang - angle_offset) % angle_stride == 0]


def _iter_velocity(cases, train_z, fluid_masks, desc=None):
    """Yield the (Ux, Uy) array of every existing ``slice_z_<z>.vtu``.

    Missing slice files are skipped silently. When *fluid_masks* has an entry
    for ``z``, the rows are filtered with that boolean mask. A non-None
    *desc* shows a tqdm progress bar over the cases.
    """
    iterable = tqdm(cases, desc=desc) if desc is not None else cases
    for _angle, case_dir in iterable:
        for z in train_z:
            vtu = case_dir / f"slice_z_{z}.vtu"
            if not vtu.exists():
                continue
            U = load_Uxy(vtu)
            mask = fluid_masks.get(z)
            if mask is not None:
                U = U[mask]
            yield U


def _load_f64(path):
    """Load a ``.npy`` file memory-mapped and return a float64 copy."""
    return np.load(path, mmap_mode='r').astype(np.float64)


def _zscore(X, mean, std):
    """Standardize: ``(X - mean) / std`` with std floored at 1e-12."""
    return (X - mean) / np.maximum(std, 1e-12)


def _minmax(X, vmin, vmax):
    """Map ``[vmin, vmax]`` linearly onto ``[-1, 1]`` (range floored at 1e-12)."""
    return 2 * (X - vmin) / np.maximum(vmax - vmin, 1e-12) - 1


def _maxabs(X, vmax_abs):
    """Divide by the per-channel max |value| (floored at 1e-12)."""
    return X / np.maximum(vmax_abs, 1e-12)


def _print_post(title, acc):
    """Print mean/std/min/max of an accumulator filled with normalized values."""
    mean, std = finalize_mean_std(acc)
    print_block(title, {"mean": mean, "std": std, "min": acc["min"], "max": acc["max"]})


def main():
    """Compute, sanity-check and save normalization statistics.

    1. Velocity (Ux, Uy) statistics over every selected ``case_<angle>``
       directory and every ``--train_z`` slice, optionally restricted to the
       nodes whose type in ``slice_node_types.npy`` is not 2.
    2. Coordinate statistics (``slice_xy.npy``, ``reduced_xy.npy``) and
       edge-feature statistics (one per entry of ``EDGE_SETS``), pooled
       across all requested z.
    3. A second pass prints the statistics after z-score, min-max and
       max-abs (or /1000, /dmax) scaling, so the parameters can be checked.
    4. All parameters are written to the NetCDF file given by ``--out``.

    Raises RuntimeError if no case directory passes the angle filter or a
    statistic has no samples.
    """
    args = _parse_args()
    slice_root = Path(args.slice_root)
    structures_root = Path(args.structures_root)
    train_z = args.train_z

    # Keep only the cases the real dataset will load. This mirrors
    # GraphDataset._get_triplets(range(1, 361, angle_stride)) anchored at
    # angle_offset.
    cases = _select_cases(slice_root, args.angle_stride, args.angle_offset)
    if not cases:
        raise RuntimeError(
            f"No case_<angle> directories under {slice_root} match "
            f"angle_stride={args.angle_stride}, angle_offset={args.angle_offset}.")

    # Optional per-z mask keeping nodes whose type is not 2.
    # NOTE(review): assumes the mask is aligned with the point order of
    # slice_z_<z>.vtu - confirm against the structure generator.
    fluid_masks = {}
    if args.fluid_only:
        for z in train_z:
            nt = np.load(structures_root / f"z_{z}" / "slice_node_types.npy", mmap_mode='r')
            fluid_masks[z] = (nt != 2)

    # ===== 1) Velocity stats (Ux, Uy) =====
    channels = ["Ux", "Uy"]
    vel_acc = online_stats_start(len(channels))
    for U in _iter_velocity(cases, train_z, fluid_masks, desc="UxUy: angles"):
        online_stats_update(vel_acc, U)
    vel_mean, vel_std = finalize_mean_std(vel_acc)
    vel_min, vel_max = vel_acc["min"], vel_acc["max"]
    vel_max_abs = np.maximum(np.abs(vel_min), np.abs(vel_max))  # for max-abs scaling
    print_block("Velocity TRAIN stats (raw)", {
        "mean": vel_mean, "std": vel_std, "min": vel_min, "max": vel_max
    })

    # ===== 2) Coordinates & edge features (structures are per z, not per angle) =====
    coord_sets = [("original", "slice_xy.npy"), ("reduced", "reduced_xy.npy")]
    coord_accs = [online_stats_start(2) for _ in coord_sets]
    edge_accs = {s: online_stats_start(len(FEAT_NAMES)) for s in EDGE_SETS}
    dmax_per_set = {s: 0.0 for s in EDGE_SETS}

    for z in tqdm(train_z, desc="Structures per-z"):
        zdir = structures_root / f"z_{z}"
        for acc, (_, fname) in zip(coord_accs, coord_sets):
            online_stats_update(acc, _load_f64(zdir / fname))
        for s in EDGE_SETS:
            feats = _load_f64(zdir / EDGE_FILE[s])
            online_stats_update(edge_accs[s], feats)
            # Column 2 is the "d" feature; initial=0.0 tolerates empty edge files.
            dmax_per_set[s] = max(dmax_per_set[s], float(feats[:, 2].max(initial=0.0)))

    coord_stats = [finalize_mean_std(acc) for acc in coord_accs]
    coords_mean = np.stack([m for m, _ in coord_stats])  # (coord_set, xy)
    coords_std = np.stack([sd for _, sd in coord_stats])
    coords_min = np.stack([acc["min"] for acc in coord_accs])
    coords_max = np.stack([acc["max"] for acc in coord_accs])

    edge_stats = [finalize_mean_std(edge_accs[s]) for s in EDGE_SETS]
    edges_mean = np.stack([m for m, _ in edge_stats])  # (edge_set, feat)
    edges_std = np.stack([sd for _, sd in edge_stats])
    edges_min = np.stack([edge_accs[s]["min"] for s in EDGE_SETS])
    edges_max = np.stack([edge_accs[s]["max"] for s in EDGE_SETS])

    # ===== Verification prints =====
    # Velocity: second pass over exactly the same slices as pass 1.
    acc_zs, acc_mm, acc_ma = (online_stats_start(len(channels)) for _ in range(3))
    for U in _iter_velocity(cases, train_z, fluid_masks):
        online_stats_update(acc_zs, _zscore(U, vel_mean, vel_std))
        online_stats_update(acc_mm, _minmax(U, vel_min, vel_max))
        online_stats_update(acc_ma, _maxabs(U, vel_max_abs))
    _print_post("Velocity POST z-score", acc_zs)
    _print_post("Velocity POST min-max", acc_mm)
    _print_post("Velocity POST max-abs", acc_ma)

    # sigma_data = sqrt of the channel-averaged second moment E[x^2]. Both
    # channels have the same count, so this is the pooled RMS over Ux+Uy.
    sigma_data_mm = float(np.sqrt(np.mean(acc_mm["sq"] / acc_mm["cnt"])))
    print(f"sigma_data (min-max normalized, pooled Ux+Uy): {sigma_data_mm:.6f}")
    # The max-abs variant is the one saved below.
    second_moment_ma = acc_ma["sq"] / acc_ma["cnt"]
    sigma_per_channel = np.sqrt(np.maximum(second_moment_ma, 0.0))
    sigma_data = float(np.sqrt(np.mean(second_moment_ma)))
    print(f"sigma_data (max-abs normalized, pooled Ux+Uy): {sigma_data:.6f}")

    # Coordinates checks
    for i, (cname, fname) in enumerate(coord_sets):
        z_acc, mm_acc, div_acc = (online_stats_start(2) for _ in range(3))
        for z in train_z:
            A = _load_f64(structures_root / f"z_{z}" / fname)
            online_stats_update(z_acc, _zscore(A, coords_mean[i], coords_std[i]))
            online_stats_update(mm_acc, _minmax(A, coords_min[i], coords_max[i]))
            online_stats_update(div_acc, A / 1000.0)
        _print_post(f"Coordinates '{cname}' POST z-score", z_acc)
        _print_post(f"Coordinates '{cname}' POST min-max", mm_acc)
        _print_post(f"Coordinates '{cname}' POST /1000", div_acc)

    # Edge features checks
    for ei, s in enumerate(EDGE_SETS):
        dm = max(dmax_per_set[s], 1e-12)
        z_acc, mm_acc, dmax_acc = (online_stats_start(len(FEAT_NAMES)) for _ in range(3))
        for z in train_z:
            feats = _load_f64(structures_root / f"z_{z}" / EDGE_FILE[s])
            online_stats_update(z_acc, _zscore(feats, edges_mean[ei], edges_std[ei]))
            online_stats_update(mm_acc, _minmax(feats, edges_min[ei], edges_max[ei]))
            # Divide the entire feature row by the set's largest "d".
            online_stats_update(dmax_acc, feats / dm)
        _print_post(f"Edges '{s}' POST z-score", z_acc)
        _print_post(f"Edges '{s}' POST min-max", mm_acc)
        _print_post(f"Edges '{s}' POST /dmax", dmax_acc)

    # ===== Save all parameters to NetCDF =====
    coord_names = [name for name, _ in coord_sets]
    ds = xr.Dataset(
        data_vars={
            "velocity_mean_std": xr.DataArray(
                np.stack([vel_mean, vel_std], 0),  # (stat, channel)
                dims=("stat", "channel"),
                coords={"stat": ["mean", "std"], "channel": channels},
            ),
            "velocity_min_max": xr.DataArray(
                np.stack([vel_min, vel_max], 0),  # (bound, channel)
                dims=("bound", "channel"),
                coords={"bound": ["min", "max"], "channel": channels},
            ),
            "velocity_max_abs": (("channel",), vel_max_abs.astype(np.float64)),
            "velocity_sigma_data": ((), np.float64(sigma_data)),
            "velocity_sigma_per_channel": (("channel",), sigma_per_channel.astype(np.float64)),
            "coords_mean_std": xr.DataArray(
                np.stack([coords_mean, coords_std], 1),  # (coord_set, stat, xy)
                dims=("coord_set", "stat", "xy"),
                coords={"coord_set": coord_names, "stat": ["mean", "std"], "xy": ["x", "y"]},
            ),
            "coords_min_max": xr.DataArray(
                np.stack([coords_min, coords_max], 1),  # (coord_set, bound, xy)
                dims=("coord_set", "bound", "xy"),
                coords={"coord_set": coord_names, "bound": ["min", "max"], "xy": ["x", "y"]},
            ),
            "edges_mean_std": xr.DataArray(
                np.stack([edges_mean, edges_std], 1),  # (edge_set, stat, feat)
                dims=("edge_set", "stat", "feat"),
                coords={"edge_set": EDGE_SETS, "stat": ["mean", "std"], "feat": FEAT_NAMES},
            ),
            "edges_min_max": xr.DataArray(
                np.stack([edges_min, edges_max], 1),  # (edge_set, bound, feat)
                dims=("edge_set", "bound", "feat"),
                coords={"edge_set": EDGE_SETS, "bound": ["min", "max"], "feat": FEAT_NAMES},
            ),
            "edges_dmax": xr.DataArray(
                np.array([dmax_per_set[s] for s in EDGE_SETS], dtype=np.float64),
                dims=("edge_set",),
                coords={"edge_set": EDGE_SETS},
            ),
        },
        attrs={
            "note": "Computed from training split only",
            # netCDF has no boolean attribute type, so store the flag as 0/1.
            "fluid_only": int(args.fluid_only),
            "train_z": np.asarray(train_z, dtype=np.int32),
            "angle_stride": int(args.angle_stride),
            "angle_offset": int(args.angle_offset),
        },
    )

    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    ds.to_netcdf(out_path)
    print(f"\n✔ Saved normalization stats to {out_path}")


if __name__ == "__main__":
    # Script entry point; main() returns None, so the exit status stays 0 on success.
    raise SystemExit(main())
