#!/usr/bin/env python3
"""
Sweep observation counts for GenDA and report reconstruction metrics, timings, and example plots.
"""
import os
import argparse
from pathlib import Path
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib import colors as mcolors
import time
import csv
from typing import Tuple
import jax
import jax.numpy as jnp
import flax.nnx as nnx
from jax.experimental import mesh_utils

from gen_da.gen_da import GenSynth, SamplerConfig, NoiseConfig
from gen_da.denoiser import (
    DenoiserArchitectureConfig,
    NoiseEncoderConfig,
    AngleEncoderConfig,
)

from training.graph_dataset import make_dataset
from training.train import _mean_angular_similarity, _graph_ssim_speed
from training.obs_sampling import GraphCacheManager

# Prefer the `prepare_batch` from the main training module; if that import
# fails for any reason, fall back to the one shipped with the MeshGraphNets
# baseline so the script can still run.
try:
    from training.train import prepare_batch
except Exception:
    from baselines.train_meshgraphnets import prepare_batch  # type: ignore


def build_model(example_batch, latent_size: int, seed: int = 0):
    """Construct a freshly initialised GenSynth model from one example batch.

    The example batch is run through `prepare_batch` only to obtain its
    forcings; the graph structure found under ``'graph_structures'`` (first
    entry when it is a list/tuple) is handed to the model constructor.

    Args:
        example_batch: A batch as produced by the dataset.
        latent_size: Latent width forwarded to `DenoiserArchitectureConfig`.
        seed: Seed for the `nnx.Rngs` used to initialise the model.

    Returns:
        The constructed `GenSynth` instance (weights not yet restored).
    """
    _unused_inputs, forcings, *_rest = prepare_batch(example_batch)

    # One-dimensional device mesh over all visible devices, axis named 'data'.
    device_count = max(1, jax.device_count())
    device_grid = mesh_utils.create_device_mesh((device_count,))
    mesh = jax.sharding.Mesh(device_grid, ('data',))

    denoiser_cfg = DenoiserArchitectureConfig(latent_size=latent_size)

    graph = forcings['graph_structures']
    if isinstance(graph, (list, tuple)):
        graph = graph[0]

    return GenSynth(
        denoiser_architecture_config=denoiser_cfg,
        sampler_config=SamplerConfig(),
        noise_config=NoiseConfig(),
        noise_encoder_config=NoiseEncoderConfig(),
        angle_encoder_config=AngleEncoderConfig(),
        rngs=nnx.Rngs(seed),
        mesh=mesh,
        example_graph_structures=graph,
        # Both the target and the guiding field are passed as 2-channel.
        target_channels=2,
        guiding_channels=2,
    )


def restore_model(model, ckpt_dir: Path):
    """Load checkpointed state into `model` in place and switch it to eval mode.

    The model is split into graph definition, RNG state and the remaining
    state; only the remaining state is used as the restore template, so the
    RNG state already present in `model` is not overwritten by this function.

    Args:
        model: An NNX module whose state layout matches the checkpoint.
        ckpt_dir: Directory of the checkpoint to restore.

    Returns:
        The same `model` object, updated and set to eval mode.
    """
    import orbax.checkpoint as ocp

    _graphdef, _rng_state, state_template = nnx.split(model, nnx.RngState, ...)
    loaded_state = ocp.PyTreeCheckpointer().restore(str(ckpt_dir), item=state_template)
    nnx.update(model, loaded_state)
    model.eval()
    return model


# ----------------------------
# Coords and grid helpers
# ----------------------------

def _coords_from_graph_struct(gs) -> np.ndarray:
    """Extract 2-D node coordinates from a graph-structure mapping.

    Args:
        gs: A mapping with an optional ``'original_coordinates'`` entry, or a
            list/tuple of such mappings (only the first one is used).

    Returns:
        An ``(N, 2)`` array. When ``'original_coordinates'`` is a 2-D array
        its first two columns are returned (zero-padded on the right if it
        has fewer than two columns). Otherwise a float32 zero array with
        ``gs.get('num_o_nodes', 0)`` rows is returned.
    """
    if isinstance(gs, (list, tuple)):
        gs = gs[0]

    raw = gs.get('original_coordinates')
    if raw is not None:
        # np.asarray(None) would yield a 0-d object array rather than None, so
        # the missing-key case has to be tested before the conversion.
        try:
            coords = np.asarray(raw)
        except (TypeError, ValueError):
            coords = None
        if coords is not None and coords.ndim == 2:
            n_cols = coords.shape[1]
            if n_cols >= 2:
                return coords[:, :2]
            return np.pad(coords, ((0, 0), (0, 2 - n_cols)))

    # No usable coordinates: placeholder zeros, one row per node.
    return np.zeros((int(gs.get('num_o_nodes', 0)), 2), dtype=np.float32)


# SciPy is an optional dependency: when the import fails, `griddata` is set to
# None and `_scatter_to_grid` switches to its NumPy-only fallback path.
try:
    from scipy.interpolate import griddata  # type: ignore
except Exception:
    griddata = None  # type: ignore


def _make_grid(coords_xy: np.ndarray, nx: int = 300, ny: int = 300):
    """Build a regular mesh spanning the bounding box of `coords_xy`.

    Args:
        coords_xy: Array whose rows are (x, y) points.
        nx: Number of grid columns.
        ny: Number of grid rows.

    Returns:
        ``(X, Y, (xmin, xmax, ymin, ymax))`` where ``X`` and ``Y`` are the
        ``np.meshgrid`` outputs of shape ``(ny, nx)``.
    """
    lower = coords_xy.min(axis=0)
    upper = coords_xy.max(axis=0)
    xmin, ymin = lower
    xmax, ymax = upper
    X, Y = np.meshgrid(np.linspace(xmin, xmax, nx), np.linspace(ymin, ymax, ny))
    return X, Y, (xmin, xmax, ymin, ymax)


def _fill_nan_nearest_1d(line: np.ndarray) -> None:
    """Fill NaNs of a 1-D view in place with the value of the nearest non-NaN cell.

    Ties are resolved towards the lower index. Lines that are entirely NaN or
    contain no NaN are left untouched.
    """
    missing = np.isnan(line)
    if not missing.any() or missing.all():
        return
    known_idx = np.flatnonzero(~missing)
    missing_idx = np.flatnonzero(missing)
    # Distance matrix (missing x known); argmin returns the first minimum, i.e.
    # the lower index on ties.
    nearest = known_idx[np.abs(known_idx[None, :] - missing_idx[:, None]).argmin(axis=1)]
    line[missing_idx] = line[nearest]


def _scatter_to_grid(coords_xy: np.ndarray, values: np.ndarray, X: np.ndarray, Y: np.ndarray) -> np.ndarray:
    """Resample scattered point values onto the regular grid ``(X, Y)``.

    With SciPy available, linear interpolation is used and cells it leaves as
    NaN are filled from a nearest-neighbour interpolation. Without SciPy,
    every point is binned to its closest grid cell and the remaining cells are
    filled with the nearest populated value, first along rows and then along
    columns.

    Args:
        coords_xy: Array whose rows are (x, y) sample locations.
        values: One value per sample.
        X, Y: Mesh arrays of shape ``(ny, nx)`` as returned by ``np.meshgrid``.

    Returns:
        Array of shape ``(ny, nx)``. The fallback path returns float32 and
        may still contain NaN only if no sample landed on the grid at all.
    """
    if griddata is not None:
        Z = griddata(coords_xy, values, (X, Y), method='linear')  # type: ignore
        holes = np.isnan(Z)
        if holes.any():
            # Linear interpolation leaves NaN wherever it has no value to
            # give; only then is the nearest-neighbour pass needed.
            Z[holes] = griddata(coords_xy, values, (X, Y), method='nearest')[holes]  # type: ignore
        return Z

    coords_xy = np.asarray(coords_xy)
    xmin, xmax = X[0, 0], X[0, -1]
    ymin, ymax = Y[0, 0], Y[-1, 0]
    ny, nx = X.shape
    # Map each point to the index of its closest grid column / row.
    ix = np.clip(((coords_xy[:, 0] - xmin) / max(xmax - xmin, 1e-12) * (nx - 1)).round().astype(int), 0, nx - 1)
    iy = np.clip(((coords_xy[:, 1] - ymin) / max(ymax - ymin, 1e-12) * (ny - 1)).round().astype(int), 0, ny - 1)
    Z = np.full((ny, nx), np.nan, dtype=np.float32)
    Z[iy, ix] = np.asarray(values, dtype=np.float32)

    # Row pass first, then column pass so rows that had no sample get filled.
    for row in Z:
        _fill_nan_nearest_1d(row)
    for col in Z.T:
        _fill_nan_nearest_1d(col)
    return Z


def _rrmse_magnitude(u_true: np.ndarray, u_pred: np.ndarray) -> float:
    """Relative RMSE between the vector magnitudes of two fields.

    The magnitude is the norm over the last axis of each input (documented
    upstream as ``(N, 2)``). The RMSE of the magnitude difference is divided
    by the RMS of the reference magnitude; ``0.0`` is returned when that
    reference RMS is exactly zero.
    """
    ref_speed, est_speed = (
        np.linalg.norm(field, axis=-1).astype(np.float64) for field in (u_true, u_pred)
    )
    residual_rms = float(np.sqrt(np.mean((est_speed - ref_speed) ** 2)))
    ref_rms = float(np.sqrt(np.mean(ref_speed ** 2)))
    if ref_rms == 0.0:
        return 0.0
    return residual_rms / ref_rms


def obs_fraction_for_count(batch, obs_count: int, fluid_only=True) -> float:
    """Convert an absolute observation count into a fraction of candidate nodes.

    Args:
        batch: Mapping with ``'graph_structures'`` (a dict, or a sequence whose
            first element is used) and an optional ``'z'`` entry (defaults to
            35.0). ``'z'`` may be a scalar or any array-like; its first
            element is used.
        obs_count: Desired number of observed nodes.
        fluid_only: Forwarded to ``GraphCacheManager.get``.

    Returns:
        ``obs_count / n_candidates`` clipped to ``[1 / n_candidates, 1]``,
        where ``n_candidates`` is ``fluid_cand.size`` of the cached graph
        entry (at least 1).
    """
    gs = batch['graph_structures']
    g = gs if isinstance(gs, dict) else gs[0]
    # Flatten before indexing so Python scalars, 0-d arrays, lists and
    # non-NumPy array types are all handled the same way.
    z_value = float(np.asarray(batch.get('z', 35.0)).reshape(-1)[0])
    gc = GraphCacheManager.get(g, z_value, fluid_only=fluid_only)
    denom = max(1, gc.fluid_cand.size)
    # The lower bound keeps at least one observation even for obs_count <= 0.
    return float(np.clip(obs_count / denom, 1.0 / denom, 1.0))


def _get_zoom_bounds(coords: np.ndarray, zoom_factor: float = 2.0) -> Tuple[float, float, float, float]:
    """
    Return ``(xmin, xmax, ymin, ymax)`` of a window centred on the bounding box
    of `coords` whose width and height are the box extents divided by
    `zoom_factor`.
    """
    x_lo, y_lo = coords.min(axis=0)
    x_hi, y_hi = coords.max(axis=0)
    center_x = (x_lo + x_hi) / 2
    center_y = (y_lo + y_hi) / 2
    width = (x_hi - x_lo) / zoom_factor
    height = (y_hi - y_lo) / zoom_factor
    half_w = width / 2
    half_h = height / 2
    return center_x - half_w, center_x + half_w, center_y - half_h, center_y + half_h


def _first_scalar(value) -> float:
    """Return the first element of a scalar or array-like batch field as a float."""
    return float(np.asarray(value).reshape(-1)[0])


def _mean_or_nan(values) -> float:
    """Mean of `values`, or NaN (without a NumPy warning) when there are none."""
    return float(np.mean(values)) if len(values) > 0 else float('nan')


def _prepare_sampling_inputs(batch, obs_frac: float, args):
    """Build ``(inputs, forcings, noisy_inputs)`` for one sampling call at a fixed observation fraction."""
    inputs, forcings, *_ = prepare_batch(
        batch,
        obs_frac=obs_frac,
        obs_frac_min=obs_frac,
        obs_frac_max=obs_frac,
        obs_neighbor_hops=int(args.obs_neighbor_hops),
        obs_on_fluid_only=True,
        obs_seed=args.seed,
        obs_mode=args.obs_mode,
    )
    # The same seed is used for every call, so the initial noise is identical
    # across observation counts.
    noisy_inputs = jax.random.normal(jax.random.PRNGKey(args.seed), inputs['U_field'].shape)
    forcings['boundary_values'] = inputs['U_field']
    return inputs, forcings, noisy_inputs


def _format_panel_axis(ax, bounds) -> None:
    """Apply the shared look of every panel cell: equal aspect, zoom window, no ticks."""
    x_lo, x_hi, y_lo, y_hi = bounds
    ax.set_aspect('equal')
    ax.set_xlim(x_lo, x_hi)
    ax.set_ylim(y_lo, y_hi)
    ax.set_xticks([])
    ax.set_yticks([])


def _annotate_rrmse(ax, u_true: np.ndarray, u_pred: np.ndarray) -> None:
    """Write the magnitude R-RMSE into the top-left corner of `ax` (best effort)."""
    try:
        rrmse_mag = _rrmse_magnitude(u_true, u_pred)
    except Exception as exc:  # the annotation must never abort the figure
        print(f"Warning: could not compute R-RMSE annotation: {exc}")
        return
    ax.text(
        0.02, 0.95, f"R-RMSE={rrmse_mag:.3f}", transform=ax.transAxes,
        ha='left', va='top', bbox=dict(boxstyle='round', facecolor='white', alpha=0.6, linewidth=0)
    )


def _observed_values(forcings, n_points: int, dtype) -> np.ndarray:
    """Return a 0/1 array marking observed nodes of the first sample (all zeros if no mask is present)."""
    if 'obs_mask' not in forcings:
        return np.zeros(n_points, dtype=dtype)
    return (np.array(forcings.get('obs_mask'))[0] > 0.5).astype(dtype)


def main():
    """Run the observation-count sweep end to end.

    Steps:
      1. Parse CLI arguments, build the test dataset and restore the model.
      2. For each observation count, reconstruct every test batch and record
         relative RMSE, graph SSIM, mean angular similarity and wall time
         (first sample of each batch only).
      3. Write the per-count means to CSV and plot them.
      4. For the case selected by ``--case_angle`` / ``--case_z``, save a
         scatter panel and an interpolated grid panel comparing ground truth
         with the reconstruction at each observation count.

    Raises:
        RuntimeError: If the test dataset yields no batches.
    """
    parser = argparse.ArgumentParser(description="Observation sweep for reconstruction with GenSynth")
    parser.add_argument('--load_checkpoint', type=str, required=True, help='GenSynth checkpoint dir')
    parser.add_argument('--slice_root', type=str, default='data_sliced_cropped_300k')
    parser.add_argument('--norm_stats', type=str, default='normalization_cropped_300k_test/normalization_stats_train.nc')
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--latent_size', type=int, default=64)
    parser.add_argument('--eval_angle_stride', type=int, default=36)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--output_dir', type=str, default='eval_outputs/observation_sweep')
    parser.add_argument('--case_angle', type=int, default=145)
    parser.add_argument('--case_z', type=int, default=35)
    parser.add_argument('--obs_min', type=int, default=100)
    parser.add_argument('--obs_max', type=int, default=15000)
    parser.add_argument('--obs_points', type=int, default=10, help='Number of observation-count samples')
    parser.add_argument('--obs_neighbor_hops', type=int, default=0)
    parser.add_argument('--obs_mode', type=str, default='random', choices=['random','swarm'])
    # Grid plotting options
    parser.add_argument('--grid_nx', type=int, default=300)
    parser.add_argument('--grid_ny', type=int, default=300)
    args = parser.parse_args()

    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Build test dataset
    test_ds = make_dataset(
        slice_root=args.slice_root,
        norm_stats_nc=args.norm_stats,
        batch_size=max(1, int(args.batch_size)),
        shuffle=False,
        seed=args.seed + 2,
        is_training=False,
        angle_stride=max(1, int(args.eval_angle_stride)),
        drop_remainder=True,
    )
    test_batches = list(test_ds)
    if not test_batches:
        raise RuntimeError("Empty test dataset")

    # Init model
    model = build_model(test_batches[0], args.latent_size, seed=args.seed)
    model = restore_model(model, Path(args.load_checkpoint))

    # Determine observation counts (at most 10, evenly spaced).
    obs_points = max(1, min(int(args.obs_points), 10))
    if obs_points == 1:
        obs_counts = [int(args.obs_max)]
    else:
        spaced = np.linspace(int(args.obs_min), int(args.obs_max), num=obs_points, dtype=int).tolist()
        # Integer truncation can repeat a count when the range is narrow; keep
        # each count once so it is not evaluated and plotted twice.
        obs_counts = list(dict.fromkeys(spaced))

    # Metric accumulators
    results = {k: {'rrmse': [], 'ssim': [], 'cosine': [], 'times': []} for k in obs_counts}

    # Main evaluation loop
    for obs_k in obs_counts:
        print(f"Evaluating with ~{obs_k} observations")
        for batch in test_batches:
            f = obs_fraction_for_count(batch, obs_k, fluid_only=True)
            inputs, forcings, noisy_inputs = _prepare_sampling_inputs(batch, f, args)
            start = time.perf_counter()
            try:
                pred = model.full_sampling(noisy_inputs=noisy_inputs, forcings=forcings)
                # JAX dispatches work asynchronously; wait for the result so
                # the timer covers the computation and not just the dispatch.
                # NOTE(review): the first call likely also includes
                # compilation time - confirm whether a warm-up is wanted.
                pred = jax.block_until_ready(pred)
            except Exception as e:
                print(f"Sampling failed: {e}")
                continue
            elapsed = time.perf_counter() - start

            # Metrics are computed on the first sample of the batch only.
            original_data = np.array(inputs['U_field'][0])
            predicted_data = np.array(pred[0])
            diff = predicted_data - original_data
            rmse_error = float(np.sqrt(np.mean(diff**2)))
            denom = float(np.sqrt(np.mean(original_data**2)))
            rel_rmse_error = rmse_error / denom if denom != 0 else 0.0
            gs = forcings['graph_structures']
            if isinstance(gs, (list, tuple)):
                gs = gs[0]
            senders = np.asarray(gs['o2o_senders']).astype(np.int64)
            receivers = np.asarray(gs['o2o_receivers']).astype(np.int64)
            speed_true = np.linalg.norm(original_data, axis=-1).astype(np.float64)
            speed_pred = np.linalg.norm(predicted_data, axis=-1).astype(np.float64)
            gssim = _graph_ssim_speed(speed_true, speed_pred, senders, receivers, include_self=True)
            mean_ang_sim = _mean_angular_similarity(original_data, predicted_data)
            results[obs_k]['rrmse'].append(rel_rmse_error)
            results[obs_k]['ssim'].append(gssim)
            results[obs_k]['cosine'].append(mean_ang_sim)
            results[obs_k]['times'].append(elapsed)

    # Save metrics and times as CSV (NaN when every sample of a count failed).
    counts = obs_counts
    rrmse_means = [_mean_or_nan(results[c]['rrmse']) for c in counts]
    ssim_means = [_mean_or_nan(results[c]['ssim']) for c in counts]
    cosine_means = [_mean_or_nan(results[c]['cosine']) for c in counts]
    time_means = [_mean_or_nan(results[c]['times']) for c in counts]
    csv_path = out_dir / 'metrics_vs_obscount.csv'
    with open(csv_path, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(['obs_count', 'mean_rrmse', 'mean_ssim', 'mean_cosine', 'mean_time'])
        for row in zip(counts, rrmse_means, ssim_means, cosine_means, time_means):
            writer.writerow(list(row))
    print(f"Saved metrics CSV to {csv_path}")

    # Plot metrics vs obs count
    plt.figure(figsize=(8,6))
    plt.plot(counts, rrmse_means, marker='o', label='Mean R-RMSE', linewidth=2)
    plt.plot(counts, ssim_means, marker='o', label='Mean SSIM', linewidth=2)
    plt.plot(counts, cosine_means, marker='o', label='Mean Cosine Similarity', linewidth=2)
    plt.xlabel('Observation count')
    plt.ylabel('Metric Value')
    plt.legend(); plt.grid(True)
    plt.tight_layout()
    plt.savefig(str(out_dir / 'metrics_vs_obscount.png'), dpi=150)
    plt.close()

    # Plot mean time vs obs count
    plt.figure(figsize=(8,6))
    plt.plot(counts, time_means, marker='o', color='purple', label='Mean Time (s)')
    plt.xlabel('Observation count'); plt.ylabel('Mean Time (seconds)')
    plt.legend(); plt.grid(True)
    plt.tight_layout()
    plt.savefig(str(out_dir / 'mean_time_vs_obscount.png'), dpi=150)
    plt.close()

    # GT vs generations for selected case and obs counts (row-wise panels)
    # Find matching batch
    found_batch = None
    for batch in test_batches:
        angle_matches = int(_first_scalar(batch['angle_deg'])) == int(args.case_angle)
        z_matches = int(_first_scalar(batch.get('z', args.case_z))) == int(args.case_z)
        if angle_matches and z_matches:
            found_batch = batch
            break
    if found_batch is None:
        print(f"Requested case angle={args.case_angle}, z={args.case_z} not in test set; skipping panel plot.")
        return

    # Preferred coordinate source: the slice file of the selected case.
    coords = None
    try:
        from pyvista import read as pv_read
        case_name = f"case_{int(args.case_angle)}"
        slice_file = Path(args.slice_root) / case_name / f"slice_z_{int(args.case_z)}.vtu"
        slc = pv_read(str(slice_file))
        coords = np.asarray(slc.points)[:, :2]
    except Exception as e:
        print(f"Warning: Could not load coordinates: {e}")

    # Prepare GT without observations
    inputs_gt, forcings_gt, *_ = prepare_batch(found_batch, obs_frac=0.0, obs_frac_min=0.0, obs_frac_max=0.0,
                                               obs_neighbor_hops=int(args.obs_neighbor_hops), obs_on_fluid_only=True,
                                               obs_seed=args.seed, obs_mode=args.obs_mode)
    original_data = np.array(inputs_gt['U_field'][0])
    original_mag = np.linalg.norm(original_data, axis=-1)

    # Resolve the coordinate fallbacks BEFORE anything depends on coords:
    # slice file -> graph structure -> zeros (one row per node).
    n_nodes = original_data.shape[0]
    if coords is not None and coords.shape[0] != n_nodes:
        print(f"Warning: Loaded {coords.shape[0]} coordinates for {n_nodes} nodes; using graph coordinates instead.")
        coords = None
    if coords is None:
        coords = _coords_from_graph_struct(forcings_gt['graph_structures'])
    if coords.shape[0] != n_nodes:
        coords = np.zeros((n_nodes, 2), dtype=np.float32)

    zoom_bounds = _get_zoom_bounds(coords, zoom_factor=2.0)

    # Run every reconstruction once; both figures reuse these results and the
    # shared |U| colour range is derived from them.
    panel_rows = []
    all_mags = [original_mag]
    for obs_k in obs_counts:
        f = obs_fraction_for_count(found_batch, obs_k, fluid_only=True)
        _inputs, forcings, noisy_inputs = _prepare_sampling_inputs(found_batch, f, args)
        pred = model.full_sampling(noisy_inputs=noisy_inputs, forcings=forcings)
        predicted_data = np.array(pred[0])
        pred_mag = np.linalg.norm(predicted_data, axis=-1)
        # Point-wise error magnitude relative to the ground-truth magnitude.
        rel_error = np.linalg.norm(predicted_data - original_data, axis=-1) / (original_mag + 1e-6)
        panel_rows.append((obs_k, predicted_data, pred_mag, forcings, rel_error))
        all_mags.append(pred_mag)

    vmin_mag = float(min(m.min() for m in all_mags))
    vmax_mag = float(max(m.max() for m in all_mags))
    cmap_mask = mcolors.ListedColormap(['lightgray', 'red'])

    # Row-wise panel: row 0 is GT, then one row per count; columns: |U|, Obs, Error
    num_rows = len(obs_counts) + 1
    fig, axes = plt.subplots(num_rows, 3, figsize=(12, 2.8*num_rows), squeeze=False, sharex=True, sharey=True)

    sc = axes[0, 0].scatter(coords[:,0], coords[:,1], c=original_mag, cmap='RdBu_r', s=1.0, vmin=vmin_mag, vmax=vmax_mag)
    axes[0, 0].set_ylabel('GT')
    axes[0, 0].set_title('|U|')
    _format_panel_axis(axes[0, 0], zoom_bounds)
    plt.colorbar(sc, ax=axes[0, 0], fraction=0.046, pad=0.02)

    # The GT row has no observation or error map; keep those cells empty.
    axes[0, 1].axis('off')
    axes[0, 1].set_title('Obs')
    axes[0, 2].axis('off')
    axes[0, 2].set_title('Rel. Error')

    # Each subsequent row is a generation for a specific obs count
    for i, (obs_k, predicted_data, pred_mag, forcings, rel_error) in enumerate(panel_rows, start=1):
        sc = axes[i, 0].scatter(coords[:,0], coords[:,1], c=pred_mag, cmap='RdBu_r', s=1.0, vmin=vmin_mag, vmax=vmax_mag)
        axes[i, 0].set_ylabel(f'obs≈{obs_k}')
        _format_panel_axis(axes[i, 0], zoom_bounds)
        plt.colorbar(sc, ax=axes[i, 0], fraction=0.046, pad=0.02)
        _annotate_rrmse(axes[i, 0], original_data, predicted_data)

        # Observation map with colorbar
        vals = _observed_values(forcings, coords.shape[0], int)
        scm = axes[i, 1].scatter(coords[:,0], coords[:,1], c=vals, cmap=cmap_mask, vmin=0, vmax=1, s=0.2, alpha=0.9)
        _format_panel_axis(axes[i, 1], zoom_bounds)
        cb = plt.colorbar(scm, ax=axes[i, 1], fraction=0.046, pad=0.02)
        cb.set_ticks([0, 1]); cb.set_ticklabels(['unobs', 'obs'])

        # Error map
        sce = axes[i, 2].scatter(coords[:,0], coords[:,1], c=rel_error, cmap='Reds', vmin=0, vmax=1, s=1.0)
        _format_panel_axis(axes[i, 2], zoom_bounds)
        plt.colorbar(sce, ax=axes[i, 2], fraction=0.046, pad=0.02)

    plt.tight_layout()
    panel_path = out_dir / f'gt_and_generations_vs_obscount_rows_scatter_angle{args.case_angle}_z{args.case_z}.png'
    plt.savefig(str(panel_path), dpi=150)
    plt.close()
    print(f"Saved row-wise scatter panel to {panel_path}")

    # Also save a grid-based (imshow) version using interpolation
    X, Y, (xmin, xmax, ymin, ymax) = _make_grid(coords, nx=int(args.grid_nx), ny=int(args.grid_ny))
    extent = (xmin, xmax, ymin, ymax)
    fig, axes = plt.subplots(num_rows, 3, figsize=(12, 2.8*num_rows), squeeze=False, sharex=False, sharey=False)

    # GT row
    Z = _scatter_to_grid(coords, original_mag, X, Y)
    im = axes[0, 0].imshow(Z, origin='lower', extent=extent, cmap='RdBu_r', vmin=vmin_mag, vmax=vmax_mag)
    axes[0, 0].set_title('|U|'); axes[0, 0].set_ylabel('GT')
    _format_panel_axis(axes[0, 0], zoom_bounds)
    plt.colorbar(im, ax=axes[0, 0], fraction=0.046, pad=0.02)
    # Same empty cells as in the scatter figure.
    axes[0, 1].axis('off')
    axes[0, 1].set_title('Obs')
    axes[0, 2].axis('off')
    axes[0, 2].set_title('Rel. Error')

    for i, (obs_k, predicted_data, pred_mag, forcings, rel_error) in enumerate(panel_rows, start=1):
        Z = _scatter_to_grid(coords, pred_mag, X, Y)
        im = axes[i, 0].imshow(Z, origin='lower', extent=extent, cmap='RdBu_r', vmin=vmin_mag, vmax=vmax_mag)
        axes[i, 0].set_ylabel(f'obs≈{obs_k}')
        _format_panel_axis(axes[i, 0], zoom_bounds)
        plt.colorbar(im, ax=axes[i, 0], fraction=0.046, pad=0.02)
        _annotate_rrmse(axes[i, 0], original_data, predicted_data)

        # Observation map
        vals = _observed_values(forcings, coords.shape[0], np.float32)
        M = _scatter_to_grid(coords, vals, X, Y)
        im2 = axes[i, 1].imshow(M, origin='lower', extent=extent, cmap=cmap_mask, vmin=0, vmax=1)
        _format_panel_axis(axes[i, 1], zoom_bounds)
        cb = plt.colorbar(im2, ax=axes[i, 1], fraction=0.046, pad=0.02)
        cb.set_ticks([0, 1]); cb.set_ticklabels(['unobs', 'obs'])

        # Error map
        Z_err = _scatter_to_grid(coords, rel_error, X, Y)
        im3 = axes[i, 2].imshow(Z_err, origin='lower', extent=extent, cmap='Reds', vmin=0, vmax=1)
        _format_panel_axis(axes[i, 2], zoom_bounds)
        plt.colorbar(im3, ax=axes[i, 2], fraction=0.046, pad=0.02)

    plt.tight_layout()
    panel_path_grid = out_dir / f'gt_and_generations_vs_obscount_rows_grid_angle{args.case_angle}_z{args.case_z}.png'
    plt.savefig(str(panel_path_grid), dpi=150)
    plt.close()
    print(f"Saved row-wise grid panel to {panel_path_grid}")


# Run the sweep only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
