#!/usr/bin/env python3
"""
Sweep CFG guidance scale (gamma) for GenDA and save reconstruction metrics and example plots.
"""
import os
import argparse
from pathlib import Path
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib import colors as mcolors
import time
import csv
import dataclasses
from typing import Tuple, List
import jax
import jax.numpy as jnp
import flax.nnx as nnx
from jax.experimental import mesh_utils

from gen_da.gen_da import GenSynth, SamplerConfig, NoiseConfig
from gen_da.denoiser import (
    DenoiserArchitectureConfig,
    NoiseEncoderConfig,
    AngleEncoderConfig,
)

from training.graph_dataset import make_dataset
from training.train import _mean_angular_similarity, _graph_ssim_speed
from training.obs_sampling import GraphCacheManager

try:
    from training.train import prepare_batch
except Exception:
    from baselines.train_meshgraphnets import prepare_batch  # type: ignore


def build_model(example_batch, latent_size: int, seed: int = 0):
    """Instantiate a freshly initialised GenSynth model shaped after ``example_batch``.

    Args:
        example_batch: A dataset batch. It is run through ``prepare_batch`` and
            only the resulting forcings' ``'graph_structures'`` entry is used
            (its first element when it is a list/tuple).
        latent_size: Latent width passed to ``DenoiserArchitectureConfig``.
        seed: Seed for the ``nnx.Rngs`` used to initialise the model.

    Returns:
        A new ``GenSynth`` instance built on a 1-D ``('data',)`` device mesh
        that spans every visible JAX device.
    """
    _, forcings, *_ = prepare_batch(example_batch)

    device_count = max(1, jax.device_count())
    devices = mesh_utils.create_device_mesh((device_count,))
    data_mesh = jax.sharding.Mesh(devices, ('data',))

    graph_structures = forcings['graph_structures']
    if isinstance(graph_structures, (list, tuple)):
        graph_structures = graph_structures[0]

    # Channel counts are fixed at 2/2 here (presumably the two in-plane velocity
    # components for both target and guidance — confirm against GenSynth).
    return GenSynth(
        denoiser_architecture_config=DenoiserArchitectureConfig(latent_size=latent_size),
        sampler_config=SamplerConfig(),
        noise_config=NoiseConfig(),
        noise_encoder_config=NoiseEncoderConfig(),
        angle_encoder_config=AngleEncoderConfig(),
        rngs=nnx.Rngs(seed),
        mesh=data_mesh,
        example_graph_structures=graph_structures,
        target_channels=2,
        guiding_channels=2,
    )


def restore_model(model, ckpt_dir: Path):
    """Load checkpointed state from ``ckpt_dir`` into ``model`` and switch it to eval mode.

    Two checkpoint layouts are attempted with an Orbax ``PyTreeCheckpointer``:

    1. The model state *without* its ``nnx.RngState`` leaves (the split pattern
       ``nnx.split(model, nnx.RngState, ...)`` also used by
       compare_observations_sweep.py).
    2. If anything in step 1 fails, the full model state from ``nnx.split(model)``.

    Args:
        model: An NNX module (e.g. from ``build_model``) whose structure matches
            the checkpoint.
        ckpt_dir: Checkpoint directory.

    Returns:
        The same ``model`` object, updated in place and put in eval mode.

    Raises:
        Whatever Orbax/NNX raises if the fallback (full-state) restore also fails.
    """
    import orbax.checkpoint as ocp

    ckptr = ocp.PyTreeCheckpointer()
    try:
        # Abstract target without RNG streams; graphdef and RNG state are unused here.
        _, _, abs_other_state = nnx.split(model, nnx.RngState, ...)
        restored_state = ckptr.restore(str(ckpt_dir), item=abs_other_state)
        nnx.update(model, restored_state)
    except Exception as err:
        # Report why the primary layout failed instead of swallowing it silently,
        # then retry with the full state layout.
        print(f"Fallback restore... (RNG-less restore failed: {err!r})")
        _, state = nnx.split(model)
        restored_state = ckptr.restore(str(ckpt_dir), item=state)
        nnx.update(model, restored_state)

    model.eval()
    return model


# ----------------------------
# Coords and grid helpers
# ----------------------------

def _coords_from_graph_struct(gs) -> np.ndarray:
    if isinstance(gs, (list, tuple)):
        gs = gs[0]
    try:
        coords = np.asarray(gs.get('original_coordinates'))
        if coords is not None:
            return coords[:, :2] if coords.shape[1] >= 2 else np.pad(coords, ((0,0),(0,max(0,2-coords.shape[1]))))[:, :2]
    except Exception:
        pass
    return np.zeros((int(gs.get('num_o_nodes', 0)), 2), dtype=np.float32)


try:
    from scipy.interpolate import griddata  # type: ignore
except Exception:
    griddata = None  # type: ignore


def _make_grid(coords_xy: np.ndarray, nx: int = 300, ny: int = 300):
    xmin, ymin = coords_xy.min(axis=0)
    xmax, ymax = coords_xy.max(axis=0)
    gx = np.linspace(xmin, xmax, nx)
    gy = np.linspace(ymin, ymax, ny)
    X, Y = np.meshgrid(gx, gy)
    return X, Y, (xmin, xmax, ymin, ymax)


def _scatter_to_grid(coords_xy: np.ndarray, values: np.ndarray, X: np.ndarray, Y: np.ndarray) -> np.ndarray:
    if griddata is None:
        xmin, xmax = X[0, 0], X[0, -1]
        ymin, ymax = Y[0, 0], Y[-1, 0]
        ny, nx = X.shape
        ix = np.clip(((coords_xy[:, 0] - xmin) / max(xmax - xmin, 1e-12) * (nx - 1)).round().astype(int), 0, nx - 1)
        iy = np.clip(((coords_xy[:, 1] - ymin) / max(ymax - ymin, 1e-12) * (ny - 1)).round().astype(int), 0, ny - 1)
        Z = np.full((ny, nx), np.nan, dtype=np.float32)
        Z[iy, ix] = values.astype(np.float32)
        # nearest fill rows/cols
        for y in range(ny):
            row = Z[y]
            good = ~np.isnan(row)
            if np.any(good):
                idx = np.where(good)[0]
                for x in range(nx):
                    if np.isnan(row[x]):
                        row[x] = row[idx[np.argmin(np.abs(idx - x))]]
        for x in range(nx):
            col = Z[:, x]
            good = ~np.isnan(col)
            if np.any(good):
                idx = np.where(good)[0]
                for y in range(ny):
                    if np.isnan(col[y]):
                        col[y] = col[idx[np.argmin(np.abs(idx - y))]]
        return Z
    else:
        Z = griddata(coords_xy, values, (X, Y), method='linear')  # type: ignore
        Znn = griddata(coords_xy, values, (X, Y), method='nearest')  # type: ignore
        Z[np.isnan(Z)] = Znn[np.isnan(Z)]
        return Z


def _rrmse_magnitude(u_true: np.ndarray, u_pred: np.ndarray) -> float:
    """Relative RMSE computed on vector magnitude |U|.
    u_true/u_pred: (N,2)
    """
    mag_true = np.linalg.norm(u_true, axis=-1).astype(np.float64)
    mag_pred = np.linalg.norm(u_pred, axis=-1).astype(np.float64)
    diff = mag_pred - mag_true
    rmse = float(np.sqrt(np.mean(diff**2)))
    denom = float(np.sqrt(np.mean(mag_true**2)))
    return rmse / denom if denom != 0.0 else 0.0


def obs_fraction_for_count(batch, obs_count: int, fluid_only=True) -> float:
    """Convert an absolute observation count into the fraction used by ``prepare_batch``.

    Args:
        batch: Dataset batch with ``'graph_structures'`` (a dict or a sequence
            whose first element is used) and an optional ``'z'`` entry (scalar,
            list/tuple, NumPy or JAX array; the first element is used;
            defaults to 35.0).
        obs_count: Desired number of observed nodes.
        fluid_only: Forwarded to ``GraphCacheManager.get``.

    Returns:
        ``obs_count / n_candidates`` clipped to ``[1 / n_candidates, 1]``, where
        ``n_candidates`` is the size of the cache's ``fluid_cand`` set (at
        least 1). The lower clip means at least one node is requested even for
        ``obs_count <= 0``.
    """
    gs = batch['graph_structures']
    g = gs if isinstance(gs, dict) else gs[0]
    # np.asarray(...).reshape(-1) handles Python scalars, lists/tuples, NumPy
    # arrays and JAX arrays alike (JAX arrays are not np.ndarray instances, so
    # an isinstance check would miss them and float() would fail for size > 1).
    z_value = float(np.asarray(batch.get('z', 35.0)).reshape(-1)[0])
    gc = GraphCacheManager.get(g, z_value, fluid_only=fluid_only)
    denom = max(1, gc.fluid_cand.size)
    return float(np.clip(obs_count / denom, 1.0 / denom, 1.0))


def _get_zoom_bounds(coords: np.ndarray, zoom_factor: float = 2.0) -> Tuple[float, float, float, float]:
    """
    Calculate zoom bounds centered on the domain center.
    """
    xmin, ymin = coords.min(axis=0)
    xmax, ymax = coords.max(axis=0)
    cx = (xmin + xmax) / 2
    cy = (ymin + ymax) / 2
    dx = (xmax - xmin) / zoom_factor
    dy = (ymax - ymin) / zoom_factor
    return cx - dx/2, cx + dx/2, cy - dy/2, cy + dy/2


def _parse_args():
    """Define and parse the command-line interface of the CFG sweep."""
    parser = argparse.ArgumentParser(description="CFG Guidance Scale (Gamma) sweep for reconstruction with GenSynth")
    parser.add_argument('--load_checkpoint', type=str, required=True, help='GenSynth checkpoint dir')
    parser.add_argument('--slice_root', type=str, default='data_sliced_cropped_300k')
    parser.add_argument('--norm_stats', type=str, default='normalization_cropped_300k_test/normalization_stats_train.nc')
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--latent_size', type=int, default=64)
    parser.add_argument('--eval_angle_stride', type=int, default=36)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--output_dir', type=str, default='eval_outputs_reconstruction/cfg_sweep')
    parser.add_argument('--case_angle', type=int, default=145)
    parser.add_argument('--case_z', type=int, default=35)

    # CFG Sweep Params
    parser.add_argument('--cfg_min', type=float, default=1.0)
    parser.add_argument('--cfg_max', type=float, default=5.0)
    parser.add_argument('--cfg_points', type=int, default=5, help='Number of gamma samples')
    parser.add_argument('--cfg_list', type=str, default=None, help='Comma-separated list of gamma values (overrides min/max/points)')

    # Fixed Observation Params
    parser.add_argument('--obs_count', type=int, default=300, help='Fixed number of observations')
    parser.add_argument('--obs_neighbor_hops', type=int, default=0)
    parser.add_argument('--obs_mode', type=str, default='random', choices=['random','swarm'])

    # Grid plotting options
    parser.add_argument('--grid_nx', type=int, default=300)
    parser.add_argument('--grid_ny', type=int, default=300)
    return parser.parse_args()


def _gamma_values(args) -> List[float]:
    """Resolve the guidance scales to sweep from --cfg_list or --cfg_min/--cfg_max/--cfg_points.

    Empty tokens in --cfg_list (e.g. a trailing comma) are ignored. Duplicates
    are dropped (first occurrence kept): results are keyed by gamma, so a
    repeated value would otherwise pool samples into a single bucket and emit
    duplicated CSV/panel rows.
    """
    if args.cfg_list:
        values = [float(tok) for tok in args.cfg_list.split(',') if tok.strip()]
        if not values:
            raise ValueError(f"--cfg_list contains no gamma values: {args.cfg_list!r}")
    else:
        pts = max(1, int(args.cfg_points))
        if pts == 1:
            values = [float(args.cfg_min)]
        else:
            values = np.linspace(float(args.cfg_min), float(args.cfg_max), num=pts).tolist()
    return list(dict.fromkeys(values))


def _first_scalar(value):
    """Return the first element of a scalar/list/NumPy/JAX value as a NumPy scalar."""
    return np.asarray(value).reshape(-1)[0]


def _set_guidance_scale(model, gamma: float) -> None:
    """Swap the model's sampler config for a copy with ``guidance_scale=gamma``."""
    # NOTE(review): reaches into the private `_sampler.cfg`; dataclasses.replace
    # assumes SamplerConfig is a (standard or chex) dataclass.
    model._sampler.cfg = dataclasses.replace(model._sampler.cfg, guidance_scale=gamma)


def _prepare_observed_batch(batch, args, obs_count: int):
    """Run prepare_batch with a fixed observation fraction equivalent to ``obs_count``."""
    f = obs_fraction_for_count(batch, obs_count, fluid_only=True)
    inputs, forcings, *_ = prepare_batch(
        batch,
        obs_frac=f,
        obs_frac_min=f,
        obs_frac_max=f,
        obs_neighbor_hops=int(args.obs_neighbor_hops),
        obs_on_fluid_only=True,
        obs_seed=args.seed,
        obs_mode=args.obs_mode,
    )
    return inputs, forcings


def _initial_noise(inputs, forcings, seed: int):
    """Draw seeded Gaussian start noise and attach ``boundary_values`` to ``forcings``."""
    # The same seed is reused for every batch/gamma so gammas are compared on
    # identical starting noise.
    noisy_inputs = jax.random.normal(jax.random.PRNGKey(seed), inputs['U_field'].shape)
    # NOTE(review): the full ground-truth field is passed as boundary_values;
    # presumably the sampler only reads it at observed/boundary nodes — confirm.
    forcings['boundary_values'] = inputs['U_field']
    return noisy_inputs


def _reconstruction_metrics(original_data, predicted_data, forcings):
    """Return (relative RMSE, graph SSIM on speed, mean angular similarity) for one sample."""
    diff = predicted_data - original_data
    rmse_error = float(np.sqrt(np.mean(diff**2)))
    denom = float(np.sqrt(np.mean(original_data**2)))
    rel_rmse_error = rmse_error / denom if denom != 0 else 0.0
    gs = forcings['graph_structures']
    if isinstance(gs, (list, tuple)):
        gs = gs[0]
    senders = np.asarray(gs['o2o_senders']).astype(np.int64)
    receivers = np.asarray(gs['o2o_receivers']).astype(np.int64)
    speed_true = np.linalg.norm(original_data, axis=-1).astype(np.float64)
    speed_pred = np.linalg.norm(predicted_data, axis=-1).astype(np.float64)
    gssim = _graph_ssim_speed(speed_true, speed_pred, senders, receivers, include_self=True)
    mean_ang_sim = _mean_angular_similarity(original_data, predicted_data)
    return rel_rmse_error, gssim, mean_ang_sim


def _evaluate_sweep(model, test_batches, cfg_values, args, obs_count: int):
    """Sample every test batch at every gamma; return per-gamma metric/time lists."""
    results = {g: {'rrmse': [], 'ssim': [], 'cosine': [], 'times': []} for g in cfg_values}
    for gamma in cfg_values:
        print(f"Evaluating with CFG Gamma = {gamma:.2f}")
        _set_guidance_scale(model, gamma)

        for batch in test_batches:
            inputs, forcings = _prepare_observed_batch(batch, args, obs_count)
            noisy_inputs = _initial_noise(inputs, forcings, args.seed)

            start = time.perf_counter()
            try:
                pred = model.full_sampling(noisy_inputs=noisy_inputs, forcings=forcings)
                # JAX dispatches asynchronously: wait for the result so the timing
                # covers the computation (and device errors surface here).
                # If full_sampling is jitted, the first call also includes compilation.
                pred = jax.block_until_ready(pred)
            except Exception as e:
                print(f"Sampling failed: {e}")
                continue
            elapsed = time.perf_counter() - start

            # Only the first sample of each batch is scored.
            original_data = np.array(inputs['U_field'][0])
            predicted_data = np.array(pred[0])
            rrmse, gssim, cosine = _reconstruction_metrics(original_data, predicted_data, forcings)

            bucket = results[gamma]
            bucket['rrmse'].append(rrmse)
            bucket['ssim'].append(gssim)
            bucket['cosine'].append(cosine)
            bucket['times'].append(elapsed)
    return results


def _mean_or_nan(values) -> float:
    """Mean of ``values``, or NaN (without a RuntimeWarning) when empty."""
    return float(np.mean(values)) if len(values) else float('nan')


def _write_summary(results, cfg_values, out_dir: Path, obs_count: int) -> None:
    """Write metrics_vs_cfg.csv and metrics_vs_cfg.png into ``out_dir``."""
    rrmse_means = [_mean_or_nan(results[c]['rrmse']) for c in cfg_values]
    ssim_means = [_mean_or_nan(results[c]['ssim']) for c in cfg_values]
    cosine_means = [_mean_or_nan(results[c]['cosine']) for c in cfg_values]
    time_means = [_mean_or_nan(results[c]['times']) for c in cfg_values]

    csv_path = out_dir / 'metrics_vs_cfg.csv'
    with open(csv_path, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(['gamma', 'mean_rrmse', 'mean_ssim', 'mean_cosine', 'mean_time'])
        for row in zip(cfg_values, rrmse_means, ssim_means, cosine_means, time_means):
            writer.writerow(row)
    print(f"Saved metrics CSV to {csv_path}")

    plt.figure(figsize=(8,6))
    plt.plot(cfg_values, rrmse_means, marker='o', label='Mean R-RMSE', linewidth=2)
    plt.plot(cfg_values, ssim_means, marker='o', label='Mean SSIM', linewidth=2)
    plt.plot(cfg_values, cosine_means, marker='o', label='Mean Cosine Similarity', linewidth=2)
    plt.xlabel('CFG Gamma')
    plt.ylabel('Metric Value')
    plt.legend(); plt.grid(True)
    plt.title(f'Metrics vs CFG Scale (Obs count={obs_count})')
    plt.tight_layout()
    plt.savefig(str(out_dir / 'metrics_vs_cfg.png'), dpi=150)
    plt.close()


def _find_case_batch(test_batches, angle: int, z: int):
    """Return the first batch whose first angle/z match, or None."""
    for batch in test_batches:
        batch_angle = int(_first_scalar(batch['angle_deg']))
        batch_z = int(_first_scalar(batch.get('z', z)))
        if batch_angle == int(angle) and batch_z == int(z):
            return batch
    return None


def _load_case_coords(args, graph_structures, n_nodes: int) -> np.ndarray:
    """Return (n_nodes, 2) plot coordinates: VTU slice, else graph struct, else zeros."""
    coords = None
    try:
        from pyvista import read as pv_read
        case_name = f"case_{int(args.case_angle)}"
        slice_file = Path(args.slice_root) / case_name / f"slice_z_{int(args.case_z)}.vtu"
        coords = np.asarray(pv_read(str(slice_file)).points)[:, :2]
    except Exception as e:
        print(f"Warning: Could not load coordinates: {e}")

    if coords is not None and coords.shape[0] != n_nodes:
        print(f"Warning: slice has {coords.shape[0]} points but the field has {n_nodes} nodes; "
              f"using graph coordinates instead.")
        coords = None
    if coords is None:
        coords = _coords_from_graph_struct(graph_structures)
    if coords.shape[0] != n_nodes:
        # Last resort: placeholder coordinates so the panels can still be drawn
        # (every node collapses onto the origin).
        print("Warning: no usable node coordinates; plotting on placeholder coordinates.")
        coords = np.zeros((n_nodes, 2), dtype=np.float32)
    return coords


def _style_map_axis(ax, zoom) -> None:
    """Apply equal aspect, the zoom window and hidden ticks to a map axis."""
    z_xmin, z_xmax, z_ymin, z_ymax = zoom
    ax.set_aspect('equal')
    ax.set_xlim(z_xmin, z_xmax)
    ax.set_ylim(z_ymin, z_ymax)
    ax.set_xticks([]); ax.set_yticks([])


def _annotate_rrmse(ax, original_data, predicted_data) -> None:
    """Best-effort R-RMSE (on |U|) text box in the top-left corner of ``ax``."""
    try:
        rrmse_mag = _rrmse_magnitude(original_data, predicted_data)
    except Exception as e:
        print(f"Warning: could not compute R-RMSE annotation: {e}")
        return
    ax.text(
        0.02, 0.95, f"R-RMSE={rrmse_mag:.3f}", transform=ax.transAxes,
        ha='left', va='top', bbox=dict(boxstyle='round', facecolor='white', alpha=0.6, linewidth=0)
    )


def _render_case_panel(mode, out_path, coords, grid, cfg_values, original_data, predictions,
                       obs_mask, obs_count, zoom, vmin_mag, vmax_mag) -> None:
    """Save a GT + per-gamma panel (cols: |U|, Obs, Rel. Error) as 'scatter' or 'grid' images."""
    num_rows = len(cfg_values) + 1
    share = mode == 'scatter'
    fig, axes = plt.subplots(num_rows, 3, figsize=(12, 2.8*num_rows), squeeze=False, sharex=share, sharey=share)

    def draw(ax, values, point_size=1.0, **style):
        # Scatter mode plots raw nodes; grid mode interpolates onto `grid` first.
        if mode == 'scatter':
            return ax.scatter(coords[:, 0], coords[:, 1], c=values, s=point_size, **style)
        X, Y, extent = grid
        return ax.imshow(_scatter_to_grid(coords, values, X, Y), origin='lower', extent=extent, **style)

    gt_mag = np.linalg.norm(original_data, axis=-1)

    # Row 0: ground truth
    ax = axes[0, 0]
    mappable = draw(ax, gt_mag, cmap='RdBu_r', vmin=vmin_mag, vmax=vmax_mag)
    ax.set_title('|U|'); ax.set_ylabel('GT')
    _style_map_axis(ax, zoom)
    fig.colorbar(mappable, ax=ax, fraction=0.046, pad=0.02)

    # Observation mask (identical for every gamma, so shown once)
    ax = axes[0, 1]
    ax.axis('off')
    if obs_mask is not None:
        observed = (obs_mask > 0.5).astype(np.float32)
        obs_style = {'point_size': 0.2, 'alpha': 0.9} if mode == 'scatter' else {}
        mappable = draw(ax, observed, cmap=mcolors.ListedColormap(['lightgray', 'red']),
                        vmin=0, vmax=1, **obs_style)
        ax.axis('on')
        ax.set_title(f'Obs (N={obs_count})')
        _style_map_axis(ax, zoom)
        cb = fig.colorbar(mappable, ax=ax, fraction=0.046, pad=0.02)
        cb.set_ticks([0, 1]); cb.set_ticklabels(['unobs', 'obs'])

    axes[0, 2].axis('off')
    axes[0, 2].set_title('Rel. Error')

    # Rows 1..N: generations
    for row, (gamma, predicted_data) in enumerate(zip(cfg_values, predictions), start=1):
        ax = axes[row, 0]
        mappable = draw(ax, np.linalg.norm(predicted_data, axis=-1), cmap='RdBu_r', vmin=vmin_mag, vmax=vmax_mag)
        ax.set_ylabel(f'γ={gamma:.1f}')
        _style_map_axis(ax, zoom)
        fig.colorbar(mappable, ax=ax, fraction=0.046, pad=0.02)
        _annotate_rrmse(ax, original_data, predicted_data)

        axes[row, 1].axis('off')

        rel_error = np.linalg.norm(predicted_data - original_data, axis=-1) / (gt_mag + 1e-6)
        ax = axes[row, 2]
        mappable = draw(ax, rel_error, cmap='Reds', vmin=0, vmax=1)
        _style_map_axis(ax, zoom)
        fig.colorbar(mappable, ax=ax, fraction=0.046, pad=0.02)

    fig.tight_layout()
    fig.savefig(str(out_path), dpi=150)
    plt.close(fig)


def main():
    """Sweep the CFG guidance scale, save metrics (CSV + plot) and case panels.

    Outputs in --output_dir: metrics_vs_cfg.csv, metrics_vs_cfg.png and, when
    the requested (--case_angle, --case_z) case is in the test set, a scatter
    and a gridded GT-vs-generation panel.
    """
    args = _parse_args()

    out_dir = Path(args.output_dir); out_dir.mkdir(parents=True, exist_ok=True)

    # Build test dataset
    test_ds = make_dataset(
        slice_root=args.slice_root,
        norm_stats_nc=args.norm_stats,
        batch_size=max(1, int(args.batch_size)),
        shuffle=False,
        seed=args.seed + 2,
        is_training=False,
        angle_stride=max(1, int(args.eval_angle_stride)),
        drop_remainder=True,
    )
    test_batches = list(test_ds)
    if len(test_batches) == 0:
        raise RuntimeError("Empty test dataset")

    # Init model
    model = build_model(test_batches[0], args.latent_size, seed=args.seed)
    model = restore_model(model, Path(args.load_checkpoint))

    cfg_values = _gamma_values(args)
    fixed_obs_count = int(args.obs_count)

    results = _evaluate_sweep(model, test_batches, cfg_values, args, fixed_obs_count)
    _write_summary(results, cfg_values, out_dir, fixed_obs_count)

    # GT vs generations for the selected case
    found_batch = _find_case_batch(test_batches, args.case_angle, args.case_z)
    if found_batch is None:
        print(f"Requested case angle={args.case_angle}, z={args.case_z} not in test set; skipping panel plot.")
        return

    inputs_gt, forcings_gt, *_ = prepare_batch(found_batch, obs_frac=0.0, obs_frac_min=0.0, obs_frac_max=0.0,
                                               obs_neighbor_hops=int(args.obs_neighbor_hops), obs_on_fluid_only=True,
                                               obs_seed=args.seed, obs_mode=args.obs_mode)
    original_data = np.array(inputs_gt['U_field'][0])
    coords = _load_case_coords(args, forcings_gt['graph_structures'], original_data.shape[0])

    # One set of observations / noise shared by every gamma.
    inputs, forcings = _prepare_observed_batch(found_batch, args, fixed_obs_count)
    obs_mask = np.array(forcings.get('obs_mask'))[0] if 'obs_mask' in forcings else None
    noisy_inputs = _initial_noise(inputs, forcings, args.seed)

    predictions = []
    for gamma in cfg_values:
        _set_guidance_scale(model, gamma)
        pred = model.full_sampling(noisy_inputs=noisy_inputs, forcings=forcings)
        predictions.append(np.array(pred[0]))

    # Shared |U| colour range across GT and all generations.
    all_mags = [np.linalg.norm(original_data, axis=-1)] + [np.linalg.norm(p, axis=-1) for p in predictions]
    vmin_mag = float(min(m.min() for m in all_mags))
    vmax_mag = float(max(m.max() for m in all_mags))
    zoom = _get_zoom_bounds(coords, zoom_factor=2.0)

    panel_path = out_dir / f'gt_and_generations_vs_cfg_rows_scatter_angle{args.case_angle}_z{args.case_z}.png'
    _render_case_panel('scatter', panel_path, coords, None, cfg_values, original_data, predictions,
                       obs_mask, fixed_obs_count, zoom, vmin_mag, vmax_mag)
    print(f"Saved row-wise scatter panel to {panel_path}")

    X, Y, (xmin, xmax, ymin, ymax) = _make_grid(coords, nx=int(args.grid_nx), ny=int(args.grid_ny))
    panel_path_grid = out_dir / f'gt_and_generations_vs_cfg_rows_grid_angle{args.case_angle}_z{args.case_z}.png'
    _render_case_panel('grid', panel_path_grid, coords, (X, Y, (xmin, xmax, ymin, ymax)), cfg_values,
                       original_data, predictions, obs_mask, fixed_obs_count, zoom, vmin_mag, vmax_mag)
    print(f"Saved row-wise grid panel to {panel_path_grid}")


if __name__ == '__main__':
    # Run the sweep only when executed as a script, not when imported.
    main()
