#!/usr/bin/env python3
"""
Training script for MeshGraphNet baseline models.

This script trains the MeshGraphNet and MultiScaleMeshGraphNet models using direct
supervised learning (no diffusion). It uses the same dataset and evaluation logic
as the diffusion training but with a simpler MSE loss.
"""

import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "2"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
import time
import argparse
import itertools
from pathlib import Path
from typing import Dict, Any
import matplotlib
matplotlib.use('Agg')  # Use non-interactive backend for cluster
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp
import flax.nnx as nnx
import optax
import orbax.checkpoint as ocp
import wandb
import numpy as np

# Set JAX memory options
jax.config.update('jax_enable_x64', False)  # Use float32 instead of float64

# Import modules
import sys
sys.path.append('..')
from gen_da.denoiser import DenoiserArchitectureConfig
from training.graph_dataset import make_dataset
from training.train import create_test_plot
from baselines.meshgraphnets import MeshGraphNet, MultiScaleMeshGraphNet
from training.obs_sampling import GraphCacheManager, sample_random_obs, DroneSwarmSampler

# Report the accelerator setup once at import time so job logs record which devices JAX found.
_visible_devices = jax.devices()
print("JAX sees these devices:", _visible_devices)
print("Device count:", jax.device_count())


def create_datasets(slice_root, norm_stats, batch_size, seed, is_training=True, *, fixed_angle=None, fixed_z=None, shuffle=True, angle_stride=1):
    """Build a dataset via ``make_dataset`` using the same settings as training/train.py.

    ``shuffle`` is honoured only for training datasets; when ``is_training`` is
    false the dataset is always built unshuffled. Incomplete final batches are
    dropped (``drop_remainder=True``).
    """
    # Evaluation data must keep a fixed order regardless of the caller's flag.
    if not is_training:
        shuffle = False

    return make_dataset(
        slice_root=slice_root,
        norm_stats_nc=norm_stats,
        batch_size=batch_size,
        seed=seed,
        shuffle=shuffle,
        is_training=is_training,
        fixed_angle=fixed_angle,
        fixed_z=fixed_z,
        angle_stride=angle_stride,
        drop_remainder=True,
    )


def _resolve_obs_frac(
    *,
    obs_frac,
    obs_frac_min,
    obs_frac_max,
    obs_seed,
    current_step,
    obs_frac_anneal_steps,
    obs_frac_start,
    obs_frac_end,
    obs_frac_jitter,
):
    """Return the observed-node fraction ``f`` for one batch.

    Priority:
    1. Annealing is used when ``current_step``, the anneal schedule and both bounds
       are given. The fraction is linearly interpolated from ``obs_frac_start`` to
       ``obs_frac_end``, jittered, then clipped to ``[obs_frac_min, obs_frac_max]``.
    2. Otherwise, if both bounds are given, a fraction is drawn uniformly in
       ``[obs_frac_min, obs_frac_max]``, jittered, and clipped to [0, 1].
    3. Otherwise ``obs_frac`` is jittered and clipped to [0, 1].

    All random draws come from a single generator seeded by ``(obs_seed, current_step)``.
    They therefore change from step to step but stay reproducible. Seeding with
    ``obs_seed`` alone would yield the same "random" value on every call.
    """
    # SeedSequence entropy must be non-negative, hence the 32-bit masks.
    step_key = int(current_step or 0) & 0xFFFFFFFF
    rng = np.random.default_rng((int(obs_seed) & 0xFFFFFFFF, step_key))

    def jittered(val):
        # Uniform +/- obs_frac_jitter perturbation, clipped to [0, 1].
        if obs_frac_jitter <= 0.0:
            return val
        return float(np.clip(val + rng.uniform(-obs_frac_jitter, obs_frac_jitter), 0.0, 1.0))

    have_bounds = obs_frac_min is not None and obs_frac_max is not None
    annealing = (
        have_bounds
        and current_step is not None
        and obs_frac_anneal_steps is not None
        and obs_frac_start is not None
        and obs_frac_end is not None
    )
    if annealing:
        progress = float(np.clip(current_step / max(1, obs_frac_anneal_steps), 0.0, 1.0))
        f_cur = obs_frac_start + progress * (obs_frac_end - obs_frac_start)
        return float(np.clip(jittered(f_cur), obs_frac_min, obs_frac_max))
    if have_bounds:
        return float(np.clip(jittered(rng.uniform(obs_frac_min, obs_frac_max)), 0.0, 1.0))
    return float(np.clip(jittered(obs_frac), 0.0, 1.0))


def _focus_weights(gc, focus_xy, focus_boost):
    """Per-node sampling weights that boost nodes near the mean coordinate.

    Nodes whose x/y offsets from the mean of ``gc.coords`` are within the half-widths
    ``focus_xy`` get weight ``max(1, focus_boost)``. All other nodes get weight 1.
    """
    xy = gc.coords
    center = np.mean(xy, axis=0)
    x_half, y_half = float(focus_xy[0]), float(focus_xy[1])
    in_box = (np.abs(xy[:, 0] - center[0]) <= x_half) & (np.abs(xy[:, 1] - center[1]) <= y_half)
    weights = np.ones(gc.N, dtype=np.float32)
    weights[in_box] *= float(max(1.0, focus_boost))
    return weights


def _swarm_mask(
    gc,
    rng,
    step_id,
    *,
    obs_seed,
    obs_neighbor_hops,
    swarm_num_drones,
    swarm_hops_radius,
    swarm_move_prob,
    swarm_traj_len,
    swarm_target_frac,
    swarm_min_frac,
):
    """Build a drone-swarm observation mask (length ``gc.N``), optionally resized.

    If ``swarm_target_frac`` is set, the mask is thinned or topped up to that
    fraction of ``gc.N`` nodes. Otherwise, if ``swarm_min_frac`` is set, the mask
    is only topped up: a minimum never removes observations. Top-up nodes are
    drawn from ``rng`` and expanded by ``obs_neighbor_hops`` hops.
    """
    # NOTE(review): the sampler is rebuilt with the same seed and queried with the
    # same step for every sample. All samples in a batch therefore appear to share
    # one swarm mask before resizing. Confirm this is intended.
    sampler = DroneSwarmSampler(
        gc,
        num_drones=int(swarm_num_drones),
        hops_radius=int(swarm_hops_radius),
        move_prob=float(swarm_move_prob),
        seed=int(obs_seed),
    )
    if int(swarm_traj_len) <= 1:
        m = sampler.mask_for_step(step_id)
    else:
        m = sampler.mask_for_span(step_id, steps=int(swarm_traj_len))
    # Normalise to float so the arithmetic below behaves the same for bool/float masks.
    m = np.asarray(m, dtype=np.float32)

    desired_frac = swarm_target_frac if swarm_target_frac is not None else swarm_min_frac
    if desired_frac is None:
        return m

    cur_cnt = int(np.sum(m))
    desired_cnt = int(round(desired_frac * gc.N))
    if swarm_target_frac is not None and cur_cnt > desired_cnt > 0:
        # Thin down to the exact target count.
        picks = np.nonzero(m > 0)[0]
        keep = rng.choice(picks, size=min(desired_cnt, picks.size), replace=False)
        m = np.zeros(gc.N, dtype=np.float32)
        m[keep] = 1.0
    elif cur_cnt < desired_cnt:
        # Top up with random unobserved nodes, expanded to simulate sensor range.
        remaining = np.nonzero(m == 0)[0]
        if remaining.size > 0:
            from training.obs_sampling import expand_mask_hops
            add = rng.choice(remaining, size=min(desired_cnt - cur_cnt, remaining.size), replace=False)
            seed_mask = np.zeros(gc.N, dtype=bool)
            seed_mask[add] = True
            added = expand_mask_hops(seed_mask, gc.indptr, gc.indices, int(obs_neighbor_hops))
            m = np.clip(m + added, 0.0, 1.0)
    return m


def prepare_batch(
    batch,
    *,
    # observation fraction controls
    obs_frac: float = 0.05,
    obs_frac_min: float | None = None,
    obs_frac_max: float | None = None,
    obs_neighbor_hops: int = 1,
    obs_on_fluid_only: bool = True,
    obs_seed: int = 1234,
    # curriculum annealing
    current_step: int | None = None,
    obs_frac_anneal_steps: int | None = None,
    obs_frac_start: float | None = None,
    obs_frac_end: float | None = None,
    obs_frac_jitter: float = 0.0,
    # focus weighting (city-center)
    focus_xy: tuple[float, float] = (0.1, 0.1),
    focus_boost: float = 8.0,
    focus_trigger_frac: float = 0.003,
    # observation mode
    obs_mode: str = 'random',  # 'random' | 'swarm'
    swarm_num_drones: int = 10,
    swarm_hops_radius: int = 1,
    swarm_move_prob: float = 0.9,
    swarm_traj_len: int = 1,
    swarm_target_frac: float | None = None,
    swarm_min_frac: float | None = None,
):
    """Prepare the target field, sparse sensor observations, mask and graph structures.

    Intended to mirror training/train.py's prepare_batch for an apples-to-apples
    comparison with the diffusion model.

    Args:
        batch: dict with 'target_inputs' (rank-3 array-like, (B, N, C)), 'angle_deg'
            (numpy array), 'graph_structures' (a dict, or a list whose first element
            is used), and optionally 'z' (scalar or per-sample sequence, default 35.0).
            Only the first z value selects the observation graph cache.
        obs_frac / obs_frac_min / obs_frac_max / current_step / obs_frac_anneal_steps /
        obs_frac_start / obs_frac_end / obs_frac_jitter: select the observed fraction
            ``f`` (annealed, uniformly drawn, or fixed). Random draws are seeded by
            ``(obs_seed, current_step)``.
        obs_neighbor_hops: hops by which sampled nodes are expanded.
        obs_on_fluid_only: forwarded to ``GraphCacheManager.get(fluid_only=...)``.
        obs_seed: base seed for all observation sampling.
        focus_xy / focus_boost / focus_trigger_frac: in 'random' mode, boost sampling
            weights inside a box around the mean coordinate when ``f`` is at most
            ``focus_trigger_frac``.
        obs_mode: 'swarm' uses a drone-swarm sampler; any other value uses random
            sampling of about ``f`` x (#candidate nodes) seeds.
        swarm_*: drone-swarm sampler settings. ``swarm_target_frac`` resizes the mask
            to an exact fraction; ``swarm_min_frac`` only tops it up.

    Returns:
        (inputs, forcings): ``inputs`` is the float32 target field as a jnp array.
        ``forcings`` has the keys 'angle_deg', 'graph_structures', 'batch_size',
        'U_field_guiding' (targets masked to observed nodes, zeros elsewhere),
        'obs_mask' (B, N_nodes) and 'obs_count' (total observed node count).
    """
    U_clean = np.array(batch['target_inputs'], dtype=np.float32)  # (B, N, C)
    B, _, _ = U_clean.shape  # unpacking also enforces a rank-3 array

    # Graph structures are shared across the batch (numpy for the sampling utils).
    gs = batch['graph_structures']
    if isinstance(gs, dict):
        graph_struct = gs
    elif isinstance(gs, list) and len(gs) > 0:
        graph_struct = gs[0]
    else:
        graph_struct = gs

    coords_o = np.asarray(graph_struct["original_coordinates"], dtype=np.float32)
    N_nodes = int(coords_o.shape[0])

    # Per-z graph cache keyed on the first sample's z (scalar, list, numpy or jax array).
    z_value = float(np.asarray(batch.get('z', 35.0)).reshape(-1)[0])
    gc = GraphCacheManager.get(graph_struct, z_value, fluid_only=obs_on_fluid_only)

    f = _resolve_obs_frac(
        obs_frac=obs_frac,
        obs_frac_min=obs_frac_min,
        obs_frac_max=obs_frac_max,
        obs_seed=obs_seed,
        current_step=current_step,
        obs_frac_anneal_steps=obs_frac_anneal_steps,
        obs_frac_start=obs_frac_start,
        obs_frac_end=obs_frac_end,
        obs_frac_jitter=obs_frac_jitter,
    )
    k_base = max(1, int(round(f * max(1, gc.fluid_cand.size))))

    # Focus weights are only consumed by random sampling, and only for small fractions.
    base_weights = None
    if obs_mode != 'swarm' and f <= float(focus_trigger_frac):
        base_weights = _focus_weights(gc, focus_xy, focus_boost)

    obs_mask = np.zeros((B, N_nodes), dtype=np.float32)
    for b in range(B):
        # Per-sample seed so samples in a batch do not get identical random masks.
        local_seed = int((obs_seed * 0x9E3779B1 + b * 0x85EBCA6B + (current_step or 0)) & 0xFFFFFFFF)
        rng = np.random.default_rng(local_seed)

        if obs_mode == 'swarm':
            m = _swarm_mask(
                gc,
                rng,
                int(current_step or 0),
                obs_seed=obs_seed,
                obs_neighbor_hops=obs_neighbor_hops,
                swarm_num_drones=swarm_num_drones,
                swarm_hops_radius=swarm_hops_radius,
                swarm_move_prob=swarm_move_prob,
                swarm_traj_len=swarm_traj_len,
                swarm_target_frac=swarm_target_frac,
                swarm_min_frac=swarm_min_frac,
            )
        else:
            m = sample_random_obs(gc, k=k_base, hops=int(obs_neighbor_hops), rng=rng,
                                  base_weights=base_weights)

        obs_mask[b] = np.clip(m, 0.0, 1.0)

    # Observed values (zeros elsewhere)
    obs_values = U_clean * obs_mask[..., None]

    # Inputs and forcings (keys match train.py)
    inputs = jnp.array(U_clean)  # targets for the supervised baseline loss
    forcings = {
        'angle_deg': jnp.array(batch['angle_deg'].astype(np.int32)),
        'graph_structures': {k: jnp.array(v) for k, v in graph_struct.items()},
        'batch_size': int(B),
        'U_field_guiding': jnp.array(obs_values),
        'obs_mask': jnp.array(obs_mask),
        'obs_count': int(obs_mask.sum()),
    }
    return inputs, forcings


def compute_loss(model, inputs, forcings):
    """Mean-squared error between the model's prediction and the dense target.

    The model is called only on ``forcings`` (the guiding observations and graph
    structures). ``inputs`` is the target field it should reproduce.

    Args:
        model: callable model (MeshGraphNet / MultiScaleMeshGraphNet) taking ``forcings``.
        inputs: target field, same shape as the model output.
        forcings: forcings dict passed straight to the model.

    Returns:
        (loss, predictions): the scalar MSE averaged over all elements, and the raw
        model output.
    """
    predictions = model(forcings)
    residual = predictions - inputs
    return jnp.mean(jnp.square(residual)), predictions


@nnx.jit
def train_step(model, optimizer, inputs, forcings):
    """Run one jitted optimisation step on the MSE objective.

    Differentiates ``compute_loss`` with respect to ``model`` and applies the
    gradients through ``optimizer.update``. Returns ``(loss, predictions)`` from the
    forward pass that produced those gradients.
    """
    grad_fn = nnx.value_and_grad(
        lambda m: compute_loss(m, inputs, forcings),
        has_aux=True,
    )
    (loss, predictions), grads = grad_fn(model)
    optimizer.update(grads)
    return loss, predictions


@nnx.jit
def eval_step(model, inputs, forcings):
    """Jitted forward pass returning ``(mse_loss, predictions)`` without touching any weights."""
    return compute_loss(model, inputs, forcings)


def create_model(model_type, config, rngs, mesh, example_graph_structures, target_channels=2, guiding_channels=4):
    """Instantiate the baseline network selected by ``model_type``.

    Args:
        model_type: 'meshgraphnet' for ``MeshGraphNet`` or 'multiscale' for
            ``MultiScaleMeshGraphNet``.
        config, rngs, mesh, example_graph_structures, target_channels,
        guiding_channels: forwarded unchanged as keyword arguments to the
            model constructor.

    Returns:
        The constructed model instance.

    Raises:
        ValueError: if ``model_type`` is not one of the two supported names.
    """
    if model_type == "meshgraphnet":
        model_cls = MeshGraphNet
    elif model_type == "multiscale":
        model_cls = MultiScaleMeshGraphNet
    else:
        raise ValueError(f"Unknown model type: {model_type}")

    return model_cls(
        config=config,
        rngs=rngs,
        mesh=mesh,
        example_graph_structures=example_graph_structures,
        target_channels=target_channels,
        guiding_channels=guiding_channels,
    )


def run_testing(model, test_batch, slice_root):
    """
    Evaluate the model on the first sample of ``test_batch`` and build a comparison plot.

    Observations are generated with a fixed evaluation configuration: random mode,
    obs_frac=0.05, 1 neighbour hop, seed 1234. The training CLI flags are not used,
    so the test metric is comparable across runs with different training settings.

    Args:
        model: the MeshGraphNet / MultiScaleMeshGraphNet model, called as ``model(forcings)``.
        test_batch: batch dict (host arrays) accepted by ``prepare_batch``. Its
            'angle_deg' and 'z' fields may be scalars or per-sample sequences.
        slice_root: root directory holding ``case_<angle>/slice_z_<z>.vtu`` files.
            These are used only to obtain x/y coordinates for plotting.

    Returns:
        fig: matplotlib figure; case_number: int (first sample's angle);
        rel_rmse_error: float, RMSE(pred - true) / RMS(true) for the first sample
        (0.0 when the target is identically zero).
    """
    # First sample's angle (used as the case id) and slice height. Both fields may
    # be scalars or per-sample sequences/arrays, so take the first flattened element.
    case_number = int(np.asarray(test_batch.get('angle_deg', [0])).reshape(-1)[0])
    slice_z = int(np.asarray(test_batch.get('z', 35)).reshape(-1)[0])

    # Fixed evaluation observation settings (see docstring).
    inputs, forcings = prepare_batch(
        test_batch,
        obs_frac=0.05,
        obs_frac_min=None,
        obs_frac_max=None,
        obs_neighbor_hops=1,
        obs_on_fluid_only=True,
        obs_seed=1234,
        focus_xy=(0.1, 0.1),
        focus_boost=8.0,
        focus_trigger_frac=0.003,
        obs_mode='random',
        swarm_num_drones=10,
        swarm_hops_radius=1,
        swarm_move_prob=0.9,
        swarm_traj_len=1,
    )

    predictions = model(forcings)

    # First sample only (drop the batch dimension).
    original_data = inputs[0]
    predicted_data = predictions[0]

    # Relative RMSE error
    diff = predicted_data - original_data
    rmse_error = float(jnp.sqrt(jnp.mean(diff**2)))
    denom = float(jnp.sqrt(jnp.mean(original_data**2)))
    rel_rmse_error = rmse_error / denom if denom != 0 else 0.0

    # Coordinates for plotting come from the matching VTU slice.
    slice_file = Path(slice_root) / f"case_{case_number}" / f"slice_z_{slice_z}.vtu"
    try:
        import pyvista as pv
        slc = pv.read(str(slice_file))
        coords = slc.points[:, :2]  # x, y coordinates
    except Exception as e:
        # Best effort: still produce a plot, but with placeholder (random) coordinates
        # whose spatial layout is meaningless.
        print(f"Warning: Could not load coordinates from {slice_file}: {e}")
        coords = np.random.randn(len(original_data), 2)

    obs_mask = np.array(forcings.get('obs_mask'))[0] if 'obs_mask' in forcings else None
    observed_values = np.array(forcings.get('U_field_guiding'))[0] if 'U_field_guiding' in forcings else None

    # create_test_plot's parameter is named mae_error; it receives the relative RMSE.
    fig = create_test_plot(
        original_data=original_data,
        predicted_data=predicted_data,
        coords=coords,
        case_number=case_number,
        mae_error=rel_rmse_error,
        slice_z=slice_z,
        obs_mask=obs_mask,
        observed_values=observed_values,
    )

    return fig, case_number, rel_rmse_error


def _parse_bool(value):
    """argparse ``type`` converting 'true'/'false'-style strings to bool.

    ``type=bool`` is an argparse pitfall: ``bool("False")`` is True, so a flag like
    ``--obs_on_fluid_only False`` would silently stay enabled.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if text in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def _forever(ds):
    """Yield batches from ``ds`` endlessly, starting a fresh pass over it each epoch.

    Unlike ``itertools.cycle``, this does not keep a copy of every batch in memory.
    It also lets a re-iterable dataset produce whatever order it yields on each new
    pass, instead of replaying the first epoch. One-shot iterators cannot be
    restarted, so they fall back to ``itertools.cycle``.

    Raises:
        RuntimeError: if a re-iterable ``ds`` yields no batches, instead of spinning forever.
    """
    import collections.abc

    if isinstance(ds, collections.abc.Iterator):
        yield from itertools.cycle(ds)
        return
    while True:
        produced = False
        for batch in ds:
            produced = True
            yield batch
        if not produced:
            raise RuntimeError("Dataset produced no batches; cannot iterate over it forever.")


def _build_arg_parser():
    """Command-line interface for baseline training."""
    parser = argparse.ArgumentParser(description="Train MeshGraphNet baselines")
    parser.add_argument('--model_name', type=str, required=True,
                        help='Name for the model (used for checkpoints and wandb)')
    parser.add_argument('--model_type', type=str, choices=['meshgraphnet', 'multiscale'],
                        default='meshgraphnet', help='Type of model to train')
    parser.add_argument('--slice_root', type=str, default='data_sliced_cropped_300k',
                        help='Root directory for sliced data')
    parser.add_argument('--norm_stats', type=str, default='normalization_cropped_300k_test/normalization_stats_train.nc',
                        help='Path to normalization statistics')
    parser.add_argument('--batch_size', type=int, default=2, help='Batch size')
    parser.add_argument('--learning_rate', type=float, default=1e-4, help='Learning rate')
    parser.add_argument('--num_steps', type=int, default=150000, help='Number of training steps')
    parser.add_argument('--eval_every', type=int, default=1000, help='Evaluation frequency')
    parser.add_argument('--save_every', type=int, default=1000, help='Checkpoint save frequency')
    parser.add_argument('--testing_steps', type=int, default=1000,
                        help='Steps between testing/evaluation runs')
    parser.add_argument('--enable_testing', action='store_true',
                        help='Enable testing during training')
    parser.add_argument('--latent_size', type=int, default=64, help='Latent size')
    parser.add_argument('--hidden_layers', type=int, default=1, help='Number of hidden layers')
    parser.add_argument('--checkpoint_dir', type=str, default='checkpoints_baselines',
                        help='Directory to save checkpoints')
    parser.add_argument('--wandb_project', type=str, default='meshgraphnet_baselines',
                        help='Wandb project name')
    parser.add_argument('--seed', type=int, default=302714, help='Random seed')
    parser.add_argument('--train_angle_stride', type=int, default=2, help='Angle subsampling for training')
    parser.add_argument('--eval_angle_stride', type=int, default=2, help='Angle subsampling for eval')
    parser.add_argument('--fixed_angle', type=int, default=None)
    parser.add_argument('--fixed_z', type=int, default=None)
    # Observation controls (aligned with training/train.py)
    parser.add_argument('--obs_frac', type=float, default=0.0005,
                        help='Approx fraction of nodes observed per sample.')
    parser.add_argument('--obs_on_fluid_only', type=_parse_bool, default=True,
                        help='Sample observations only on fluid nodes (type==1).')
    parser.add_argument('--obs_seed', type=int, default=1234)
    parser.add_argument('--obs_neighbor_hops', type=int, default=1,
                        help='Expand observations by this many o2o hops.')
    parser.add_argument('--obs_frac_min', type=float, default=0.0016,
                        help='Lower bound of observed fraction per batch.')
    parser.add_argument('--obs_frac_max', type=float, default=0.050,
                        help='Upper bound of observed fraction per batch.')
    parser.add_argument('--obs_frac_start', type=float, default=0.050,
                        help='Starting fraction for anneal at step 1.')
    parser.add_argument('--obs_frac_end', type=float, default=0.0016,
                        help='Ending fraction after anneal.')
    parser.add_argument('--obs_frac_anneal_steps', type=int, default=20000,
                        help='Steps over which to anneal obs fraction.')
    parser.add_argument('--obs_frac_jitter', type=float, default=0.0002,
                        help='Uniform +/- jitter around per-step fraction.')
    parser.add_argument('--obs_focus_xy', type=str, default="500,500",
                        help='Half-widths (x_half,y_half) of focus box.')
    parser.add_argument('--obs_focus_boost', type=float, default=8.0,
                        help='Weight boost inside focus box when triggered.')
    parser.add_argument('--obs_focus_trigger_frac', type=float, default=0.0003,
                        help='Enable focus when fraction <= this.')
    parser.add_argument('--obs_mode', type=str, default='swarm', choices=['random', 'swarm'],
                        help='Observation mode: random iid or drone swarm.')
    parser.add_argument('--swarm_num_drones', type=int, default=25)
    parser.add_argument('--swarm_hops_radius', type=int, default=1)
    parser.add_argument('--swarm_move_prob', type=float, default=0.9)
    parser.add_argument('--swarm_traj_len', type=int, default=1)
    parser.add_argument('--swarm_target_frac', type=float, default=None,
                        help='Target fraction for swarm masks (optional).')
    parser.add_argument('--swarm_min_frac', type=float, default=None,
                        help='Minimum fraction for swarm masks (optional).')
    return parser


def _parse_focus_xy(spec, default=(0.1, 0.1)):
    """Parse an 'x_half,y_half' string into a float pair; return ``default`` if malformed."""
    try:
        parts = [float(p) for p in str(spec).split(',')]
    except ValueError:
        return default
    return (parts[0], parts[1]) if len(parts) >= 2 else default


def _obs_kwargs(args, focus_xy):
    """Observation keyword arguments for ``prepare_batch`` derived from the CLI flags.

    ``current_step`` is deliberately excluded; callers pass it per call.
    """
    return dict(
        obs_frac=args.obs_frac,
        obs_frac_min=args.obs_frac_min,
        obs_frac_max=args.obs_frac_max,
        obs_neighbor_hops=args.obs_neighbor_hops,
        obs_on_fluid_only=bool(args.obs_on_fluid_only),
        obs_seed=args.obs_seed,
        obs_frac_anneal_steps=args.obs_frac_anneal_steps,
        obs_frac_start=args.obs_frac_start,
        obs_frac_end=args.obs_frac_end,
        obs_frac_jitter=args.obs_frac_jitter,
        focus_xy=focus_xy,
        focus_boost=args.obs_focus_boost,
        focus_trigger_frac=args.obs_focus_trigger_frac,
        obs_mode=args.obs_mode,
        swarm_num_drones=args.swarm_num_drones,
        swarm_hops_radius=args.swarm_hops_radius,
        swarm_move_prob=args.swarm_move_prob,
        swarm_traj_len=args.swarm_traj_len,
        swarm_target_frac=args.swarm_target_frac,
        swarm_min_frac=args.swarm_min_frac,
    )


def _make_optimizer(model, learning_rate, num_steps, warmup_steps=1000):
    """Build the AdamW optimizer and its LR schedule (same recipe as train.py).

    The schedule warms up linearly from 0 to ``learning_rate`` over ``warmup_steps``,
    then cosine-decays to 0 over the remaining steps. Gradients are clipped to a
    global norm of 5 and weight decay is 0.1.

    Returns:
        (optimizer, scheduler)
    """
    warmup_fn = optax.linear_schedule(
        init_value=0.0,
        end_value=learning_rate,
        transition_steps=warmup_steps,
    )
    # optax rejects non-positive decay_steps, so clamp for short runs (num_steps <= warmup).
    cosine_fn = optax.cosine_decay_schedule(
        init_value=learning_rate,
        decay_steps=max(1, num_steps - warmup_steps),
        alpha=0.0,
    )
    scheduler = optax.join_schedules(schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps])
    tx = optax.chain(
        optax.clip_by_global_norm(5.0),
        optax.adamw(learning_rate=scheduler, b1=0.9, b2=0.999, weight_decay=0.1),
    )
    return nnx.Optimizer(model, tx), scheduler


def _run_test_and_log(model, test_iter, slice_root, step):
    """Run ``run_testing`` on the next test batch and log the result to wandb.

    Failures are reported but never abort training. The model is always returned
    to training mode and the figure is always closed.
    """
    import traceback

    print(f"Running testing at step {step}...")
    fig = None
    try:
        test_batch = next(test_iter)
        # Copy any array exposing a `.device` attribute to host memory.
        test_batch_clean = {
            key: (np.array(value) if hasattr(value, 'device') else value)
            for key, value in test_batch.items()
        }
        model.eval()
        fig, case_number, mae_error = run_testing(
            model=model,
            test_batch=test_batch_clean,
            slice_root=slice_root,
        )
        wandb.log({
            'test_mae': mae_error,
            'test_case_number': case_number,
            'test_plot': wandb.Image(fig),
            'step': step,
        })
        print(f"Testing completed: Case {case_number}, MAE: {mae_error:.4f}")
    except Exception as e:
        print(f"Testing failed at step {step}: {e}")
        traceback.print_exc()
    finally:
        if fig is not None:
            plt.close(fig)
        model.train()


def _save_checkpoint(model, checkpoint_dir, step):
    """Save the model's non-RNG state to ``checkpoint_dir/checkpoint_step_<step>`` with Orbax.

    Any existing checkpoint at that path is removed first.
    """
    import shutil

    model.eval()
    try:
        checkpoint_path = checkpoint_dir / f"checkpoint_step_{step}"
        if checkpoint_path.exists():
            shutil.rmtree(checkpoint_path)
        # RNG state is split off and not saved; only the remaining state is written.
        _graphdef, _rng_state, other_state = nnx.split(model, nnx.RngState, ...)
        other_state = jax.device_get(other_state)  # ensure the state is on host
        ocp.PyTreeCheckpointer().save(str(checkpoint_path), other_state)
        print(f"✔ Checkpoint saved at step {step}")
    finally:
        model.train()


def main():
    """Parse CLI flags, build the data, model and optimizer, then run supervised training.

    Steps run from 1 to ``--num_steps`` inclusive. Each step logs loss and learning
    rate to stdout and wandb. With ``--enable_testing``, a single-sample test runs
    at step 1 and every ``--testing_steps`` steps. A checkpoint is saved every
    ``--save_every`` steps.
    """
    args = _build_arg_parser().parse_args()

    checkpoint_dir = Path(args.checkpoint_dir).resolve() / args.model_name
    checkpoint_dir.mkdir(exist_ok=True, parents=True)

    wandb.init(
        project=args.wandb_project,
        name=args.model_name,
        config=vars(args),
    )

    print("Creating datasets...")
    train_ds = create_datasets(
        args.slice_root, args.norm_stats,
        batch_size=args.batch_size, seed=args.seed, is_training=True,
        fixed_angle=args.fixed_angle, fixed_z=args.fixed_z,
        shuffle=True, angle_stride=max(1, int(args.train_angle_stride)),
    )
    # Testing uses single-sample batches, as in train.py.
    test_ds = create_datasets(
        args.slice_root, args.norm_stats,
        batch_size=1, seed=args.seed + 2, is_training=False,
        fixed_angle=args.fixed_angle, fixed_z=args.fixed_z,
        shuffle=False, angle_stride=max(1, int(args.eval_angle_stride)),
    )
    train_iter = _forever(train_ds)
    test_iter = _forever(test_ds)

    focus_xy = _parse_focus_xy(args.obs_focus_xy)
    obs_kwargs = _obs_kwargs(args, focus_xy)

    # One batch is consumed to give the model example graph structures.
    example_batch = next(train_iter)
    _, example_forcings = prepare_batch(example_batch, current_step=1, **obs_kwargs)

    config = DenoiserArchitectureConfig(
        latent_size=args.latent_size,
        hidden_layers=args.hidden_layers,
    )

    print(f"Initializing {args.model_type} model...")
    rngs = nnx.Rngs(args.seed)
    mesh = None  # placeholder: no device mesh is passed to the model
    model = create_model(
        args.model_type, config, rngs, mesh,
        example_forcings['graph_structures'],
        target_channels=2, guiding_channels=2,
    )

    optimizer, scheduler = _make_optimizer(model, args.learning_rate, args.num_steps)

    print("Starting training...")
    model.train()
    for step in range(1, args.num_steps + 1):
        batch = next(train_iter)
        inputs, forcings = prepare_batch(batch, current_step=step, **obs_kwargs)

        train_loss, _ = train_step(model, optimizer, inputs, forcings)

        # Pull the loss to host once. optax evaluates the schedule at the update count
        # *before* incrementing it, so loop step `step` (1-based) applied scheduler(step - 1).
        loss_value = float(train_loss)
        current_lr = float(scheduler(step - 1))
        print(f"Step {step}: Train Loss = {loss_value:.6f}, LR = {current_lr:.6f}")
        wandb.log({"train_loss": loss_value, "lr": current_lr, "step": step})

        if args.enable_testing and (
            step == 1 or (args.testing_steps > 0 and step % args.testing_steps == 0)
        ):
            _run_test_and_log(model, test_iter, args.slice_root, step)

        if args.save_every > 0 and step % args.save_every == 0:
            _save_checkpoint(model, checkpoint_dir, step)

    model.eval()
    print("Training completed!")
    wandb.finish()


if __name__ == "__main__":
    main()