#!/usr/bin/env python3
"""
Compare GenDA reconstructions with MeshGraphNet, multi-scale MeshGraphNet and (when baselines.lcsvd is importable) sensor-POD/LCSVD baselines on a shared test set, computing standard error metrics and saving summary plots.
"""
import os
from pathlib import Path
import argparse
import time
import csv
from typing import Dict, Any, List, Optional, Tuple

import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib import colors as mcolors
from scipy.spatial import cKDTree, Delaunay
from scipy.interpolate import LinearNDInterpolator

import jax
import jax.numpy as jnp
import flax.nnx as nnx
from jax.experimental import mesh_utils

from gen_da.gen_da import GenSynth, SamplerConfig, NoiseConfig
from gen_da.denoiser import (
    DenoiserArchitectureConfig,
    NoiseEncoderConfig,
    AngleEncoderConfig,
)

from baselines.meshgraphnets import MeshGraphNet, MultiScaleMeshGraphNet
from training.graph_dataset import make_dataset
from training.obs_sampling import GraphCacheManager

# Reuse the metric helpers defined in training.train
from training.train import _mean_angular_similarity, _graph_ssim_speed

try:
    # Prefer training.prepare_batch to ensure identical forcings
    from training.train import prepare_batch
except Exception:
    # Fallback if train.prepare_batch isn't importable
    from baselines.train_meshgraphnets import prepare_batch  # type: ignore

# Import POD baseline components (assuming baselines/lcsvd.py is available)
try:
    from baselines.lcsvd import SensorPOD, collect_case_map, interpolate_to_reference, vectorize_field
    POD_AVAILABLE = True
except ImportError:
    POD_AVAILABLE = False
    print("Warning: Could not import SensorPOD from baselines.lcsvd. POD baseline disabled.")


def _interp_field_to_reference(
    src_coords_xy: np.ndarray,
    tgt_coords_xy: np.ndarray,
    field: np.ndarray,
) -> np.ndarray:
    """Interpolate a (N_src,2) field from src_coords onto tgt_coords using
    the same approach as baselines/lcsvd.py: LinearNDInterpolator with
    nearest-neighbor fill for NaNs.

    Returns: (N_tgt,2)
    """
    src_coords_xy = np.asarray(src_coords_xy)
    tgt_coords_xy = np.asarray(tgt_coords_xy)
    field = np.asarray(field)
    if field.ndim != 2 or field.shape[1] != 2:
        raise ValueError(f"Expected field (N,2), got {field.shape}")

    try:
        src_tri = Delaunay(src_coords_xy)
    except Exception:
        src_tri = None

    tree = cKDTree(src_coords_xy)
    _, nn_indices = tree.query(tgt_coords_xy)

    if src_tri is None:
        return field[nn_indices].astype(np.float32)

    lin_interp = LinearNDInterpolator(src_tri, field)
    new_val = lin_interp(tgt_coords_xy)

    # Fill NaNs with nearest-neighbor
    mask_nan = np.isnan(new_val)
    if np.any(mask_nan):
        rows_with_nan = np.any(mask_nan, axis=1)
        new_val[rows_with_nan] = field[nn_indices[rows_with_nan]]

    return np.asarray(new_val).astype(np.float32)


def _init_gen_da(example_batch, latent_size: int, seed: int = 0):
    """Construct a freshly initialised GenSynth model and its device mesh.

    ``example_batch`` is passed through ``prepare_batch`` solely to obtain a
    template graph structure (the first element if a list/tuple is given);
    the prepared inputs are discarded. The mesh is one-dimensional over all
    visible JAX devices with the axis name ``'data'``.

    Returns:
        ``(model, mesh)``.
    """
    _, forcings, *_ = prepare_batch(example_batch)
    device_count = max(1, jax.device_count())
    device_grid = mesh_utils.create_device_mesh((device_count,))
    mesh = jax.sharding.Mesh(device_grid, ('data',))

    denoiser_cfg = DenoiserArchitectureConfig(latent_size=latent_size)

    graph_template = forcings['graph_structures']
    if isinstance(graph_template, (list, tuple)):
        graph_template = graph_template[0]

    # Both the target and the guiding field are 2-component (Ux, Uy).
    model = GenSynth(
        denoiser_architecture_config=denoiser_cfg,
        sampler_config=SamplerConfig(),
        noise_config=NoiseConfig(),
        noise_encoder_config=NoiseEncoderConfig(),
        angle_encoder_config=AngleEncoderConfig(),
        rngs=nnx.Rngs(seed),
        mesh=mesh,
        example_graph_structures=graph_template,
        target_channels=2,
        guiding_channels=2,
    )
    return model, mesh


def _restore_nnx_model(model, ckpt_dir: Path):
    """Load checkpointed (non-RNG) state into ``model`` in place and set eval mode.

    The model is split into graphdef / RNG state / everything else; only the
    last part is used as the restore template and written back via
    ``nnx.update``. Returns the same ``model`` object.
    """
    import orbax.checkpoint as ocp

    _graphdef, _rng_state, template_state = nnx.split(model, nnx.RngState, ...)
    checkpointer = ocp.PyTreeCheckpointer()
    nnx.update(model, checkpointer.restore(str(ckpt_dir), item=template_state))
    model.eval()
    return model


def _init_baseline(model_type: str, example_forcings: Dict[str, Any], latent_size: int, seed: int = 0):
    """Instantiate an untrained MeshGraphNet-family baseline.

    Args:
        model_type: ``'meshgraphnet'`` or ``'multiscale'``.
        example_forcings: dict whose ``'graph_structures'`` entry (or its
            first element, if a list/tuple) serves as the template graph.
        latent_size: latent width passed to ``DenoiserArchitectureConfig``.
        seed: seed for ``nnx.Rngs``.

    Raises:
        ValueError: for an unrecognised ``model_type``.
    """
    arch = DenoiserArchitectureConfig(latent_size=latent_size)
    template = example_forcings['graph_structures']
    if isinstance(template, (list, tuple)):
        template = template[0]

    constructors = {
        'meshgraphnet': MeshGraphNet,
        'multiscale': MultiScaleMeshGraphNet,
    }
    ctor = constructors.get(model_type)
    if ctor is None:
        raise ValueError("Unknown baseline type")
    return ctor(arch, nnx.Rngs(seed), None, template, target_channels=2, guiding_channels=2)


def _obs_fraction_for_count(batch: Dict[str, Any], obs_count: int, fluid_only: bool = True) -> float:
    """Convert an absolute sensor count into a fraction of candidate nodes.

    The candidate set is ``GraphCacheManager.get(graph, z, fluid_only).fluid_cand``
    for the batch's (first) graph structure and z-slice.

    Args:
        batch: dict with ``'graph_structures'`` (a dict, or a sequence whose
            first element is used) and optionally ``'z'`` (defaults to 35.0).
        obs_count: desired number of observed nodes.
        fluid_only: forwarded to ``GraphCacheManager.get``.

    Returns:
        ``obs_count / n_candidates`` clipped to ``[1 / n_candidates, 1.0]``,
        where ``n_candidates`` is at least 1, so at least one node is requested.
    """
    gs = batch['graph_structures']
    g = gs if isinstance(gs, dict) else gs[0]
    # ``z`` may be a Python scalar, a list/tuple, a NumPy array (including 0-d)
    # or a JAX array of per-sample values: flatten and take the first entry.
    z_field = batch.get('z', 35.0)
    z_value = float(np.asarray(z_field).reshape(-1)[0])
    gc = GraphCacheManager.get(g, z_value, fluid_only=fluid_only)
    denom = max(1, gc.fluid_cand.size)
    return float(np.clip(obs_count / denom, 1.0 / denom, 1.0))


def _pick_batch_for_angle(test_batches: List[Dict[str, Any]], angle: int) -> Optional[Dict[str, Any]]:
    for b in test_batches:
        try:
            a = int(b['angle_deg'][0])
            if a == angle:
                return b
        except Exception:
            continue
    return None


def _coords_from_graph_struct(gs: Dict[str, Any], z_value: Optional[float] = None, structures_root: Optional[str] = None) -> np.ndarray:
    """Return best-available physical XY coordinates for plotting.
    Preference order:
      1) If structures_root and z_value provided, load raw slice_xy.npy from disk
      2) Fallback to fields inside graph_structures dict, preferring less-normalized keys
    """
    if isinstance(gs, (list, tuple)):
        gs = gs[0]
    # Determine expected number of nodes from in-memory graph struct
    expected_n = None
    try:
        if 'original_coordinates' in gs and gs['original_coordinates'] is not None:
            expected_n = int(np.asarray(gs['original_coordinates']).shape[0])
        elif 'reduced_coordinates' in gs and gs['reduced_coordinates'] is not None:
            expected_n = int(np.asarray(gs['reduced_coordinates']).shape[0])
        elif 'node_types' in gs and gs['node_types'] is not None:
            expected_n = int(np.asarray(gs['node_types']).shape[0])
        elif 'num_o_nodes' in gs:
            expected_n = int(gs['num_o_nodes'])
    except Exception:
        expected_n = None

    # Try loading raw coords from disk if possible (bypasses any normalization applied in the dataset)
    if structures_root is not None and z_value is not None:
        try:
            raw_path = Path(structures_root) / f"z_{int(z_value)}" / "slice_xy.npy"
            raw = np.load(raw_path)
            if expected_n is None or int(raw.shape[0]) == expected_n:
                return raw[:, :2] if raw.shape[1] >= 2 else np.pad(raw, ((0,0),(0,max(0,2-raw.shape[1]))))[:, :2]
            # if sizes mismatch, fall back to in-memory coords
        except Exception:
            pass
    # Prefer the most "raw/physical" coordinates available in the struct
    coord_keys = ['raw_coordinates', 'physical_coordinates', 'mesh_coordinates', 'original_coordinates', 'coordinates']
    coords = None
    for k in coord_keys:
        if k in gs and gs[k] is not None:
            coords = np.asarray(gs.get(k))
            break
    if coords is None:
        # Fallback to zero coords if missing
        return np.zeros((int(gs.get('num_o_nodes', 0)), 2), dtype=np.float32)
    # take XY
    if coords.shape[1] >= 2:
        return coords[:, :2]
    else:
        pad = np.zeros((coords.shape[0], 2), dtype=coords.dtype)
        pad[:, :coords.shape[1]] = coords
        return pad


# ----------------------------
# Grid utilities (for imshow-based extra plots)
# ----------------------------
try:
    from scipy.interpolate import griddata  # type: ignore
except Exception:
    griddata = None  # type: ignore


def _make_grid(coords_xy: np.ndarray, nx: int = 300, ny: int = 300):
    xmin, ymin = coords_xy.min(axis=0)
    xmax, ymax = coords_xy.max(axis=0)
    gx = np.linspace(xmin, xmax, nx)
    gy = np.linspace(ymin, ymax, ny)
    X, Y = np.meshgrid(gx, gy)
    return X, Y, (xmin, xmax, ymin, ymax)


def _scatter_to_grid(coords_xy: np.ndarray, values: np.ndarray, X: np.ndarray, Y: np.ndarray) -> np.ndarray:
    if griddata is None:
        xmin, xmax = X[0, 0], X[0, -1]
        ymin, ymax = Y[0, 0], Y[-1, 0]
        ny, nx = X.shape
        ix = np.clip(((coords_xy[:, 0] - xmin) / max(xmax - xmin, 1e-12) * (nx - 1)).round().astype(int), 0, nx - 1)
        iy = np.clip(((coords_xy[:, 1] - ymin) / max(ymax - ymin, 1e-12) * (ny - 1)).round().astype(int), 0, ny - 1)
        Z = np.full((ny, nx), np.nan, dtype=np.float32)
        Z[iy, ix] = values.astype(np.float32)
        # nearest fill rows/cols
        for y in range(ny):
            row = Z[y]
            good = ~np.isnan(row)
            if np.any(good):
                idx = np.where(good)[0]
                for x in range(nx):
                    if np.isnan(row[x]):
                        row[x] = row[idx[np.argmin(np.abs(idx - x))]]
        for x in range(nx):
            col = Z[:, x]
            good = ~np.isnan(col)
            if np.any(good):
                idx = np.where(good)[0]
                for y in range(ny):
                    if np.isnan(col[y]):
                        col[y] = col[idx[np.argmin(np.abs(idx - y))]]
        return Z
    else:
        Z = griddata(coords_xy, values, (X, Y), method='linear')  # type: ignore
        Znn = griddata(coords_xy, values, (X, Y), method='nearest')  # type: ignore
        Z[np.isnan(Z)] = Znn[np.isnan(Z)]
        return Z


def _evaluate_one_case_metrics(original_data: np.ndarray, predicted_data: np.ndarray, gs: Dict[str, Any]) -> Dict[str, float]:
    """Score one reconstruction against its ground truth.

    Returns a dict with:
      - ``"rrmse"``: RMSE of the error divided by the RMS of ``original_data``
        (0.0 when that RMS is exactly zero);
      - ``"cosine"``: ``_mean_angular_similarity(original_data, predicted_data)``;
      - ``"ssim"``: ``_graph_ssim_speed`` on per-node speeds (norm over the last
        axis), using the ``o2o_senders``/``o2o_receivers`` edges of ``gs``.
    """
    err = predicted_data - original_data
    rmse = float(np.sqrt(np.mean(err**2)))
    ref_rms = float(np.sqrt(np.mean(original_data**2)))
    rrmse = 0.0 if ref_rms == 0 else rmse / ref_rms

    cosine = float(_mean_angular_similarity(original_data, predicted_data))

    edge_src = np.asarray(gs['o2o_senders']).astype(np.int64)
    edge_dst = np.asarray(gs['o2o_receivers']).astype(np.int64)
    true_speed = np.linalg.norm(original_data, axis=-1).astype(np.float64)
    pred_speed = np.linalg.norm(predicted_data, axis=-1).astype(np.float64)
    ssim = float(_graph_ssim_speed(true_speed, pred_speed, edge_src, edge_dst, include_self=True))

    return {"rrmse": rrmse, "cosine": cosine, "ssim": ssim}


def _vector_to_mag(arr: np.ndarray) -> np.ndarray:
    """arr: (N,2) -> (N,) magnitude"""
    return np.linalg.norm(arr, axis=-1)


def _rrmse_field(u_true: np.ndarray, u_pred: np.ndarray) -> float:
    """Relative RMSE over the vector field (Ux,Uy)."""
    diff = (u_pred - u_true).astype(np.float64)
    rmse = float(np.sqrt(np.mean(diff**2)))
    denom = float(np.sqrt(np.mean((u_true.astype(np.float64))**2)))
    return rmse / denom if denom != 0.0 else 0.0


def _get_zoom_bounds(coords: np.ndarray, zoom_factor: float = 2.0) -> Tuple[float, float, float, float]:
    """
    Calculate zoom bounds centered on the domain center.
    """
    xmin, ymin = coords.min(axis=0)
    xmax, ymax = coords.max(axis=0)
    cx = (xmin + xmax) / 2
    cy = (ymin + ymax) / 2
    dx = (xmax - xmin) / zoom_factor
    dy = (ymax - ymin) / zoom_factor
    return cx - dx/2, cx + dx/2, cy - dy/2, cy + dy/2


def _pod_predict_on_current_mesh(
    pod_model: "SensorPOD",
    pod_ref_coords: np.ndarray,
    pod_ref_tree: "cKDTree",
    curr_coords: np.ndarray,
    u_curr: np.ndarray,
    obs_mask_1d: np.ndarray,
) -> Optional[np.ndarray]:
    """Reconstruct a POD (LCSVD) prediction on a per-batch mesh, matching `lcsvd.py` semantics.

    In `lcsvd.py`, sensors index into the *same mesh* the POD modes were fit on, and
    `y_obs` is taken from that mesh: `y_obs = vectorize_field(field)[obs_idx]`.

    Here, POD modes are fit on `pod_ref_coords` (reference mesh). For each batch:
    1) interpolate the ground truth `u_curr` from the current mesh onto the reference
       mesh with `_interp_field_to_reference` (piecewise-linear, nearest-neighbour fill
       outside the convex hull)
    2) snap each observed current-mesh node to its nearest reference node
       (several sensors may snap to the same reference node)
    3) build `obs_idx` and read `y_obs` from the interpolated reference-mesh field
    4) reconstruct on the reference mesh and interpolate the prediction back onto
       `curr_coords` the same way

    Args:
        pod_model: fitted POD model exposing `reconstruct_from_sensors(obs_idx, y_obs)`.
        pod_ref_coords: coordinates of the reference mesh the modes were fit on.
        pod_ref_tree: `cKDTree` over `pod_ref_coords`.
        curr_coords: (N, 2) coordinates of the current mesh.
        u_curr: (N, 2) ground-truth field on the current mesh.
        obs_mask_1d: boolean-like observation mask with N entries (reshaped to 1-D).

    Returns:
        The prediction interpolated onto `curr_coords`, or None when the model or tree
        is missing, `u_curr` is not (N, 2), the mask or coordinates do not have N
        entries, or no node is observed.
    """
    if pod_model is None or pod_ref_tree is None:
        return None

    if u_curr.ndim != 2 or u_curr.shape[1] != 2:
        return None

    curr_coords = np.asarray(curr_coords)
    if curr_coords.shape[0] != u_curr.shape[0]:
        return None

    # Ensure mask is 1D bool of length N
    m = np.asarray(obs_mask_1d).astype(bool)
    if m.ndim != 1:
        m = m.reshape(-1)
    if m.size != u_curr.shape[0]:
        return None

    sensor_idx_curr = np.where(m)[0]
    if sensor_idx_curr.size == 0:
        return None

    # 1) Linear interpolation onto the reference mesh; a plain NN copy
    #    (u_curr[tree.query]) gives jagged fields and a poor POD projection.
    u_ref = _interp_field_to_reference(curr_coords, pod_ref_coords, u_curr)

    # 2) map sensor coords to reference mesh indices (NN)
    sensor_coords = curr_coords[sensor_idx_curr]
    _, sensor_idx_ref = pod_ref_tree.query(sensor_coords, k=1)
    sensor_idx_ref = np.atleast_1d(sensor_idx_ref).flatten().astype(np.int64)

    # 3) build obs_idx and y_obs like `lcsvd.py`.
    # NOTE(review): assumes vectorize_field flattens an (N, 2) field row-major, so
    # node i occupies entries 2*i (Ux) and 2*i+1 (Uy) -- confirm against lcsvd.py.
    obs_idx = np.empty(2 * sensor_idx_ref.size, dtype=np.int64)
    obs_idx[0::2] = 2 * sensor_idx_ref
    obs_idx[1::2] = 2 * sensor_idx_ref + 1
    y_obs = vectorize_field(u_ref)[obs_idx]

    # 4) reconstruct on the reference mesh, then interpolate back to the current mesh.
    # NOTE(review): _interp_field_to_reference requires pred_ref to be (N_ref, 2).
    pred_ref = pod_model.reconstruct_from_sensors(obs_idx, y_obs)
    return _interp_field_to_reference(pod_ref_coords, curr_coords, pred_ref)


def _plot_5x3_example(
    filepath: Path,
    coords: np.ndarray,
    u_gt: np.ndarray,
    obs_mask: np.ndarray,
    preds: Dict[str, Optional[np.ndarray]],
    angle: int,
    obs_count: int,
):
    """Save a 6x4 scatter-plot panel comparing model predictions for one case.

    Layout:
      - Columns: Ux, Uy, Err Ux, Err Uy.
      - Rows: GT; sensor positions; MGN (``preds['meshgraphnet']``);
        MS-MGN (``preds['multiscale']``); LCSVD (``preds['sensor_pod']``);
        GenDA (``preds['gen_da']``).

    Details:
      - A missing or None prediction is drawn as a "<label> not available" placeholder.
      - Ux/Uy colour limits are shared across GT and every non-None entry of ``preds``.
      - Each Ux/Uy prediction panel is annotated with its component-wise R-RMSE.
      - Error panels show ``|pred_c - gt_c| / (|u_gt| + 1e-6)`` on a 0..1 colour scale.
      - Every panel is zoomed onto the central half of the domain.
      - The figure is written to ``filepath``, creating parent directories as needed.
    """
    rows = ['GT', 'Sensors Pos', 'MGN', 'MS-MGN', 'LCSVD', 'GenDA']
    cols = ['Ux', 'Uy', 'Err Ux', 'Err Uy']
    fig, axes = plt.subplots(len(rows), len(cols), figsize=(20, 16), constrained_layout=True)
    fig.suptitle(f"Angle {angle} | obs={obs_count}", fontsize=14, fontweight='bold')

    # Calculate zoom bounds
    z_xmin, z_xmax, z_ymin, z_ymax = _get_zoom_bounds(coords, zoom_factor=2.0)

    # Component-wise limits using GT and available preds
    all_vals = {'Ux': [u_gt[:, 0]], 'Uy': [u_gt[:, 1]]}
    for arr in preds.values():
        if arr is not None:
            all_vals['Ux'].append(arr[:, 0])
            all_vals['Uy'].append(arr[:, 1])
    vlims = {
        comp: (float(min(v.min() for v in vals)), float(max(v.max() for v in vals)))
        for comp, vals in all_vals.items()
    }

    # GT speed normalises every error panel; compute it once.
    gt_mag = np.linalg.norm(u_gt, axis=-1)

    # Row 0: GT
    for j, comp in enumerate([0, 1]):
        vmin, vmax = vlims['Ux' if comp == 0 else 'Uy']
        s = axes[0, j].scatter(coords[:, 0], coords[:, 1], c=u_gt[:, comp], cmap='RdBu_r', s=2.0, vmin=vmin, vmax=vmax)
        axes[0, j].set_title(cols[j])
        axes[0, j].set_ylabel(rows[0])
        axes[0, j].set_aspect('equal')
        axes[0, j].set_xlim(z_xmin, z_xmax)
        axes[0, j].set_ylim(z_ymin, z_ymax)
        axes[0, j].set_xticks([]); axes[0, j].set_yticks([])
        plt.colorbar(s, ax=axes[0, j], fraction=0.046, pad=0.02)

    # GT Error (Empty)
    axes[0, 2].set_title(cols[2]); axes[0, 2].axis('off')
    axes[0, 3].set_title(cols[3]); axes[0, 3].axis('off')

    # Row 1: Obs mask (gray for unobserved, red for observed) with colorbar
    m = obs_mask.astype(bool)
    cmap_mask = mcolors.ListedColormap(['lightgray', 'red'])
    vals = m.astype(np.int32)
    for j in range(2):
        s = axes[1, j].scatter(coords[:, 0], coords[:, 1], c=vals, cmap=cmap_mask, vmin=0, vmax=1, s=0.2, alpha=0.9)
        if j == 0:
            axes[1, j].set_ylabel(rows[1])
        axes[1, j].set_aspect('equal')
        axes[1, j].set_xlim(z_xmin, z_xmax)
        axes[1, j].set_ylim(z_ymin, z_ymax)
        axes[1, j].set_xticks([]); axes[1, j].set_yticks([])
        cb = plt.colorbar(s, ax=axes[1, j], fraction=0.046, pad=0.02)
        cb.set_ticks([0, 1]); cb.set_ticklabels(['unobs', 'obs'])

    # Obs Error (Empty)
    axes[1, 2].axis('off')
    axes[1, 3].axis('off')

    def draw_pred_row(row_idx: int, key: str, label: str):
        """Draw Ux/Uy and their relative-error panels for ``preds[key]`` in row ``row_idx``."""
        pred = preds.get(key)
        # Ux, Uy
        for j, comp in enumerate([0, 1]):
            ax = axes[row_idx, j]
            if pred is None:
                ax.text(0.5, 0.5, f"{label} not available", transform=ax.transAxes, ha='center', va='center')
                ax.set_axis_off()
                continue
            vmin, vmax = vlims['Ux' if comp == 0 else 'Uy']
            s = ax.scatter(coords[:, 0], coords[:, 1], c=pred[:, comp], cmap='RdBu_r', s=2.0, vmin=vmin, vmax=vmax)
            if j == 0:
                ax.set_ylabel(label)
            ax.set_aspect('equal')
            ax.set_xlim(z_xmin, z_xmax)
            ax.set_ylim(z_ymin, z_ymax)
            ax.set_xticks([]); ax.set_yticks([])
            plt.colorbar(s, ax=ax, fraction=0.046, pad=0.02)
            # Component-wise R-RMSE annotation (Ux vs Uy differ); best effort only.
            try:
                num = float(np.sqrt(np.mean((pred[:, comp] - u_gt[:, comp])**2)))
                den = float(np.sqrt(np.mean((u_gt[:, comp])**2)))
                rr_val = (num / den) if den != 0.0 else 0.0
                ax.text(0.02, 0.98, f"R-RMSE={rr_val:.3f}", transform=ax.transAxes, va='top',
                        bbox=dict(boxstyle='round,pad=0.25', facecolor='white', alpha=0.8))
            except Exception:
                pass

        # Error Fields (Ux, Uy)
        if pred is None:
            axes[row_idx, 2].axis('off')
            axes[row_idx, 3].axis('off')
            return
        for j, comp in enumerate([0, 1]):
            ax = axes[row_idx, 2 + j]
            rel_error = np.abs(pred[:, comp] - u_gt[:, comp]) / (gt_mag + 1e-6)
            s = ax.scatter(coords[:, 0], coords[:, 1], c=rel_error, cmap='Reds', vmin=0, vmax=1, s=2.0)
            ax.set_aspect('equal')
            ax.set_xlim(z_xmin, z_xmax)
            ax.set_ylim(z_ymin, z_ymax)
            ax.set_xticks([]); ax.set_yticks([])
            plt.colorbar(s, ax=ax, fraction=0.046, pad=0.02)

    draw_pred_row(2, 'meshgraphnet', 'MGN')
    draw_pred_row(3, 'multiscale', 'MS-MGN')
    draw_pred_row(4, 'sensor_pod', 'LCSVD')
    draw_pred_row(5, 'gen_da', 'GenDA')

    filepath.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(filepath, dpi=150)
    plt.close(fig)


def _plot_5x3_example_grid(
    filepath: Path,
    coords: np.ndarray,
    u_gt: np.ndarray,
    obs_mask: np.ndarray,
    preds: Dict[str, Optional[np.ndarray]],
    angle: int,
    obs_count: int,
    nx: int = 300,
    ny: int = 300,
):
    """imshow-based variant of the example panel using gridded data.

    Layout:
      - Columns: Ux, Uy, Err Ux, Err Uy.
      - Rows: GT; sensor positions; MGN (``preds['meshgraphnet']``);
        MS-MGN (``preds['multiscale']``); LCSVD (``preds['sensor_pod']``);
        GenDA (``preds['gen_da']``).

    Gridding:
      - Every scattered field is resampled with ``_scatter_to_grid`` onto an
        ``nx`` x ``ny`` grid spanning the bounding box of ``coords``, then drawn
        with ``imshow``.
      - The 0/1 sensor mask is resampled the same way. Intermediate values produced
        by the resampling fall into one of the two mask colours, so the observed
        region is only approximate.

    Styling:
      - Ux/Uy colour limits are shared across GT and every non-None prediction.
      - Error panels show ``|pred_c - gt_c| / (|u_gt| + 1e-6)`` on a 0..1 scale.
      - Every panel is zoomed onto the central half of the domain.
      - The figure is written to ``filepath``, creating parent directories as needed.
    """
    rows = ['GT', 'Sensors Pos', 'MGN', 'MS-MGN', 'LCSVD', 'GenDA']
    cols = ['Ux', 'Uy', 'Err Ux', 'Err Uy']
    X, Y, (xmin, xmax, ymin, ymax) = _make_grid(coords, nx=nx, ny=ny)
    extent = (xmin, xmax, ymin, ymax)
    fig, axes = plt.subplots(len(rows), len(cols), figsize=(20, 16), constrained_layout=True)
    fig.suptitle(f"Angle {angle} | obs={obs_count}", fontsize=14, fontweight='bold')

    # Calculate zoom bounds
    z_xmin, z_xmax, z_ymin, z_ymax = _get_zoom_bounds(coords, zoom_factor=2.0)

    # Component-wise limits using GT and available preds
    all_vals = {'Ux': [u_gt[:, 0]], 'Uy': [u_gt[:, 1]]}
    for arr in preds.values():
        if arr is not None:
            all_vals['Ux'].append(arr[:, 0])
            all_vals['Uy'].append(arr[:, 1])
    vlims = {
        comp: (float(min(v.min() for v in vals)), float(max(v.max() for v in vals)))
        for comp, vals in all_vals.items()
    }

    # GT speed normalises every error panel; compute it once.
    gt_mag = np.linalg.norm(u_gt, axis=-1)

    # Row 0 (GT)
    for j, comp in enumerate([0, 1]):
        Z = _scatter_to_grid(coords, u_gt[:, comp], X, Y)
        vmin, vmax = vlims['Ux' if comp == 0 else 'Uy']
        im = axes[0, j].imshow(Z, origin='lower', extent=extent, cmap='RdBu_r', vmin=vmin, vmax=vmax)
        axes[0, j].set_title(cols[j])
        axes[0, j].set_ylabel(rows[0])
        axes[0, j].set_aspect('equal')
        axes[0, j].set_xlim(z_xmin, z_xmax)
        axes[0, j].set_ylim(z_ymin, z_ymax)
        axes[0, j].set_xticks([]); axes[0, j].set_yticks([])
        plt.colorbar(im, ax=axes[0, j], fraction=0.046, pad=0.02)

    # GT Error (Empty)
    axes[0, 2].set_title(cols[2]); axes[0, 2].axis('off')
    axes[0, 3].set_title(cols[3]); axes[0, 3].axis('off')

    # Row 1: Obs mask
    M = _scatter_to_grid(coords, obs_mask.astype(np.float32), X, Y)
    cmap_mask = mcolors.ListedColormap(['lightgray', 'red'])
    for j in range(2):
        im = axes[1, j].imshow(M, origin='lower', extent=extent, cmap=cmap_mask, vmin=0, vmax=1)
        if j == 0:
            axes[1, j].set_ylabel(rows[1])
        axes[1, j].set_aspect('equal')
        axes[1, j].set_xlim(z_xmin, z_xmax)
        axes[1, j].set_ylim(z_ymin, z_ymax)
        axes[1, j].set_xticks([]); axes[1, j].set_yticks([])
        cb = plt.colorbar(im, ax=axes[1, j], fraction=0.046, pad=0.02)
        cb.set_ticks([0, 1]); cb.set_ticklabels(['unobs', 'obs'])

    # Obs Error (Empty)
    axes[1, 2].axis('off')
    axes[1, 3].axis('off')

    def draw_pred_row(row_idx: int, key: str, label: str):
        """Draw gridded Ux/Uy and relative-error panels for ``preds[key]`` in row ``row_idx``."""
        pred = preds.get(key)
        # Ux, Uy
        for j, comp in enumerate([0, 1]):
            ax = axes[row_idx, j]
            if pred is None:
                ax.text(0.5, 0.5, f"{label} not available", transform=ax.transAxes, ha='center', va='center')
                ax.set_axis_off()
                continue
            Z = _scatter_to_grid(coords, pred[:, comp], X, Y)
            vmin, vmax = vlims['Ux' if comp == 0 else 'Uy']
            im = ax.imshow(Z, origin='lower', extent=extent, cmap='RdBu_r', vmin=vmin, vmax=vmax)
            if j == 0:
                ax.set_ylabel(label)
            ax.set_aspect('equal')
            ax.set_xlim(z_xmin, z_xmax)
            ax.set_ylim(z_ymin, z_ymax)
            ax.set_xticks([]); ax.set_yticks([])
            plt.colorbar(im, ax=ax, fraction=0.046, pad=0.02)
            # Component-wise R-RMSE on the original (ungridded) nodes; best effort only.
            try:
                num = float(np.sqrt(np.mean((pred[:, comp] - u_gt[:, comp])**2)))
                den = float(np.sqrt(np.mean((u_gt[:, comp])**2)))
                rr_val = (num / den) if den != 0 else 0.0
                ax.text(0.02, 0.98, f"R-RMSE={rr_val:.3f}", transform=ax.transAxes, va='top',
                        bbox=dict(boxstyle='round,pad=0.25', facecolor='white', alpha=0.8))
            except Exception:
                pass

        # Error Fields (Ux, Uy)
        if pred is None:
            axes[row_idx, 2].axis('off')
            axes[row_idx, 3].axis('off')
            return
        for j, comp in enumerate([0, 1]):
            ax = axes[row_idx, 2 + j]
            rel_error = np.abs(pred[:, comp] - u_gt[:, comp]) / (gt_mag + 1e-6)
            Z = _scatter_to_grid(coords, rel_error, X, Y)
            im = ax.imshow(Z, origin='lower', extent=extent, cmap='Reds', vmin=0, vmax=1)
            ax.set_aspect('equal')
            ax.set_xlim(z_xmin, z_xmax)
            ax.set_ylim(z_ymin, z_ymax)
            ax.set_xticks([]); ax.set_yticks([])
            plt.colorbar(im, ax=ax, fraction=0.046, pad=0.02)

    draw_pred_row(2, 'meshgraphnet', 'MGN')
    draw_pred_row(3, 'multiscale', 'MS-MGN')
    draw_pred_row(4, 'sensor_pod', 'LCSVD')
    draw_pred_row(5, 'gen_da', 'GenDA')

    filepath.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(filepath, dpi=150)
    plt.close(fig)


def _plot_5x5_grid(
    filepath: Path,
    coords_list: List[np.ndarray],
    grid_data: Dict[str, List[Optional[np.ndarray]]],
    angles: List[int],
    obs_counts: List[int],
):
    """Save a 6 x (2 * n_angles) panel of velocity magnitude |U| and its relative error.

    Layout:
      - Each angle gets two columns: |U| and ``|pred - gt| / (|gt| + 1e-6)`` (0..1 scale).
      - Rows: GT, Sensors Pos, MGN, MS-MGN, LCSVD, GenDA.

    grid_data keys (each a list with one entry per angle):
      - 'gt': (N, 2) arrays
      - 'mask': (N,) bool/float arrays
      - 'meshgraphnet', 'multiscale', 'sensor_pod', 'gen_da': (N, 2) arrays or None.
        A missing key or a None entry draws an "N/A" placeholder.

    Other arguments:
      - coords_list[i]: node coordinates for angle i.
      - obs_counts[i]: observation count used for angle i (shown in the title).

    Styling:
      - The |U| colour scale is shared, per angle, across GT and every available model.
      - Every panel is zoomed onto the central half of the domain.
      - The figure is written to ``filepath``, creating parent directories as needed.
    """
    row_names = ['GT', 'Sensors Pos', 'MGN', 'MS-MGN', 'LCSVD', 'GenDA']
    n_angles = len(angles)
    # 2 columns per angle: |U| and Error
    fig, axes = plt.subplots(6, 2 * n_angles, figsize=(5.2 * n_angles, 14), constrained_layout=True)
    cmap_mask = mcolors.ListedColormap(['lightgray', 'red'])
    # Every drawn field (GT and all four model rows) contributes to the shared |U| scale.
    mag_keys = ['gt', 'meshgraphnet', 'multiscale', 'sensor_pod', 'gen_da']

    for col_idx, (angle, oc) in enumerate(zip(angles, obs_counts)):
        # Map to subplot columns: 2*col_idx, 2*col_idx+1
        c_mag = 2 * col_idx
        c_err = 2 * col_idx + 1

        coords = coords_list[col_idx]
        # Calculate zoom bounds
        z_xmin, z_xmax, z_ymin, z_ymax = _get_zoom_bounds(coords, zoom_factor=2.0)

        # Compute per-angle vmin/vmax across GT + available model preds
        mags_col = []
        for key in mag_keys:
            arr = grid_data.get(key, [None] * n_angles)[col_idx]
            if arr is not None:
                mags_col.append(_vector_to_mag(arr))
        if mags_col:
            vmin_col = float(np.min([m.min() for m in mags_col]))
            vmax_col = float(np.max([m.max() for m in mags_col]))
        else:
            vmin_col, vmax_col = 0.0, 1.0

        # Row 0: GT
        gt = grid_data['gt'][col_idx]
        sc = axes[0, c_mag].scatter(coords[:, 0], coords[:, 1], c=_vector_to_mag(gt), cmap='RdBu_r', s=2.0, vmin=vmin_col, vmax=vmax_col)
        # Title: bold angle with a degree sign (U+00B0) and the observation count
        ang_label = f"{angle}\u00b0\nobs={oc}"
        axes[0, c_mag].set_title(ang_label, fontweight='bold')
        axes[0, c_mag].set_aspect('equal')
        axes[0, c_mag].set_xlim(z_xmin, z_xmax)
        axes[0, c_mag].set_ylim(z_ymin, z_ymax)
        axes[0, c_mag].set_xticks([]); axes[0, c_mag].set_yticks([])
        plt.colorbar(sc, ax=axes[0, c_mag], fraction=0.046, pad=0.02)

        # GT Error (Empty)
        axes[0, c_err].set_title("Err |U|")
        axes[0, c_err].axis('off')

        # Row 1: mask (with colorbar)
        m = np.asarray(grid_data['mask'][col_idx]).astype(bool)
        vals = m.astype(np.int32)
        sc_mask = axes[1, c_mag].scatter(coords[:, 0], coords[:, 1], c=vals, cmap=cmap_mask, vmin=0, vmax=1, s=2.0, alpha=0.9)
        axes[1, c_mag].set_aspect('equal')
        axes[1, c_mag].set_xlim(z_xmin, z_xmax)
        axes[1, c_mag].set_ylim(z_ymin, z_ymax)
        axes[1, c_mag].set_xticks([]); axes[1, c_mag].set_yticks([])
        cbm = plt.colorbar(sc_mask, ax=axes[1, c_mag], fraction=0.046, pad=0.02)
        cbm.set_ticks([0, 1])

        # Mask Error (Empty)
        axes[1, c_err].axis('off')

        # Rows 2..5: models. Called immediately within this iteration, so the
        # closure over coords/gt/limits of the current angle is safe.
        def draw_row(ridx: int, key: str, label: str):
            """Draw |U| and relative error for ``grid_data[key]`` at this angle in row ``ridx``."""
            arr = grid_data.get(key, [None] * n_angles)[col_idx]
            ax_m = axes[ridx, c_mag]
            ax_e = axes[ridx, c_err]

            if arr is None:
                ax_m.text(0.5, 0.5, f"{label} N/A", transform=ax_m.transAxes, ha='center', va='center')
                ax_m.set_axis_off()
                ax_e.axis('off')
                return

            # Magnitude
            sc2 = ax_m.scatter(coords[:, 0], coords[:, 1], c=_vector_to_mag(arr), cmap='RdBu_r', s=2.0, vmin=vmin_col, vmax=vmax_col)
            ax_m.set_aspect('equal')
            ax_m.set_xlim(z_xmin, z_xmax)
            ax_m.set_ylim(z_ymin, z_ymax)
            ax_m.set_xticks([]); ax_m.set_yticks([])
            plt.colorbar(sc2, ax=ax_m, fraction=0.046, pad=0.02)
            # Annotate R-RMSE for this column/model vs GT (vector field-wise); best effort.
            try:
                rr_val = _rrmse_field(gt, arr)
                ax_m.text(0.02, 0.98, f"R-RMSE={rr_val:.3f}", transform=ax_m.transAxes, va='top',
                          bbox=dict(boxstyle='round,pad=0.25', facecolor='white', alpha=0.8))
            except Exception:
                pass

            # Error Map (|Pred - GT| / |GT|)
            diff_mag = np.linalg.norm(arr - gt, axis=-1)
            gt_mag = np.linalg.norm(gt, axis=-1)
            rel_error = diff_mag / (gt_mag + 1e-6)

            sc_e = ax_e.scatter(coords[:, 0], coords[:, 1], c=rel_error, cmap='Reds', vmin=0, vmax=1, s=2.0)
            ax_e.set_aspect('equal')
            ax_e.set_xlim(z_xmin, z_xmax)
            ax_e.set_ylim(z_ymin, z_ymax)
            ax_e.set_xticks([]); ax_e.set_yticks([])
            plt.colorbar(sc_e, ax=ax_e, fraction=0.046, pad=0.02)

        draw_row(2, 'meshgraphnet', 'MGN')
        draw_row(3, 'multiscale', 'MS-MGN')
        draw_row(4, 'sensor_pod', 'LCSVD')
        draw_row(5, 'gen_da', 'GenDA')

    # Row labels on the left
    for r, name in enumerate(row_names):
        axes[r, 0].set_ylabel(name)

    filepath.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(filepath, dpi=150)
    plt.close(fig)


def _plot_5x5_grid_imshow(
    filepath: Path,
    coords_list: List[np.ndarray],
    grid_data: Dict[str, List[Optional[np.ndarray]]],
    angles: List[int],
    obs_counts: List[int],
    nx: int = 300,
    ny: int = 300,
):
    """Plot the rasterised (imshow) variant of the multi-angle |U| comparison grid.

    The figure has 6 rows (GT, sensor mask, MGN, MS-MGN, LCSVD, GenDA) and two
    subplot columns per angle: the velocity magnitude |U| and the relative
    error |pred - GT| / |GT| (colour scale fixed to [0, 1]). Each angle is
    rasterised with ``_make_grid`` / ``_scatter_to_grid`` on its own rectangular
    ``nx`` x ``ny`` grid and zoomed with ``_get_zoom_bounds``.

    Args:
        filepath: Output image path; parent directories are created.
        coords_list: Per-column node coordinates (columns 0/1 are used as x/y).
        grid_data: Per-column node arrays. ``'gt'`` and ``'mask'`` are required;
            ``'meshgraphnet'``, ``'multiscale'``, ``'sensor_pod'`` and ``'gen_da'``
            are optional and a missing key / ``None`` entry is drawn as "N/A".
        angles: Angle shown in each column header.
        obs_counts: Observation count shown in each column header.
        nx, ny: Raster resolution passed to ``_make_grid``.

    Only the first ``min(len(angles), len(obs_counts), len(coords_list),
    len(grid_data['gt']))`` columns are drawn, so mismatched inputs cannot
    produce empty subplot columns or index errors.

    Raises:
        ValueError: If there is no complete column to draw.
    """
    row_names = ['GT', 'Sensors Pos', 'MGN', 'MS-MGN', 'LCSVD', 'GenDA']
    # (row index, grid_data key, label used when the prediction is missing)
    model_rows = [
        (2, 'meshgraphnet', 'MGN'),
        (3, 'multiscale', 'MS-MGN'),
        (4, 'sensor_pod', 'LCSVD'),
        (5, 'gen_da', 'GenDA'),
    ]
    # Keys whose magnitudes define each column's shared |U| colour range.
    # NOTE(review): 'sensor_pod' is not included, so the LCSVD panel is clipped
    # to the range of the other fields — confirm this is intended.
    range_keys = ('gt', 'meshgraphnet', 'multiscale', 'gen_da')

    n_cols = min(len(angles), len(obs_counts), len(coords_list), len(grid_data['gt']))
    if n_cols == 0:
        raise ValueError("Nothing to plot: no complete columns in angles/obs_counts/coords_list/grid_data['gt']")

    fig, axes = plt.subplots(len(row_names), 2 * n_cols, figsize=(5.2 * n_cols, 14),
                             constrained_layout=True, squeeze=False)
    cmap_mask = mcolors.ListedColormap(['lightgray', 'red'])

    def _entry(key: str, idx: int) -> Optional[np.ndarray]:
        """Return ``grid_data[key][idx]``, or None when the key or column is missing."""
        values = grid_data.get(key)
        if values is None or idx >= len(values):
            return None
        return values[idx]

    def _frame(ax, bounds: Tuple[float, float, float, float]) -> None:
        """Apply the zoomed, equal-aspect, tick-free framing shared by every panel."""
        x0, x1, y0, y1 = bounds
        ax.set_aspect('equal')
        ax.set_xlim(x0, x1)
        ax.set_ylim(y0, y1)
        ax.set_xticks([])
        ax.set_yticks([])

    for col_idx in range(n_cols):
        angle, oc = angles[col_idx], obs_counts[col_idx]
        c_mag, c_err = 2 * col_idx, 2 * col_idx + 1

        coords = coords_list[col_idx]
        zoom = _get_zoom_bounds(coords, zoom_factor=2.0)

        X, Y, (xmin, xmax, ymin, ymax) = _make_grid(coords, nx=nx, ny=ny)
        extent = (xmin, xmax, ymin, ymax)

        # Per-column vmin/vmax of the vector magnitude across available fields.
        mags_col = [_vector_to_mag(a) for a in (_entry(k, col_idx) for k in range_keys) if a is not None]
        if mags_col:
            vmin_col = float(np.min([mag.min() for mag in mags_col]))
            vmax_col = float(np.max([mag.max() for mag in mags_col]))
        else:
            vmin_col, vmax_col = 0.0, 1.0

        # Row 0: ground truth |U| (its error panel only carries the column header).
        gt = grid_data['gt'][col_idx]
        ax_gt = axes[0, c_mag]
        im = ax_gt.imshow(_scatter_to_grid(coords, _vector_to_mag(gt), X, Y), origin='lower',
                          extent=extent, cmap='RdBu_r', vmin=vmin_col, vmax=vmax_col)
        ax_gt.set_title(f"{angle}º\nobs={oc}", fontweight='bold')
        _frame(ax_gt, zoom)
        plt.colorbar(im, ax=ax_gt, fraction=0.046, pad=0.02)
        axes[0, c_err].set_title("Err |U|")
        axes[0, c_err].axis('off')

        # Row 1: observation mask (no error panel).
        observed = np.asarray(grid_data['mask'][col_idx]).astype(bool)
        ax_mask = axes[1, c_mag]
        im_m = ax_mask.imshow(_scatter_to_grid(coords, observed.astype(np.float32), X, Y), origin='lower',
                              extent=extent, cmap=cmap_mask, vmin=0, vmax=1)
        _frame(ax_mask, zoom)
        cbm = plt.colorbar(im_m, ax=ax_mask, fraction=0.046, pad=0.02)
        cbm.set_ticks([0, 1])  # 0 = unobserved, 1 = observed
        axes[1, c_err].axis('off')

        # Rows 2..5: model predictions and their relative-error maps.
        gt_mag = np.linalg.norm(gt, axis=-1)
        for ridx, key, label in model_rows:
            arr = _entry(key, col_idx)
            ax_m, ax_e = axes[ridx, c_mag], axes[ridx, c_err]
            if arr is None:
                ax_m.text(0.5, 0.5, f"{label} N/A", transform=ax_m.transAxes, ha='center', va='center')
                ax_m.set_axis_off()
                ax_e.axis('off')
                continue

            im2 = ax_m.imshow(_scatter_to_grid(coords, _vector_to_mag(arr), X, Y), origin='lower',
                              extent=extent, cmap='RdBu_r', vmin=vmin_col, vmax=vmax_col)
            _frame(ax_m, zoom)
            plt.colorbar(im2, ax=ax_m, fraction=0.046, pad=0.02)
            # Best-effort R-RMSE annotation: a failing metric must not abort the figure.
            try:
                rr_val = _rrmse_field(gt, arr)
                ax_m.text(0.02, 0.98, f"R-RMSE={rr_val:.3f}", transform=ax_m.transAxes, va='top',
                          bbox=dict(boxstyle='round,pad=0.25', facecolor='white', alpha=0.8))
            except Exception:
                pass

            # Relative error |pred - GT| / |GT|; the epsilon guards against
            # division by zero where |GT| vanishes.
            rel_error = np.linalg.norm(arr - gt, axis=-1) / (gt_mag + 1e-6)
            im_e = ax_e.imshow(_scatter_to_grid(coords, rel_error, X, Y), origin='lower',
                               extent=extent, cmap='Reds', vmin=0, vmax=1)
            _frame(ax_e, zoom)
            plt.colorbar(im_e, ax=ax_e, fraction=0.046, pad=0.02)

    for r, name in enumerate(row_names):
        axes[r, 0].set_ylabel(name)

    filepath.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(filepath, dpi=150)
    plt.close(fig)


def _first_graph_structure(forcings: Dict[str, Any]) -> Any:
    """Return ``forcings['graph_structures']``, unwrapping the first entry of a list/tuple."""
    gs = forcings['graph_structures']
    if isinstance(gs, (list, tuple)):
        gs = gs[0]
    return gs


def _obs_mask_1d(forcings: Dict[str, Any]) -> np.ndarray:
    """Return the first batch element's observation mask as a 1D boolean array.

    A 3D ``forcings['obs_mask']`` (presumably batch x nodes x channels — TODO
    confirm) uses channel 0; any other rank is indexed with ``[0]``. Values
    > 0.5 count as observed.
    """
    raw = np.asarray(forcings.get('obs_mask'))
    if raw.ndim == 3:
        return raw[0, :, 0] > 0.5
    return raw[0] > 0.5


def _batch_z_value(batch: Dict[str, Any], default: float = 35.0) -> float:
    """Return the z value of the first batch element (``default`` if ``'z'`` is absent).

    Scalars, lists/tuples and NumPy/JAX arrays of any shape are accepted; the
    first element in flattened order is used.
    """
    return float(np.asarray(batch.get('z', default)).reshape(-1)[0])


def _parse_obs_counts(obs_counts_arg: str, fallback: int) -> List[int]:
    """Parse a comma-separated list of observation counts.

    Falls back to ``[fallback]`` when the list contains no numbers, and drops
    duplicates while keeping first-seen order (a duplicate would otherwise be
    evaluated twice and written twice to the CSV).
    """
    counts = [int(x) for x in obs_counts_arg.split(',') if x.strip()]
    if not counts:
        counts = [int(fallback)]
    return list(dict.fromkeys(counts))


def _append_metrics(store: Dict[str, List[float]], case_metrics: Dict[str, float]) -> None:
    """Append each metric of one evaluated case to the matching list in ``store``."""
    for key, value in case_metrics.items():
        store[key].append(value)


def main():
    """Evaluate GenDA and the enabled baselines on the test set, then write metrics and plots.

    Steps:
      1. Build the test dataset and restore GenDA plus (optionally) MGN / MS-MGN.
      2. Optionally fit the Sensor-POD (LCSVD) baseline on ``--pod_train_z_values``.
      3. For every observation count, reconstruct each test batch with every
         model and accumulate R-RMSE / SSIM / cosine metrics (first batch
         element only).
      4. Write ``metrics_by_model_and_obs.csv`` and ``metrics_summary.png``.
      5. Optionally draw per-angle example plots and the multi-angle grid plots.
    """
    parser = argparse.ArgumentParser(description="Compare reconstruction quality: GenSynth vs baselines")
    parser.add_argument('--gen_da_checkpoint', type=str, required=True)
    parser.add_argument('--baseline_mgn_checkpoint', type=str, default="")
    parser.add_argument('--baseline_multiscale_checkpoint', type=str, default="")
    parser.add_argument('--slice_root', type=str, default='data_sliced_cropped_300k')
    parser.add_argument('--norm_stats', type=str, default='normalization_cropped_300k_test/normalization_stats_train.nc')
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--latent_size', type=int, default=64)
    parser.add_argument('--eval_angle_stride', type=int, default=36)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--output_dir', type=str, default='eval_outputs/reconstruction_compare')
    parser.add_argument('--graph_structures_root', type=str, default='structures_cropped_300k',
                        help='Root directory that contains z_*/slice_xy.npy (raw XY) for plotting with physical coords')
    # Observation count(s)
    parser.add_argument('--obs_count', type=int, default=3000, help='Legacy: single approx number of observed points per case')
    parser.add_argument('--obs_counts', type=str, default='', help='Comma-separated list of observation counts, e.g. "350,7000,15000"')
    parser.add_argument('--obs_neighbor_hops', type=int, default=0)
    # NOTE(review): --obs_mode is parsed but every prepare_batch call below forces 'random'.
    parser.add_argument('--obs_mode', type=str, default='random', choices=['random','swarm'])
    # POD config
    parser.add_argument('--enable_pod', action='store_true', help='Include Sensor-POD baseline')
    parser.add_argument('--pod_train_z_values', type=str, default='15,20,28,45', help='Z-slices for POD training')
    parser.add_argument('--pod_n_modes', type=int, default=15)

    # Plotting controls
    parser.add_argument('--example_angles', type=str, default='', help='Comma-separated angles for 5x2 example plots (columns Ux,Uy)')
    parser.add_argument('--grid_angles', type=str, default='', help='Comma-separated angles for 5x5 magnitude grid')
    parser.add_argument('--num_example_plots', type=int, default=None, help='Limit number of example angle plots; if set, sample this many from example_angles')
    args = parser.parse_args()

    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    examples_dir = out_dir / 'examples'
    grid_dir = out_dir / 'grid'

    # ------------------------------------------------------------------
    # Test dataset
    # ------------------------------------------------------------------
    test_ds = make_dataset(
        slice_root=args.slice_root,
        norm_stats_nc=args.norm_stats,
        batch_size=max(1, int(args.batch_size)),
        shuffle=False,
        seed=args.seed + 2,
        is_training=False,
        angle_stride=max(1, int(args.eval_angle_stride)),
        drop_remainder=True,
    )
    test_batches = list(test_ds)
    if not test_batches:
        raise RuntimeError("Empty test dataset")

    # The first batch supplies example forcings for model initialisation.
    example_batch = test_batches[0]
    _, example_forcings, *_ = prepare_batch(example_batch)

    # ------------------------------------------------------------------
    # Models
    # ------------------------------------------------------------------
    gen_da, _ = _init_gen_da(example_batch, args.latent_size, seed=args.seed)
    gen_da = _restore_nnx_model(gen_da, Path(args.gen_da_checkpoint))

    mgn = None
    if args.baseline_mgn_checkpoint:
        mgn = _init_baseline('meshgraphnet', example_forcings, args.latent_size, seed=args.seed)
        mgn = _restore_nnx_model(mgn, Path(args.baseline_mgn_checkpoint))

    mgn_ms = None
    if args.baseline_multiscale_checkpoint:
        mgn_ms = _init_baseline('multiscale', example_forcings, args.latent_size, seed=args.seed)
        mgn_ms = _restore_nnx_model(mgn_ms, Path(args.baseline_multiscale_checkpoint))

    # ------------------------------------------------------------------
    # Optional Sensor-POD (LCSVD) baseline
    # ------------------------------------------------------------------
    pod_model = None
    pod_ref_coords: Optional[np.ndarray] = None
    if args.enable_pod and not POD_AVAILABLE:
        print("Warning: --enable_pod was set but SensorPOD is unavailable; POD baseline disabled.")
    elif args.enable_pod:
        print("🚀 Initializing & Fitting SensorPOD Baseline...")
        try:
            # 1. Reference mesh coordinates come from the test set's example batch.
            pod_ref_coords = _coords_from_graph_struct(example_forcings['graph_structures'], z_value=None, structures_root=None)

            # 2. Collect training snapshots from the training z-slices onto the reference mesh.
            train_zs = [float(z.strip()) for z in args.pod_train_z_values.split(',') if z.strip()]
            all_train_snapshots = []

            # Training data is assumed to live in slice_root with the same normalization.
            needed_approx = set(range(1, 360, 4))  # sub-sample cases for speed

            for z_val in train_zs:
                print(f"   -> Loading POD training data for Z={z_val}...")
                ds_train = make_dataset(
                    slice_root=args.slice_root,
                    norm_stats_nc=args.norm_stats,
                    batch_size=8,
                    shuffle=False,
                    seed=args.seed,
                    is_training=False,
                    fixed_z=int(z_val),
                    drop_remainder=False,
                    angle_stride=1,
                )

                case_map, src_coords, _ = collect_case_map(ds_train, needed_cases=needed_approx,
                                                           max_iterations=500, verbose=False)
                if not case_map or src_coords is None:
                    print(f"      Warning: No data for Z={z_val}")
                    continue

                # Interpolate only when the source mesh differs from the reference mesh.
                if pod_ref_coords is not None and src_coords.shape == pod_ref_coords.shape and np.allclose(src_coords, pod_ref_coords, atol=1e-5):
                    interp_map = case_map
                else:
                    if pod_ref_coords is None:
                        raise ValueError("POD reference coordinates not initialized")
                    interp_map = interpolate_to_reference(src_coords, pod_ref_coords, case_map)

                for case_id in sorted(interp_map.keys()):
                    all_train_snapshots.append(interp_map[case_id])

            if all_train_snapshots:
                combined_train_map = dict(enumerate(all_train_snapshots))
                pod_model = SensorPOD(n_modes=args.pod_n_modes)
                t0 = time.time()
                pod_model.fit_modes(combined_train_map)
                print(f"      Fitted POD with {len(all_train_snapshots)} snapshots in {time.time()-t0:.2f}s")
            else:
                print("      Error: No training snapshots collected for POD. Disabled.")

        except Exception as e:
            print(f"      POD Init failed: {e}")
            pod_model = None
            pod_ref_coords = None

    # KDTree over the POD reference mesh, used to map predictions onto each case's mesh.
    pod_tree = None
    if pod_model is not None and pod_ref_coords is not None:
        print("      Building KDTree for POD reference mesh...")
        pod_tree = cKDTree(pod_ref_coords)
    pod_active = pod_model is not None and pod_tree is not None and pod_ref_coords is not None

    # ------------------------------------------------------------------
    # Quantitative evaluation
    # ------------------------------------------------------------------
    obs_counts = _parse_obs_counts(args.obs_counts, args.obs_count)

    model_names = ['gen_da'] + (["meshgraphnet"] if mgn is not None else []) + \
                  (["multiscale"] if mgn_ms is not None else []) + \
                  (["sensor_pod"] if pod_model is not None else [])

    metrics_by_count: Dict[int, Dict[str, Dict[str, List[float]]]] = {
        oc: {m: {"rrmse": [], "ssim": [], "cosine": []} for m in model_names} for oc in obs_counts
    }

    # (model, metrics key, name used in log messages); evaluated in this order.
    baseline_models = [(mgn, 'meshgraphnet', 'MeshGraphNet'), (mgn_ms, 'multiscale', 'MultiScale')]

    for oc in obs_counts:
        print(f"Evaluating obs_count={oc} across {len(test_batches)} batches...")
        oc_metrics = metrics_by_count[oc]
        for batch in test_batches:
            frac = _obs_fraction_for_count(batch, oc, fluid_only=True)
            inputs, forcings, *_ = prepare_batch(
                batch,
                obs_frac=frac,
                obs_frac_min=frac,
                obs_frac_max=frac,
                obs_neighbor_hops=int(args.obs_neighbor_hops),
                obs_on_fluid_only=True,
                obs_seed=args.seed + int(oc),
                obs_mode='random',  # enforce random for quantitative compare
            )

            # GenDA sampling.
            # NOTE(review): the same PRNG key (args.seed) is used for every batch,
            # so all cases start from the same noise draw — confirm this is intended.
            try:
                noise_shape = inputs['U_field'].shape
                rng_key = jax.random.PRNGKey(args.seed)
                noisy_inputs = jax.random.normal(rng_key, noise_shape)
                # This mutates the shared forcings, so the baselines below also see it.
                forcings['boundary_values'] = inputs['U_field']
                pred_g = gen_da.full_sampling(noisy_inputs=noisy_inputs, forcings=forcings)
                original_data = np.array(inputs['U_field'][0])
                predicted = np.array(pred_g[0])
                gs = _first_graph_structure(forcings)
                _append_metrics(oc_metrics['gen_da'], _evaluate_one_case_metrics(original_data, predicted, gs))
            except Exception as e:
                print(f"GenDA sampling failed on a batch: {e}")

            # MeshGraphNet baselines.
            for model, key, disp in baseline_models:
                if model is None:
                    continue
                try:
                    pred_b = model(forcings)
                    predicted = np.array(pred_b[0])
                    original_data = np.array(inputs['U_field'][0])
                    gs = _first_graph_structure(forcings)
                    _append_metrics(oc_metrics[key], _evaluate_one_case_metrics(original_data, predicted, gs))
                except Exception as e:
                    print(f"{disp} inference failed: {e}")

            # Sensor-POD.
            if pod_active:
                try:
                    original_data = np.array(inputs['U_field'][0])
                    gs = _first_graph_structure(forcings)
                    curr_coords = _coords_from_graph_struct(gs, z_value=None, structures_root=None)
                    pred_pod = _pod_predict_on_current_mesh(
                        pod_model=pod_model,
                        pod_ref_coords=pod_ref_coords,
                        pod_ref_tree=pod_tree,
                        curr_coords=curr_coords,
                        u_curr=original_data,
                        obs_mask_1d=_obs_mask_1d(forcings),
                    )
                    if pred_pod is not None:
                        _append_metrics(oc_metrics['sensor_pod'], _evaluate_one_case_metrics(original_data, pred_pod, gs))
                except Exception as e:
                    print(f"POD Recon failed at oc={oc}: {e}")

    # ------------------------------------------------------------------
    # CSV: one row per (obs_count, model)
    # ------------------------------------------------------------------
    csv_path = out_dir / 'metrics_by_model_and_obs.csv'
    with open(csv_path, 'w', newline='') as fh:
        writer = csv.writer(fh)
        writer.writerow(['obs_count', 'model', 'mean_rrmse', 'std_rrmse', 'mean_ssim', 'std_ssim', 'mean_cosine', 'std_cosine', 'num_cases'])
        for oc in obs_counts:
            for name in model_names:
                row: List[Any] = [int(oc), name]
                for metric in ('rrmse', 'ssim', 'cosine'):
                    vals = np.array(metrics_by_count[oc][name][metric], dtype=np.float64)
                    row.append(float(vals.mean()) if vals.size else np.nan)
                    row.append(float(vals.std()) if vals.size else np.nan)
                row.append(len(metrics_by_count[oc][name]['rrmse']))
                writer.writerow(row)
    print(f"Saved metrics CSV to {csv_path}")

    # ------------------------------------------------------------------
    # Summary bar plot for the last obs_count evaluated (for convenience)
    # ------------------------------------------------------------------
    last_oc = obs_counts[-1]
    last_metrics = metrics_by_count[last_oc]
    display_names = {'meshgraphnet': 'MGN', 'multiscale': 'MS-MGN', 'gen_da': 'GenDA', 'sensor_pod': 'LCSVD'}
    width = 0.25
    x = np.arange(len(model_names))
    plt.figure(figsize=(8, 5))
    for metric, label, offset in (('rrmse', 'R-RMSE', -width), ('ssim', 'SSIM', 0.0), ('cosine', 'Cosine Similarity', width)):
        heights = [np.mean(last_metrics[m][metric]) if last_metrics[m][metric] else np.nan for m in model_names]
        plt.bar(x + offset, heights, width, label=label)
    plt.xticks(x, [display_names.get(m, m) for m in model_names])
    plt.title(f"Metrics summary @ obs_count={last_oc}")
    plt.ylabel('Metric')
    plt.grid(True, axis='y', alpha=0.3)
    plt.legend()
    plt.tight_layout()
    fig_path = out_dir / 'metrics_summary.png'
    plt.savefig(fig_path, dpi=150)
    plt.close()
    print(f"Saved summary plot to {fig_path}")

    # Plot sections draw a random obs_count uniformly from [min, max] of the evaluated counts.
    oc_min, oc_max = int(min(obs_counts)), int(max(obs_counts))

    # ----------------------------------------------------------
    # 5x2 example plots for selected angles
    # ----------------------------------------------------------
    if args.example_angles.strip():
        all_angles = [int(x) for x in args.example_angles.split(',') if x.strip()]
        if args.num_example_plots is not None and args.num_example_plots < len(all_angles):
            rng = np.random.default_rng(args.seed)
            all_angles = [int(a) for a in rng.choice(all_angles, size=args.num_example_plots, replace=False)]

        for ang in all_angles:
            b = _pick_batch_for_angle(test_batches, ang)
            if b is None:
                print(f"Warning: angle {ang} not found in the test set; skipping 5x2 plot.")
                continue
            oc = int(np.random.default_rng(args.seed + ang).integers(low=oc_min, high=oc_max + 1))
            frac = _obs_fraction_for_count(b, oc, fluid_only=True)
            inputs, forcings, *_ = prepare_batch(
                b,
                obs_frac=frac, obs_frac_min=frac, obs_frac_max=frac,
                obs_neighbor_hops=int(args.obs_neighbor_hops),
                obs_on_fluid_only=True,
                obs_seed=args.seed + ang + oc,
                obs_mode='random',
            )
            # Prefer raw XY from disk (physical coords) for plotting; fall back to
            # graph coordinates when the node counts do not match.
            z_value = _batch_z_value(b)
            coords = _coords_from_graph_struct(forcings['graph_structures'], z_value, args.graph_structures_root)
            u_gt = np.array(inputs['U_field'][0])
            obs_mask = _obs_mask_1d(forcings)
            if len(coords) != len(u_gt):
                coords = _coords_from_graph_struct(forcings['graph_structures'], z_value=None, structures_root=None)

            preds: Dict[str, Optional[np.ndarray]] = {'meshgraphnet': None, 'multiscale': None, 'gen_da': None, 'sensor_pod': None}
            # GenDA
            try:
                rng_key = jax.random.PRNGKey(args.seed + ang)
                noisy = jax.random.normal(rng_key, inputs['U_field'].shape)
                forcings['boundary_values'] = inputs['U_field']
                pred = gen_da.full_sampling(noisy_inputs=noisy, forcings=forcings)
                preds['gen_da'] = np.array(pred[0])
            except Exception as e:
                print(f"GenDA sampling failed for angle {ang}: {e}")
            # Baselines
            for model, key, disp in baseline_models:
                if model is None:
                    continue
                try:
                    pred = model(forcings)
                    preds[key] = np.array(pred[0])
                except Exception as e:
                    print(f"{disp} failed for angle {ang}: {e}")
            # POD
            if pod_active:
                try:
                    curr_coords_pod = _coords_from_graph_struct(forcings['graph_structures'], z_value=None, structures_root=None)
                    preds['sensor_pod'] = _pod_predict_on_current_mesh(
                        pod_model=pod_model,
                        pod_ref_coords=pod_ref_coords,
                        pod_ref_tree=pod_tree,
                        curr_coords=curr_coords_pod,
                        u_curr=u_gt,
                        obs_mask_1d=obs_mask,
                    )
                except Exception as e:
                    print(f"POD viz failed for angle {ang}: {e}")

            save_path = examples_dir / f"angle_{ang}_5x3_obs{oc}.png"
            _plot_5x3_example(save_path, coords, u_gt, obs_mask, preds, ang, oc)
            print(f"Saved 5x3 example plot to {save_path}")

            # Extra: imshow/grid version
            save_path_grid = examples_dir / f"angle_{ang}_5x3_grid_obs{oc}.png"
            _plot_5x3_example_grid(save_path_grid, coords, u_gt, obs_mask, preds, ang, oc)
            print(f"Saved 5x3 (grid) plot to {save_path_grid}")

    # ----------------------------------------------------------
    # 5x5 grid of |U| for multiple angles
    # ----------------------------------------------------------
    if args.grid_angles.strip():
        angles = [int(x) for x in args.grid_angles.split(',') if x.strip()]
        grid_data: Dict[str, List[Optional[np.ndarray]]] = {
            'gt': [],
            'mask': [],
            'meshgraphnet': [],
            'multiscale': [],
            'gen_da': [],
            'sensor_pod': [],
        }
        # Only angles that actually produce a column are recorded here, so the
        # column labels stay aligned with grid_data/coords_list when some
        # requested angles are missing from the test set.
        used_angles: List[int] = []
        used_counts: List[int] = []
        coords_list: List[np.ndarray] = []
        for ang in angles:
            b = _pick_batch_for_angle(test_batches, ang)
            if b is None:
                print(f"Warning: angle {ang} not found; skipping column.")
                continue
            oc = int(np.random.default_rng(args.seed + ang * 3).integers(low=oc_min, high=oc_max + 1))
            frac = _obs_fraction_for_count(b, oc, fluid_only=True)
            inputs, forcings, *_ = prepare_batch(
                b,
                obs_frac=frac, obs_frac_min=frac, obs_frac_max=frac,
                obs_neighbor_hops=int(args.obs_neighbor_hops),
                obs_on_fluid_only=True,
                obs_seed=args.seed + ang * 3 + oc,
                obs_mode='random',
            )
            # Coordinates for plotting: raw XY for this z when sizes match, else graph coords.
            z_value = _batch_z_value(b)
            u_gt = np.array(inputs['U_field'][0])
            coords_plot = _coords_from_graph_struct(forcings['graph_structures'], z_value, args.graph_structures_root)
            if len(coords_plot) != len(u_gt):
                coords_plot = _coords_from_graph_struct(forcings['graph_structures'], z_value=None, structures_root=None)
            coords_list.append(coords_plot)
            grid_data['gt'].append(u_gt)

            mask_1d = _obs_mask_1d(forcings)
            grid_data['mask'].append(mask_1d)

            used_angles.append(ang)
            used_counts.append(oc)

            # GenDA
            genda_arr: Optional[np.ndarray] = None
            try:
                rng_key = jax.random.PRNGKey(args.seed + ang * 7)
                noisy = jax.random.normal(rng_key, inputs['U_field'].shape)
                forcings['boundary_values'] = inputs['U_field']
                pred = gen_da.full_sampling(noisy_inputs=noisy, forcings=forcings)
                genda_arr = np.array(pred[0])
            except Exception as e:
                print(f"GenDA sampling failed for angle {ang}: {e}")
            grid_data['gen_da'].append(genda_arr)

            # MGN / MS-MGN: always append (None on failure) to keep columns aligned.
            for model, key, disp in baseline_models:
                base_arr: Optional[np.ndarray] = None
                if model is not None:
                    try:
                        pred = model(forcings)
                        base_arr = np.array(pred[0])
                    except Exception as e:
                        print(f"{disp} failed for angle {ang}: {e}")
                grid_data[key].append(base_arr)

            # POD
            pod_arr: Optional[np.ndarray] = None
            if pod_active:
                try:
                    # Use coords from graph structure (not plotting/raw coords) to match POD reference domain
                    curr_coords_pod = _coords_from_graph_struct(forcings['graph_structures'], z_value=None, structures_root=None)
                    pod_arr = _pod_predict_on_current_mesh(
                        pod_model=pod_model,
                        pod_ref_coords=pod_ref_coords,
                        pod_ref_tree=pod_tree,
                        curr_coords=curr_coords_pod,
                        u_curr=u_gt,
                        obs_mask_1d=mask_1d,
                    )
                except Exception as e:
                    print(f"POD viz failed for angle {ang}: {e}")
            grid_data['sensor_pod'].append(pod_arr)

        if grid_data['gt']:
            angles_tag = '-'.join(map(str, used_angles))
            save_path = grid_dir / f"angles_{angles_tag}_5x5.png"
            _plot_5x5_grid(save_path, coords_list, grid_data, used_angles, used_counts)
            print(f"Saved 5x5 grid plot to {save_path}")

            # Extra: imshow/grid version
            save_path_grid = grid_dir / f"angles_{angles_tag}_5x5_grid.png"
            _plot_5x5_grid_imshow(save_path_grid, coords_list, grid_data, used_angles, used_counts)
            print(f"Saved 5x5 (grid) plot to {save_path_grid}")
        else:
            print("Grid plot skipped: no valid angles found.")

if __name__ == "__main__":
    # Entry point: run the full comparison only when executed as a script.
    main()
