#!/usr/bin/env python3
"""
Training script for GenDA diffusion model.
"""

import os
# If you want to pin specific GPUs, uncomment and set as needed:
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
import time
import argparse
import itertools
from pathlib import Path
from typing import Dict, Any

import matplotlib
matplotlib.use('Agg')  # Use non-interactive backend for cluster
import matplotlib.pyplot as plt
from matplotlib import colors as mcolors
import jax
import jax.numpy as jnp
import numpy as np
import pyvista as pv
import optax
import orbax.checkpoint as ocp
import wandb
import flax.nnx as nnx
from jax.experimental import mesh_utils
import dataclasses

# JAX precision: keep 64-bit types disabled so arrays default to 32-bit (lower memory use)
jax.config.update('jax_enable_x64', False)

# --- GenSynth / training imports ---
from gen_da.gen_da import GenSynth, SamplerConfig, NoiseConfig
from gen_da.denoiser import (
    DenoiserArchitectureConfig,
    NoiseEncoderConfig,
    AngleEncoderConfig,  # Updated to include rotation config
)
from training.graph_dataset import make_dataset
from training.obs_sampling import GraphCacheManager, sample_random_obs, DroneSwarmSampler

print("JAX sees these devices:", jax.devices())
print("Device count:", jax.device_count())


# ----------------------------
# Dataset + batching utilities
# ----------------------------
def create_datasets(
    slice_root,
    norm_stats,
    batch_size,
    seed,
    is_training=True,
    *,
    fixed_angle=None,
    fixed_z=None,
    shuffle=True,
    angle_stride=1,  # NEW: evenly spaced subset controller
):
    """
    Build the batched graph dataset by delegating to ``make_dataset``.

    Every argument is forwarded unchanged (``norm_stats`` is passed under the
    name ``norm_stats_nc``); ``drop_remainder`` is always ``True`` here.

    NOTE(review): the batch layout below is carried over from the previous
    docstring and is not visible in this file -- confirm against
    ``training.graph_dataset.make_dataset``:
      - 'target_inputs'  (B, N_o, C_target)
      - 'guiding_inputs' (B, N_o, C_guiding)
      - 'case_number'    (B,)
      - 'graph_structures' (dict shared per batch or list of dicts)
      - 'z' (scalar/int for the slice depth)

    Returns:
      Whatever ``make_dataset`` returns for these settings.
    """
    dataset_options = {
        "slice_root": slice_root,
        "norm_stats_nc": norm_stats,
        "batch_size": batch_size,
        "shuffle": shuffle,
        "seed": seed,
        "is_training": is_training,
        "fixed_angle": fixed_angle,
        "fixed_z": fixed_z,
        # Incomplete trailing batches are always dropped.
        "drop_remainder": True,
        "angle_stride": angle_stride,
    }
    return make_dataset(**dataset_options)


def prepare_batch(
    batch: Dict[str, Any],
    *,
    obs_frac=0.05,
    obs_frac_min: float | None = None,
    obs_frac_max: float | None = None,
    obs_neighbor_hops: int = 1,
    obs_on_fluid_only: bool = True,
    obs_seed: int = 1234,
    # NEW:
    obs_jax_key: jax.Array | None = None,
    focus_xy: tuple[float, float] = (0.1, 0.1),
    focus_boost: float = 8.0,
    focus_trigger_frac: float = 0.003,
    # Fast modes
    obs_mode: str = 'random',        # 'random' | 'swarm'
    swarm_num_drones: int = 10,
    swarm_hops_radius: int = 1,
    swarm_move_prob: float = 0.9,
    swarm_traj_len: int = 1,
    swarm_target_frac: float | None = None,
    swarm_min_frac: float | None = None,
):
    """
    Turn one dataset batch into model inputs/forcings with a sampled observation mask.

    The clean field is read from ``batch['target_inputs']``. A 0/1 mask over
    nodes is drawn per sample, and the guiding field is the clean field
    multiplied by that mask (zeros at unobserved nodes).

    Args:
      batch: dict from the dataset. Reads 'target_inputs', 'graph_structures'
        (a dict, or a sequence of dicts of which the first is used),
        'angle_deg', and optionally 'z' plus the curriculum keys
        'current_step', 'obs_frac_anneal_steps', 'obs_frac_start',
        'obs_frac_end'.
      obs_frac: observation fraction used when ``obs_frac_min``/``obs_frac_max``
        are not both given.
      obs_frac_min, obs_frac_max: when both are given the fraction is drawn
        uniformly from this range (or annealed first, see below).
      obs_neighbor_hops: forwarded as ``hops`` to ``sample_random_obs``.
      obs_on_fluid_only: forwarded as ``fluid_only`` to ``GraphCacheManager.get``.
      obs_seed: NumPy seed base when ``obs_jax_key`` is None (the first
        'angle_deg' entry is added to it); also the ``DroneSwarmSampler`` seed.
      obs_jax_key: optional JAX key; its two 32-bit words are combined into a
        64-bit seed for the NumPy generator.
      focus_xy: half-widths of the origin-centred box (|x| <= focus_xy[0],
        |y| <= focus_xy[1]) whose nodes get boosted sampling weight.
      focus_boost: weight multiplier applied to nodes inside the box.
      focus_trigger_frac: the box weighting is only enabled when the chosen
        fraction is <= this value.
      obs_mode: 'swarm' uses ``DroneSwarmSampler``; any other value uses
        ``sample_random_obs``.
      swarm_num_drones, swarm_hops_radius, swarm_move_prob: forwarded to
        ``DroneSwarmSampler``.
      swarm_traj_len: if > 1, masks come from ``mask_for_span`` over that many
        steps, otherwise from ``mask_for_step``.
      swarm_target_frac: swarm mode only; masks covering more than this
        fraction of all N nodes are randomly thinned down to it.
      swarm_min_frac: swarm mode only; masks covering less than this fraction
        of all N nodes are topped up with random candidate nodes.

    Returns:
      ``(inputs, forcings, z)`` where ``inputs = {'U_field': clean field}``,
      ``forcings`` holds 'angle_deg', 'graph_structures', 'batch_size',
      'U_field_guiding', 'obs_mask' and 'obs_count', and ``z`` is
      ``jnp.array(batch.get('z', -1))``.
    """
    if obs_jax_key is not None:
        try:
            # JAX >= 0.4.14: PRNGKey is a struct; key_data gives uint32[2]
            kd = jax.random.key_data(obs_jax_key)         # shape (2,), dtype=uint32
            kd = np.asarray(jax.device_get(kd), dtype=np.uint64)
        except Exception:
            # Fallback for older JAX: sample two u32s
            # NOTE(review): maxval=2**32 does not fit in uint32 -- confirm this
            # fallback actually runs on the JAX versions it is meant for.
            kd = jax.random.randint(
                obs_jax_key, shape=(2,), minval=0, maxval=2**32, dtype=jnp.uint32
            )
            kd = np.asarray(jax.device_get(kd), dtype=np.uint64)

        # Combine two u32s into one u64 seed (portable, deterministic per key)
        seed64 = int((kd[0] << 32) | kd[1])
        rng = np.random.default_rng(seed64)
    else:
        # Deterministic per (obs_seed, first angle in the batch)
        rng = np.random.default_rng(obs_seed + int(batch['angle_deg'][0]))

    U_clean = np.array(batch['target_inputs'], dtype=np.float32)  # (B,N,2) -- TODO confirm

    # graph_structures (shared across batch)
    gs = batch['graph_structures']
    if isinstance(gs, dict):
        graph_struct = {k: np.array(v) for k, v in gs.items()}
        B = U_clean.shape[0]
    else:
        # Sequence of per-sample dicts: only the first one is used
        graph_struct = {k: np.array(v) for k, v in gs[0].items()}
        B = len(gs)

    coords_o = graph_struct["original_coordinates"]
    N = coords_o.shape[0]

    # Build/fetch per-z graph cache (CSR + candidates)
    z_field = batch.get('z', 35.0)
    if isinstance(z_field, (list, tuple, np.ndarray)):
        z_value = float(np.asarray(z_field).flatten()[0])
    else:
        z_value = float(z_field)
    gc = GraphCacheManager.get(graph_struct, z_value, fluid_only=obs_on_fluid_only)

    # ----- fraction for THIS batch -----
    # Curriculum: anneal from obs_frac_start towards obs_frac_end for the first
    # obs_frac_anneal_steps steps, then sample uniformly from [min, max].
    # The curriculum parameters are read from the batch dict itself; if any is
    # missing we fall back to the non-annealed behaviour.
    current_step = batch.get('current_step', None)
    obs_frac_anneal_steps = batch.get('obs_frac_anneal_steps', None)
    obs_frac_start = batch.get('obs_frac_start', None)
    obs_frac_end = batch.get('obs_frac_end', None)

    if (
        current_step is not None and
        obs_frac_anneal_steps is not None and
        obs_frac_start is not None and
        obs_frac_end is not None and
        (obs_frac_min is not None) and (obs_frac_max is not None)
    ):
        # Anneal for first obs_frac_anneal_steps
        if current_step < obs_frac_anneal_steps:
            frac = obs_frac_start + (obs_frac_end - obs_frac_start) * min(current_step / obs_frac_anneal_steps, 1.0)
            f = float(np.clip(frac, obs_frac_min, obs_frac_max))
        else:
            # Uniformly sample from min to max
            f = float(rng.uniform(low=obs_frac_min, high=obs_frac_max))
    else:
        # Fallback: old behavior
        if (obs_frac_min is not None) and (obs_frac_max is not None):
            f = float(rng.uniform(low=obs_frac_min, high=obs_frac_max))
        else:
            f = float(obs_frac)
    f = float(np.clip(f, 0.0, 1.0))

    # Number of base picks: a fraction of the candidate nodes, at least one
    k = max(1, int(round(f * max(1, gc.fluid_cand.size))))

    # Decide whether to enable city-center focus
    enable_focus = (f <= float(focus_trigger_frac))

    # Precompute a per-node weight vector for base picks (only used if enable_focus)
    if enable_focus:
        # assumes gc.coords has one row per node (N rows) -- TODO confirm
        xy = np.asarray(gc.coords, dtype=np.float64)[:, :2]
        in_box = (np.abs(xy[:, 0]) <= focus_xy[0]) & (np.abs(xy[:, 1]) <= focus_xy[1])
        base_weights = np.ones(N, dtype=np.float64)
        base_weights[in_box] *= float(focus_boost)
    else:
        base_weights = None  # unused

    # ----- sample per batch, then expand -----
    obs_mask = np.zeros((B, N), dtype=np.float32)
    if obs_mode == 'swarm':
        swarm = DroneSwarmSampler(gc,
                                  num_drones=int(swarm_num_drones),
                                  hops_radius=int(swarm_hops_radius),
                                  move_prob=float(swarm_move_prob),
                                  seed=int(obs_seed))
        # vary step per-sample so different obs masks; still deterministic given (seed,z,angle)
        for b in range(B):
            step = int(rng.integers(0, 2**31) + b)
            if int(swarm_traj_len) > 1:
                obs_mask[b] = swarm.mask_for_span(start_step=step, steps=int(swarm_traj_len))
            else:
                obs_mask[b] = swarm.mask_for_step(step)
    else:
        for b in range(B):
            m = sample_random_obs(gc, k=k, hops=int(obs_neighbor_hops), rng=rng, base_weights=base_weights)
            obs_mask[b] = m

    # Optional: in swarm mode, enforce target or minimum fraction by thinning/upsampling
    if obs_mode == 'swarm' and (swarm_target_frac is not None or swarm_min_frac is not None):
        target = None if swarm_target_frac is None else float(swarm_target_frac)
        min_frac = None if swarm_min_frac is None else float(swarm_min_frac)
        # Nodes eligible for top-up; fall back to every node if no candidates are cached
        cand = gc.fluid_cand if gc.fluid_cand.size > 0 else np.arange(N, dtype=np.int32)
        for b in range(B):
            m = obs_mask[b].astype(bool)
            cur = m.sum() / float(N)
            # Thin if above target
            if (target is not None) and (cur > target) and (cur > 0):
                keep = int(max(1, round(target * N)))
                idx = np.where(m)[0]
                if idx.size > keep:
                    drop = np.setdiff1d(idx, rng.choice(idx, size=keep, replace=False), assume_unique=False)
                    m[drop] = False
            # Ensure minimum coverage if too small
            if (min_frac is not None) and (m.sum() / float(N) < min_frac):
                need = int(max(0, round(min_frac * N) - m.sum()))
                if need > 0:
                    add_cand = np.setdiff1d(cand, np.where(m)[0], assume_unique=False)
                    if add_cand.size > 0:
                        add = rng.choice(add_cand, size=min(need, add_cand.size), replace=False)
                        m[add] = True
            obs_mask[b] = m.astype(np.float32)

    obs_mask_3 = obs_mask[..., None]    # (B,N,1)
    obs_values = U_clean * obs_mask_3   # zeros elsewhere

    inputs = {'U_field': jnp.array(U_clean)}
    forcings = {
        'angle_deg': jnp.array(batch['angle_deg'].astype(np.int32)),
        'graph_structures': {k: jnp.array(v) for k, v in graph_struct.items()},
        'batch_size': U_clean.shape[0],
        'U_field_guiding': jnp.array(obs_values),
        'obs_mask': jnp.array(obs_mask),  # (B,N)
        # helpful to log:
        'obs_count': int(obs_mask.sum()),
    }
    return inputs, forcings, jnp.array(batch.get('z', -1))



def _mean_angular_similarity(u_true: np.ndarray, u_pred: np.ndarray, eps: float = 1e-8) -> float:
    """
    Average per-node cosine similarity between two 2D vector fields.

    u_true, u_pred: (N, 2) arrays of vectors, one row per node.
    eps: added to each vector norm so zero-length vectors do not divide by zero.
    Returns a scalar in [-1, 1]; 1 means every direction matches.
    """
    true_x, true_y = u_true[:, 0], u_true[:, 1]
    pred_x, pred_y = u_pred[:, 0], u_pred[:, 1]

    # Norms are offset by eps, so an all-zero vector yields a cosine of 0.
    true_mag = np.linalg.norm(u_true, axis=-1) + eps
    pred_mag = np.linalg.norm(u_pred, axis=-1) + eps

    dot = true_x*pred_x + true_y*pred_y
    # Guard against round-off pushing the ratio outside [-1, 1].
    cosine = np.clip(dot / (true_mag * pred_mag), -1.0, 1.0)
    return float(np.mean(cosine))


def _graph_ssim_speed(
    speed_true: np.ndarray,
    speed_pred: np.ndarray,
    senders: np.ndarray,
    receivers: np.ndarray,
    include_self: bool = True,
) -> float:
    """
    SSIM-style score between two scalar node fields, using 1-hop graph windows.

    For node i the window is the set of senders j with an edge j -> i, plus i
    itself when ``include_self`` is True. Window means, variances and the
    covariance are computed from accumulated sums, and

        SSIM_i = ((2 mu_t mu_p + C1) * (2 sigma_tp + C2))
                 / ((mu_t^2 + mu_p^2 + C1) * (sigma_t^2 + sigma_p^2 + C2))

    with C1 = (0.01 L)^2 and C2 = (0.03 L)^2, where L is the value range over
    both fields (1.0 if that range is zero). Per-node values are clipped to
    [-1, 1] and the mean over nodes is returned.
    """
    num_nodes = speed_true.shape[0]

    # Gather sender values once; every statistic below is built from these.
    nb_true = speed_true[senders]
    nb_pred = speed_pred[senders]
    edge_terms = (nb_true, nb_pred, nb_true**2, nb_pred**2, nb_true*nb_pred)

    # Scatter-add each per-edge quantity into its receiving node.
    accumulated = []
    for term in edge_terms:
        bucket = np.zeros(num_nodes, dtype=np.float64)
        np.add.at(bucket, receivers, term)
        accumulated.append(bucket)
    acc_t, acc_p, acc_tt, acc_pp, acc_tp = accumulated

    window_count = np.zeros(num_nodes, dtype=np.int32)
    np.add.at(window_count, receivers, 1)

    if include_self:
        # The node's own value is part of its window.
        window_count += 1
        acc_t += speed_true
        acc_p += speed_pred
        acc_tt += speed_true**2
        acc_pp += speed_pred**2
        acc_tp += speed_true*speed_pred

    # Isolated nodes (empty window) divide by 1 instead of 0.
    window_size = np.maximum(window_count, 1).astype(np.float64)
    mean_true = acc_t / window_size
    mean_pred = acc_p / window_size
    variance_true = acc_tt / window_size - mean_true**2
    variance_pred = acc_pp / window_size - mean_pred**2
    covariance = acc_tp / window_size - mean_true*mean_pred

    # Stabilising constants scale with the dynamic range of both fields.
    pooled = np.concatenate([speed_true, speed_pred], axis=0)
    value_range = float(np.max(pooled) - np.min(pooled))
    if value_range == 0.0:
        value_range = 1.0
    c_lum = (0.01 * value_range) ** 2
    c_con = (0.03 * value_range) ** 2

    numerator = (2.0 * mean_true * mean_pred + c_lum) * (2.0 * covariance + c_con)
    denominator = (mean_true**2 + mean_pred**2 + c_lum) * (variance_true + variance_pred + c_con)
    per_node = np.clip(numerator / (denominator + 1e-12), -1.0, 1.0)
    return float(np.mean(per_node))


def create_test_plot(original_data, predicted_data, coords, case_number, mae_error,
                     slice_z=None, obs_mask=None, observed_values=None):
    """
    Build a 4x2 figure: Original / Observed (sensors) / Predicted / error map for Ux, Uy.

    Args:
      original_data   : (N,2) ground-truth field (columns Ux, Uy).
      predicted_data  : (N,2) predicted field.
      coords          : (N,2+) node coordinates; columns 0 and 1 are plotted.
      case_number     : shown in the figure title.
      mae_error       : scalar shown in the title as ``relRMSE`` (the parameter
                        name is historical).
      slice_z         : optional slice depth shown in the title.
      obs_mask        : (N,) or (N,1) float/bool, optional; truthy = observed node.
      observed_values : (N,2) optional. If omitted, values are derived from
                        ``obs_mask`` and ``original_data``.

    Unobserved nodes in the Observed panel are set to -inf. This now also
    happens when ``observed_values`` is supplied together with ``obs_mask``;
    previously a caller passing zeros at unobserved nodes had them drawn as
    real zero measurements that also skewed the colour range.

    NOTE(review): the cmap's 'under' colour is configured for the -inf points,
    but matplotlib's ``scatter`` documents that non-finite ``c`` values are
    only drawn with ``plotnonfinite=True`` -- confirm how they render.

    The bottom row is titled 'RMSE' but shows the per-node absolute error
    |pred - gt|; the RMSE in the title and stats box is its root-mean-square.

    Returns:
      The matplotlib Figure (the caller is responsible for closing it).
    """
    rows = 4
    fig, axes = plt.subplots(rows, 2, figsize=(18, 22))
    slice_info = f" | Z={slice_z}" if slice_z is not None else ""
    fig.suptitle(
        f'Velocity Components – Case {case_number}{slice_info} | relRMSE={mae_error:.4f}',
        fontsize=16, fontweight='bold'
    )

    # Build boolean mask of observed nodes
    if obs_mask is not None:
        m = np.asarray(obs_mask).astype(bool).reshape(-1)
        obs_count = int(m.sum())
    else:
        m = None
        obs_count = 0

    # Derive observed values. IMPORTANT: mark unobserved as -inf
    if observed_values is None:
        if m is not None:
            ov = np.full_like(original_data, -np.inf, dtype=np.float32)
            ov[m] = original_data[m]
        else:
            ov = None
    else:
        ov = np.array(observed_values, copy=True)
        if not np.issubdtype(ov.dtype, np.floating):
            # -inf cannot be stored in an integer/bool array
            ov = ov.astype(np.float32)
        # If the caller passed NaN/inf for unobserved, turn those into -inf
        ov[~np.isfinite(ov)] = -np.inf
        # The mask is authoritative: anything it marks unobserved becomes -inf,
        # even if the caller filled those entries with zeros.
        if m is not None and m.shape[0] == ov.shape[0]:
            ov[~m] = -np.inf

    component_names = ['Ux', 'Uy']
    for i, comp_name in enumerate(component_names):
        gt = original_data[:, i]
        pred = predicted_data[:, i]
        error_field = np.abs(pred - gt)
        comp_rmse = float(np.sqrt(np.mean(error_field**2)))

        # 1) Original
        s1 = axes[0, i].scatter(coords[:, 0], coords[:, 1], c=gt, cmap='RdBu_r', s=1.0, alpha=0.9)
        axes[0, i].set_title(f'Original {comp_name}', fontsize=14)
        axes[0, i].set_aspect('equal')
        axes[0, i].grid(True, alpha=0.3)
        cb1 = plt.colorbar(s1, ax=axes[0, i], shrink=0.8)
        cb1.set_label(f'Norm {comp_name}', fontsize=12)

        # 2) Observed (sensor) -- unobserved points carry -inf
        if ov is not None:
            vals = ov[:, i]

            # Determine vmin/vmax from the finite (i.e. observed) values only
            finite_obs = np.isfinite(vals)
            if finite_obs.any():
                vmin = float(np.min(vals[finite_obs]))
                vmax = float(np.max(vals[finite_obs]))
            else:
                # Fallback to GT range if no finite observed values
                vmin = float(np.min(gt))
                vmax = float(np.max(gt))
                # avoid degenerate norm
                if not np.isfinite(vmin) or not np.isfinite(vmax) or vmin == vmax:
                    vmin, vmax = -1.0, 1.0

            # Make a cmap with an 'under' color for -inf points
            cmap_obs = plt.get_cmap('RdBu_r').copy()
            cmap_obs.set_under('black', alpha=1.0)  # choose your under color

            norm_obs = mcolors.Normalize(vmin=vmin, vmax=vmax, clip=False)

            s2 = axes[1, i].scatter(
                coords[:, 0], coords[:, 1],
                c=vals, cmap=cmap_obs, norm=norm_obs,
                s=3.0, alpha=0.95
            )
            axes[1, i].set_title(f'Observed {comp_name} (count={obs_count})', fontsize=14)
            axes[1, i].set_aspect('equal')
            axes[1, i].grid(True, alpha=0.3)
            # Use extend='min' to show the 'under' color box on the colorbar
            cb2 = plt.colorbar(s2, ax=axes[1, i], shrink=0.8, extend='min')
            cb2.set_label(f'Norm {comp_name}', fontsize=12)
        else:
            axes[1, i].text(0.5, 0.5, 'No sensor data', transform=axes[1, i].transAxes,
                            ha='center', va='center')
            axes[1, i].set_axis_off()

        # 3) Predicted
        s3 = axes[2, i].scatter(coords[:, 0], coords[:, 1], c=pred, cmap='RdBu_r', s=1.0, alpha=0.9)
        axes[2, i].set_title(f'Predicted {comp_name}', fontsize=14)
        axes[2, i].set_aspect('equal')
        axes[2, i].grid(True, alpha=0.3)
        cb3 = plt.colorbar(s3, ax=axes[2, i], shrink=0.8)
        cb3.set_label(f'Norm {comp_name}', fontsize=12)

        # 4) Error map (per-node absolute error; labels keep the historical 'RMSE' wording)
        s4 = axes[3, i].scatter(coords[:, 0], coords[:, 1], c=error_field, cmap='viridis', s=1.0, alpha=0.95)
        axes[3, i].set_title(f'RMSE {comp_name} (mean={comp_rmse:.3f})', fontsize=14)
        axes[3, i].set_aspect('equal')
        axes[3, i].grid(True, alpha=0.3)
        cb4 = plt.colorbar(s4, ax=axes[3, i], shrink=0.8)
        cb4.set_label('RMSE (spatial)', fontsize=12)

        # ------------------------ Stats boxes ------------------------
        orig_stats = f'Min: {gt.min():.3f}\nMax: {gt.max():.3f}\nMean: {gt.mean():.3f}'
        axes[0, i].text(0.02, 0.98, orig_stats, transform=axes[0, i].transAxes,
                        va='top', bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8))

        pred_stats = f'Min: {pred.min():.3f}\nMax: {pred.max():.3f}\nMean: {pred.mean():.3f}'
        axes[2, i].text(0.02, 0.98, pred_stats, transform=axes[2, i].transAxes,
                        va='top', bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8))

        error_stats = f'Max Error: {error_field.max():.3f}\nMean Error: {error_field.mean():.3f}\nRMSE: {comp_rmse:.3f}'
        axes[3, i].text(0.02, 0.98, error_stats, transform=axes[3, i].transAxes,
                        va='top', bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8))

    plt.tight_layout()
    return fig


# ----------------------------
# Testing (full sampling)
# ----------------------------
def run_testing(model, test_batch, slice_root,
                obs_frac=0.05, obs_frac_min=None, obs_frac_max=None,
                obs_neighbor_hops=1, obs_on_fluid_only=True, obs_seed=1234, eval_key=None, focus_x=0.1, focus_y=0.1,
                focus_boost=8.0, focus_trigger_frac=0.003,
                # observation mode & swarm params
                obs_mode: str = 'random',
                swarm_num_drones: int = 25,
                swarm_hops_radius: int = 1,
                swarm_move_prob: float = 0.9,
                swarm_traj_len: int = 1,
                swarm_target_frac: float | None = None,
                swarm_min_frac: float | None = None,
                ):
    """
    Run full sampling on one test batch and score/plot the first sample.

    The batch is converted with ``prepare_batch`` (observation options are
    forwarded as-is; ``eval_key`` becomes ``obs_jax_key``), the model samples
    from fixed-seed Gaussian noise via ``model.full_sampling``, and the first
    sample of the batch is compared with the clean field.

    Args:
      model: object exposing ``full_sampling(noisy_inputs=..., forcings=...)``.
      test_batch: dataset batch dict; 'angle_deg' is required, 'z' optional
        (scalar or per-sample array; the first entry is used, default 35).
      slice_root: directory containing ``case_<angle>/slice_z_<z>.vtu`` files,
        read only to obtain plotting coordinates.
      remaining arguments: see ``prepare_batch``.

    Returns:
      ``(fig, case_number, rel_rmse_error, gssim, mean_ang_sim)``.

    Raises:
      RuntimeError: if ``model.full_sampling`` fails (original error chained).
    """
    case_number = int(test_batch['angle_deg'][0])
    # 'z' may be a scalar or a per-sample array; int() on a multi-element array
    # raises, so take the first entry the same way prepare_batch does.
    slice_z = int(np.asarray(test_batch.get('z', 35)).reshape(-1)[0])

    inputs, forcings, _ = prepare_batch(test_batch,
                                        obs_frac=obs_frac, obs_frac_min=obs_frac_min, obs_frac_max=obs_frac_max,
                                        obs_neighbor_hops=obs_neighbor_hops,
                                        obs_on_fluid_only=obs_on_fluid_only,
                                        obs_seed=obs_seed,
                                        obs_jax_key=eval_key,                 # set to None for fixed eval sensors
                                        focus_xy=(focus_x, focus_y),
                                        focus_boost=focus_boost,
                                        focus_trigger_frac=focus_trigger_frac,
                                        obs_mode=obs_mode,
                                        swarm_num_drones=swarm_num_drones,
                                        swarm_hops_radius=swarm_hops_radius,
                                        swarm_move_prob=swarm_move_prob,
                                        swarm_traj_len=swarm_traj_len,
                                        swarm_target_frac=swarm_target_frac,
                                        swarm_min_frac=swarm_min_frac,
                                        )

    # Generate a pure-noise input for the sampler, matching target shape.
    # The key is fixed, so every test run starts from the same noise.
    noise_shape = inputs['U_field'].shape
    rng_key = jax.random.PRNGKey(42)
    noisy_inputs = jax.random.normal(rng_key, noise_shape)

    # Add boundary values to forcings for sampling
    forcings['boundary_values'] = inputs['U_field']  # Provide clean boundary values

    try:
        predicted_output = model.full_sampling(noisy_inputs=noisy_inputs, forcings=forcings)
    except Exception as e:
        # Chain explicitly so the underlying failure stays in the traceback
        raise RuntimeError(f"Full sampling failed: {e}") from e

    # Only the first sample of the batch is scored and plotted
    original_data = np.array(inputs['U_field'][0])
    predicted_data = np.array(predicted_output[0])

    diff = predicted_data - original_data
    rmse_error = float(np.sqrt(np.mean(diff**2)))
    denom = float(np.sqrt(np.mean(original_data**2)))
    rel_rmse_error = rmse_error / denom if denom != 0 else 0.0

    ###############################################################
    # ----- New similarity metrics -----
    # 1) Mean angular similarity (directional agreement)
    mean_ang_sim = _mean_angular_similarity(original_data, predicted_data)

    # 2) Graph-SSIM on speed (uses o2o neighborhoods)
    # Extract graph edges from forcings (handles dict or list)
    gs = forcings['graph_structures']
    if isinstance(gs, (list, tuple)):
        gs = gs[0]
    senders = np.asarray(gs['o2o_senders']).astype(np.int64)
    receivers = np.asarray(gs['o2o_receivers']).astype(np.int64)

    speed_true = np.linalg.norm(original_data, axis=-1).astype(np.float64)   # (N,)
    speed_pred = np.linalg.norm(predicted_data, axis=-1).astype(np.float64)  # (N,)
    gssim = _graph_ssim_speed(speed_true, speed_pred, senders, receivers, include_self=True)
    ###############################################################

    # Coordinates for plotting
    case_name = f"case_{case_number}"
    slice_file = Path(slice_root) / case_name / f"slice_z_{int(slice_z)}.vtu"
    try:
        slc = pv.read(str(slice_file))
        coords = slc.points[:, :2]
    except Exception as e:
        # Best effort: the plot still renders, but on random placeholder positions
        print(f"Warning: Could not load coordinates from {slice_file}: {e}")
        coords = np.random.randn(len(original_data), 2)

    obs_mask = np.array(forcings.get('obs_mask'))[0] if 'obs_mask' in forcings else None
    observed_values = np.array(forcings.get('U_field_guiding'))[0] if 'U_field_guiding' in forcings else None

    fig = create_test_plot(
        original_data=original_data,
        predicted_data=predicted_data,
        coords=coords,
        case_number=case_number,
        mae_error=rel_rmse_error,
        slice_z=slice_z,
        obs_mask=obs_mask,
        observed_values=observed_values,
    )
    return fig, case_number, rel_rmse_error, gssim, mean_ang_sim


# ----------------------------
# Diffusion loss wrapper
# ----------------------------
def diffusion_loss(model, inputs, forcings):
    """
    Evaluate ``model.loss`` on the clean field stored under ``inputs['U_field']``.

    ``forcings`` is handed to the model untouched. The result is returned as a
    ``(loss, aux)`` pair; callers that only need the scalar can drop ``aux``.
    """
    loss_value, extras = model.loss(clean_inputs=inputs['U_field'], forcings=forcings)
    return loss_value, extras


@nnx.jit
def train_step(model, optimizer, inputs, forcings):
    """
    Run one optimisation step and return the scalar diffusion loss.

    The loss from ``diffusion_loss`` is differentiated with respect to the
    model, and the resulting gradients are passed to ``optimizer.update``.
    The auxiliary output of the loss is discarded.
    """
    def objective(net):
        # has_aux=True expects a (value, aux) pair; the aux slot is unused here.
        value, _ = diffusion_loss(net, inputs, forcings)
        return value, None

    grad_fn = nnx.value_and_grad(objective, has_aux=True)
    (loss_value, _), gradients = grad_fn(model)
    optimizer.update(gradients)
    return loss_value


@nnx.jit
def eval_step(model, inputs, forcings):
    """
    Compute the diffusion loss for evaluation; no gradients, no optimizer update.

    Returns only the loss; the auxiliary output of ``diffusion_loss`` is dropped.
    """
    loss_value, _unused_aux = diffusion_loss(model, inputs, forcings)
    return loss_value


# ----------------------------
# Main
# ----------------------------
def _parse_bool_flag(value):
    """Convert a command-line token such as ``true``/``0``/``no`` to a bool."""
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if token in {"0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def _build_arg_parser():
    """Create the argument parser holding every training option."""
    parser = argparse.ArgumentParser(description="Train GenSynth diffusion model (baseline-style script)")
    parser.add_argument('--model_name', type=str, required=True,
                        help='Name for the run (used for checkpoints and wandb run name)')
    parser.add_argument('--slice_root', type=str, default='data_sliced_cropped_300k',
                        help='Root directory for sliced data')
    parser.add_argument('--norm_stats', type=str, default='normalization_cropped_300k_test/normalization_stats_train.nc',
                        help='Path to normalization statistics')
    parser.add_argument('--batch_size', type=int, default=2, help='Batch size')
    parser.add_argument('--learning_rate', type=float, default=3e-4, help='Base learning rate')
    parser.add_argument('--num_steps', type=int, default=200000, help='Number of training steps')
    parser.add_argument('--eval_every', type=int, default=1000, help='(Optional) eval frequency (loss only)')
    parser.add_argument('--testing_steps', type=int, default=1000,
                        help='Steps between testing/evaluation runs with full sampling')
    parser.add_argument('--enable_testing', action='store_true',
                        help='Enable testing during training (full sampling + plot)')
    parser.add_argument('--latent_size', type=int, default=64, help='Denoiser latent size')
    parser.add_argument('--checkpoint_dir', type=str, default='checkpoints',
                        help='Directory to save checkpoints')
    parser.add_argument('--save_every', type=int, default=5000, help='Checkpoint save frequency')
    parser.add_argument('--wandb_project', type=str, default='GenSynth',
                        help='Weights & Biases project name')
    parser.add_argument('--seed', type=int, default=142730, help='Random seed')
    parser.add_argument('--beta1', type=float, default=0.9, help='AdamW beta1')
    parser.add_argument('--beta2', type=float, default=0.999, help='AdamW beta2')
    parser.add_argument('--weight_decay', type=float, default=0.1, help='AdamW weight decay')
    parser.add_argument('--grad_clip_norm', type=float, default=5.0, help='Global norm clip')
    parser.add_argument('--overfit_single', action='store_true', help='Overfit a single mesh/angle for debugging')
    parser.add_argument('--fixed_angle', type=int, default=None, help='Angle to overfit (1..360)')
    parser.add_argument('--fixed_z', type=int, default=None, help='Z to overfit (e.g., 35)')

    # Evenly spaced angle-subset controls
    parser.add_argument('--train_angle_stride', type=int, default=2,
                        help='Use every k-th angle for TRAIN (e.g., 18 => 20 angles).')
    parser.add_argument('--eval_angle_stride', type=int, default=2,
                        help='Use every k-th angle for EVAL/TEST. Use 1 to evaluate on all.')

    parser.add_argument('--obs_frac', type=float, default=0.0005,
                        help='Fraction of nodes marked as observed per sample (approx).')
    # Parsed as a real boolean: without a type, "--obs_on_fluid_only False"
    # produced the truthy string "False" and the option could never be
    # switched off. A bare "--obs_on_fluid_only" means True.
    parser.add_argument('--obs_on_fluid_only', type=_parse_bool_flag, nargs='?', const=True, default=True,
                        help='Sample observations only on fluid nodes (type==1).')
    parser.add_argument('--obs_seed', type=int, default=1234)
    parser.add_argument('--unobs_loss_mult', type=float, default=3.0,
                        help='Multiply loss on unobserved nodes to focus on reconstruction.')
    parser.add_argument('--channel_weights', type=str, default="1.0,1.0",
                        help='Comma list like "1.0,1.5" to weight Ux,Uy in the loss.')

    parser.add_argument('--obs_frac_min', type=float, default=0.0016,
                        help='Lower bound fraction of nodes to observe (overrides --obs_frac if set)')
    parser.add_argument('--obs_frac_max', type=float, default=0.050,
                        help='Upper bound fraction of nodes to observe (overrides --obs_frac if set)')
    parser.add_argument('--obs_neighbor_hops', type=int, default=1,
                        help='Expand observations to this many o2o hops (sensor range). 0 = none.')

    parser.add_argument('--obs_frac_start', type=float, default=0.050,
                        help='Starting fraction of observed nodes at step 1.')
    parser.add_argument('--obs_frac_end', type=float, default=0.0016,
                        help='Ending fraction of observed nodes after anneal.')
    parser.add_argument('--obs_frac_anneal_steps', type=int, default=20000,
                        help='Steps over which to anneal obs fraction (then hold at end).')
    parser.add_argument('--obs_frac_jitter', type=float, default=0.0002,
                        help='Uniform +/- jitter around the per-step fraction.')

    parser.add_argument('--obs_focus_xy', type=str, default="500,500",
                        help='Half-widths (x_half,y_half) of the focus box. Default "500,500".')
    parser.add_argument('--obs_focus_boost', type=float, default=8.0,
                        help='Multiply sampling weight inside the box by this when focus triggers.')
    parser.add_argument('--obs_focus_trigger_frac', type=float, default=0.0003,
                        help='Enable focus weights when current obs fraction <= this.')

    # Observation mode controls
    parser.add_argument('--obs_mode', type=str, default='swarm', choices=['random', 'swarm'],
                        help="'random' for iid sensors or 'swarm' to simulate drone trajectories.")
    parser.add_argument('--swarm_num_drones', type=int, default=25, help='Number of drones in swarm mode.')
    parser.add_argument('--swarm_hops_radius', type=int, default=1, help='Observation radius in hops around drones.')
    parser.add_argument('--swarm_move_prob', type=float, default=0.9, help='Probability a drone moves each step.')
    parser.add_argument('--swarm_traj_len', type=int, default=100, help='Union observations over this many steps (trajectory length).')
    # Optional: enforce minimum or target fraction in swarm mode by thinning/upsampling
    parser.add_argument('--swarm_target_frac', type=float, default=0.032, help='If set, thin/upsample swarm mask to approach this fraction.')
    parser.add_argument('--swarm_min_frac', type=float, default=None, help='If set, ensure at least this fraction; will randomly add from candidates if needed.')

    # Classifier-free guidance knobs
    parser.add_argument('--cfg_dropout_prob', type=float, default=0.10,
                        help='Probability to drop observations during training (unconditional).')
    parser.add_argument('--guidance_scale', type=float, default=1.5,
                        help='Guidance scale during sampling. 1.0 uses conditional only; >1 amplifies; 0 uncond.')
    return parser


def _cosine_obs_fraction(step, *, start, end, anneal_steps):
    """Cosine-interpolate the observed-node fraction from ``start`` to ``end``."""
    # Progress is clamped to [0, 1], so the value holds at ``end`` after the anneal.
    progress = min(1.0, max(0.0, step / max(1, anneal_steps)))
    # The cosine term runs 1 -> 0, which maps onto start -> end.
    cos_term = 0.5 * (1.0 + np.cos(np.pi * progress))
    return float(end + (start - end) * cos_term)


def _obs_sampling_bounds(step, frac, args):
    """Return ``(lo, hi, swarm_target)`` for observation sampling at ``step``."""
    if step <= args.obs_frac_anneal_steps:
        # During the curriculum: a narrow jitter band around the annealed fraction.
        lo = max(0.0, frac - args.obs_frac_jitter)
        hi = min(1.0, frac + args.obs_frac_jitter)
        return lo, hi, args.swarm_target_frac
    # After the anneal: the full configured range, and no single forced
    # target fraction in swarm mode.
    return float(args.obs_frac_min), float(args.obs_frac_max), None


def _attach_schedule_metadata(batch, step, args):
    """Write the current step and anneal settings into ``batch`` in place."""
    # NOTE(review): presumably read downstream by prepare_batch - confirm.
    batch['current_step'] = step
    batch['obs_frac_anneal_steps'] = int(args.obs_frac_anneal_steps)
    batch['obs_frac_start'] = float(args.obs_frac_start)
    batch['obs_frac_end'] = float(args.obs_frac_end)


def _cycle_forever(ds):
    """Yield batches from ``ds`` endlessly, restarting it after every pass."""
    while True:
        produced_any = False
        for item in ds:
            produced_any = True
            yield item
        if not produced_any:
            # An empty dataset would otherwise spin here without ever yielding.
            raise RuntimeError("Dataset yielded no batches; check slice_root and the angle/z filters.")


def main(argv=None):
    """
    Train the GenSynth diffusion model from command-line options.

    The function parses the options, opens a wandb run, builds the train and
    test datasets, instantiates the model and an AdamW optimizer with a
    warmup + cosine learning-rate schedule, and then runs the training loop.
    Depending on the options the loop also computes an eval loss, runs a full
    sampling test with a plot, and writes checkpoints under
    ``<checkpoint_dir>/<model_name>``.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the script behaviour.
    """
    import shutil
    import traceback

    parser = _build_arg_parser()
    args = parser.parse_args(argv)

    # Validate the focus box before any side effect (directories, wandb run).
    try:
        focus_x, focus_y = (float(v) for v in args.obs_focus_xy.split(','))
    except ValueError:
        parser.error(f"--obs_focus_xy expects two comma-separated numbers, got {args.obs_focus_xy!r}")

    # Set up directories
    checkpoint_root = Path(args.checkpoint_dir).resolve() / args.model_name
    checkpoint_root.mkdir(exist_ok=True, parents=True)

    # Initialize wandb
    wandb.init(project=args.wandb_project, name=args.model_name, config=vars(args))

    # Create datasets and iterators
    print("Creating datasets...")
    if args.overfit_single:
        # Overfit one example: batch_size=1, fixed angle/z, angle stride 1.
        train_ds = create_datasets(args.slice_root, args.norm_stats, batch_size=1, seed=args.seed, is_training=True,
                                   fixed_angle=args.fixed_angle, fixed_z=args.fixed_z, shuffle=True, angle_stride=1)
        test_ds = create_datasets(args.slice_root, args.norm_stats, batch_size=1, seed=args.seed + 2, is_training=False,
                                  fixed_angle=args.fixed_angle, fixed_z=args.fixed_z, shuffle=True, angle_stride=1)
        print("[Dataset] overfit_single=True → ignoring stride (train/eval on the same single case).")
    else:
        train_ds = create_datasets(
            args.slice_root, args.norm_stats, args.batch_size, args.seed,
            is_training=True, angle_stride=max(1, int(args.train_angle_stride)))
        test_ds = create_datasets(
            args.slice_root, args.norm_stats, batch_size=1, seed=args.seed + 2,
            is_training=False, angle_stride=max(1, int(args.eval_angle_stride)))

    train_iter = _cycle_forever(train_ds)
    test_iter = _cycle_forever(test_ds)

    # Observation-sampling options that are identical for every
    # prepare_batch / run_testing call below.
    shared_obs_kwargs = dict(
        obs_neighbor_hops=args.obs_neighbor_hops,
        obs_on_fluid_only=args.obs_on_fluid_only,
        obs_seed=args.obs_seed,
        obs_mode=args.obs_mode,
        swarm_num_drones=args.swarm_num_drones,
        swarm_hops_radius=args.swarm_hops_radius,
        swarm_move_prob=args.swarm_move_prob,
        swarm_traj_len=args.swarm_traj_len,
        swarm_min_frac=args.swarm_min_frac,
    )

    # Example batch for model initialization (this batch is not trained on).
    example_batch = next(train_iter)
    example_inputs, example_forcings, _ = prepare_batch(
        example_batch,
        obs_frac=args.obs_frac,
        obs_frac_min=args.obs_frac_min,
        obs_frac_max=args.obs_frac_max,
        swarm_target_frac=args.swarm_target_frac,
        **shared_obs_kwargs,
    )

    # Build a 1D device mesh over all visible devices
    n_devices = jax.device_count()
    mesh = jax.sharding.Mesh(mesh_utils.create_device_mesh((n_devices,)), ('data',))

    # Configure + instantiate GenSynth
    den_cfg = DenoiserArchitectureConfig(
        latent_size=args.latent_size,
    )
    samp_cfg = SamplerConfig(guidance_scale=float(args.guidance_scale))
    noise_cfg = NoiseConfig(cfg_dropout_prob=float(args.cfg_dropout_prob))
    if args.overfit_single:
        # Disable CFG dropout for easier overfitting. The previous call,
        # dataclasses.replace(noise_cfg) without overrides, changed nothing.
        noise_cfg = dataclasses.replace(noise_cfg, cfg_dropout_prob=0.0)
        # Also reduce sampler steps for speed
        samp_cfg = dataclasses.replace(samp_cfg, num_steps=20)

    example_graph = example_forcings['graph_structures']
    if isinstance(example_graph, (list, tuple)):
        example_graph = example_graph[0]

    target_channels = example_inputs['U_field'].shape[-1]
    # NOTE(review): the guiding channel count is taken from the target field
    # as well - confirm the guiding input really has the same channel count.
    guiding_channels = example_inputs['U_field'].shape[-1]

    rngs = nnx.Rngs(args.seed)
    model = GenSynth(
        denoiser_architecture_config=den_cfg,
        sampler_config=samp_cfg,
        noise_config=noise_cfg,
        noise_encoder_config=NoiseEncoderConfig(),
        angle_encoder_config=AngleEncoderConfig(),
        rngs=rngs,
        mesh=mesh,
        example_graph_structures=example_graph,
        target_channels=target_channels,
        guiding_channels=guiding_channels,
    )

    # LR schedule: linear warmup (fixed 5000 steps) followed by cosine decay
    warmup_steps = 5000
    warmup_fn = optax.linear_schedule(0.0, args.learning_rate, warmup_steps)
    cosine_fn = optax.cosine_decay_schedule(args.learning_rate, max(args.num_steps - warmup_steps, 1), alpha=0.0)
    scheduler = optax.join_schedules([warmup_fn, cosine_fn], boundaries=[warmup_steps])

    # Optimizer (AdamW + clip)
    tx = optax.chain(
        optax.clip_by_global_norm(args.grad_clip_norm),
        optax.adamw(learning_rate=scheduler, b1=args.beta1, b2=args.beta2, weight_decay=args.weight_decay),
    )
    optimizer = nnx.Optimizer(model, tx)

    # ----------------------------
    # Training loop
    # ----------------------------
    print("Starting training...")
    model.train()
    for step in range(1, args.num_steps + 1):
        batch = next(train_iter)

        # Annealed observation fraction and the sampling bounds derived from
        # it. Eval and testing at this step reuse the same values.
        frac = _cosine_obs_fraction(
            step,
            start=args.obs_frac_start,
            end=args.obs_frac_end,
            anneal_steps=args.obs_frac_anneal_steps,
        )
        lo, hi, swarm_target = _obs_sampling_bounds(step, frac, args)

        _attach_schedule_metadata(batch, step, args)

        obs_key = rngs.noise()  # fresh key every iteration

        inputs, forcings, _ = prepare_batch(
            batch,
            obs_frac=frac,
            obs_frac_min=lo,
            obs_frac_max=hi,
            obs_jax_key=obs_key,
            focus_xy=(focus_x, focus_y),
            focus_boost=args.obs_focus_boost,
            focus_trigger_frac=args.obs_focus_trigger_frac,
            swarm_target_frac=swarm_target,
            **shared_obs_kwargs,
        )

        # For overfit mode, force the angle to the fixed value
        if args.overfit_single and (args.fixed_angle is not None):
            forcings['angle_deg'] = jnp.array([int(args.fixed_angle)], dtype=jnp.int32)

        train_loss = train_step(model, optimizer, inputs, forcings)

        current_lr = scheduler(step)
        print(f"Step {step}: Train Loss = {float(train_loss):.6f}, LR = {float(current_lr):.6f}")
        # Passing step= keeps wandb's x-axis equal to the training step even
        # when eval/test metrics are logged in the same iteration.
        wandb.log({
            "train_loss": float(train_loss),
            "lr": float(current_lr),
        }, step=step)

        # Optional eval-only loss (no sampling)
        if args.eval_every and (step % args.eval_every == 0):
            eval_batch = next(test_iter)
            _attach_schedule_metadata(eval_batch, step, args)

            eval_key = rngs.noise()
            eval_inputs, eval_forcings, _ = prepare_batch(
                eval_batch,
                obs_frac=frac,
                obs_frac_min=lo,
                obs_frac_max=hi,
                obs_jax_key=eval_key,
                focus_xy=(focus_x, focus_y),
                focus_boost=args.obs_focus_boost,
                focus_trigger_frac=args.obs_focus_trigger_frac,
                swarm_target_frac=swarm_target,
                **shared_obs_kwargs,
            )

            model.eval()
            eval_loss = eval_step(model, eval_inputs, eval_forcings)
            wandb.log({"eval_loss": float(eval_loss), "step": step}, step=step)
            model.train()

        # Full sampling test (plot) on schedule; always runs at step 1.
        # A testing_steps of 0 disables the periodic runs instead of raising
        # ZeroDivisionError.
        testing_due = (step == 1) or bool(args.testing_steps and step % args.testing_steps == 0)
        if args.enable_testing and testing_due:
            print(f"Running testing at step {step}...")
            fig = None
            try:
                test_batch = next(test_iter)
                test_clean = {k: (np.array(v) if hasattr(v, 'device') else v) for k, v in test_batch.items()}
                if args.overfit_single and args.fixed_angle is not None:
                    test_clean['angle_deg'] = np.array([int(args.fixed_angle)], dtype=np.int32)

                _attach_schedule_metadata(test_clean, step, args)

                model.eval()
                eval_key = rngs.noise()
                fig, case_number, mae_error, gssim, mean_ang_sim = run_testing(
                    model=model,
                    test_batch=test_clean,
                    slice_root=args.slice_root,
                    obs_frac=frac,
                    obs_frac_min=lo,
                    obs_frac_max=hi,
                    eval_key=eval_key,
                    focus_x=focus_x,
                    focus_y=focus_y,
                    focus_boost=args.obs_focus_boost,
                    focus_trigger_frac=args.obs_focus_trigger_frac,
                    swarm_target_frac=swarm_target,
                    **shared_obs_kwargs,
                )
                test_metrics = {
                    "test_mae": mae_error,
                    "test_graph_ssim_speed": gssim,
                    "test_mean_angular_similarity": mean_ang_sim,
                    "test_case_number": case_number,
                    "test_plot": wandb.Image(fig),
                }
                # NOTE: this counts the observed nodes of the most recent
                # *training* batch; run_testing does not return the mask it
                # sampled. The key name is kept for dashboard compatibility.
                if 'obs_mask' in forcings:
                    test_metrics["test_observed_nodes"] = int(np.array(forcings['obs_mask']).sum())
                wandb.log(test_metrics, step=step)
                print(f"Testing completed: Case {case_number}, MAE: {mae_error:.4f}")
            except Exception as e:
                # Testing is best-effort: report the failure and keep training.
                print(f"Testing failed at step {step}: {e}")
                traceback.print_exc()
            finally:
                # Release the figure even if logging failed, and always
                # return the model to training mode.
                if fig is not None:
                    plt.close(fig)
                model.train()

        # Checkpointing (always at the final step; save_every == 0 disables
        # the periodic saves instead of raising ZeroDivisionError)
        periodic_save = bool(args.save_every and step % args.save_every == 0)
        if periodic_save or (step == args.num_steps):
            model.eval()
            checkpoint_path = checkpoint_root / f"checkpoint_step_{step}"

            # Remove existing checkpoint dir if present
            if checkpoint_path.exists():
                shutil.rmtree(checkpoint_path)

            # Split off the RNG state and save the remaining model state only
            _, _, other_state = nnx.split(model, nnx.RngState, ...)
            other_state = jax.device_get(other_state)
            ocp.PyTreeCheckpointer().save(str(checkpoint_path), other_state)
            print(f"✔ Checkpoint saved at step {step}")
            model.train()

    model.eval()
    print("Training completed!")
    wandb.finish()


# Run the training entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()


