
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
graph_dataset.py (updated)

Adds configurable normalization for:
  - velocity:  zscore | minmax | maxabs | none
  - coords:    div1000 | zscore | minmax | none
  - edges:     dmax | zscore | minmax | none

"""

from pathlib import Path
from typing import Dict, Sequence, List, Any, Optional, Iterable

import numpy as np
import xarray as xr
import pyvista as pv
from collections import defaultdict
import math
from collections import deque

# ------------- normalization helpers -------------

# Edge-set names, in the order used as ``set_index`` (0..3) when normalizing
# edge features; presumably "o" = original nodes and "r" = reduced nodes.
EDGE_SETS = [
    "o2o",
    "o2r",
    "r2r",
    "r2o",
]
# Names of the edge feature columns (presumably the column order of the
# ``*_feats.npy`` arrays -- confirm against the graph builder).
FEAT_NAMES = [
    "dx",
    "dy",
    "d",
]

def _select_minmax(da):
    # Accept either dim name and fall back to positional
    for dim in ("bound", "stat"):
        if dim in da.dims:
            return (
                da.sel({dim: "min"}).values.astype(np.float32),
                da.sel({dim: "max"}).values.astype(np.float32),
            )
    arr = da.values
    if arr.shape[0] != 2:
        raise ValueError(f"Expected first dimension size 2 for min/max, got {arr.shape}")
    return arr[0].astype(np.float32), arr[1].astype(np.float32)

def _load_all_norm(path: str):
    """Load all normalization statistics from a NetCDF stats file.

    Every statistic is copied into a float32 NumPy array, and the file is
    closed before returning.

    Returned keys:
        vel_mean, vel_std, vel_min, vel_max, vel_max_abs (only if the file
        has ``velocity_max_abs``),
        coords_mean, coords_std, coords_min, coords_max,
        edges_mean, edges_std, edges_min, edges_max, edges_dmax.

    ``*_std`` entries have ``1e-12`` added so they can be used as divisors.
    """
    def mean_std(ds, name):
        # Split a (..., stat=[mean, std], ...) variable into two float32 arrays.
        da = ds[name]
        mean = da.sel(stat="mean").values.astype(np.float32)
        std = da.sel(stat="std").values.astype(np.float32) + 1e-12
        return mean, std

    out = {}
    # ``with`` releases the file handle; ``.values`` materialises every array,
    # so nothing returned still references the open dataset.
    with xr.open_dataset(path) as ds:
        # velocity
        out["vel_mean"], out["vel_std"] = mean_std(ds, "velocity_mean_std")    # (2,)
        out["vel_min"], out["vel_max"] = _select_minmax(ds["velocity_min_max"])
        if "velocity_max_abs" in ds:
            out["vel_max_abs"] = ds["velocity_max_abs"].values.astype(np.float32)

        # coords -- shapes: (coord_set=['original','reduced'], stat, xy)
        out["coords_mean"], out["coords_std"] = mean_std(ds, "coords_mean_std")  # (2,2)
        out["coords_min"], out["coords_max"] = _select_minmax(ds["coords_min_max"])

        # edges -- shapes: (edge_set=4, feat=3)
        out["edges_mean"], out["edges_std"] = mean_std(ds, "edges_mean_std")    # (4,3)
        out["edges_min"], out["edges_max"] = _select_minmax(ds["edges_min_max"])
        out["edges_dmax"] = ds["edges_dmax"].values.astype(np.float32)          # (4,)

    return out

def _zscore(X, mean, std): return (X - mean) / std
def _minmax(X, vmin, vmax): return 2 * (X - vmin) / np.maximum(vmax - vmin, 1e-12) - 1
def _maxabs(X, vabs):       return X / np.maximum(vabs, 1e-12)

def _apply_vel_norm(U_xy: np.ndarray, stats, mode: str) -> np.ndarray:
    if mode == "zscore":
        return _zscore(U_xy, stats["vel_mean"], stats["vel_std"])
    if mode == "minmax":
        return _minmax(U_xy, stats["vel_min"], stats["vel_max"])
    if mode == "maxabs":
        return _maxabs(U_xy, stats["vel_max_abs"])
    return U_xy

def _apply_coord_norm(A: np.ndarray, stats, which: int, mode: str) -> np.ndarray:
    # which: 0 original, 1 reduced
    if mode == "div1000":
        return A / 1000.0
    if mode == "zscore":
        return _zscore(A, stats["coords_mean"][which], stats["coords_std"][which])
    if mode == "minmax":
        return _minmax(A, stats["coords_min"][which], stats["coords_max"][which])
    return A

def _apply_edge_norm(F: np.ndarray, stats, set_index: int, mode: str) -> np.ndarray:
    if mode == "dmax":
        dmax = max(float(stats["edges_dmax"][set_index]), 1e-12)
        return F / dmax
    if mode == "zscore":
        return _zscore(F, stats["edges_mean"][set_index], stats["edges_std"][set_index])
    if mode == "minmax":
        return _minmax(F, stats["edges_min"][set_index], stats["edges_max"][set_index])
    return F

# ------------- dataset -------------

class GraphDataset(Sequence):
    """
    Indexable dataset of velocity slices, each paired with the graph for its z.

    Returns dicts with keys:
        - target_inputs:    (N_o, 2) float32 -- first two components of the
                            "U" point field of the target slice, normalized
                            per ``vel_norm``
        - angle_deg:        int16 -- the ``<n>`` of the target's ``case_<n>``
        - graph_structures: dict of numpy arrays (coords, edges,
                            senders/receivers, node_types), normalized per
                            ``coord_norm`` / ``edge_norm``; cached per z, so
                            every item with the same z shares ONE dict object
                            (do not mutate it)
        - z:                float32

    Items are laid out z-block by z-block in ``available_z`` order. Within a
    block, angle order is shuffled with a fixed seed when ``shuffle_triplets``
    is true.
    """

    # Accepted normalization modes (``None`` is also accepted, as "none").
    _VEL_MODES = ("zscore", "minmax", "maxabs", "none")
    _COORD_MODES = ("div1000", "zscore", "minmax", "none")
    _EDGE_MODES = ("dmax", "zscore", "minmax", "none")

    def __init__(self,
                 slice_root: str = "data_sliced_cropped_300k",
                 norm_stats_nc: str = "normalization_cropped_300k_test/normalization_stats_train.nc",
                 graph_structures: str = "structures_cropped_300k",
                 is_training: bool = True,
                 *,
                 vel_norm: str = "minmax",          # "zscore" | "minmax" | "maxabs" | "none"
                 coord_norm: str = "div1000",       # "div1000" | "zscore" | "minmax" | "none"
                 edge_norm: str = "dmax",           # "dmax" | "zscore" | "minmax" | "none"
                 fixed_angle: Optional[int] = None,
                 fixed_z: Optional[int] = None,
                 shuffle_triplets: bool = True,
                 angle_stride: int = 1):
        """Index the slice triplets, load normalization stats, set up the graph cache.

        Args:
            slice_root: directory holding ``case_<angle>/slice_z_<z>.vtu`` files.
            norm_stats_nc: NetCDF stats file read by ``_load_all_norm``.
            graph_structures: directory holding ``z_<z>/`` graph ``.npy`` files.
            is_training: selects z set ``[15, 20, 28, 45]`` (True) or
                ``[35, 40]`` (False); ignored when ``fixed_z`` is given.
            vel_norm, coord_norm, edge_norm: normalization modes (see above).
            fixed_angle: if given, only this target angle is used.
            fixed_z: if given, only this z is used.
            shuffle_triplets: shuffle angle order within each z block.
            angle_stride: step between target angles; values < 1 become 1.

        Raises:
            ValueError: unknown normalization mode, or ``vel_norm="maxabs"``
                when the stats file has no ``velocity_max_abs``.
        """
        self._angle_stride = max(1, int(angle_stride))
        self.slice_root = Path(slice_root)
        self._fixed_angle = fixed_angle
        self._fixed_z = fixed_z
        self._shuffle_triplets = shuffle_triplets

        # Choose z set
        if fixed_z is not None:
            self.available_z = [int(fixed_z)]
        else:
            self.available_z = [15, 20, 28, 45] if is_training else [35, 40]

        # Triplets: one block per z; fixed seed keeps the shuffle reproducible.
        base_triplets = self._get_triplets()
        self.triplets = []
        rng = np.random.default_rng(302714)
        for z in self.available_z:
            block = []
            for t in base_triplets:
                block.append({
                    "guiding": (f"{t['guiding'][0]}/slice_z_{z}.vtu",
                                f"{t['guiding'][1]}/slice_z_{z}.vtu"),
                    "target":  f"{t['target']}/slice_z_{z}.vtu",
                    "z": z,
                })
            if self._shuffle_triplets:
                rng.shuffle(block)
            self.triplets.extend(block)

        # Normalization stats
        self._stats = _load_all_norm(norm_stats_nc)
        # Fail here rather than silently skipping normalization (typo) or
        # raising KeyError on the first __getitem__ (missing maxabs stats).
        self._validate_norm_modes(vel_norm, coord_norm, edge_norm, self._stats)
        self._vel_mode   = vel_norm
        self._coord_mode = coord_norm
        self._edge_mode  = edge_norm

        # Graph structure layout: cache key -> file name inside z_<z>/
        self.graph_structures_base_path = Path(graph_structures)
        self.graph_structures_files = {
            "original_coordinates": "slice_xy.npy",
            "reduced_coordinates":  "reduced_xy.npy",
            "node_types":           "slice_node_types.npy",
            "o2o_senders":          "edges_oo_senders.npy",
            "o2o_receivers":        "edges_oo_receivers.npy",
            "o2o_features":         "edges_oo_feats.npy",
            "o2r_senders":          "edges_o2r_senders.npy",
            "o2r_receivers":        "edges_o2r_receivers.npy",
            "o2r_features":         "edges_o2r_feats.npy",
            "r2r_senders":          "edges_rr_senders.npy",
            "r2r_receivers":        "edges_rr_receivers.npy",
            "r2r_features":         "edges_rr_feats.npy",
            "r2o_senders":          "edges_r2o_senders.npy",
            "r2o_receivers":        "edges_r2o_receivers.npy",
            "r2o_features":         "edges_r2o_feats.npy",
        }

        # Per-z cache
        self._graph_cache: Dict[float, Dict[str, np.ndarray]] = {}

    @classmethod
    def _validate_norm_modes(cls, vel_norm, coord_norm, edge_norm, stats) -> None:
        """Reject unknown normalization modes up front.

        The ``_apply_*_norm`` helpers pass data through unchanged for any
        mode they do not recognise, so without this check a typo such as
        ``"zcore"`` would disable normalization without any error.

        Raises:
            ValueError: on an unknown mode, or on ``vel_norm="maxabs"`` when
                ``stats`` has no ``"vel_max_abs"`` entry.
        """
        checks = (
            ("vel_norm", vel_norm, cls._VEL_MODES),
            ("coord_norm", coord_norm, cls._COORD_MODES),
            ("edge_norm", edge_norm, cls._EDGE_MODES),
        )
        for label, mode, allowed in checks:
            if mode is not None and mode not in allowed:
                raise ValueError(f"Unknown {label} {mode!r}; expected one of {allowed}")
        if vel_norm == "maxabs" and "vel_max_abs" not in stats:
            raise ValueError(
                "vel_norm='maxabs' requires 'velocity_max_abs' in the normalization stats file")

    def __len__(self) -> int:
        return len(self.triplets)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Load and normalize target slice ``idx`` and attach the graph for its z.

        NOTE(review): the triplet's "guiding" slice paths are not read here.
        """
        triplet = self.triplets[idx]

        # Target velocity: keep only the first two components of "U".
        target_f = self.slice_root / triplet["target"]
        slc = pv.read(str(target_f))
        U = np.asarray(slc.point_data["U"], dtype=np.float32)[:, :2]  # (N,2)
        U = _apply_vel_norm(U, self._stats, self._vel_mode)           # normalized as configured

        # "case_<n>/slice_z_<z>.vtu" -> n
        case_number = int(triplet["target"].split("/")[0].split("_")[-1])

        # Graph for this z (cached; already normalized per config)
        z_value = float(triplet["z"])
        graph_data = self._load_graph_structures_for_z(z_value)

        return {
            "target_inputs": U.astype(np.float32),
            "angle_deg": np.int16(case_number),
            "graph_structures": graph_data,
            "z": np.float32(z_value),
        }

    # ---------- helpers ----------

    def _load_graph_structures_for_z(self, z_value: float) -> Dict[str, np.ndarray]:
        """Load, normalize and cache the graph arrays stored under ``z_<int(z)>``.

        Coordinates and edge features are cast to float32 (a copy) and then
        normalized. The remaining arrays (node types, senders, receivers) stay
        read-only memory maps. The same dict object is returned on every later
        call for this z.

        Raises:
            FileNotFoundError: if any expected ``.npy`` file is missing.
        """
        if z_value in self._graph_cache:
            return self._graph_cache[z_value]

        z_dir = self.graph_structures_base_path / f"z_{int(z_value)}"
        g = {}
        for key, fname in self.graph_structures_files.items():
            arr_path = z_dir / fname
            if not arr_path.exists():
                raise FileNotFoundError(f"Graph structure file {arr_path} not found.")
            g[key] = np.load(arr_path, mmap_mode='r')

        # Apply coordinate normalization (0 = original set, 1 = reduced set)
        g["original_coordinates"] = _apply_coord_norm(
            g["original_coordinates"].astype(np.float32), self._stats, which=0, mode=self._coord_mode)
        g["reduced_coordinates"] = _apply_coord_norm(
            g["reduced_coordinates"].astype(np.float32),  self._stats, which=1, mode=self._coord_mode)

        # Apply edge normalization per set; EDGE_SETS order gives the stats row.
        for set_index, edge_set in enumerate(EDGE_SETS):
            key = f"{edge_set}_features"
            g[key] = _apply_edge_norm(
                g[key].astype(np.float32), self._stats, set_index, self._edge_mode)

        self._graph_cache[z_value] = g
        return g

    def _get_triplets(self) -> List[Dict[str, str]]:
        """Build ``{"guiding": (prev, next), "target": case}`` directory triplets.

        Targets are ``fixed_angle`` alone, or angles ``1..360`` stepped by
        ``angle_stride``. Guiding neighbours are always +/-1 (wrapping
        1 <-> 360), independent of ``angle_stride``.
        """
        if self._fixed_angle is not None:
            angles = [int(self._fixed_angle)]
        else:
            angles = range(1, 361, self._angle_stride)
        out = []
        for i in angles:
            prev_angle = i - 1 if i > 1 else 360
            next_angle = i + 1 if i < 360 else 1
            out.append({"guiding": (f"case_{prev_angle}", f"case_{next_angle}"),
                        "target":  f"case_{i}"})
        return out

# -------- batching & reshuffling that works with forever(ds) --------

def _collate_same_z(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
    z_vals = {float(b["z"]) for b in batch}
    if len(z_vals) != 1:
        raise ValueError(f"Batch contains multiple z values: {z_vals}")
    out: Dict[str, Any] = {}
    keys = list(batch[0].keys())
    for k in keys:
        if k == "graph_structures":
            out[k] = batch[0][k]
        elif k == "z":
            out[k] = batch[0][k]
        else:
            out[k] = np.stack([b[k] for b in batch], axis=0)
    return out

class EpochIterable:
    """
    Iterable that rebuilds (and optionally reshuffles) same-z batches on every new iteration.

    Every batch holds items of a single z value. Each iteration is one epoch,
    counted when the generator first runs. Its RNG is seeded with
    ``seed + epoch_number``, so successive epochs shuffle differently but the
    whole sequence is reproducible for a given ``seed``. Batches of different
    z values are interleaved round-robin.

    Works nicely with:

        def forever(ds):
            while True:
                for b in ds:
                    yield b
    """

    def __init__(self, base_ds: "GraphDataset", batch_size: int, shuffle: bool, seed: int,
                 drop_remainder: bool):
        """Group the indices of ``base_ds.triplets`` by z.

        Raises:
            ValueError: if ``batch_size`` is smaller than 1.
        """
        self.base_ds = base_ds
        self.batch_size = int(batch_size)
        if self.batch_size < 1:
            # 0 would crash range()/division later; negatives would yield nothing.
            raise ValueError(f"batch_size must be >= 1, got {batch_size!r}")
        self.shuffle = bool(shuffle)
        self.seed = int(seed)
        self.drop_remainder = bool(drop_remainder)
        self._epoch = 0  # number of epochs started so far

        # Pre-group indices by z so no batch ever mixes z values.
        self.z_to_indices: Dict[float, List[int]] = defaultdict(list)
        for idx, trip in enumerate(self.base_ds.triplets):
            self.z_to_indices[float(trip["z"])].append(idx)

    def _chunks(self, idxs: List[int]) -> List[List[int]]:
        """Split ``idxs`` into batch-sized chunks; drop a short tail if ``drop_remainder``."""
        bs = self.batch_size
        chunks = [idxs[i:i + bs] for i in range(0, len(idxs), bs)]
        if self.drop_remainder and chunks and len(chunks[-1]) < bs:
            chunks.pop()
        return chunks

    def __iter__(self):
        rng = np.random.default_rng(self.seed + self._epoch)
        self._epoch += 1

        # One queue of index chunks per z (z order shuffled first, then each z's
        # indices -- this fixes the order in which the RNG is consumed).
        z_keys = list(self.z_to_indices.keys())
        if self.shuffle:
            rng.shuffle(z_keys)

        queues = []
        for z in z_keys:
            idxs = list(self.z_to_indices[z])
            if self.shuffle:
                rng.shuffle(idxs)
            queues.append(deque(self._chunks(idxs)))

        # Round-robin over z values until every queue is drained.
        while any(queues):
            for q in queues:
                if q:
                    chunk = q.popleft()
                    yield _collate_same_z([self.base_ds[j] for j in chunk])

    def __len__(self):
        """Number of batches per epoch (identical for every epoch)."""
        total = 0
        for idxs in self.z_to_indices.values():
            full, rem = divmod(len(idxs), self.batch_size)
            total += full + (1 if rem and not self.drop_remainder else 0)
        return total

# -------- public convenience makers --------

def make_dataset(
    *,  # <- use this one for `forever(ds)`
    batch_size: int = 8,
    shuffle: bool = True,
    seed: int = 0,
    drop_remainder: bool = True,
    angle_stride: int = 1,
    **ds_kwargs
) -> Iterable[Dict[str, Any]]:
    """Build a GraphDataset and wrap it in an epoch-reshuffling batch iterable.

    ``angle_stride`` and every extra keyword argument are forwarded to
    ``GraphDataset``; the remaining options configure ``EpochIterable``.
    """
    dataset = GraphDataset(angle_stride=angle_stride, **ds_kwargs)
    return EpochIterable(
        base_ds=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        seed=seed,
        drop_remainder=drop_remainder,
    )
