"""
Distributed training script for the FlowCast model with a diffusion objective on the SEVIR dataset.
"""

import sys
import os
import copy
import argparse
import datetime
import random
import numpy as np
import wandb
import namegenerator
from tqdm import tqdm
from matplotlib import pyplot as plt

import torch
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR, SequentialLR
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from contextlib import nullcontext
import math

sys.path.append(os.getcwd())  # Add the current working directory to the path
os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"
os.environ["TOKENIZERS_PARALLELISM"] = "false"


from experiments.sevir.dataset.sevirfulldataset import (
    DynamicEncodedSequentialSevirDataset,
    dynamic_encoded_sequential_collate,
    DynamicSequentialSevirDataset,
    dynamic_sequential_collate,
    post_process_samples,
)
from common.utils.utils import EarlyStopping, compute_mean_std
from common.models.flowcast.cuboid_transformer_unet import (
    CuboidTransformerUNet,
)
from omegaconf import OmegaConf
from common.utils.utils import warmup_lambda
from common.diffusion.diffusion import (
    register_diffusion_buffers,
    p_sample_loop,
    p_losses,
    ddim_sample_loop,
)

from common.metrics.metrics_streaming_probabilistic import (
    MetricsAccumulator,
)
from common.utils.utils import (
    calculate_metrics,
    ema,
)
from experiments.sevir.display.cartopy import make_animation


# DDP Setup and Cleanup Functions
def setup_ddp():
    """Initialize the DDP environment if launched by a distributed launcher.

    When ``RANK`` and ``WORLD_SIZE`` are present in the environment (as set by
    ``torchrun``), the default process group is created via ``env://``
    rendezvous. NCCL is used when CUDA is available; otherwise gloo is used so
    that CPU-only distributed runs do not fail at backend creation.

    Returns:
        tuple: ``(rank, world_size, local_rank, device)``. In non-distributed
        mode this is ``(0, 1, 0, cuda-or-cpu device)``.

    Raises:
        KeyError: if ``RANK``/``WORLD_SIZE`` are set but ``LOCAL_RANK`` is not.
    """
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        local_rank = int(os.environ["LOCAL_RANK"])
        print(f"Initializing DDP: Rank {rank}/{world_size}, Local Rank {local_rank}")

        use_cuda = torch.cuda.is_available()
        if use_cuda:
            # Bind this process to its GPU *before* creating the NCCL group so
            # that NCCL communicators (and the barrier below) use the right device.
            torch.cuda.set_device(local_rank)
            device = torch.device(f"cuda:{local_rank}")
        else:
            device = torch.device("cpu")

        # Rendezvous URL/port are provided by the launcher (torchrun) via env://
        dist.init_process_group(
            backend="nccl" if use_cuda else "gloo",
            init_method="env://",
            rank=rank,
            world_size=world_size,
            timeout=datetime.timedelta(hours=6),
        )
        # Ensure all processes have finished initialization before continuing
        dist.barrier()
        return rank, world_size, local_rank, device
    else:
        print("Not running in distributed mode. Using single device.")
        return (
            0,
            1,
            0,
            torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        )  # Rank 0, World Size 1, Local Rank 0


def cleanup_ddp():
    """Destroy the default process group, if one was created.

    Safe to call in non-distributed runs: it returns without doing anything
    when no process group is initialized.
    """
    if not dist.is_initialized():
        return
    dist.destroy_process_group()
    print("Cleaned up DDP.")


def reduce_tensor(tensor: torch.Tensor, world_size: int) -> torch.Tensor:
    """Average a tensor across all DDP processes.

    The input is cloned, summed across ranks with ``all_reduce`` and divided by
    ``world_size``. The input tensor itself is never modified.

    If no process group is initialized (single-process runs, where
    ``setup_ddp`` does not call ``init_process_group``), the collective is
    skipped. The local value then is the global value, and only the division
    is applied.

    Args:
        tensor: Tensor to average. In-place true division is applied, so it
            is expected to have a floating-point dtype.
        world_size: Number of participating processes.

    Returns:
        A new tensor holding the cross-process mean.
    """
    rt = tensor.clone()
    # Calling all_reduce without an initialized default group raises, so only
    # communicate when a group actually exists.
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= world_size
    return rt


# Command-line arguments: only config/data file locations are taken from the
# CLI; all hyperparameters are read from the YAML file given by --config.
parser = argparse.ArgumentParser(
    description=(
        "Distributed diffusion training of the FlowCast model on SEVIR. "
        "Hyperparameters are read from the YAML file passed via --config."
    )
)


parser.add_argument(
    "--config",
    type=str,
    default="experiments/sevir/runner/flowcast_diffusion/flowcast_diffusion_config_easy.yaml",
    help="Path to the YAML configuration file (loaded with OmegaConf).",
)
parser.add_argument(
    "--train_file",
    type=str,
    default="datasets/sevir/data/sevir_full_latent_vae_kl1e4/nowcast_training_full.h5",
    help="HDF5 file with the (encoded) training sequences.",
)
parser.add_argument(
    "--train_meta",
    type=str,
    default="datasets/sevir/data/sevir_full_latent_vae_kl1e4/nowcast_training_full_META.csv",
    help="CSV metadata file accompanying --train_file.",
)
parser.add_argument(
    "--val_file",
    type=str,
    default="datasets/sevir/data/sevir_full_latent_vae_kl1e4/nowcast_validation_full.h5",
    help="HDF5 file with the (encoded) validation sequences.",
)
parser.add_argument(
    "--val_meta",
    type=str,
    default="datasets/sevir/data/sevir_full_latent_vae_kl1e4/nowcast_validation_full_META.csv",
    help="CSV metadata file accompanying --val_file.",
)

parser.add_argument(
    "--partial_evaluation_file",
    type=str,
    default="datasets/sevir/data/sevir_full/nowcast_validation_full.h5",
    help="HDF5 file used for partial evaluation.",
)
parser.add_argument(
    "--partial_evaluation_meta",
    type=str,
    default="datasets/sevir/data/sevir_full/nowcast_validation_full_META.csv",
    help="CSV metadata file accompanying --partial_evaluation_file.",
)


# Main Function
def main():
    """
    Main function to orchestrate the distributed training of the diffusion model.

    Handles the following steps:
    1. Parses command-line arguments and loads the YAML configuration.
    2. Sets up the distributed data parallel (DDP) environment.
    3. Initializes Weights & Biases for logging (on the main process).
    4. Creates datasets and distributed dataloaders for training and validation.
    5. Initializes the model (CuboidTransformerUNet), optimizer, and LR scheduler.
    6. Sets up EMA model handling and early stopping.
    7. Runs the main training loop, which includes:
        - Training for one epoch.
        - Running a validation loop.
        - Performing periodic partial evaluation with sample generation and metrics calculation.
        - Checking for early stopping.
    8. Cleans up the DDP environment upon completion.
    """
    args = parser.parse_args()
    config = OmegaConf.load(args.config)

    # --- DDP Initialization ---
    rank, world_size, local_rank, device = setup_ddp()
    is_main_process = rank == 0  # Flag for rank 0 process

    # Assign arguments to variables
    DEBUG_MODE = config.run_params.debug_mode
    RUN_STRING = config.run_params.run_string
    CARTOPY_FEATURES = config.partial_evaluation_params.cartopy_features
    FINAL_TRAINING = config.training_params.final_training

    # Create unique RUN_ID only on main process, then broadcast? Or use timestamp + namegenerator on all?
    # Simple approach: let each process generate, main process uses its ID for saving/wandb.
    run_id_base = (
        datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        + "_"
        + RUN_STRING
        + "_"
        + namegenerator.gen()
    )
    MAIN_RUN_ID = f"{run_id_base}_main"  # ID used by rank 0 for saving etc.

    DEBUG_PRINT_PREFIX = (
        f"[DEBUG Rank {rank}] " if DEBUG_MODE else f"[Rank {rank}] "
    )  # Add rank to debug prints

    ENABLE_WANDB = (
        config.run_params.enable_wandb and is_main_process
    )  # Only main process logs to WandB
    PARTIAL_EVALUATION = config.partial_evaluation_params.partial_evaluation
    PARTIAL_EVALUATION_INTERVAL = (
        config.partial_evaluation_params.partial_evaluation_interval
    )
    PARTIAL_EVALUATION_BATCHES = (
        config.partial_evaluation_params.partial_evaluation_batches
    )
    AUTOENCODER_CHECKPOINT = config.autoencoder_params.autoencoder_checkpoint
    THRESHOLDS = np.array([16, 74, 133, 160, 181, 219], dtype=np.float32)

    if (
        PARTIAL_EVALUATION and AUTOENCODER_CHECKPOINT is None and is_main_process
    ):  # Check only on main
        # No need to raise on all processes
        raise ValueError(
            "Partial Evaluation is enabled but Autoencoder Checkpoint is not provided"
        )

    # Target locations of sample training & testing data
    TRAIN_FILE = args.train_file
    TRAIN_META = args.train_meta
    VAL_FILE = args.val_file
    VAL_META = args.val_meta
    PRELOAD_MODEL = config.run_params.preload_model
    BATCH_SIZE = config.training_params.micro_batch_size  # PER GPU
    LEARNING_RATE = config.optimizer_params.learning_rate
    NUM_EPOCHS = config.training_params.num_epochs
    NUM_WORKERS = config.training_params.num_workers  # PER DATALOADER
    EARLY_STOPPING_PATIENCE = config.training_params.early_stopping_patience
    EARLY_STOPPING_METRIC = config.training_params.early_stopping_metric
    LAG_TIME = config.data_params.lag_time
    LEAD_TIME = config.data_params.lead_time
    TIME_SPACING = config.data_params.time_spacing
    GRAD_CLIP = config.training_params.gradient_clip_val

    # Assert that if early stopping metric is partial_csi_m or partial_mse, partial evaluation must be enabled
    if (
        EARLY_STOPPING_METRIC in ["partial_csi_m", "partial_mse"]
        and not PARTIAL_EVALUATION
        and is_main_process
    ):
        raise ValueError(
            f"Early stopping metric {EARLY_STOPPING_METRIC} requires partial evaluation to be enabled"
        )

    # Flowcast Config
    # Optimizer
    OPTIMIZER_TYPE = config.optimizer_params.optimizer_type
    WEIGHT_DECAY = config.optimizer_params.weight_decay

    # Scheduler
    SCHEDULER_TYPE = config.scheduler_params.scheduler_type
    LR_PLATEAU_FACTOR = config.scheduler_params.lr_plateau_factor
    LR_PLATEAU_PATIENCE = config.scheduler_params.lr_plateau_patience
    LR_COSINE_WARMUP_ITER_PERCENTAGE = (
        config.scheduler_params.lr_cosine_warmup_iter_percentage
    )
    LR_COSINE_MIN_WARMUP_LR_RATIO = (
        config.scheduler_params.lr_cosine_min_warmup_lr_ratio
    )
    LR_COSINE_MIN_LR_RATIO = config.scheduler_params.lr_cosine_min_lr_ratio

    # Asinh Transform
    ASINH_TRANSFORM = config.data_params.asinh_transform
    NORMALIZED_AUTOENCODER = config.autoencoder_params.normalized_autoencoder

    USE_FP16 = config.training_params.fp16

    # EMA Model Saving
    EMA_MODEL_SAVING = config.ema_model_saving_params.ema_model_saving
    EMA_MODEL_SAVING_DECAY = config.ema_model_saving_params.ema_model_saving_decay

    # Gradient Accumulation
    GRAD_ACCUMULATION_STEPS = (
        config.training_params.grad_accumulation_steps
    )  # Global effective batch size = BATCH_SIZE * world_size * GRAD_ACCUMULATION_STEPS

    # Load the config
    model_config = OmegaConf.to_object(config.latent_model)

    # Flowcast
    BASE_UNITS = model_config["base_units"]
    SCALE_ALPHA = model_config["scale_alpha"]
    NUM_HEADS = model_config["num_heads"]
    ATTN_DROP = model_config["attn_drop"]
    PROJ_DROP = model_config["proj_drop"]
    FFN_DROP = model_config["ffn_drop"]
    DOWNSAMPLE = model_config["downsample"]
    DOWNSAMPLE_TYPE = model_config["downsample_type"]
    UPSAMPLE_TYPE = model_config["upsample_type"]
    UPSAMPLE_KERNEL_SIZE = model_config["upsample_kernel_size"]
    DEPTH = model_config["depth"]
    BLOCK_ATTN_PATTERNS = [model_config["self_pattern"]] * len(DEPTH)
    NUM_GLOBAL_VECTORS = model_config["num_global_vectors"]
    USE_GLOBAL_VECTOR_FFN = model_config["use_global_vector_ffn"]
    USE_GLOBAL_SELF_ATTN = model_config["use_global_self_attn"]
    SEPARATE_GLOBAL_QKV = model_config["separate_global_qkv"]
    GLOBAL_DIM_RATIO = model_config["global_dim_ratio"]
    SELF_PATTERN = model_config["self_pattern"]
    FFN_ACTIVATION = model_config["ffn_activation"]
    GATED_FFN = model_config["gated_ffn"]
    NORM_LAYER = model_config["norm_layer"]
    PADDING_TYPE = model_config["padding_type"]
    CHECKPOINT_LEVEL = model_config["checkpoint_level"]
    POS_EMBED_TYPE = model_config["pos_embed_type"]
    USE_RELATIVE_POS = model_config["use_relative_pos"]
    SELF_ATTN_USE_FINAL_PROJ = model_config["self_attn_use_final_proj"]
    ATTN_LINEAR_INIT_MODE = model_config["attn_linear_init_mode"]
    FFN_LINEAR_INIT_MODE = model_config["ffn_linear_init_mode"]
    FFN2_LINEAR_INIT_MODE = model_config["ffn2_linear_init_mode"]
    ATTN_PROJ_LINEAR_INIT_MODE = model_config["attn_proj_linear_init_mode"]
    CONV_INIT_MODE = model_config["conv_init_mode"]
    DOWN_UP_LINEAR_INIT_MODE = model_config["down_up_linear_init_mode"]
    GLOBAL_PROJ_LINEAR_INIT_MODE = model_config["global_proj_linear_init_mode"]
    NORM_INIT_MODE = model_config["norm_init_mode"]
    TIME_EMBED_CHANNELS_MULT = model_config["time_embed_channels_mult"]
    TIME_EMBED_USE_SCALE_SHIFT_NORM = model_config["time_embed_use_scale_shift_norm"]
    TIME_EMBED_DROPOUT = model_config["time_embed_dropout"]
    UNET_RES_CONNECT = model_config["unet_res_connect"]

    # Diffusion Config
    diffusion_config = OmegaConf.to_object(config.diffusion_params)
    TIMESTEPS = diffusion_config["timesteps"]
    BETA_SCHEDULE = diffusion_config["beta_schedule"]
    CLIP_DENOISED = diffusion_config["clip_denoised"]
    LINEAR_START = diffusion_config["linear_start"]
    LINEAR_END = diffusion_config["linear_end"]
    COSINE_S = diffusion_config["cosine_s"]
    GIVEN_BETAS = diffusion_config["given_betas"]
    ORIGINAL_ELBO_WEIGHT = diffusion_config["original_elbo_weight"]
    L_SIMPLE_WEIGHT = diffusion_config["l_simple_weight"]
    P2_GAMMA = diffusion_config.get("p2_gamma", 0.5)
    P2_K = diffusion_config.get("p2_k", 1.0)
    # Partial evaluation DDIM config
    PARTIAL_USE_DDIM = config.partial_evaluation_params.get("use_ddim", False)
    PARTIAL_DDIM_NUM_STEPS = config.partial_evaluation_params.get("ddim_num_steps", 50)
    PARTIAL_DDIM_ETA = config.partial_evaluation_params.get("ddim_eta", 1.0)

    # --- Print Config only on Main Process ---
    if is_main_process:
        print(f"--- Distributed Training Config ---")
        print(f"World Size: {world_size}")
        print(f"Batch Size PER GPU: {BATCH_SIZE}")
        print(f"Global Batch Size (before accumulation): {BATCH_SIZE * world_size}")
        print(f"Gradient Accumulation Steps: {GRAD_ACCUMULATION_STEPS}")
        print(
            f"Effective Global Batch Size: {BATCH_SIZE * world_size * GRAD_ACCUMULATION_STEPS}"
        )
        print(f"FP16 Enabled: {USE_FP16}")
        print(f"-----------------------------------")
        print(f"{DEBUG_PRINT_PREFIX}Run ID (Main): {MAIN_RUN_ID}")
        print(f"{DEBUG_PRINT_PREFIX}Run String: {RUN_STRING}")
        print(f"{DEBUG_PRINT_PREFIX}Debug Mode: {DEBUG_MODE}")
        # ... (keep all other print statements inside this `if is_main_process:` block) ...
        print(f"{DEBUG_PRINT_PREFIX}Normalized Autoencoder: {NORMALIZED_AUTOENCODER}")
        print(
            f"{DEBUG_PRINT_PREFIX}Enable Wandb: {config.run_params.enable_wandb}"
        )  # Print the raw arg
        print(f"{DEBUG_PRINT_PREFIX}Partial Evaluation: {PARTIAL_EVALUATION}")
        print(
            f"{DEBUG_PRINT_PREFIX}Partial Evaluation Interval: {PARTIAL_EVALUATION_INTERVAL}"
        )
        print(
            f"{DEBUG_PRINT_PREFIX}Partial Evaluation Batches: {PARTIAL_EVALUATION_BATCHES}"
        )
        print(f"{DEBUG_PRINT_PREFIX}Autoencoder Checkpoint: {AUTOENCODER_CHECKPOINT}")
        print(f"{DEBUG_PRINT_PREFIX}Training File: {TRAIN_FILE}")
        print(f"{DEBUG_PRINT_PREFIX}Training Meta: {TRAIN_META}")
        print(f"{DEBUG_PRINT_PREFIX}Preload Model: {PRELOAD_MODEL}")
        # print(f"{DEBUG_PRINT_PREFIX}Batch Size: {BATCH_SIZE}") # Already printed above
        print(f"{DEBUG_PRINT_PREFIX}Learning Rate: {LEARNING_RATE}")
        print(f"{DEBUG_PRINT_PREFIX}Number of Epochs: {NUM_EPOCHS}")
        print(f"{DEBUG_PRINT_PREFIX}Number of Workers: {NUM_WORKERS}")
        print(f"{DEBUG_PRINT_PREFIX}Early Stopping Patience: {EARLY_STOPPING_PATIENCE}")
        print(f"{DEBUG_PRINT_PREFIX}Early Stopping Metric: {EARLY_STOPPING_METRIC}")
        print(f"{DEBUG_PRINT_PREFIX}Lag Time: {LAG_TIME}")
        print(f"{DEBUG_PRINT_PREFIX}Lead Time: {LEAD_TIME}")
        print(f"{DEBUG_PRINT_PREFIX}Time Spacing: {TIME_SPACING}")

        print(f"{DEBUG_PRINT_PREFIX}Gradient Clip Value: {GRAD_CLIP}")
        print(f"{DEBUG_PRINT_PREFIX}ASINH_TRANSFORM: {ASINH_TRANSFORM}")
        # print(f"{DEBUG_PRINT_PREFIX}FP16 (mixed precision): {USE_FP16}") # Printed above

        print(f"--------- {DEBUG_PRINT_PREFIX}Flowcast Config ---------")
        print(f"{DEBUG_PRINT_PREFIX}Base Units: {BASE_UNITS}")
        print(f"{DEBUG_PRINT_PREFIX}Scale Alpha: {SCALE_ALPHA}")
        print(f"{DEBUG_PRINT_PREFIX}Depth: {DEPTH}")
        print(f"{DEBUG_PRINT_PREFIX}Block Attn Patterns: {BLOCK_ATTN_PATTERNS}")

        print(f"{DEBUG_PRINT_PREFIX}Downsample: {DOWNSAMPLE}")
        print(f"{DEBUG_PRINT_PREFIX}Downsample Type: {DOWNSAMPLE_TYPE}")
        print(f"{DEBUG_PRINT_PREFIX}Upsample Type: {UPSAMPLE_TYPE}")
        print(f"{DEBUG_PRINT_PREFIX}Num Global Vectors: {NUM_GLOBAL_VECTORS}")
        print(
            f"{DEBUG_PRINT_PREFIX}ATTN_PROJ_LINEAR_INIT_MODE: {ATTN_PROJ_LINEAR_INIT_MODE}"
        )
        print(
            f"{DEBUG_PRINT_PREFIX}Global Proj Linear Init Mode: {GLOBAL_PROJ_LINEAR_INIT_MODE}"
        )
        print(f"{DEBUG_PRINT_PREFIX}Use Global Vector FFN: {USE_GLOBAL_VECTOR_FFN}")
        print(f"{DEBUG_PRINT_PREFIX}Use Global Self Attn: {USE_GLOBAL_SELF_ATTN}")
        print(f"{DEBUG_PRINT_PREFIX}Separate Global QKV: {SEPARATE_GLOBAL_QKV}")
        print(f"{DEBUG_PRINT_PREFIX}Global Dim Ratio: {GLOBAL_DIM_RATIO}")
        print(f"{DEBUG_PRINT_PREFIX}Self Pattern: {SELF_PATTERN}")
        print(f"{DEBUG_PRINT_PREFIX}Attn Drop: {ATTN_DROP}")
        print(f"{DEBUG_PRINT_PREFIX}Proj Drop: {PROJ_DROP}")
        print(f"{DEBUG_PRINT_PREFIX}FFN Drop: {FFN_DROP}")
        print(f"{DEBUG_PRINT_PREFIX}Num Heads: {NUM_HEADS}")
        print(f"{DEBUG_PRINT_PREFIX}FFN Activation: {FFN_ACTIVATION}")
        print(f"{DEBUG_PRINT_PREFIX}Gated FFN: {GATED_FFN}")
        print(f"{DEBUG_PRINT_PREFIX}Norm Layer: {NORM_LAYER}")
        print(f"{DEBUG_PRINT_PREFIX}Padding Type: {PADDING_TYPE}")
        print(f"{DEBUG_PRINT_PREFIX}Pos Embed Type: {POS_EMBED_TYPE}")
        print(f"{DEBUG_PRINT_PREFIX}Use Relative Pos: {USE_RELATIVE_POS}")
        print(
            f"{DEBUG_PRINT_PREFIX}Self Attn Use Final Proj: {SELF_ATTN_USE_FINAL_PROJ}"
        )
        print(f"{DEBUG_PRINT_PREFIX}Checkpoint Level: {CHECKPOINT_LEVEL}")
        print(f"{DEBUG_PRINT_PREFIX}Attn Linear Init Mode: {ATTN_LINEAR_INIT_MODE}")
        print(f"{DEBUG_PRINT_PREFIX}FFN Linear Init Mode: {FFN_LINEAR_INIT_MODE}")
        print(f"{DEBUG_PRINT_PREFIX}Conv Init Mode: {CONV_INIT_MODE}")
        print(
            f"{DEBUG_PRINT_PREFIX}Down Up Linear Init Mode: {DOWN_UP_LINEAR_INIT_MODE}"
        )
        print(f"{DEBUG_PRINT_PREFIX}Norm Init Mode: {NORM_INIT_MODE}")
        print(
            f"{DEBUG_PRINT_PREFIX}Time Embed Channels Mult: {TIME_EMBED_CHANNELS_MULT}"
        )
        print(
            f"{DEBUG_PRINT_PREFIX}Time Embed Use Scale Shift Norm: {TIME_EMBED_USE_SCALE_SHIFT_NORM}"
        )
        print(f"{DEBUG_PRINT_PREFIX}Time Embed Dropout: {TIME_EMBED_DROPOUT}")
        print(f"{DEBUG_PRINT_PREFIX}UNET Res Connect: {UNET_RES_CONNECT}")

        print(f"--------- {DEBUG_PRINT_PREFIX}Optimizer Config ---------")
        # print(f"{DEBUG_PRINT_PREFIX}Learning Rate: {LEARNING_RATE}") # Printed above
        print(f"{DEBUG_PRINT_PREFIX}Optimizer Type: {OPTIMIZER_TYPE}")
        print(f"{DEBUG_PRINT_PREFIX}Weight Decay: {WEIGHT_DECAY}")
        print(f"{DEBUG_PRINT_PREFIX}Scheduler Type: {SCHEDULER_TYPE}")
        print(f"{DEBUG_PRINT_PREFIX}LR Plateau Factor: {LR_PLATEAU_FACTOR}")
        print(f"{DEBUG_PRINT_PREFIX}LR Plateau Patience: {LR_PLATEAU_PATIENCE}")
        print(
            f"{DEBUG_PRINT_PREFIX}LR Cosine Warmup Iter Percentage: {LR_COSINE_WARMUP_ITER_PERCENTAGE}"
        )
        print(
            f"{DEBUG_PRINT_PREFIX}LR Cosine Min Warmup LR Ratio: {LR_COSINE_MIN_WARMUP_LR_RATIO}"
        )
        print(f"{DEBUG_PRINT_PREFIX}LR Cosine Min LR Ratio: {LR_COSINE_MIN_LR_RATIO}")

        print(f"--------- {DEBUG_PRINT_PREFIX}EMA Model Saving Config ---------")
        print(f"{DEBUG_PRINT_PREFIX}EMA Model Saving: {EMA_MODEL_SAVING}")
        if EMA_MODEL_SAVING:
            print(
                f"{DEBUG_PRINT_PREFIX}EMA Model Saving Decay: {EMA_MODEL_SAVING_DECAY}"
            )
        print(f"------------------------------------------------------------")

        print(f"--------- {DEBUG_PRINT_PREFIX}Diffusion Config ---------")
        print(f"{DEBUG_PRINT_PREFIX}Timesteps: {TIMESTEPS}")
        print(f"{DEBUG_PRINT_PREFIX}Beta Schedule: {BETA_SCHEDULE}")
        print(f"{DEBUG_PRINT_PREFIX}Clip Denoised: {CLIP_DENOISED}")
        print(f"{DEBUG_PRINT_PREFIX}Linear Start: {LINEAR_START}")
        print(f"{DEBUG_PRINT_PREFIX}Linear End: {LINEAR_END}")
        print(f"{DEBUG_PRINT_PREFIX}Cosine S: {COSINE_S}")
        print(f"{DEBUG_PRINT_PREFIX}Given Betas: {GIVEN_BETAS}")
        print(f"{DEBUG_PRINT_PREFIX}Original Elbo Weight: {ORIGINAL_ELBO_WEIGHT}")
        print(f"{DEBUG_PRINT_PREFIX}L Simple Weight: {L_SIMPLE_WEIGHT}")
        print(f"{DEBUG_PRINT_PREFIX}p2_gamma: {P2_GAMMA}")
        print(f"{DEBUG_PRINT_PREFIX}p2_k: {P2_K}")
        print(
            f"{DEBUG_PRINT_PREFIX}Partial Eval DDIM: use={PARTIAL_USE_DDIM}, steps={PARTIAL_DDIM_NUM_STEPS}, eta={PARTIAL_DDIM_ETA}"
        )
        print(f"------------------------------------------------------------")

    project_name = "sevir-nowcasting-cfm"
    # Initialize wandb only on the main process
    if ENABLE_WANDB:  # Already checks is_main_process
        # Make sure effective batch size is logged correctly
        config_dict = {
            "learning_rate": LEARNING_RATE,
            "batch_size_per_gpu": BATCH_SIZE,
            "world_size": world_size,
            "grad_accumulation_steps": GRAD_ACCUMULATION_STEPS,
            "effective_batch_size": BATCH_SIZE * world_size * GRAD_ACCUMULATION_STEPS,
            "num_epochs": NUM_EPOCHS,
            "num_workers": NUM_WORKERS,
            "early_stopping_patience": EARLY_STOPPING_PATIENCE,
            "lag_time": LAG_TIME,
            "lead_time": LEAD_TIME,
            "time_spacing": TIME_SPACING,
            "grad_clip": GRAD_CLIP,
            "dataset": "SEVIR",
            "model": "Flowcast-Diff",
            # "learning_rate": LEARNING_RATE, # Duplicate
            "optimizer_type": OPTIMIZER_TYPE,
            "weight_decay": WEIGHT_DECAY,
            "scheduler_type": SCHEDULER_TYPE,
            "lr_plateau_factor": LR_PLATEAU_FACTOR,
            "lr_plateau_patience": LR_PLATEAU_PATIENCE,
            "lr_cosine_warmup_iter_percentage": LR_COSINE_WARMUP_ITER_PERCENTAGE,
            "lr_cosine_min_warmup_lr_ratio": LR_COSINE_MIN_WARMUP_LR_RATIO,
            "lr_cosine_min_lr_ratio": LR_COSINE_MIN_LR_RATIO,
            "base_units": BASE_UNITS,
            "scale_alpha": SCALE_ALPHA,
            "depth": DEPTH,
            "block_attn_patterns": BLOCK_ATTN_PATTERNS,
            "downsample": DOWNSAMPLE,
            "downsample_type": DOWNSAMPLE_TYPE,
            "upsample_type": UPSAMPLE_TYPE,
            "num_global_vectors": NUM_GLOBAL_VECTORS,
            "use_global_vector_ffn": USE_GLOBAL_VECTOR_FFN,
            "use_global_self_attn": USE_GLOBAL_SELF_ATTN,
            "separate_global_qkv": SEPARATE_GLOBAL_QKV,
            "global_dim_ratio": GLOBAL_DIM_RATIO,
            "self_pattern": SELF_PATTERN,
            "attn_drop": ATTN_DROP,
            "proj_drop": PROJ_DROP,
            "ffn_drop": FFN_DROP,
            "num_heads": NUM_HEADS,
            "ffn_activation": FFN_ACTIVATION,
            "gated_ffn": GATED_FFN,
            "norm_layer": NORM_LAYER,
            "padding_type": PADDING_TYPE,
            "pos_embed_type": POS_EMBED_TYPE,
            "use_relative_pos": USE_RELATIVE_POS,
            "self_attn_use_final_proj": SELF_ATTN_USE_FINAL_PROJ,
            "checkpoint_level": CHECKPOINT_LEVEL,
            "attn_linear_init_mode": ATTN_LINEAR_INIT_MODE,
            "ffn_linear_init_mode": FFN_LINEAR_INIT_MODE,
            "conv_init_mode": CONV_INIT_MODE,
            "down_up_linear_init_mode": DOWN_UP_LINEAR_INIT_MODE,
            "norm_init_mode": NORM_INIT_MODE,
            # --- End Cuboid Attention WandB Config ---
            # Other non mentioned config
            "ASINH_TRANSFORM": ASINH_TRANSFORM,
            "fp16": USE_FP16,
            # adaptive weighting disabled
            "loss_weighting_gamma": None,
            "loss_weighting_epsilon": None,
            "initial_ema_loss": None,
            "ema_model_saving": EMA_MODEL_SAVING,
            "ema_model_saving_decay": EMA_MODEL_SAVING_DECAY,
            # --- Diffusion Config ---
            "timesteps": TIMESTEPS,
            "beta_schedule": BETA_SCHEDULE,
            "clip_denoised": CLIP_DENOISED,
            "linear_start": LINEAR_START,
            "linear_end": LINEAR_END,
            "cosine_s": COSINE_S,
            "given_betas": GIVEN_BETAS,
            "original_elbo_weight": ORIGINAL_ELBO_WEIGHT,
            "l_simple_weight": L_SIMPLE_WEIGHT,
            "p2_gamma": P2_GAMMA,
            "p2_k": P2_K,
            # --- End Diffusion Config ---
        }
        wandb.init(
            project=project_name,
            name=MAIN_RUN_ID,  # Use the main ID for the WandB run name
            config=config_dict,
        )

    # --- Artifacts and Model Saving (Main Process Only) ---
    # Define paths using MAIN_RUN_ID so all processes know the target, but only rank 0 writes.
    ARTIFACTS_FOLDER = f"artifacts/sevir/flowcast_diffusion/{MAIN_RUN_ID}"
    PLOTS_FOLDER = f"{ARTIFACTS_FOLDER}/plots"
    ANIMATIONS_FOLDER = f"{PLOTS_FOLDER}/animations"
    METRICS_FOLDER = f"{PLOTS_FOLDER}/metrics"
    MODEL_SAVE_DIR = f"{ARTIFACTS_FOLDER}/models"
    MODEL_SAVE_PATH = os.path.join(MODEL_SAVE_DIR, "early_stopping_model" + ".pt")

    if is_main_process:
        os.makedirs(PLOTS_FOLDER, exist_ok=True)
        os.makedirs(ANIMATIONS_FOLDER, exist_ok=True)
        os.makedirs(METRICS_FOLDER, exist_ok=True)
        os.makedirs(MODEL_SAVE_DIR, exist_ok=True)

    # Ensure non-main processes wait for directory creation
    if world_size > 1:
        dist.barrier()

    print(f"{DEBUG_PRINT_PREFIX}Using device: {device}")
    if device.type == "cpu" and is_main_process:  # Print warning only once
        print(DEBUG_PRINT_PREFIX + "Warning: CPU is used for computation!")

    # --- Seed Setting (Important for DDP Reproducibility) ---
    # Set seed consistently across all processes initially
    seed = 42
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)  # Set seed for all GPUs

    # --- Dataset Creation (Same logic) ---
    train_dataset = DynamicEncodedSequentialSevirDataset(
        meta_csv=TRAIN_META,
        data_file=TRAIN_FILE,
        data_type="vil",
        raw_seq_len=49,
        lag_time=LAG_TIME,
        lead_time=LEAD_TIME,
        time_spacing=TIME_SPACING,
        stride=12,
        channel_last=True,
        debug_mode=DEBUG_MODE,
        asinh_transform=ASINH_TRANSFORM,
        transform=None,
    )
    val_dataset = DynamicEncodedSequentialSevirDataset(
        meta_csv=VAL_META,
        data_file=VAL_FILE,
        data_type="vil",
        raw_seq_len=49,
        lag_time=LAG_TIME,
        lead_time=LEAD_TIME,
        time_spacing=TIME_SPACING,
        stride=12,
        channel_last=True,
        debug_mode=DEBUG_MODE,
        asinh_transform=ASINH_TRANSFORM,
        transform=None,
    )

    # --- Distributed Samplers ---
    train_sampler = DistributedSampler(
        train_dataset, num_replicas=world_size, rank=rank, shuffle=True, seed=seed
    )
    # No need to shuffle val_sampler if val_loader shuffle is False, but consistency is good.
    val_sampler = DistributedSampler(
        val_dataset, num_replicas=world_size, rank=rank, shuffle=False
    )

    # --- DataLoaders with Distributed Samplers ---
    # Set shuffle=False in DataLoader, as DistributedSampler handles shuffling.
    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,  # Sampler handles shuffling
        collate_fn=dynamic_encoded_sequential_collate,
        num_workers=NUM_WORKERS if not DEBUG_MODE else 0,
        pin_memory=True if not DEBUG_MODE else False,
        sampler=train_sampler,
        drop_last=True,  # Recommended for DDP
        persistent_workers=True if not DEBUG_MODE else False,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,  # Can potentially use a larger validation batch size
        shuffle=False,  # Sampler handles shuffling (or lack thereof)
        collate_fn=dynamic_encoded_sequential_collate,
        num_workers=NUM_WORKERS if not DEBUG_MODE else 0,
        pin_memory=True if not DEBUG_MODE else False,
        sampler=val_sampler,
        drop_last=False,  # Usually not necessary for validation
        persistent_workers=True if not DEBUG_MODE else False,
    )

    # --- Get Input Shape (Only on Main Process) ---
    input_shape = None
    output_shape = None
    if is_main_process:
        # Iterate briefly to get shapes
        temp_loader = (
            DataLoader(  # Temporary loader without sampler to get shapes easily
                train_dataset,
                batch_size=BATCH_SIZE,
                shuffle=True,
                collate_fn=dynamic_encoded_sequential_collate,
                num_workers=0,
            )
        )
        for batch in temp_loader:
            inputs_cpu, outputs_cpu, _ = batch
            print(f"Inputs shape (from rank 0): {inputs_cpu.shape}")
            print(f"Outputs shape (from rank 0): {outputs_cpu.shape}")
            input_shape = inputs_cpu.shape
            output_shape = outputs_cpu.shape
            break
        del temp_loader  # Clean up temporary loader
    # --- Broadcast shapes to all processes ---
    # Use torch.distributed.broadcast_object_list for complex objects like shapes tuple
    shapes_list = [input_shape, output_shape]
    if world_size > 1:
        dist.broadcast_object_list(shapes_list, src=0)
    input_shape, output_shape = shapes_list[0], shapes_list[1]

    if input_shape is None or output_shape is None:
        raise RuntimeError("Could not determine input/output shapes.")

    # --- Compute/Load Mean/Std (Only on Main Process, then Broadcast) ---
    preload_model_state_dict = None
    preload_global_step = None
    preload_best_val_loss = None
    preload_std = None
    preload_mean = None
    preload_optimizer_state_dict = None
    mean = None
    std = None

    if PRELOAD_MODEL is not None:
        if is_main_process:
            print(f"{DEBUG_PRINT_PREFIX}Attempting to load checkpoint: {PRELOAD_MODEL}")
            try:
                # Load checkpoint on CPU first to avoid GPU mem issues on rank 0
                model_info = torch.load(PRELOAD_MODEL, map_location="cpu")
                preload_model_state_dict = model_info["model_state_dict"]
                preload_global_step = model_info.get(
                    "global_step", 0
                )  # Use get for safety
                preload_best_val_loss = model_info.get("best_metric", None)
                preload_std = model_info.get("std", None)
                preload_mean = model_info.get("mean", None)
                preload_optimizer_state_dict = model_info.get(
                    "optimizer_state_dict", None
                )
                print(
                    f"{DEBUG_PRINT_PREFIX}Successfully loaded model info from checkpoint"
                )
            except FileNotFoundError:
                print(
                    f"{DEBUG_PRINT_PREFIX}Preload model file not found: {PRELOAD_MODEL}. Starting from scratch."
                )
            except Exception as e:
                print(
                    f"{DEBUG_PRINT_PREFIX}Error loading checkpoint: {e}. Starting from scratch."
                )
        # Broadcast loaded info (or lack thereof)
        loaded_info = [
            preload_model_state_dict,
            preload_global_step,
            preload_best_val_loss,
            preload_std,
            preload_mean,
            preload_optimizer_state_dict,
        ]
        if world_size > 1:
            dist.broadcast_object_list(loaded_info, src=0)
        (
            preload_model_state_dict,
            preload_global_step,
            preload_best_val_loss,
            preload_std,
            preload_mean,
            preload_optimizer_state_dict,
        ) = loaded_info
    else:
        if is_main_process:
            print(f"{DEBUG_PRINT_PREFIX}No preload model specified.")

    if preload_mean is None or preload_std is None:
        if is_main_process:
            # Compute mean/std using a temporary loader on rank 0
            print(f"{DEBUG_PRINT_PREFIX}Computing mean and std...")
            temp_loader = DataLoader(
                train_dataset,
                batch_size=BATCH_SIZE * 4,
                shuffle=False,  # Use larger batch for faster compute
                collate_fn=dynamic_encoded_sequential_collate,
                num_workers=NUM_WORKERS // 2 if NUM_WORKERS > 1 else 0,
                pin_memory=False,
            )
            mean, std = compute_mean_std(temp_loader, channel_last=True)
            del temp_loader
            print(f"{DEBUG_PRINT_PREFIX}Computed Mean: {mean}")
            print(f"{DEBUG_PRINT_PREFIX}Computed Std: {std}")
        # Broadcast computed mean/std
        mean_std_list = [mean, std]
        if world_size > 1:
            dist.broadcast_object_list(mean_std_list, src=0)
        mean, std = mean_std_list[0], mean_std_list[1]
        if mean is None or std is None:
            raise RuntimeError("Mean/Std computation/broadcast failed.")
    else:
        raise RuntimeError("To be implemented.")

    # --- Model Creation ---
    # Create the model
    IN_TIMESTEPS = input_shape[1]  # Condition on the past frames and the noise
    OUTPUT_TIMESTEPS = output_shape[1]  # Number of output channels

    # Input Shape for the Flowcast model: (T_in+T_out, H_in, W_in, C_in)
    input_shape_flowcast = (
        IN_TIMESTEPS,
        input_shape[2],
        input_shape[3],
        input_shape[4],
    )
    output_shape_flowcast = (
        OUTPUT_TIMESTEPS,
        output_shape[2],
        output_shape[3],
        output_shape[4],
    )

    # Instantiate model on CPU first
    model = CuboidTransformerUNet(
        input_shape=input_shape_flowcast,
        target_shape=output_shape_flowcast,  # Exclude first dimension (batch size)
        base_units=BASE_UNITS,
        block_units=None,  # multiply by 2 when downsampling in each layer
        scale_alpha=SCALE_ALPHA,
        num_heads=NUM_HEADS,
        attn_drop=ATTN_DROP,
        proj_drop=PROJ_DROP,
        ffn_drop=FFN_DROP,
        downsample=DOWNSAMPLE,
        downsample_type=DOWNSAMPLE_TYPE,
        upsample_type=UPSAMPLE_TYPE,
        upsample_kernel_size=UPSAMPLE_KERNEL_SIZE,
        depth=DEPTH,
        block_attn_patterns=BLOCK_ATTN_PATTERNS,
        # global vectors
        num_global_vectors=NUM_GLOBAL_VECTORS,
        use_global_vector_ffn=USE_GLOBAL_VECTOR_FFN,
        use_global_self_attn=USE_GLOBAL_SELF_ATTN,
        separate_global_qkv=SEPARATE_GLOBAL_QKV,
        global_dim_ratio=GLOBAL_DIM_RATIO,
        # misc
        ffn_activation=FFN_ACTIVATION,
        gated_ffn=GATED_FFN,
        norm_layer=NORM_LAYER,
        padding_type=PADDING_TYPE,
        checkpoint_level=CHECKPOINT_LEVEL,
        pos_embed_type=POS_EMBED_TYPE,
        use_relative_pos=USE_RELATIVE_POS,
        self_attn_use_final_proj=SELF_ATTN_USE_FINAL_PROJ,
        # initialization
        attn_linear_init_mode=ATTN_LINEAR_INIT_MODE,
        ffn_linear_init_mode=FFN_LINEAR_INIT_MODE,
        ffn2_linear_init_mode=FFN2_LINEAR_INIT_MODE,
        attn_proj_linear_init_mode=ATTN_PROJ_LINEAR_INIT_MODE,
        conv_init_mode=CONV_INIT_MODE,
        down_linear_init_mode=DOWN_UP_LINEAR_INIT_MODE,
        up_linear_init_mode=DOWN_UP_LINEAR_INIT_MODE,
        global_proj_linear_init_mode=GLOBAL_PROJ_LINEAR_INIT_MODE,
        norm_init_mode=NORM_INIT_MODE,
        # timestep embedding for diffusion
        time_embed_channels_mult=TIME_EMBED_CHANNELS_MULT,
        time_embed_use_scale_shift_norm=TIME_EMBED_USE_SCALE_SHIFT_NORM,
        time_embed_dropout=TIME_EMBED_DROPOUT,
        unet_res_connect=UNET_RES_CONNECT,
        mean=mean,
        std=std,
    )

    # Load state dict if preloaded (before moving to GPU and wrapping)
    if preload_model_state_dict is not None:
        try:
            model.load_state_dict(preload_model_state_dict)
            if is_main_process:
                print(
                    f"{DEBUG_PRINT_PREFIX}Successfully loaded pre-trained model state dict."
                )
        except Exception as e:
            if is_main_process:
                print(
                    f"{DEBUG_PRINT_PREFIX}Error loading pre-trained model state dict: {e}. Model weights might be random."
                )

    # --- Move model to device and Wrap with DDP ---
    model = model.to(device)
    if world_size > 1:
        # find_unused_parameters not needed; use static_graph for performance
        model = DDP(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
            find_unused_parameters=False,
            static_graph=True,
            gradient_as_bucket_view=False,
        )
        if is_main_process:
            print(f"{DEBUG_PRINT_PREFIX}Wrapped model with DDP.")

    # --- EMA Model Handling ---
    ema_model = None
    if EMA_MODEL_SAVING:
        # Create EMA model on CPU, load state, then move to device
        # Crucially, do NOT wrap ema_model with DDP
        ema_model = copy.deepcopy(
            model.module if world_size > 1 else model
        )  # Get the underlying model
        ema_model.to(device)
        if is_main_process:
            print(f"{DEBUG_PRINT_PREFIX}Created EMA model.")

    # --- Calculate total steps ---
    # len(train_loader) will give the number of batches *per GPU*
    num_batches_per_epoch = len(train_loader)
    total_num_steps = int(NUM_EPOCHS * num_batches_per_epoch)
    if is_main_process:
        print(f"{DEBUG_PRINT_PREFIX}Batches per epoch per GPU: {num_batches_per_epoch}")
        print(f"{DEBUG_PRINT_PREFIX}Total training steps: {total_num_steps}")

    # --- Optimizer ---
    # Pass DDP model's parameters (it handles underlying model)
    if OPTIMIZER_TYPE == "adam":
        optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    elif OPTIMIZER_TYPE == "adamw":
        optimizer = torch.optim.AdamW(
            model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY
        )
    else:
        # Ensure error is raised only once
        if is_main_process:
            raise ValueError(f"Invalid optimizer type: {OPTIMIZER_TYPE}")
        else:  # Other processes should wait or exit cleanly
            dist.barrier()  # Wait for rank 0 potentially raising error
            cleanup_ddp()
            sys.exit(1)  # Exit if rank 0 raised error

    # Load optimizer state if available
    if preload_optimizer_state_dict is not None:
        try:
            optimizer.load_state_dict(preload_optimizer_state_dict)
            # Move optimizer state to the correct device
            for state in optimizer.state.values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        state[k] = v.to(device)
            if is_main_process:
                print(
                    f"{DEBUG_PRINT_PREFIX}Successfully loaded pre-trained optimizer state dict."
                )
        except Exception as e:
            if is_main_process:
                print(
                    f"{DEBUG_PRINT_PREFIX}Error loading pre-trained optimizer state dict: {e}. Optimizer state reset."
                )

    # --- Scheduler ---
    warmup_iter = int(np.round(LR_COSINE_WARMUP_ITER_PERCENTAGE * total_num_steps))

    if SCHEDULER_TYPE == "plateau":
        # Note: ReduceLROnPlateau needs the metric value, which needs aggregation in DDP
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode="min",
            factor=LR_PLATEAU_FACTOR,
            patience=LR_PLATEAU_PATIENCE,
        )
    elif SCHEDULER_TYPE == "cosine":
        warmup_scheduler = LambdaLR(
            optimizer,
            lr_lambda=warmup_lambda(
                warmup_steps=warmup_iter, min_lr_ratio=LR_COSINE_MIN_WARMUP_LR_RATIO
            ),
        )
        cosine_scheduler = CosineAnnealingLR(
            optimizer,
            T_max=(total_num_steps - warmup_iter),
            eta_min=LR_COSINE_MIN_LR_RATIO * LEARNING_RATE,
        )
        scheduler = SequentialLR(
            optimizer,
            schedulers=[warmup_scheduler, cosine_scheduler],
            milestones=[warmup_iter],
        )
    else:
        if is_main_process:
            raise ValueError(f"Invalid scheduler type: {SCHEDULER_TYPE}")
        else:
            dist.barrier()
            cleanup_ddp()
            sys.exit(1)

    # --- Diffusion ---

    # --- Partial Evaluation Setup (Main Process Only) ---
    ae_model = None
    val_sample_loader = None
    val_sample_sampler = None
    if PARTIAL_EVALUATION:
        if is_main_process and not os.path.exists(AUTOENCODER_CHECKPOINT):
            raise FileNotFoundError(
                f"[Rank 0] AE Model not found at {AUTOENCODER_CHECKPOINT}"
            )

        if is_main_process:
            print(
                f"{DEBUG_PRINT_PREFIX}Loading Autoencoder for evaluation on all ranks..."
            )

        from diffusers.models.autoencoders import AutoencoderKL

        ae_model = AutoencoderKL(
            in_channels=1,
            out_channels=1,
            down_block_types=config.autoencoder_params.down_block_types,
            up_block_types=config.autoencoder_params.up_block_types,
            block_out_channels=config.autoencoder_params.block_out_channels,
            act_fn=config.autoencoder_params.act_fn,
            latent_channels=config.autoencoder_params.latent_channels,
            norm_num_groups=config.autoencoder_params.norm_num_groups,
            layers_per_block=config.autoencoder_params.layers_per_block,
        )

        checkpoint = torch.load(AUTOENCODER_CHECKPOINT, map_location=device)
        ae_model.load_state_dict(checkpoint["model_state_dict"])
        ae_model = ae_model.to(device)
        ae_model.eval()
        if is_main_process:
            print(
                f"{DEBUG_PRINT_PREFIX}Autoencoder loaded successfully on rank {rank}."
            )

        VAL_SAMPLE_FILE = args.partial_evaluation_file
        VAL_SAMPLE_META = args.partial_evaluation_meta

        val_sample_dataset = DynamicSequentialSevirDataset(
            meta_csv=VAL_SAMPLE_META,
            data_file=VAL_SAMPLE_FILE,
            data_type="vil",
            raw_seq_len=49,
            lag_time=LAG_TIME,
            lead_time=LEAD_TIME,
            time_spacing=TIME_SPACING,
            stride=12,
            channel_last=False,
            debug_mode=DEBUG_MODE,
            log_transform=False,
        )

        if world_size > 1:
            val_sample_sampler = DistributedSampler(
                val_sample_dataset, num_replicas=world_size, rank=rank, shuffle=False
            )

        val_sample_loader = DataLoader(
            val_sample_dataset,
            batch_size=(BATCH_SIZE // 4 if BATCH_SIZE > 4 else BATCH_SIZE),
            shuffle=False,
            collate_fn=dynamic_sequential_collate,
            num_workers=NUM_WORKERS if not DEBUG_MODE else 0,
            pin_memory=True if not DEBUG_MODE else False,
            sampler=val_sample_sampler,
        )
        if is_main_process:
            print(
                f"{DEBUG_PRINT_PREFIX}Distributed test loader created for partial evaluation."
            )

    # --- Early Stopping (Initialize on Main Process Only) ---
    early_stopping = None
    best_val_loss_init = None  # Need to sync initial best loss if preloading

    if EARLY_STOPPING_METRIC == "val_loss":
        best_val_loss_init = (
            float("inf") if preload_best_val_loss is None else preload_best_val_loss
        )
        metric_direction = "minimize"
    elif EARLY_STOPPING_METRIC in ["partial_csi_m", "partial_mse"]:
        # Maximize CSI, Minimize MSE - Need to handle MSE inversion if needed by EarlyStopping class
        # Assuming EarlyStopping takes metric_direction
        best_val_loss_init = (
            -np.inf if preload_best_val_loss is None else preload_best_val_loss
        )
        metric_direction = (
            "maximize" if EARLY_STOPPING_METRIC == "partial_csi_m" else "minimize"
        )  # Check if MSE needs maximizing (-MSE)
        if (
            metric_direction == "minimize"
        ):  # If minimizing MSE, need to ensure comparison logic is correct
            best_val_loss_init = (
                float("inf") if preload_best_val_loss is None else preload_best_val_loss
            )
            print(
                f"{DEBUG_PRINT_PREFIX} Early stopping set to MINIMIZE {EARLY_STOPPING_METRIC}"
            )
        else:
            print(
                f"{DEBUG_PRINT_PREFIX} Early stopping set to MAXIMIZE {EARLY_STOPPING_METRIC}"
            )
    else:
        # Handle potential unknown metric
        if is_main_process:
            print(
                f"{DEBUG_PRINT_PREFIX} Warning: Unknown early stopping metric '{EARLY_STOPPING_METRIC}'. Defaulting to validation loss."
            )
        best_val_loss_init = (
            float("inf") if preload_best_val_loss is None else preload_best_val_loss
        )
        metric_direction = "minimize"

    # Broadcast the initial best metric value from rank 0
    initial_metric_tensor = torch.tensor(
        best_val_loss_init if best_val_loss_init is not None else float("inf"),
        device=device,
    )
    if world_size > 1:
        dist.broadcast(initial_metric_tensor, src=0)
    synced_best_val_loss_init = initial_metric_tensor.item()

    if is_main_process:
        early_stopping = EarlyStopping(
            patience=EARLY_STOPPING_PATIENCE,
            verbose=True,
            path=MODEL_SAVE_PATH,  # Only rank 0 saves the model
            initial_best_metric=synced_best_val_loss_init,
            metric_direction=metric_direction,
        )
        print(
            f"{DEBUG_PRINT_PREFIX}Initialized EarlyStopping with initial best metric: {synced_best_val_loss_init}, direction: {metric_direction}"
        )

    if is_main_process:
        print(f"{DEBUG_PRINT_PREFIX}Starting training, run id: {MAIN_RUN_ID}")

    # --- Global Step ---
    global_step = 0 if preload_global_step is None else preload_global_step

    # --- AMP Grad Scaler ---
    # Each process needs its own scaler
    scaler = torch.amp.GradScaler(device=device.type, enabled=USE_FP16)

    # --- Diffusion Buffers ---
    diffusion_buffers = register_diffusion_buffers(
        beta_schedule=BETA_SCHEDULE,
        timesteps=TIMESTEPS,
        linear_start=LINEAR_START,
        linear_end=LINEAR_END,
        cosine_s=COSINE_S,
        given_betas=GIVEN_BETAS,
        device=device,
    )

    # ===============================================================
    # %% Training Loop
    # ===============================================================
    for epoch in range(NUM_EPOCHS):
        # --- Set Sampler Epoch ---
        train_sampler.set_epoch(epoch)
        val_sampler.set_epoch(epoch)

        # --- Training Epoch ---
        model.train()  # Set model to train mode (including DDP wrapper)
        train_loss_accum = 0.0
        train_count = 0
        # no accumulators

        # Conditional tqdm for main process only
        train_bar_desc = f"Training Epoch {epoch} (Rank {rank})"
        train_bar = tqdm(
            train_loader,
            desc=train_bar_desc,
            disable=not is_main_process,
            position=rank,
            leave=False,
        )

        optimizer.zero_grad()  # Zero gradients at the start of the epoch / accumulation cycle

        # Access to variables from outer scope
        for batch_idx, batch in enumerate(train_bar):
            inputs, outputs, metadata = batch
            # Move data to the process's device
            x_start = outputs.to(device, non_blocking=True)
            x_cond = inputs.to(device, non_blocking=True)

            # Normalize (access underlying model via .module if DDP wrapped)
            current_model = model.module if world_size > 1 else model
            x_cond = current_model.normalize(x_cond)
            x_start = current_model.normalize(x_start)

            # Use DDP no_sync() on non-final accumulation steps
            ddp_sync_ctx = (
                model.no_sync()
                if (world_size > 1 and (batch_idx + 1) % GRAD_ACCUMULATION_STEPS != 0)
                else nullcontext()
            )

            # Mixed precision context
            with ddp_sync_ctx, torch.amp.autocast(
                device_type=device.type, enabled=USE_FP16
            ):
                # --- Diffusion Training ---
                t = torch.randint(
                    0, TIMESTEPS, (x_start.shape[0],), device=device
                ).long()
                noise = torch.randn_like(x_start)
                final_batch_loss, _ = p_losses(
                    model,
                    x_start,
                    t,
                    diffusion_buffers,
                    x_cond,
                    noise=noise,
                    clip_denoised=CLIP_DENOISED,
                    original_elbo_weight=ORIGINAL_ELBO_WEIGHT,
                    l_simple_weight=L_SIMPLE_WEIGHT,
                    p2_gamma=P2_GAMMA,
                    p2_k=P2_K,
                )

            # --- Backpropagation ---
            # Scale loss for accumulation
            scaled_loss = final_batch_loss / GRAD_ACCUMULATION_STEPS
            # scaler handles the scaling and potential DDP synchronization points within backward
            scaler.scale(scaled_loss).backward()

            # Store raw loss for logging average
            raw_loss_value = final_batch_loss.item()  # Get loss value before division
            train_loss_accum += raw_loss_value
            train_count += 1

            # --- Optimizer Step ---
            if (batch_idx + 1) % GRAD_ACCUMULATION_STEPS == 0 or (batch_idx + 1) == len(
                train_loader
            ):
                # Unscale gradients before clipping
                scaler.unscale_(optimizer)
                # Clip gradients (applied to the model parameters that optimizer manages)
                torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
                # Optimizer step
                scaler.step(optimizer)
                # Update scaler for next iteration
                scaler.update()
                # Zero gradients *after* stepping and updating scaler
                optimizer.zero_grad()

                # Scheduler step (if applicable) - needs to happen *after* optimizer.step()
                if SCHEDULER_TYPE == "cosine":
                    scheduler.step()

                # --- EMA Model Update (moved here) ---
                if EMA_MODEL_SAVING and ema_model is not None:
                    ema(
                        (
                            model.module if world_size > 1 else model
                        ),  # Access underlying model
                        ema_model,
                        EMA_MODEL_SAVING_DECAY,
                    )

                # --- Logging (Main Process Only) ---
                if is_main_process:
                    current_lr = optimizer.param_groups[0]["lr"]
                    log_data = {
                        "training_loss_step": raw_loss_value,  # Log the loss of the last batch in step
                        "learning_rate": current_lr,
                    }
                    if ENABLE_WANDB:
                        wandb.log(log_data, step=global_step)

                # --- End Logging ---

            global_step += BATCH_SIZE  # Increment global step based on per-GPU batch size processed
            if is_main_process:
                train_bar.set_postfix({"training_loss": f"{raw_loss_value:.4f}"})

            if DEBUG_MODE and batch_idx >= 2:  # Check batch_idx for debug break
                if is_main_process:
                    print(
                        f"{DEBUG_PRINT_PREFIX}Debug break after {batch_idx+1} batches."
                    )
                break
        # --- End Training Batch Loop ---
        train_bar.close()

        if FINAL_TRAINING:

            val_train_bar_desc = f"Training with Val Epoch {epoch} (Rank {rank})"
            val_train_bar = tqdm(
                val_loader,
                desc=val_train_bar_desc,
                disable=not is_main_process,
                position=rank,
                leave=False,
            )
            # Access to variables from outer scope
            for batch_idx, batch in enumerate(val_train_bar):
                inputs, outputs, metadata = batch
                # Move data to the process's device
                x_start = outputs.to(device, non_blocking=True)
                x_cond = inputs.to(device, non_blocking=True)

                # Normalize (access underlying model via .module if DDP wrapped)
                current_model = model.module if world_size > 1 else model
                x_cond = current_model.normalize(x_cond)
                x_start = current_model.normalize(x_start)

                # Use DDP no_sync() on non-final accumulation steps
                ddp_sync_ctx = (
                    model.no_sync()
                    if (
                        world_size > 1
                        and (batch_idx + 1) % GRAD_ACCUMULATION_STEPS != 0
                    )
                    else nullcontext()
                )

                # Mixed precision context
                with ddp_sync_ctx, torch.amp.autocast(
                    device_type=device.type, enabled=USE_FP16
                ):
                    # --- Diffusion Training ---
                    t = torch.randint(
                        0, TIMESTEPS, (x_start.shape[0],), device=device
                    ).long()
                    noise = torch.randn_like(x_start)
                    final_batch_loss, _ = p_losses(
                        model,
                        x_start,
                        t,
                        diffusion_buffers,
                        x_cond,
                        noise=noise,
                        clip_denoised=CLIP_DENOISED,
                        original_elbo_weight=ORIGINAL_ELBO_WEIGHT,
                        l_simple_weight=L_SIMPLE_WEIGHT,
                        p2_gamma=P2_GAMMA,
                        p2_k=P2_K,
                    )

                # --- Backpropagation ---
                # Scale loss for accumulation
                scaled_loss = final_batch_loss / GRAD_ACCUMULATION_STEPS
                # scaler handles the scaling and potential DDP synchronization points within backward
                scaler.scale(scaled_loss).backward()

                # Store raw loss for logging average
                raw_loss_value = (
                    final_batch_loss.item()
                )  # Get loss value before division
                train_loss_accum += raw_loss_value
                train_count += 1

                # --- Optimizer Step ---
                if (batch_idx + 1) % GRAD_ACCUMULATION_STEPS == 0 or (
                    batch_idx + 1
                ) == len(train_loader):
                    # Unscale gradients before clipping
                    scaler.unscale_(optimizer)
                    # Clip gradients (applied to the model parameters that optimizer manages)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
                    # Optimizer step
                    scaler.step(optimizer)
                    # Update scaler for next iteration
                    scaler.update()
                    # Zero gradients *after* stepping and updating scaler
                    optimizer.zero_grad()

                    # Scheduler step (if applicable) - needs to happen *after* optimizer.step()
                    if SCHEDULER_TYPE == "cosine":
                        scheduler.step()

                    # --- EMA Model Update (moved here) ---
                    if EMA_MODEL_SAVING and ema_model is not None:
                        ema(
                            (
                                model.module if world_size > 1 else model
                            ),  # Access underlying model
                            ema_model,
                            EMA_MODEL_SAVING_DECAY,
                        )

                    # --- Logging (Main Process Only) ---
                    if is_main_process:
                        current_lr = optimizer.param_groups[0]["lr"]
                        log_data = {
                            "training_loss_step": raw_loss_value,  # Log the loss of the last batch in step
                            "learning_rate": current_lr,
                        }
                        if ENABLE_WANDB:
                            wandb.log(log_data, step=global_step)

                    # --- End Logging ---

                global_step += BATCH_SIZE  # Increment global step based on per-GPU batch size processed
                if is_main_process:
                    val_train_bar.set_postfix(
                        {"training_loss": f"{raw_loss_value:.4f}"}
                    )

                if DEBUG_MODE and batch_idx >= 2:  # Check batch_idx for debug break
                    if is_main_process:
                        print(
                            f"{DEBUG_PRINT_PREFIX}Debug break after {batch_idx+1} batches."
                        )
                    break
            val_train_bar.close()

        # --- Aggregate and Calculate Average Training Loss for Logging ---
        avg_train_loss_local = (
            train_loss_accum / train_count if train_count > 0 else 0.0
        )
        loss_tensor = torch.tensor(avg_train_loss_local, device=device)
        if world_size > 1:
            loss_tensor = reduce_tensor(
                loss_tensor, world_size
            )  # Average across processes
        avg_train_loss_global = loss_tensor.item()

        # --- Aggregate and Log Adaptive Weighting Stats ---
        # Adaptive weighting disabled

        # ===============================================================
        # %% Validation Step
        # ===============================================================
        model.eval()  # Set model to eval mode
        if ema_model is not None:
            ema_model.eval()  # Also set EMA model to eval

        val_loss_accum = 0.0
        val_count = 0

        val_bar_desc = f"Validation Epoch {epoch} (Rank {rank})"
        val_bar = tqdm(
            val_loader,
            desc=val_bar_desc,
            disable=not is_main_process,
            position=rank,
            leave=False,
        )

        with torch.no_grad():  # No gradients needed for validation
            for batch in val_bar:
                inputs, outputs, metadata = batch
                x_start = outputs.to(device, non_blocking=True)
                x_cond = inputs.to(device, non_blocking=True)

                current_model = model.module if world_size > 1 else model
                x_cond = current_model.normalize(x_cond)
                x_start = current_model.normalize(x_start)

                # Use AMP context manager for consistency, although grads aren't computed
                with torch.amp.autocast(device_type=device.type, enabled=USE_FP16):
                    noise = torch.randn_like(x_start)
                    t = torch.randint(
                        0, TIMESTEPS, (x_start.shape[0],), device=device
                    ).long()

                    loss, _ = p_losses(
                        model,
                        x_start,
                        t,
                        diffusion_buffers,
                        x_cond,
                        noise=noise,
                        clip_denoised=CLIP_DENOISED,
                        original_elbo_weight=ORIGINAL_ELBO_WEIGHT,
                        l_simple_weight=L_SIMPLE_WEIGHT,
                        p2_gamma=P2_GAMMA,
                        p2_k=P2_K,
                    )

                val_loss_accum += loss.item() * inputs.size(
                    0
                )  # Accumulate weighted by batch size
                val_count += inputs.size(0)  # Accumulate total samples processed

                if is_main_process:
                    val_bar.set_postfix(
                        {"validation_loss": "{:.4f}".format(loss.item())}
                    )

                if (
                    DEBUG_MODE and val_count // BATCH_SIZE >= 3
                ):  # Check number of batches processed
                    if is_main_process:
                        print(f"{DEBUG_PRINT_PREFIX}Debug break during validation.")
                    break
        val_bar.close()

        # --- Aggregate Validation Loss Across Processes ---
        # Sum accumulated loss and counts across all GPUs
        val_loss_total_tensor = torch.tensor(val_loss_accum, device=device)
        val_count_total_tensor = torch.tensor(val_count, device=device)

        if world_size > 1:
            dist.all_reduce(val_loss_total_tensor, op=dist.ReduceOp.SUM)
            dist.all_reduce(val_count_total_tensor, op=dist.ReduceOp.SUM)

        # Calculate global average validation loss
        avg_val_loss_global = (
            (val_loss_total_tensor / val_count_total_tensor).item()
            if val_count_total_tensor > 0
            else 0.0
        )

        # --- Scheduler Step (Plateau) ---
        if SCHEDULER_TYPE == "plateau":
            # Plateau scheduler needs the global average loss
            scheduler.step(avg_val_loss_global)

        # --- Partial Evaluation (Main Process Only) ---
        partial_eval_results = None
        should_run_partial_eval = PARTIAL_EVALUATION and (
            epoch % PARTIAL_EVALUATION_INTERVAL == 0
        )
        # All ranks enter barrier before rank-0-only partial eval
        if world_size > 1 and should_run_partial_eval:
            dist.barrier()

        if should_run_partial_eval:
            # Ensure cropping is off for full evaluation
            if is_main_process:
                print(
                    f"{DEBUG_PRINT_PREFIX}Running Distributed Partial Evaluation for Epoch {epoch}..."
                )

            eval_model = (
                ema_model
                if EMA_MODEL_SAVING and ema_model is not None
                else (model.module if world_size > 1 else model)
            )
            eval_model.eval()

            partial_eval_results = partial_evaluate_model(
                model=eval_model,
                device=device,
                val_sample_loader=val_sample_loader,
                thresholds=THRESHOLDS,
                global_step=global_step,
                epoch=epoch,
                ae_model=ae_model,
                normalized_autoencoder=NORMALIZED_AUTOENCODER,
                asinh_transform=ASINH_TRANSFORM,
                partial_evaluation_batches=PARTIAL_EVALUATION_BATCHES,
                lead_time=LEAD_TIME,
                enable_wandb=ENABLE_WANDB,
                wandb_instance=wandb if ENABLE_WANDB else None,
                debug_print_prefix=DEBUG_PRINT_PREFIX,
                plots_folder=METRICS_FOLDER,
                cartopy_features=CARTOPY_FEATURES,
                ema_model_evaluated=EMA_MODEL_SAVING and ema_model is not None,
                diffusion_buffers=diffusion_buffers,
                timesteps=TIMESTEPS,
                clip_denoised=CLIP_DENOISED,
                use_ddim=PARTIAL_USE_DDIM,
                ddim_num_steps=PARTIAL_DDIM_NUM_STEPS,
                ddim_eta=PARTIAL_DDIM_ETA,
                batch_size_autoencoder=(None if BATCH_SIZE > 2 else BATCH_SIZE),
            )
            if is_main_process:
                print(f"{DEBUG_PRINT_PREFIX}Partial Evaluation finished.")

        # All ranks wait until rank 0 finishes partial eval
        if world_size > 1 and should_run_partial_eval:
            dist.barrier()

        # --- Logging Epoch Results (Main Process Only) ---
        if is_main_process:
            current_lr = optimizer.param_groups[0]["lr"]  # Get current LR
            print(
                f"Finished Epoch {epoch} - Train Loss: {avg_train_loss_global:.4f}, Val Loss: {avg_val_loss_global:.4f}, LR: {current_lr:.6f}"
            )
            if ENABLE_WANDB:
                log_data = {
                    "epoch": epoch,
                    "avg_training_loss": avg_train_loss_global,
                    "avg_validation_loss": avg_val_loss_global,
                    "learning_rate": current_lr,  # Log LR at epoch end
                }
                wandb.log(log_data, step=global_step)  # Log against global step

        # --- Early Stopping Check (Main Process Only) ---
        stop_signal_tensor = torch.tensor(
            0, device=device, dtype=torch.int
        )  # 0 = continue, 1 = stop
        if is_main_process and early_stopping is not None:
            # Determine the metric to use for early stopping
            if EARLY_STOPPING_METRIC == "val_loss":
                current_metric = avg_val_loss_global
            elif EARLY_STOPPING_METRIC == "partial_mse":
                # Handle case where partial eval didn't run or failed
                current_metric = (
                    partial_eval_results.get("mse_mean", float("inf"))
                    if partial_eval_results
                    else float("inf")
                )
            elif EARLY_STOPPING_METRIC == "partial_csi_m":
                current_metric = (
                    partial_eval_results.get("csi_m", -np.inf)
                    if partial_eval_results
                    else -np.inf
                )
            else:  # Default to val_loss if metric unknown
                current_metric = avg_val_loss_global

            # Prepare model for saving (use EMA if enabled, otherwise unwrapped DDP model)
            model_to_save = (
                ema_model
                if EMA_MODEL_SAVING and ema_model is not None
                else (model.module if world_size > 1 else model)
            )

            # Call early stopping check
            early_stopping(current_metric, model_to_save, optimizer, epoch, global_step)

            if early_stopping.early_stop:
                print(f"{DEBUG_PRINT_PREFIX}Early stopping triggered.")
                stop_signal_tensor = torch.tensor(1, device=device, dtype=torch.int)

        # Broadcast stop signal from rank 0 to all processes
        if world_size > 1:
            dist.broadcast(stop_signal_tensor, src=0)

        # All processes check the signal
        if stop_signal_tensor.item() == 1:
            if is_main_process:
                print("Early stopping condition met. Finalizing training.")
            break  # Exit epoch loop

        # Ensure all processes finish the epoch steps before starting the next one
        if world_size > 1:
            dist.barrier()
    # --- End Epoch Loop ---

    # --- Final Cleanup ---
    if is_main_process:
        print(f"{DEBUG_PRINT_PREFIX}Finished training, run id: {MAIN_RUN_ID}")
        if ENABLE_WANDB:
            wandb.finish()

    cleanup_ddp()


# %% Helper function for partial evaluation
def partial_evaluate_model(
    model,  # Should be the *unwrapped* model (or EMA model)
    device,  # Device of the main process (rank 0)
    val_sample_loader,  # Dataloader *only* on main process
    thresholds,
    global_step,
    epoch,
    ae_model,  # AE model *only* on main process device
    normalized_autoencoder,
    asinh_transform,
    partial_evaluation_batches,
    lead_time,
    enable_wandb,  # Flag
    wandb_instance,  # Actual wandb object
    debug_print_prefix,  # Rank 0 prefix
    plots_folder,  # Output folder path
    cartopy_features,
    ema_model_evaluated,  # Bool flag
    diffusion_buffers,
    timesteps,
    clip_denoised,
    use_ddim=False,
    ddim_num_steps=100,
    ddim_eta=0.0,
    batch_size_autoencoder=None,
):
    """
    Runs a partial evaluation on a subset of the validation data.

    This function is executed on all ranks. It generates predictions for a few
    batches, decodes them back to pixel space using a pre-trained autoencoder,
    and computes various nowcasting metrics. Metrics are aggregated across all
    DDP processes, and the final results and sample animations are logged by the
    main process.

    Args:
        model: The unwrapped diffusion model (or EMA model) to evaluate.
        device: The device for the current process.
        val_sample_loader: DataLoader for the validation subset.
        thresholds (np.ndarray): Rainfall intensity thresholds for CSI calculation.
        global_step (int): Current global training step for logging.
        epoch (int): Current epoch number for logging and saving artifacts.
        ae_model: The pre-trained AutoencoderKL model for decoding.
        normalized_autoencoder (bool): Whether the autoencoder expects normalized input.
        asinh_transform (bool): Whether to apply an inverse asinh transform to latents.
        partial_evaluation_batches (int): Number of batches to evaluate.
        lead_time (int): Number of future frames to predict.
        enable_wandb (bool): Whether to log results to Weights & Biases.
        wandb_instance: The active wandb instance.
        debug_print_prefix (str): Prefix for print statements.
        plots_folder (str): Directory to save plots and animations.
        cartopy_features (list): List of features to draw on maps.
        ema_model_evaluated (bool): Flag indicating if the EMA model is being used.
        diffusion_buffers (dict): Pre-computed diffusion schedule tensors.
        timesteps (int): Total number of diffusion timesteps.
        clip_denoised (bool): Whether to clip the denoised output.
        use_ddim (bool): If True, use DDIM sampling; otherwise, use DDPM.
        ddim_num_steps (int): Number of steps for DDIM sampling.
        ddim_eta (float): DDIM eta parameter (0 for deterministic, 1 for DDPM-like).
        batch_size_autoencoder (int, optional): Batch size for autoencoder inference.

    Returns:
        dict or None: A dictionary of computed metrics on the main process, otherwise None.
    """
    # This function runs on all ranks when distributed.
    results = None
    model.eval()  # Ensure model is in eval mode
    if ae_model:
        ae_model.eval()

    # Determine rank info for controlled logging/progress
    is_distributed = dist.is_available() and dist.is_initialized()
    rank = dist.get_rank() if is_distributed else 0
    is_rank0 = (not is_distributed) or (rank == 0)

    with torch.no_grad():
        metrics_accumulators = [
            MetricsAccumulator(
                lead_time=lt,  # Use lt index correctly
                thresholds=thresholds,
                pool_size=16,
                compute_mse=True,
                compute_threshold=True,
                compute_apsd=False,
                compute_ssim=True,
                ssim_data_range=255.0,
                device=device,
            )
            for lt in range(lead_time)  # Correct range
        ]
        count = 0
        y_pred_batches = []  # Changed variable name
        y_true_batches = []  # Changed variable name
        last_metadata = None

        # Determine per-rank target samples based on requested total batches
        if dist.is_available() and dist.is_initialized():
            world_size = dist.get_world_size()
            local_target_batches = int(
                math.ceil(partial_evaluation_batches / world_size)
            )
        else:
            local_target_batches = partial_evaluation_batches
        local_target_samples = local_target_batches * val_sample_loader.batch_size

        eval_bar = tqdm(
            val_sample_loader,
            desc=f"Partial Eval Epoch {epoch}",
            leave=False,
            disable=not is_rank0,
        )

        for batch in eval_bar:
            x_cond, x_true, metadata = batch
            last_metadata = metadata
            # --- Preprocessing on Rank 0 Device ---
            x_cond = x_cond.to(device, non_blocking=True)
            x_true = x_true.to(
                device, non_blocking=True
            )  # Keep true on device for potential ops

            B, C, T_in, H, W = x_cond.shape
            x_cond = x_cond.permute(0, 2, 1, 3, 4).reshape(B * T_in, C, H, W)

            if normalized_autoencoder:
                x_cond = x_cond / 255.0

            # --- Autoencoder Encoding ---
            if ae_model:
                encoded_chunks = []
                bs_ae = (
                    batch_size_autoencoder
                    if batch_size_autoencoder is not None
                    else x_cond.shape[0]
                )
                for i in range(0, x_cond.shape[0], bs_ae):
                    chunk = x_cond[i : i + bs_ae]
                    with torch.amp.autocast(
                        device_type=device.type, enabled=False
                    ):  # Use AMP if enabled
                        encoded_chunk = ae_model.encode(chunk)
                    encoded_chunk = encoded_chunk.latent_dist.mode()
                    encoded_chunks.append(encoded_chunk)
                x_cond = torch.cat(encoded_chunks, dim=0)
            else:
                # Should not happen if partial eval is enabled, but handle defensively
                print(
                    f"{debug_print_prefix}Warning: AE model not available for encoding in partial eval."
                )
                # Need latent shape info - this path might fail without AE
                latent_channels, latent_H, latent_W = (
                    4,
                    H // 8,
                    W // 8,
                )  # Placeholder, adjust!
                x_cond = torch.randn(
                    B * T_in, latent_channels, latent_H, latent_W, device=device
                )  # Dummy data

            if asinh_transform:
                x_cond = torch.asinh(x_cond)

            latent_channels, latent_H, latent_W = (
                x_cond.shape[1],
                x_cond.shape[2],
                x_cond.shape[3],
            )
            x_cond = x_cond.reshape(
                B, T_in, latent_channels, latent_H, latent_W
            ).permute(0, 2, 1, 3, 4)

            # --- Normalization (using the main model's stats) ---
            x_cond = model.normalize(x_cond)
            x_cond = x_cond.permute(0, 2, 3, 4, 1)  # B, Tin, Hz, Wz, Cz

            # --- Generative Step (ODE Integration) ---
            B, Tin, Hz, Wz, Cz = x_cond.shape
            x_true = x_true.squeeze(
                1
            )  # remove channel dim [B, T_future, H_true, W_true]
            T_future = x_true.shape[1]
            H_true, W_true = x_true.shape[2], x_true.shape[3]

            x_true_downsampled_example = torch.zeros(
                (B, T_future, Hz, Wz, Cz), device=device
            )

            if use_ddim:
                x_pred_sample = ddim_sample_loop(
                    model,
                    shape=x_true_downsampled_example.shape,
                    cond=x_cond,
                    diffusion_buffers=diffusion_buffers,
                    device=device,
                    total_timesteps=timesteps,
                    ddim_num_steps=ddim_num_steps,
                    eta=ddim_eta,
                    clip_denoised=clip_denoised,
                )
            else:
                x_pred_sample = p_sample_loop(
                    model,
                    shape=x_true_downsampled_example.shape,
                    cond=x_cond,
                    diffusion_buffers=diffusion_buffers,
                    device=device,
                    timesteps=timesteps,
                    clip_denoised=clip_denoised,
                )

            x_pred = x_pred_sample.unsqueeze(1)  # Add sample dim: [B, 1, T, Hz, Wz, Cz]

            # --- Postprocessing on Rank 0 Device ---
            x_pred_np = x_pred.cpu().numpy()  # Move to CPU for numpy ops
            x_true_np = x_true.cpu().numpy()  # Move true to CPU

            x_pred_np = (x_pred_np * model.std.numpy() + model.mean.numpy()).astype(
                np.float32
            )  # Denormalize using model's stats
            if asinh_transform:
                x_pred_np = np.sinh(x_pred_np)

            B, S, T, H_latent, W_latent, C_latent = x_pred_np.shape
            x_pred_np = x_pred_np.reshape(B * S * T, H_latent, W_latent, C_latent)

            x_pred_tensor = torch.from_numpy(x_pred_np).to(
                device
            )  # Back to device for AE
            x_pred_tensor = x_pred_tensor.permute(0, 3, 1, 2)  # [BST, C, Hl, Wl]

            # --- Autoencoder Decoding ---
            if ae_model:
                decoded_chunks = []
                bs_ae = (
                    batch_size_autoencoder
                    if batch_size_autoencoder is not None
                    else x_pred_tensor.shape[0]
                )
                for i in range(0, x_pred_tensor.shape[0], bs_ae):
                    chunk = x_pred_tensor[i : i + bs_ae]
                    with torch.amp.autocast(device_type=device.type, enabled=False):
                        decoded_chunk = ae_model.decode(chunk)
                    decoded_chunk = decoded_chunk.sample
                    decoded_chunks.append(decoded_chunk)
                x_pred_tensor = torch.cat(
                    decoded_chunks, dim=0
                )  # [BST, 1, H_true, W_true]
            else:
                # Handle case without AE - generate dummy output of correct shape
                print(
                    f"{debug_print_prefix}Warning: AE model not available for decoding in partial eval."
                )
                x_pred_tensor = (
                    torch.rand(B * S * T, 1, H_true, W_true, device=device) * 255.0
                )

            if normalized_autoencoder:
                x_pred_tensor = x_pred_tensor * 255.0

            if torch.isnan(x_pred_tensor).any():
                print(f"{debug_print_prefix}NaN values found in x_pred after decode")

            x_pred_tensor = x_pred_tensor.reshape(B, S, T, 1, H_true, W_true)
            x_pred_tensor = x_pred_tensor.permute(
                0, 1, 2, 4, 5, 3
            )  # [B, S, T, H, W, C]
            if x_pred_tensor.shape[-1] == 1:
                x_pred_tensor = x_pred_tensor.squeeze(-1)  # [B, S, T, H, W]

            x_pred_np = x_pred_tensor.cpu().numpy().astype(np.float32)

            # --- Append batch results ---
            y_pred_batches.append(x_pred_np)  # Shape: [B, S, T, H, W]
            y_true_batches.append(x_true_np)  # Shape: [B, T, H, W]

            count += B  # Increment count by batch size
            if count >= local_target_samples:  # Per-rank target based on world size
                break
        eval_bar.close()

        # If distributed, reduce accumulators across ranks before computing
        def _allreduce_metrics_accumulators(
            metrics_accumulators, thresholds_list, device
        ):
            if not (dist.is_available() and dist.is_initialized()):
                return
            float_scalar_fields = [
                "mse_sum",
                "apsd_sum",
                "ssim_sum",
                "mse_from_mean_sum",
                "apsd_from_mean_sum",
                "ssim_from_mean_sum",
                "crps_sum",
            ]
            int_scalar_fields = [
                "mse_count",
                "apsd_count",
                "ssim_score_count",
                "mse_from_mean_count",
                "apsd_from_mean_count",
                "ssim_from_mean_frames_count",
                "crps_count",
            ]
            float_dict_fields = [
                "csi_sum",
                "pod_sum",
                "far_sum",
                "hss_sum",
                "csi_pooled_sum",
            ]
            int_dict_fields = [
                "csi_count",
                "pod_count",
                "far_count",
                "hss_count",
                "csi_pooled_count",
                "csi_from_mean_hits",
                "csi_from_mean_misses",
                "csi_from_mean_false_alarms",
                "pixel_contingency_from_mean_hits",
                "pixel_contingency_from_mean_misses",
                "pixel_contingency_from_mean_false_alarms",
                "pixel_contingency_from_mean_correct_negatives",
                "csi_pooled_from_mean_hits",
                "csi_pooled_from_mean_misses",
                "csi_pooled_from_mean_false_alarms",
            ]
            for acc in metrics_accumulators:
                for name in float_scalar_fields:
                    t = torch.tensor(
                        getattr(acc, name), device=device, dtype=torch.float64
                    )
                    dist.all_reduce(t, op=dist.ReduceOp.SUM)
                    setattr(acc, name, t.item())
                for name in int_scalar_fields:
                    t = torch.tensor(
                        float(getattr(acc, name)), device=device, dtype=torch.float64
                    )
                    dist.all_reduce(t, op=dist.ReduceOp.SUM)
                    setattr(acc, name, int(t.item()))
                for name in float_dict_fields:
                    keys = list(getattr(acc, name).keys())
                    arr = [getattr(acc, name)[th] for th in keys]
                    t = torch.tensor(arr, device=device, dtype=torch.float64)
                    dist.all_reduce(t, op=dist.ReduceOp.SUM)
                    for i, th in enumerate(keys):
                        getattr(acc, name)[th] = t[i].item()
                for name in int_dict_fields:
                    keys = list(getattr(acc, name).keys())
                    arr = [float(getattr(acc, name)[th]) for th in keys]
                    t = torch.tensor(arr, device=device, dtype=torch.float64)
                    dist.all_reduce(t, op=dist.ReduceOp.SUM)
                    for i, th in enumerate(keys):
                        getattr(acc, name)[th] = int(t[i].item())

        has_local_data = len(y_pred_batches) > 0
        if has_local_data:
            # Concatenate all batch results for this rank
            y_pred_array = np.concatenate(
                y_pred_batches, axis=0
            )  # [Local_B, S, T, H, W]
            y_true_array = np.concatenate(y_true_batches, axis=0)  # [Local_B, T, H, W]

            # Post-process predictions (clamp, etc.)
            y_pred_array = post_process_samples(
                y_pred_array, clamp_min=0.0, clamp_max=255.0
            )

            # Update local metrics accumulators
            for metrics_accumulator in metrics_accumulators:
                metrics_accumulator.update(y_true_array, y_pred_array)

        # All-reduce accumulators across ranks
        _allreduce_metrics_accumulators(metrics_accumulators, thresholds, device)

        # Compute final metrics on rank 0 only
        is_distributed = dist.is_available() and dist.is_initialized()
        is_rank0 = (not is_distributed) or (dist.get_rank() == 0)
        if is_rank0:
            results = calculate_metrics(
                num_lead_times=lead_time,
                metrics_accumulators=metrics_accumulators,
                thresholds=thresholds,
            )

        # Print results
        if is_rank0 and results is not None:
            EMA_SUFFIX = "(EMA)" if ema_model_evaluated else ""
            print(
                f"{debug_print_prefix}Partial Results {EMA_SUFFIX}: MSE: {results.get('mse_mean', 'N/A')}, "
                f"CSI-M: {results.get('csi_m', 'N/A')}, CSI (pool)-M: {results.get('csi_pool_m', 'N/A')}, "
                f"HSS-M: {results.get('hss_m', 'N/A')}, FAR-M: {results.get('far_m', 'N/A')}, POD-M: {results.get('pod_m', 'N/A')}, "
                f"SSIM-M: {results.get('ssim_mean', 'N/A')}"
            )

        # Log results to WandB
        EMA_SUFFIX_WANDB = "_EMA" if ema_model_evaluated else ""
        if is_rank0 and results is not None and enable_wandb and wandb_instance:
            log_dict = {
                f"partial_mse{EMA_SUFFIX_WANDB}": results["mse_mean"],
                f"partial_csi_m{EMA_SUFFIX_WANDB}": results["csi_m"],
                f"partial_csi_pool_m{EMA_SUFFIX_WANDB}": results["csi_pool_m"],
                f"partial_hss_m{EMA_SUFFIX_WANDB}": results["hss_m"],
                f"partial_far_m{EMA_SUFFIX_WANDB}": results["far_m"],
                f"partial_pod_m{EMA_SUFFIX_WANDB}": results["pod_m"],
                f"partial_ssim{EMA_SUFFIX_WANDB}": results["ssim_mean"],
            }
            wandb_instance.log(log_dict, step=global_step)

        # --- Plotting Animations (Still on Rank 0) ---
        if (
            is_rank0
            and results is not None
            and len(y_pred_batches) > 0
            and len(y_true_batches) > 0
            and last_metadata is not None
        ):
            try:
                # Plot only the last sample from this rank's processed set
                y_pred_array_plot = np.concatenate(y_pred_batches, axis=0)
                y_true_array_plot = np.concatenate(y_true_batches, axis=0)
                sample_pred_plot = y_pred_array_plot[-3, 0]
                sample_true_plot = y_true_array_plot[-3]

                epoch_anim_folder_suffix = "_ema" if ema_model_evaluated else ""
                epoch_anim_folder = os.path.join(
                    plots_folder,
                    f"animations{epoch_anim_folder_suffix}",
                    f"Epoch_{epoch}",
                )
                os.makedirs(epoch_anim_folder, exist_ok=True)

                fig1 = plt.figure()
                anim1 = make_animation(
                    sample_pred_plot,
                    last_metadata[0],
                    title=f"Output Epoch {epoch}{EMA_SUFFIX}",
                    fig=fig1,
                    cartopy_features=cartopy_features,
                )
                anim1_path = os.path.join(
                    epoch_anim_folder, f"output_test_animation_sample0.gif"
                )
                anim1.save(anim1_path, writer="imagemagick", fps=6)
                plt.close(fig1)

                fig2 = plt.figure()
                anim2 = make_animation(
                    sample_true_plot,
                    last_metadata[0],
                    title=f"Target Epoch {epoch}",
                    fig=fig2,
                    cartopy_features=cartopy_features,
                )
                anim2_path = os.path.join(
                    epoch_anim_folder, "target_test_animation.gif"
                )
                anim2.save(anim2_path, writer="imagemagick", fps=6)
                plt.close(fig2)

                if enable_wandb and wandb_instance:
                    wandb_instance.log(
                        {
                            f"Prediction Animation{EMA_SUFFIX_WANDB}": wandb.Video(
                                anim1_path, fps=6, format="gif"
                            ),
                            "Target Animation": wandb.Video(
                                anim2_path, fps=6, format="gif"
                            ),
                        },
                        step=global_step,
                    )
            except Exception as e:
                print(f"{debug_print_prefix} Error creating or saving animations: {e}")

    return results  # Return metrics dictionary on rank 0; others return None


if __name__ == "__main__":
    # torchrun exports both RANK and WORLD_SIZE; if either is missing this is a
    # plain single-process launch. The global name is kept as-is for main().
    is_distributed = all(key in os.environ for key in ("RANK", "WORLD_SIZE"))

    # Warn (without aborting) when extra GPUs would sit idle in a single-process run.
    # CUDA is only queried for non-distributed launches.
    if not is_distributed and torch.cuda.is_available():
        if torch.cuda.device_count() > 1:
            print("WARNING: Multiple GPUs available but not running in distributed mode.")
            print(
                "Use `torchrun --standalone --nnodes=1 --nproc_per_node=NUM_GPUS your_script_name.py [args]`"
            )

    main()
