#!/usr/bin/env python3
"""
Evaluate GenDA reconstruction under multiple observation sampling strategies, reporting metrics and example plots per strategy and observation count range.
"""
from __future__ import annotations

import argparse
from pathlib import Path
from typing import Dict, Any, List, Optional, Tuple

import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp
import flax.nnx as nnx
from jax.experimental import mesh_utils

from gen_da.gen_da import GenSynth, SamplerConfig, NoiseConfig
from gen_da.denoiser import (
    DenoiserArchitectureConfig,
    NoiseEncoderConfig,
    AngleEncoderConfig,
)

from training.graph_dataset import make_dataset
from training.obs_sampling import GraphCacheManager, expand_mask_hops  # type: ignore

# Reuse metrics from training script
from training.train import _mean_angular_similarity, _graph_ssim_speed

try:
    from training.train import prepare_batch  # ensure identical preprocessing
except Exception:
    from baselines.train_meshgraphnets import prepare_batch  # type: ignore

# Optional deps
try:
    from scipy.interpolate import griddata  # type: ignore
except Exception:
    griddata = None  # type: ignore
try:
    import pyvista as pv  # type: ignore
except Exception:
    pv = None  # type: ignore


# ----------------------------
# Model init / restore
# ----------------------------

def _init_gen_da(example_batch, latent_size: int, seed: int = 0):
    """Construct a freshly initialised ``GenSynth`` model from one example batch.

    The batch is run through ``prepare_batch`` with no observations
    (``obs_frac=0.0``, ``obs_neighbor_hops=0``); the resulting inputs and
    forcings provide the example graph structure and the channel count.

    Args:
        example_batch: A dataset batch accepted by ``prepare_batch``.
        latent_size: Latent width passed to ``DenoiserArchitectureConfig``.
        seed: Seed for the ``nnx.Rngs`` used to initialise the model.

    Returns:
        The constructed ``GenSynth`` instance (no checkpoint is loaded here).
    """
    example_inputs, example_forcings, *_ = prepare_batch(
        example_batch, obs_frac=0.0, obs_neighbor_hops=0
    )

    # One-dimensional device mesh over every visible device, axis named 'data'.
    device_total = max(1, jax.device_count())
    device_mesh = jax.sharding.Mesh(
        mesh_utils.create_device_mesh((device_total,)), ('data',)
    )
    denoiser_cfg = DenoiserArchitectureConfig(latent_size=latent_size)

    # The forcings may carry one graph structure or a list/tuple of them;
    # only the first is used as the example.
    graph_structures = example_forcings['graph_structures']
    if isinstance(graph_structures, (list, tuple)):
        graph_structures = graph_structures[0]

    # Target and guiding fields share the channel count of 'U_field'.
    n_channels = int(example_inputs['U_field'].shape[-1])

    return GenSynth(
        denoiser_architecture_config=denoiser_cfg,
        sampler_config=SamplerConfig(),
        noise_config=NoiseConfig(),
        noise_encoder_config=NoiseEncoderConfig(),
        angle_encoder_config=AngleEncoderConfig(),
        rngs=nnx.Rngs(seed),
        mesh=device_mesh,
        example_graph_structures=graph_structures,
        target_channels=n_channels,
        guiding_channels=n_channels,
    )


def _restore_nnx_model(model, ckpt_dir: Path):
    """Load checkpointed state from ``ckpt_dir`` into ``model`` and return it.

    The model is split with an ``nnx.RngState`` filter; only the remaining
    (third) state is used as the restore template, so the RNG state is not
    read from the checkpoint. The model is updated in place and switched to
    eval mode.

    Args:
        model: An NNX module whose state layout matches the checkpoint.
        ckpt_dir: Directory handed to ``PyTreeCheckpointer.restore``.

    Returns:
        The same ``model`` object, updated and in eval mode.
    """
    import orbax.checkpoint as ocp

    # Graph definition and RNG state are not needed for restoring.
    _, _, template_state = nnx.split(model, nnx.RngState, ...)
    restored = ocp.PyTreeCheckpointer().restore(str(ckpt_dir), item=template_state)
    nnx.update(model, restored)
    model.eval()
    return model


# ----------------------------
# Utility: metrics, coords, rrmse
# ----------------------------

def _evaluate_one_case_metrics(original_data: np.ndarray, predicted_data: np.ndarray, gs: Dict[str, Any]) -> Dict[str, float]:
    """Compute relative RMSE, angular similarity and graph SSIM for one case.

    Args:
        original_data: Reference field; the last axis holds vector components.
        predicted_data: Reconstructed field with the same shape.
        gs: Graph structure providing ``'o2o_senders'`` and
            ``'o2o_receivers'`` edge index arrays.

    Returns:
        Dict with keys ``"rrmse"``, ``"cosine"`` and ``"ssim"``. ``"rrmse"``
        is 0.0 when the reference field has zero RMS.
    """
    residual = predicted_data - original_data
    rmse = float(np.sqrt(np.mean(residual**2)))
    reference_rms = float(np.sqrt(np.mean(original_data**2)))
    if reference_rms != 0:
        rrmse = rmse / reference_rms
    else:
        rrmse = 0.0

    angular = _mean_angular_similarity(original_data, predicted_data)

    # SSIM is evaluated on the speed (vector magnitude) over the o2o edges.
    edge_src = np.asarray(gs['o2o_senders']).astype(np.int64)
    edge_dst = np.asarray(gs['o2o_receivers']).astype(np.int64)
    true_speed = np.linalg.norm(original_data, axis=-1).astype(np.float64)
    pred_speed = np.linalg.norm(predicted_data, axis=-1).astype(np.float64)
    structural = _graph_ssim_speed(true_speed, pred_speed, edge_src, edge_dst, include_self=True)

    return {"rrmse": rrmse, "cosine": float(angular), "ssim": float(structural)}


def _coords_from_graph_struct(gs: Dict[str, Any]) -> np.ndarray:
    if isinstance(gs, (list, tuple)):
        gs = gs[0]
    coords = np.asarray(gs.get('original_coordinates'))
    if coords is None:
        return np.zeros((int(gs.get('num_o_nodes', 0)), 2), dtype=np.float32)
    return coords[:, :2] if coords.shape[1] >= 2 else np.pad(coords, ((0,0),(0,max(0,2 - coords.shape[1]))))[:, :2]


def _rrmse_field(u_true: np.ndarray, u_pred: np.ndarray) -> float:
    diff = (u_pred - u_true).astype(np.float64)
    rmse = float(np.sqrt(np.mean(diff**2)))
    denom = float(np.sqrt(np.mean((u_true.astype(np.float64))**2)))
    return rmse / denom if denom != 0.0 else 0.0


# ----------------------------
# Grid utilities for VTU export
# ----------------------------

def _make_grid(coords_xy: np.ndarray, nx: int = 300, ny: int = 300):
    xmin, ymin = coords_xy.min(axis=0)
    xmax, ymax = coords_xy.max(axis=0)
    gx = np.linspace(xmin, xmax, nx)
    gy = np.linspace(ymin, ymax, ny)
    X, Y = np.meshgrid(gx, gy)
    return X, Y, (xmin, xmax, ymin, ymax)


def _scatter_to_grid(coords_xy: np.ndarray, values: np.ndarray, X: np.ndarray, Y: np.ndarray) -> np.ndarray:
    if griddata is None:
        # nearest assignment + simple inpainting
        xmin, xmax = X[0, 0], X[0, -1]
        ymin, ymax = Y[0, 0], Y[-1, 0]
        ny, nx = X.shape
        ix = np.clip(((coords_xy[:, 0] - xmin) / max(xmax - xmin, 1e-12) * (nx - 1)).round().astype(int), 0, nx - 1)
        iy = np.clip(((coords_xy[:, 1] - ymin) / max(ymax - ymin, 1e-12) * (ny - 1)).round().astype(int), 0, ny - 1)
        Z = np.full((ny, nx), np.nan, dtype=np.float32)
        Z[iy, ix] = values.astype(np.float32)
        # inpaint by nearest along rows and columns
        for y in range(ny):
            row = Z[y]
            good = ~np.isnan(row)
            if np.any(good):
                idx = np.where(good)[0]
                for x in range(nx):
                    if np.isnan(row[x]):
                        row[x] = row[idx[np.argmin(np.abs(idx - x))]]
        for x in range(nx):
            col = Z[:, x]
            good = ~np.isnan(col)
            if np.any(good):
                idx = np.where(good)[0]
                for y in range(ny):
                    if np.isnan(col[y]):
                        col[y] = col[idx[np.argmin(np.abs(idx - y))]]
        return Z
    else:
        Z = griddata(coords_xy, values, (X, Y), method='linear')  # type: ignore
        Znn = griddata(coords_xy, values, (X, Y), method='nearest')  # type: ignore
        Z[np.isnan(Z)] = Znn[np.isnan(Z)]
        return Z


def _unstructured_from_rect_grid(X: np.ndarray, Y: np.ndarray):
    """Create a PyVista UnstructuredGrid of quads from rectilinear grid X,Y.

    Points are the grid nodes in C (row-major) order with z = 0. Each cell is
    a quad over four neighbouring nodes, listed counter-clockwise as
    (lower-left, lower-right, upper-right, upper-left) in index space.
    Uses QUAD cells so VTU opens in ParaView properly.

    Args:
        X: ``(ny, nx)`` array of node x coordinates.
        Y: ``(ny, nx)`` array of node y coordinates.

    Returns:
        ``(grid, ny, nx)``.

    Raises:
        RuntimeError: If PyVista could not be imported.
    """
    if pv is None:
        raise RuntimeError("PyVista not available")
    ny, nx = X.shape
    pts = np.column_stack([X.ravel(order='C'), Y.ravel(order='C'), np.zeros(nx * ny, dtype=X.dtype)])

    # Flat index of the lower-left node of every quad, row by row.
    rows, cols = np.meshgrid(
        np.arange(ny - 1, dtype=np.int64),
        np.arange(nx - 1, dtype=np.int64),
        indexing='ij',
    )
    lower_left = (rows * nx + cols).ravel()
    num_cells = lower_left.size

    # Legacy VTK cell array layout: [4, p0, p1, p2, p3, 4, p0, ...].
    cells = np.column_stack([
        np.full(num_cells, 4, dtype=np.int64),
        lower_left,             # a
        lower_left + 1,         # b: right neighbour
        lower_left + nx + 1,    # e: diagonal neighbour in the next row
        lower_left + nx,        # d: neighbour in the next row
    ]).ravel()

    # 9 is the VTK id for VTK_QUAD, used when this PyVista has no CellType enum.
    quad_type = getattr(pv, 'CellType').QUAD if hasattr(pv, 'CellType') else 9
    celltypes = np.full(num_cells, quad_type, dtype=np.uint8)

    grid = pv.UnstructuredGrid(cells, celltypes, pts)  # type: ignore
    return grid, ny, nx


# ----------------------------
# Sampling strategies
# ----------------------------

def _match_target_count(mask: np.ndarray, target: int, candidates: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Adjust a boolean mask to have exactly `target` trues using `candidates` universe.
    If mask has too many, randomly unset extras; if too few, randomly add from remaining candidates.
    """
    mask = mask.astype(bool).copy()
    current = int(mask.sum())
    cand = np.asarray(candidates)
    if current > target:
        idx = np.flatnonzero(mask)
        drop = rng.choice(idx, size=(current - target), replace=False)
        mask[drop] = False
    elif current < target:
        missing = target - current
        avail = cand[~mask[cand]] if cand.size else np.flatnonzero(~mask)
        if avail.size > 0:
            add = rng.choice(avail, size=min(missing, avail.size), replace=False)
            mask[add] = True
    return mask


def _random_sampling_mask(gc, target_count: int, hops: int, rng: np.random.Generator) -> np.ndarray:
    """Build an observation mask from uniformly random seed nodes.

    Seeds are drawn without replacement from ``gc.fluid_cand`` (or from all
    nodes when that array is empty), optionally grown with
    ``expand_mask_hops``, and finally adjusted with ``_match_target_count``
    so every strategy is compared at the same observation count.

    Args:
        gc: Graph cache exposing ``coords``, ``fluid_cand``, ``indptr`` and
            ``indices``.
        target_count: Desired number of observed nodes.
        hops: Expansion hops applied to the seeds; 0 disables expansion.
        rng: NumPy random generator.

    Returns:
        Boolean mask with one entry per row of ``gc.coords``.
    """
    n_nodes = gc.coords.shape[0]
    if gc.fluid_cand.size > 0:
        pool = gc.fluid_cand
    else:
        pool = np.arange(n_nodes, dtype=np.int32)

    seeds = rng.choice(pool, size=min(target_count, pool.size), replace=False)
    observed = np.zeros(n_nodes, dtype=bool)
    observed[seeds] = True
    if hops > 0:
        observed = expand_mask_hops(observed, gc.indptr, gc.indices, hops)

    # Match the exact count for fairness between strategies.
    return _match_target_count(observed, target_count, pool, rng)


def _farthest_point_pairs(coords: np.ndarray, cand: np.ndarray, rng: np.random.Generator) -> Tuple[int, int]:
    """Heuristic: pick a random candidate as start, then choose end as farthest candidate by Euclidean distance."""
    if cand.size == 0:
        return 0, 0
    s = int(rng.choice(cand))
    pts = coords[cand, :2]
    d2 = np.sum((pts - coords[s, :2])**2, axis=1)
    e = int(cand[int(np.argmax(d2))])
    return s, e


def _nearest_indices_along_segment(coords: np.ndarray, a_idx: int, b_idx: int, n_samples: int) -> np.ndarray:
    """Sample n_samples points along segment (A->B) and pick nearest graph nodes.
    Brute force nearest search for robustness (O(N*n_samples))."""
    A = coords[a_idx, :2]
    B = coords[b_idx, :2]
    ts = np.linspace(0.0, 1.0, max(2, n_samples))
    pts = (1 - ts)[:, None] * A[None, :] + ts[:, None] * B[None, :]
    # compute nearest neighbor by brute-force
    diffs = coords[:, :2][None, :, :] - pts[:, None, :]
    d2 = np.sum(diffs**2, axis=-1)  # (n_samples, N)
    idxs = np.argmin(d2, axis=1)
    return np.asarray(idxs, dtype=np.int32)


def _cloud_sampling_mask(gc, target_count: int, hops: int, rng: np.random.Generator, max_iters: int = 1000) -> np.ndarray:
    """Build an observation mask made of spread-out "clouds" of nodes.

    Cloud centres are chosen greedily: the first is a random candidate and
    each further one is the candidate farthest (in the first two coordinate
    columns) from all centres so far, with a random candidate as fallback
    when that node is already a centre. After each addition the centres are
    grown with ``expand_mask_hops``; seeding stops once the grown mask holds
    at least ``target_count`` nodes or after ``max_iters`` rounds. The result
    is finally adjusted with ``_match_target_count``.

    Args:
        gc: Graph cache exposing ``coords``, ``fluid_cand``, ``indptr`` and
            ``indices``.
        target_count: Desired number of observed nodes.
        hops: Expansion hops around each centre; 0 keeps only the centres.
        rng: NumPy random generator.
        max_iters: Maximum number of seeding rounds. With a value <= 0 only
            the initial random centre is used (unexpanded).

    Returns:
        Boolean mask with one entry per row of ``gc.coords``.
    """
    n_nodes = gc.coords.shape[0]
    cand = gc.fluid_cand if gc.fluid_cand.size > 0 else np.arange(n_nodes, dtype=np.int32)
    if cand.size == 0:
        return np.zeros(n_nodes, dtype=bool)

    cand_xy = gc.coords[cand, :2]

    def _sq_dist_to(node: int) -> np.ndarray:
        # Squared distance from every candidate to one node.
        return np.sum((cand_xy - gc.coords[node, :2]) ** 2, axis=-1)

    # Start from a random seed.
    first = int(rng.choice(cand))
    chosen = {first}
    centres = np.zeros(n_nodes, dtype=bool)
    centres[first] = True
    # Running "distance to nearest centre", updated incrementally instead of
    # rebuilding the full candidates x centres matrix every round.
    nearest_d2 = _sq_dist_to(first)

    # Defined up front so the result is valid even if the loop never runs.
    mask = centres.copy()
    for _ in range(max_iters):
        # Copies keep `centres` independent of whatever is done to `mask`.
        if hops > 0:
            mask = expand_mask_hops(centres.copy(), gc.indptr, gc.indices, hops)
        else:
            mask = centres.copy()
        if int(mask.sum()) >= target_count:
            break

        # Greedy farthest-point step to reduce overlap between clouds.
        next_idx = int(cand[int(np.argmax(nearest_d2))])
        if next_idx in chosen:
            # Fallback to random if stuck.
            next_idx = int(rng.choice(cand))
        chosen.add(next_idx)
        centres[next_idx] = True
        nearest_d2 = np.minimum(nearest_d2, _sq_dist_to(next_idx))

    # Adjust to the exact target.
    return _match_target_count(mask, target_count, cand, rng)


def _trajectory_sampling_mask(
    gc,
    target_count: int,
    hops: int,
    rng: np.random.Generator,
    num_trajs: int = 2,
    traj_len: int = 50,
    traj_min_len: Optional[int] = None,
) -> np.ndarray:
    cand = gc.fluid_cand if gc.fluid_cand.size > 0 else np.arange(gc.coords.shape[0], dtype=np.int32)
    if cand.size == 0:
        return np.zeros(gc.coords.shape[0], dtype=bool)

    indptr, indices, coords = gc.indptr, gc.indices, gc.coords
    n = coords.shape[0]
    core_mask = np.zeros(n, dtype=bool)     # nodes visited along trajectories
    visit_order: list[int] = []             # keeps order to trim from the end if needed

    tmin = int(traj_len // 2) if traj_min_len is None else int(traj_min_len)
    tmin = max(2, min(tmin, int(traj_len)))

    def neighbors(i: int) -> np.ndarray:
        return indices[indptr[i]:indptr[i+1]]

    def add_node(i: int):
        if not core_mask[i]:
            core_mask[i] = True
            visit_order.append(int(i))

    # helper to build a single car trajectory of length L
    def run_car(start_idx: int, L: int):
        cur, prev = start_idx, -1
        add_node(cur)
        for _ in range(max(0, L - 1)):
            nbrs = neighbors(cur)
            if nbrs.size == 0:
                break
            # prefer fluid candidates and unvisited neighbors (coverage!)
            nbrs = nbrs[np.isin(nbrs, cand)] if cand.size else nbrs
            unvisited = nbrs[~core_mask[nbrs]]
            if unvisited.size > 0:
                nbrs = unvisited
            # avoid immediate backtracking if there’s a choice
            if prev >= 0 and nbrs.size > 1:
                nbrs = nbrs[nbrs != prev] if np.any(nbrs != prev) else nbrs
            # score by distance from start (push outward), choose among top-k for variety
            vec = coords[nbrs, :2] - coords[start_idx, :2]
            d2 = np.sum(vec * vec, axis=1)
            k = min(3, nbrs.size)
            nxt = int(rng.choice(nbrs[np.argsort(-d2)[:k]]))
            prev, cur = cur, nxt
            add_node(cur)

    # 1) run the requested number of cars with variable lengths
    remaining = max(0, int(target_count))
    for _ in range(max(1, int(num_trajs))):
        if core_mask.sum() >= target_count:
            break
        start = int(rng.choice(cand[~core_mask[cand]]) if np.any(~core_mask[cand]) else rng.choice(cand))
        L = int(rng.integers(low=tmin, high=int(traj_len) + 1))
        run_car(start, L)

    # 2) if we’re still short after those cars, spawn more cars automatically
    #    (keeps the mask trajectory-like instead of adding random points)
    safety = 0
    while True:
        expanded = expand_mask_hops(core_mask, indptr, indices, hops) if hops > 0 else core_mask.copy()
        if int(expanded.sum()) >= target_count:
            mask = expanded
            break
        # add another short car from an unvisited candidate
        if safety > 10_000:  # guard
            mask = expanded
            break
        safety += 1
        remaining = target_count - int(expanded.sum())
        # choose a modest extra length so we converge smoothly
        extra_len = int(min(traj_len, max(tmin, remaining // 2)))
        start_choices = cand[~core_mask[cand]]
        start = int(rng.choice(start_choices) if start_choices.size else rng.choice(cand))
        run_car(start, extra_len)

    # 3) if we overshot after expansion, trim ONLY the expansion fringe (keep core intact)
    if int(mask.sum()) > target_count:
        keep = core_mask.copy()
        fringe = np.flatnonzero(mask & ~keep)
        need_drop = int(mask.sum()) - int(target_count)
        if need_drop > 0 and fringe.size > 0:
            drop = rng.choice(fringe, size=min(need_drop, fringe.size), replace=False)
            mask[drop] = False

    return mask



def _get_zoom_bounds(coords: np.ndarray, zoom_factor: float = 2.0) -> Tuple[float, float, float, float]:
    """
    Calculate zoom bounds centered on the domain center.
    """
    xmin, ymin = coords.min(axis=0)
    xmax, ymax = coords.max(axis=0)
    cx = (xmin + xmax) / 2
    cy = (ymin + ymax) / 2
    dx = (xmax - xmin) / zoom_factor
    dy = (ymax - ymin) / zoom_factor
    return cx - dx/2, cx + dx/2, cy - dy/2, cy + dy/2


# ----------------------------
# Plotting
# ----------------------------

def _plot_4x4_strategies(filepath: Path,
                         coords: np.ndarray,
                         gt: np.ndarray,
                         masks: Dict[str, np.ndarray],
                         preds: Dict[str, Optional[np.ndarray]],
                         angle: int,
                         obs_count: int):
    """Save a scatter-plot comparison of the three sampling strategies.

    The figure has 4 rows (GT, Random, Cloud, Trajectory) and 5 columns
    (Ux, Uy, observation mask, relative error of Ux, relative error of Uy).
    All panels are zoomed with ``_get_zoom_bounds`` and each velocity
    component shares one colour range across GT and every prediction.

    Args:
        filepath: Output image path; parent directories are created.
        coords: ``(N, 2)`` node positions.
        gt: Ground-truth field; columns 0 and 1 are plotted as Ux and Uy.
        masks: Boolean observation masks under the keys ``'random'``,
            ``'cloud'`` and ``'trajectory'``.
        preds: Predictions under the same keys; a ``None`` or missing entry
            renders the row as "N/A".
        angle: Angle shown in the figure title.
        obs_count: Observation count shown in the figure title.
    """
    row_labels = ['GT', 'Random', 'Cloud', 'Trajectory']
    col_titles = ['Ux', 'Uy', 'Obs', 'Err Ux', 'Err Uy']
    comp_names = ('Ux', 'Uy')

    # Zoom window shared by every panel.
    z_xmin, z_xmax, z_ymin, z_ymax = _get_zoom_bounds(coords, zoom_factor=2.0)

    # Component-wise global limits for better visual consistency.
    fields = [gt] + [arr for arr in preds.values() if arr is not None]
    vlims = {
        name: (
            float(min(f[:, comp].min() for f in fields)),
            float(max(f[:, comp].max() for f in fields)),
        )
        for comp, name in enumerate(comp_names)
    }

    # Loop invariants: error normaliser and the two-colour mask colormap.
    gt_mag = np.linalg.norm(gt, axis=-1)
    mask_cmap = matplotlib.colors.ListedColormap(['lightgray', 'red'])

    fig, axes = plt.subplots(4, 5, figsize=(20, 12), constrained_layout=True)
    try:
        fig.suptitle(f"Angle {angle} | obs={obs_count}", fontsize=14, fontweight='bold')

        def frame(ax):
            # Common framing: equal aspect, zoom window, no ticks.
            ax.set_aspect('equal')
            ax.set_xlim(z_xmin, z_xmax)
            ax.set_ylim(z_ymin, z_ymax)
            ax.set_xticks([])
            ax.set_yticks([])

        def draw_field(ax, field, comp):
            vmin, vmax = vlims[comp_names[comp]]
            sc = ax.scatter(coords[:, 0], coords[:, 1], c=field[:, comp], cmap='RdBu_r',
                            s=2.0, vmin=vmin, vmax=vmax)
            frame(ax)
            fig.colorbar(sc, ax=ax, fraction=0.046, pad=0.02)

        # Row 0: GT Ux, Uy; the remaining panels only carry column titles.
        for comp in (0, 1):
            ax = axes[0, comp]
            draw_field(ax, gt, comp)
            ax.set_title(col_titles[comp])
            if comp == 0:
                ax.set_ylabel(row_labels[0])
        for j in (2, 3, 4):
            axes[0, j].axis('off')
            axes[0, j].set_title(col_titles[j])

        # Strategy rows (1..3).
        def draw_row(ridx: int, key: str, label: str):
            observed = masks[key].astype(bool)
            pred = preds.get(key)

            for comp in (0, 1):
                ax = axes[ridx, comp]
                if pred is None:
                    ax.text(0.5, 0.5, f"{label} N/A", transform=ax.transAxes, ha='center', va='center')
                    ax.set_axis_off()
                    continue
                draw_field(ax, pred, comp)
                if comp == 0:
                    ax.set_ylabel(label)
                # Component-specific R-RMSE annotation (separate for Ux and Uy).
                num = float(np.sqrt(np.mean((pred[:, comp] - gt[:, comp])**2)))
                den = float(np.sqrt(np.mean((gt[:, comp])**2)))
                rr = (num / den) if den != 0.0 else 0.0
                ax.text(0.02, 0.98, f"R-RMSE={rr:.3f}", transform=ax.transAxes, va='top',
                        bbox=dict(boxstyle='round,pad=0.25', facecolor='white', alpha=0.8))

            # Mask column (drawn even when the prediction is missing).
            axm = axes[ridx, 2]
            sc_m = axm.scatter(coords[:, 0], coords[:, 1], c=observed.astype(np.int32),
                               cmap=mask_cmap, vmin=0, vmax=1, s=0.2, alpha=0.9)
            frame(axm)
            cbm = fig.colorbar(sc_m, ax=axm, fraction=0.046, pad=0.02)
            cbm.set_ticks([0, 1])
            cbm.set_ticklabels(['unobs', 'obs'])

            # Error columns (Ux, Uy): |error| relative to the GT speed.
            if pred is None:
                axes[ridx, 3].axis('off')
                axes[ridx, 4].axis('off')
                return
            for comp in (0, 1):
                axe = axes[ridx, 3 + comp]
                rel_error = np.abs(pred[:, comp] - gt[:, comp]) / (gt_mag + 1e-6)
                sc_e = axe.scatter(coords[:, 0], coords[:, 1], c=rel_error, cmap='Reds',
                                   vmin=0, vmax=1, s=2.0)
                frame(axe)
                fig.colorbar(sc_e, ax=axe, fraction=0.046, pad=0.02)

        draw_row(1, 'random', 'Random')
        draw_row(2, 'cloud', 'Cloud')
        draw_row(3, 'trajectory', 'Trajectory')

        filepath.parent.mkdir(parents=True, exist_ok=True)
        # Save this figure explicitly rather than pyplot's "current" figure.
        fig.savefig(filepath, dpi=150)
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)


def _plot_4x4_strategies_grid(filepath: Path,
                              coords: np.ndarray,
                              gt: np.ndarray,
                              masks: Dict[str, np.ndarray],
                              preds: Dict[str, Optional[np.ndarray]],
                              angle: int,
                              obs_count: int,
                              nx: int = 300,
                              ny: int = 300):
    """Same as _plot_4x4_strategies but using imshow on a regular grid.

    Values are produced via scattered->grid interpolation
    (``_scatter_to_grid``) on an ``nx`` x ``ny`` grid from ``_make_grid``.
    The R-RMSE annotations are still computed on the original node values,
    not on the interpolated images.

    Args:
        filepath: Output image path; parent directories are created.
        coords: ``(N, 2)`` node positions.
        gt: Ground-truth field; columns 0 and 1 are plotted as Ux and Uy.
        masks: Boolean observation masks under the keys ``'random'``,
            ``'cloud'`` and ``'trajectory'``.
        preds: Predictions under the same keys; a ``None`` or missing entry
            renders the row as "N/A".
        angle: Angle shown in the figure title.
        obs_count: Observation count shown in the figure title.
        nx: Number of grid columns.
        ny: Number of grid rows.
    """
    row_labels = ['GT', 'Random', 'Cloud', 'Trajectory']
    col_titles = ['Ux', 'Uy', 'Obs', 'Err Ux', 'Err Uy']
    comp_names = ('Ux', 'Uy')

    X, Y, (xmin, xmax, ymin, ymax) = _make_grid(coords, nx=nx, ny=ny)
    extent = (xmin, xmax, ymin, ymax)

    # Zoom window shared by every panel.
    z_xmin, z_xmax, z_ymin, z_ymax = _get_zoom_bounds(coords, zoom_factor=2.0)

    # Component-wise global limits for better visual consistency.
    fields = [gt] + [arr for arr in preds.values() if arr is not None]
    vlims = {
        name: (
            float(min(f[:, comp].min() for f in fields)),
            float(max(f[:, comp].max() for f in fields)),
        )
        for comp, name in enumerate(comp_names)
    }

    # Loop invariants: error normaliser and the two-colour mask colormap.
    gt_mag = np.linalg.norm(gt, axis=-1)
    mask_cmap = matplotlib.colors.ListedColormap(['lightgray', 'red'])

    fig, axes = plt.subplots(4, 5, figsize=(20, 12), constrained_layout=True)
    try:
        fig.suptitle(f"Angle {angle} | obs={obs_count}", fontsize=14, fontweight='bold')

        def frame(ax):
            # Common framing: equal aspect, zoom window, no ticks.
            ax.set_aspect('equal')
            ax.set_xlim(z_xmin, z_xmax)
            ax.set_ylim(z_ymin, z_ymax)
            ax.set_xticks([])
            ax.set_yticks([])

        def show(ax, node_values, **imshow_kwargs):
            # Interpolate node values to the grid and draw them with a colorbar.
            image = ax.imshow(_scatter_to_grid(coords, node_values, X, Y),
                              origin='lower', extent=extent, **imshow_kwargs)
            frame(ax)
            return fig.colorbar(image, ax=ax, fraction=0.046, pad=0.02)

        def draw_field(ax, field, comp):
            vmin, vmax = vlims[comp_names[comp]]
            show(ax, field[:, comp], cmap='RdBu_r', vmin=vmin, vmax=vmax)

        # Row 0: GT Ux, Uy; the remaining panels only carry column titles.
        for comp in (0, 1):
            ax = axes[0, comp]
            draw_field(ax, gt, comp)
            ax.set_title(col_titles[comp])
            if comp == 0:
                ax.set_ylabel(row_labels[0])
        for j in (2, 3, 4):
            axes[0, j].axis('off')
            axes[0, j].set_title(col_titles[j])

        # Strategy rows (1..3).
        def draw_row(ridx: int, key: str, label: str):
            observed = masks[key].astype(bool)
            pred = preds.get(key)

            for comp in (0, 1):
                ax = axes[ridx, comp]
                if pred is None:
                    ax.text(0.5, 0.5, f"{label} N/A", transform=ax.transAxes, ha='center', va='center')
                    ax.set_axis_off()
                    continue
                draw_field(ax, pred, comp)
                if comp == 0:
                    ax.set_ylabel(label)
                # Component-wise R-RMSE on imshow panels.
                num = float(np.sqrt(np.mean((pred[:, comp] - gt[:, comp])**2)))
                den = float(np.sqrt(np.mean((gt[:, comp])**2)))
                rr = (num / den) if den != 0.0 else 0.0
                ax.text(0.02, 0.98, f"R-RMSE={rr:.3f}", transform=ax.transAxes, va='top',
                        bbox=dict(boxstyle='round,pad=0.25', facecolor='white', alpha=0.8))

            # Mask column rendered on the grid (drawn even without a prediction).
            cb = show(axes[ridx, 2], observed.astype(np.float32), cmap=mask_cmap, vmin=0, vmax=1)
            cb.set_ticks([0, 1])
            cb.set_ticklabels(['unobs', 'obs'])

            # Error columns (Ux, Uy): |error| relative to the GT speed.
            if pred is None:
                axes[ridx, 3].axis('off')
                axes[ridx, 4].axis('off')
                return
            for comp in (0, 1):
                rel_error = np.abs(pred[:, comp] - gt[:, comp]) / (gt_mag + 1e-6)
                show(axes[ridx, 3 + comp], rel_error, cmap='Reds', vmin=0, vmax=1)

        draw_row(1, 'random', 'Random')
        draw_row(2, 'cloud', 'Cloud')
        draw_row(3, 'trajectory', 'Trajectory')

        filepath.parent.mkdir(parents=True, exist_ok=True)
        # Save this figure explicitly rather than pyplot's "current" figure.
        fig.savefig(filepath, dpi=150)
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)


# ----------------------------
# Main
# ----------------------------

def _parse_example_angles(parser: argparse.ArgumentParser, spec: str) -> List[int]:
    """Parse the comma-separated ``--example_angles`` value into integers.

    Empty tokens are ignored, so ``""`` and ``","`` both yield an empty list.
    A non-integer token aborts through ``parser.error`` (exit status 2).
    """
    try:
        return [int(tok) for tok in spec.split(',') if tok.strip()]
    except ValueError:
        parser.error(f"--example_angles must be a comma-separated list of integers, got {spec!r}")


def _mean_std(values) -> Tuple[float, float]:
    """Return ``(mean, population std)`` of *values*, or ``(nan, nan)`` when empty."""
    arr = np.asarray(values, dtype=np.float64)
    if arr.size == 0:
        return np.nan, np.nan
    return float(arr.mean()), float(arr.std())


def _prepare_eval_case(batch: Dict[str, Any], seed: int):
    """Preprocess one batch without observations and fetch its graph cache.

    Returns ``(inputs, forcings, first_graph_structure, graph_cache)``.
    """
    inputs, forcings, *_ = prepare_batch(
        batch,
        obs_frac=0.0,
        obs_neighbor_hops=0,
        obs_on_fluid_only=True,
        obs_seed=seed,
        obs_mode='random',
    )
    gs = forcings['graph_structures']
    gs0 = gs[0] if isinstance(gs, (list, tuple)) else gs

    # 'z' may be stored per batch element or as a plain scalar; 35.0 is the
    # fallback used when the batch carries no 'z' entry.
    z_field = batch.get('z', 35.0)
    z_value = float(z_field[0] if isinstance(z_field, (list, tuple, np.ndarray)) else z_field)
    gc = GraphCacheManager.get(gs0, z_value, fluid_only=True)
    return inputs, forcings, gs0, gc


def _build_matched_masks(gc, obs_count: int, rng: np.random.Generator,
                         args: argparse.Namespace) -> Dict[str, np.ndarray]:
    """Build one observation mask per strategy, all targeting *obs_count*.

    The three samplers share *rng* and are called in a fixed order
    (random, cloud, trajectory) so results are reproducible for a given seed.
    """
    mask_random = _random_sampling_mask(gc, obs_count, hops=0, rng=rng)
    mask_cloud = _cloud_sampling_mask(gc, obs_count, hops=int(args.cloud_hops), rng=rng)
    mask_traj = _trajectory_sampling_mask(
        gc, obs_count, hops=int(args.traj_hops), rng=rng,
        num_trajs=int(args.traj_num), traj_len=int(args.traj_len),
        traj_min_len=int(args.traj_min_len),
    )
    return {'random': mask_random, 'cloud': mask_cloud, 'trajectory': mask_traj}


def _forcings_with_obs(forcings: Dict[str, Any], inputs: Dict[str, Any],
                       u_clean: np.ndarray, mask: np.ndarray):
    """Return ``(forcings_copy, obs_values)`` with *mask* injected as observations.

    ``obs_values`` is the clean field multiplied by the mask; the mask gets a
    leading axis and is broadcast over the trailing axis of *u_clean*.
    """
    obs_mask = mask.astype(np.float32)[None, :]
    obs_values = u_clean * obs_mask[..., None]
    forc = {
        **forcings,
        'U_field_guiding': jnp.array(obs_values),
        'obs_mask': jnp.array(obs_mask),
        'batch_size': u_clean.shape[0],
        # boundary values handed to the sampler
        'boundary_values': inputs['U_field'],
    }
    return forc, obs_values


def _sample_reconstruction(model, inputs: Dict[str, Any], forc: Dict[str, Any],
                           key_seed: int) -> np.ndarray:
    """Run ``model.full_sampling`` from seeded Gaussian noise; return element 0."""
    rng_key = jax.random.PRNGKey(key_seed)
    noisy_inputs = jax.random.normal(rng_key, inputs['U_field'].shape)
    pred = model.full_sampling(noisy_inputs=noisy_inputs, forcings=forc)
    return np.array(pred[0])


def _save_strategy_vtu(vtu_path: Path, coords, gt: np.ndarray,
                       masks: Dict[str, np.ndarray],
                       pred_map: Dict[str, Optional[np.ndarray]],
                       obs_map: Dict[str, Optional[np.ndarray]],
                       nx: int, ny: int) -> None:
    """Scatter GT, masks, observations and predictions onto a grid and save it."""
    X, Y, _ = _make_grid(coords, nx=nx, ny=ny)
    grid, _ny, _nx = _unstructured_from_rect_grid(X, Y)

    def put(name: str, values) -> None:
        gridded = _scatter_to_grid(coords, values, X, Y)
        grid.point_data[name] = gridded.astype(np.float32).ravel(order='C')

    # Ground truth (point-associated data)
    put('Ux_GT', gt[:, 0])
    put('Uy_GT', gt[:, 1])

    for key, pred in pred_map.items():
        put(f'mask_{key}', masks[key].astype(np.float32))
        # Observed (masked) values that were given to the model; only
        # recorded for strategies whose sampling succeeded.
        obs_vals = obs_map.get(key)
        if obs_vals is not None:
            put(f'Ux_obs_{key}', obs_vals[:, 0])
            put(f'Uy_obs_{key}', obs_vals[:, 1])
        if pred is not None:
            put(f'Ux_{key}', pred[:, 0])
            put(f'Uy_{key}', pred[:, 1])

    grid.save(str(vtu_path))  # type: ignore


def main():
    """Command-line entry point: compare observation sampling strategies.

    Workflow:
      1. Parse the CLI and validate it up front, so a bad observation range
         or a malformed ``--example_angles`` aborts (exit status 2) before the
         dataset and checkpoint are loaded.
      2. Load the test dataset, build the model and restore the checkpoint.
      3. For every test batch draw an observation count in
         ``[obs_min, obs_max]``, build one mask per strategy, sample a
         reconstruction and accumulate the metrics. A failing strategy is
         reported and skipped rather than aborting the run.
      4. Write the aggregated metrics to a CSV file in ``--output_dir``.
      5. For each angle in ``--example_angles`` save comparison plots and,
         with ``--save_vtu``, a VTU file.

    Raises:
        RuntimeError: if the test dataset yields no batches.
        SystemExit: on invalid command-line arguments.
    """
    p = argparse.ArgumentParser(description="Evaluate GenDA across observation sampling strategies")
    p.add_argument('--gen_da_checkpoint', type=str, required=True)
    p.add_argument('--slice_root', type=str, default='data_sliced_cropped_300k')
    p.add_argument('--norm_stats', type=str, default='normalization_cropped_300k_test/normalization_stats_train.nc')
    p.add_argument('--batch_size', type=int, default=1)
    p.add_argument('--latent_size', type=int, default=64)
    p.add_argument('--eval_angle_stride', type=int, default=36)
    p.add_argument('--seed', type=int, default=42)
    p.add_argument('--output_dir', type=str, default='eval_outputs/strategy_compare')
    # Observation count range
    p.add_argument('--obs_min', type=int, default=100, help='Minimum number of observations')
    p.add_argument('--obs_max', type=int, default=1000, help='Maximum number of observations (inclusive)')
    # Hops settings
    p.add_argument('--cloud_hops', type=int, default=1)
    p.add_argument('--traj_hops', type=int, default=1)
    p.add_argument('--traj_num', type=int, default=2)
    p.add_argument('--traj_len', type=int, default=50)
    p.add_argument('--traj_min_len', type=int, default=15,
                   help='Minimum steps per trajectory (car); actual length is uniform in [traj_min_len, traj_len]')
    # Plotting
    p.add_argument('--example_angles', type=str, default='')
    p.add_argument('--num_example_plots', type=int, default=None)
    # VTU grid options
    p.add_argument('--save_vtu', action='store_true', help='Save VTU files for plotted angles')
    p.add_argument('--grid_nx', type=int, default=300)
    p.add_argument('--grid_ny', type=int, default=300)
    args = p.parse_args()

    # Fail fast: these would otherwise surface as a ValueError only after the
    # dataset/model were loaded (obs range) or after the whole evaluation had
    # finished (example angles / number of plots).
    obs_min, obs_max = int(args.obs_min), int(args.obs_max)
    if obs_min < 0:
        p.error(f"--obs_min must be non-negative, got {obs_min}")
    if obs_max < obs_min:
        p.error(f"--obs_max ({obs_max}) must be >= --obs_min ({obs_min})")
    angle_list = _parse_example_angles(p, args.example_angles)
    if angle_list and args.num_example_plots is not None and args.num_example_plots < 0:
        p.error(f"--num_example_plots must be non-negative, got {args.num_example_plots}")

    out_dir = Path(args.output_dir); out_dir.mkdir(parents=True, exist_ok=True)
    examples_dir = out_dir / 'examples'; examples_dir.mkdir(parents=True, exist_ok=True)
    vtu_dir = out_dir / 'vtu'; vtu_dir.mkdir(parents=True, exist_ok=True)

    # Dataset
    test_ds = make_dataset(
        slice_root=args.slice_root,
        norm_stats_nc=args.norm_stats,
        batch_size=max(1, int(args.batch_size)),
        shuffle=False,
        seed=args.seed + 3,
        is_training=False,
        angle_stride=max(1, int(args.eval_angle_stride)),
        drop_remainder=True,
    )
    test_batches = list(test_ds)
    if not test_batches:
        raise RuntimeError('Empty test dataset')

    # Example batch/model
    example_batch = test_batches[0]
    model = _init_gen_da(example_batch, args.latent_size, seed=args.seed)
    model = _restore_nnx_model(model, Path(args.gen_da_checkpoint))

    # Metrics accumulator across range: {strategy: {metric: [...]}}
    strategies = ['random', 'cloud', 'trajectory']
    metric_names = ['rrmse', 'ssim', 'cosine']
    metrics: Dict[str, Dict[str, List[float]]] = {s: {k: [] for k in metric_names} for s in strategies}
    used_counts: List[int] = []
    # First batch seen for each angle; reused by the example-plot section.
    batch_by_angle: Dict[int, Dict[str, Any]] = {}

    # Main evaluation: for each batch, draw an obs_count ~ Uniform[obs_min, obs_max]
    print(f"Evaluating strategies for obs_count in [{args.obs_min}, {args.obs_max}] across {len(test_batches)} batches...")
    for batch in test_batches:
        angle = int(batch['angle_deg'][0])
        batch_by_angle.setdefault(angle, batch)

        # One generator per batch: it first yields the observation count and
        # then drives the three mask samplers.
        rng = np.random.default_rng(args.seed + angle)
        oc = int(rng.integers(low=obs_min, high=obs_max + 1))
        used_counts.append(oc)

        # Base inputs/forcings carry no observations; masks are injected below.
        inputs, forcings, gs0, gc = _prepare_eval_case(batch, args.seed)
        masks = _build_matched_masks(gc, oc, rng, args)
        u_clean = np.array(inputs['U_field'])  # identical for every strategy

        for strat, msk in masks.items():
            forc, _ = _forcings_with_obs(forcings, inputs, u_clean, msk)
            try:
                predicted = _sample_reconstruction(model, inputs, forc, args.seed)
                original = np.array(inputs['U_field'][0])
                mvals = _evaluate_one_case_metrics(original, predicted, gs0)
                for k, v in mvals.items():
                    metrics[strat][k].append(v)
            except Exception as e:
                # Deliberate best-effort: one failing case must not abort the run.
                print(f"Strategy {strat} failed on a batch: {e}")

    # Write CSV
    import csv  # only needed here; kept local because the file's import block is not touched

    csv_path = out_dir / 'metrics_by_strategy_and_obsrange.csv'
    mean_count, std_count = _mean_std(used_counts)
    with open(csv_path, 'w', newline='') as f:
        w = csv.writer(f)
        w.writerow(['strategy', 'obs_min', 'obs_max', 'mean_obs_count', 'std_obs_count', 'mean_rrmse', 'std_rrmse', 'mean_ssim', 'std_ssim', 'mean_cosine', 'std_cosine', 'num_cases'])
        for s in strategies:
            row: List[Any] = [s, obs_min, obs_max, mean_count, std_count]
            for k in metric_names:
                row.extend(_mean_std(metrics[s][k]))
            # num_cases counts the successfully evaluated cases of this strategy.
            row.append(len(metrics[s]['rrmse']))
            w.writerow(row)
    print(f"Saved metrics CSV to {csv_path}")

    # Example plots
    if args.num_example_plots is not None and args.num_example_plots < len(angle_list):
        rng = np.random.default_rng(args.seed)
        # Convert back to plain ints so seeds and file names never see numpy scalars.
        angle_list = [int(a) for a in rng.choice(angle_list, size=args.num_example_plots, replace=False)]

    for ang in angle_list:
        b = batch_by_angle.get(ang)
        if b is None:
            print(f"Angle {ang} not found; skipping example plot.")
            continue

        # pick obs_count from [obs_min, obs_max]
        oc = int(np.random.default_rng(args.seed + ang).integers(low=obs_min, high=obs_max + 1))
        inputs, forcings, _, gc = _prepare_eval_case(b, args.seed)
        # Fresh generator for the masks (unlike the metrics loop, the count
        # draw above does not advance this stream).
        rng = np.random.default_rng(args.seed + ang)
        masks = _build_matched_masks(gc, oc, rng, args)

        coords = _coords_from_graph_struct(forcings['graph_structures'])
        gt = np.array(inputs['U_field'][0])
        u_clean = np.array(inputs['U_field'])
        pred_map: Dict[str, Optional[np.ndarray]] = {s: None for s in strategies}
        # Store the actually provided observation values (masked GT) per strategy
        obs_map: Dict[str, Optional[np.ndarray]] = {s: None for s in strategies}

        for strat, msk in masks.items():
            forc, obs_values = _forcings_with_obs(forcings, inputs, u_clean, msk)
            try:
                pred_map[strat] = _sample_reconstruction(model, inputs, forc, args.seed + ang)
                # Save the observation values (masked inputs) for VTU export
                obs_map[strat] = np.array(obs_values[0])
            except Exception as e:
                print(f"Sampling failed for {strat} @ angle {ang}: {e}")

        save_path = examples_dir / f"angle_{ang}_4x4_obs{oc}.png"
        _plot_4x4_strategies(save_path, coords, gt, masks, pred_map, ang, oc)
        print(f"Saved example 4x4 plot to {save_path}")

        # Also save an imshow-based version using a grid
        save_path_grid = examples_dir / f"angle_{ang}_4x4_grid_obs{oc}.png"
        _plot_4x4_strategies_grid(save_path_grid, coords, gt, masks, pred_map, ang, oc, nx=int(args.grid_nx), ny=int(args.grid_ny))
        print(f"Saved example 4x4 (grid) plot to {save_path_grid}")

        # Optionally save VTU (structured grid) with GT and strategies
        if not args.save_vtu:
            continue
        if pv is None:
            print("PyVista not available; skipping VTU export.")
            continue
        try:
            vtu_path = vtu_dir / f'angle_{ang}_obs{oc}.vtu'
            _save_strategy_vtu(vtu_path, coords, gt, masks, pred_map, obs_map,
                               nx=int(args.grid_nx), ny=int(args.grid_ny))
            print(f"Saved VTU (unstructured) grid to {vtu_path}")
        except Exception as e:
            print(f"VTU export failed for angle {ang}: {e}")


# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main()
