import argparse
import datetime
import gc
import os
import random
import sys

from omegaconf import OmegaConf

sys.path.append(os.getcwd())  # Add cwd to path
os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"
import matplotlib.pyplot as plt
import namegenerator
import numpy as np
import torch
import wandb
import h5py
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from torchdiffeq import odeint_adjoint as odeint
from tqdm import tqdm

from common.metrics.metrics_streaming_probabilistic import MetricsAccumulator
from common.models.flowcast.cuboid_transformer_unet import (
    CuboidTransformerUNet,
)
from common.utils.utils import calculate_metrics
from experiments.arso.dataset.arsodataset import (
    ArsoH5Dataset,
    post_process_samples,
)
from experiments.arso.dataset.arsodataset_autoencoder import (
    pad_to_multiple_of_16_5d,
    remove_padding,
)
from experiments.arso.display.cartopy import make_animation_arso


# Command-line interface: every option is an optional string with a default.
parser = argparse.ArgumentParser(description="Script for testing FlowCast model.")

# (flag, default value, help text) for each supported option.
_CLI_OPTIONS = (
    (
        "--artifacts_folder",
        "artifacts/arso/flowcast/2025-08-29_18-18-48_flowcast_ddp_nippy-brown-newt_main",
        "Artifacts folder to load model from",
    ),
    (
        "--config",
        "experiments/arso/runner/flowcast/flowcast_config.yaml",
        "Path to the configuration file.",
    ),
    (
        "--data_file",
        "datasets/arso/final_sequence_data/sequence_data_ds1.h5",
        "Data file to use.",
    ),
)
for _flag, _default, _help in _CLI_OPTIONS:
    parser.add_argument(_flag, type=str, default=_default, help=_help)

# Read the command line, then load the YAML configuration it points to.
args = parser.parse_args()
config = OmegaConf.load(args.config)


# --- Copy settings from the loaded config into module-level constants ---

# Run-level switches.
_run_cfg = config.run_params
DEBUG_MODE = _run_cfg.debug_mode
ENABLE_WANDB = _run_cfg.enable_wandb
RUN_STRING = _run_cfg.run_string

# Evaluation settings.
_test_cfg = config.test_params
BATCH_SIZE = _test_cfg.micro_batch_size
NUM_WORKERS = _test_cfg.num_workers
PROBABILISTIC_SAMPLES = _test_cfg.probabilistic_samples
BATCH_SIZE_AUTOENCODER = _test_cfg.batch_size_autoencoder
CARTOPY_FEATURES = _test_cfg.cartopy_features
THRESHOLDS = np.array(_test_cfg.thresholds, dtype=np.float32)
EULER_STEPS = _test_cfg.euler_steps

# Data settings; the data file comes from the command line, not the config.
DATA_FILE = args.data_file
# Falls back to False when the config does not define "asinh_transform".
ASINH_TRANSFORM = getattr(config.data_params, "asinh_transform", False)

# Autoencoder checkpoint location and architecture.
_ae_cfg = config.autoencoder_params
PRELOAD_AE_MODEL = _ae_cfg.autoencoder_checkpoint
NORMALIZED_AUTOENCODER = _ae_cfg.normalized_autoencoder
LATENT_CHANNELS = _ae_cfg.latent_channels
NORM_NUM_GROUPS = _ae_cfg.norm_num_groups
LAYERS_PER_BLOCK = _ae_cfg.layers_per_block
ACT_FN = _ae_cfg.act_fn
BLOCK_OUT_CHANNELS = _ae_cfg.block_out_channels
DOWN_BLOCK_TYPES = _ae_cfg.down_block_types
UP_BLOCK_TYPES = _ae_cfg.up_block_types


# Unique run identifier: "<timestamp>_<run string>_<generated name>".
_timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
RUN_ID = "_".join((_timestamp, RUN_STRING, namegenerator.gen()))

ARTIFACTS_FOLDER = args.artifacts_folder

# Prefix put in front of console messages while debugging.
if DEBUG_MODE:
    DEBUG_PRINT_PREFIX = "[DEBUG] "
else:
    DEBUG_PRINT_PREFIX = ""


# Convert the latent-model section of the config into plain Python containers.
model_config = OmegaConf.to_object(config.latent_model)


def _pick(*keys):
    """Return the latent-model config values for ``keys``, in the given order."""
    return [model_config[key] for key in keys]


# Flowcast: widths, heads and dropout rates.
BASE_UNITS, SCALE_ALPHA, NUM_HEADS = _pick("base_units", "scale_alpha", "num_heads")
ATTN_DROP, PROJ_DROP, FFN_DROP = _pick("attn_drop", "proj_drop", "ffn_drop")

# Down/upsampling.
DOWNSAMPLE, DOWNSAMPLE_TYPE = _pick("downsample", "downsample_type")
UPSAMPLE_TYPE, UPSAMPLE_KERNEL_SIZE = _pick("upsample_type", "upsample_kernel_size")

# One attention pattern per entry of DEPTH, all equal to the configured pattern.
DEPTH = model_config["depth"]
_pattern_for_blocks = model_config["self_pattern"]
BLOCK_ATTN_PATTERNS = [_pattern_for_blocks] * len(DEPTH)

# Global vectors.
NUM_GLOBAL_VECTORS, USE_GLOBAL_VECTOR_FFN = _pick(
    "num_global_vectors", "use_global_vector_ffn"
)
USE_GLOBAL_SELF_ATTN, SEPARATE_GLOBAL_QKV = _pick(
    "use_global_self_attn", "separate_global_qkv"
)
GLOBAL_DIM_RATIO, SELF_PATTERN = _pick("global_dim_ratio", "self_pattern")

# Miscellaneous layer options.
FFN_ACTIVATION, GATED_FFN, NORM_LAYER = _pick(
    "ffn_activation", "gated_ffn", "norm_layer"
)
PADDING_TYPE, CHECKPOINT_LEVEL = _pick("padding_type", "checkpoint_level")
POS_EMBED_TYPE, USE_RELATIVE_POS = _pick("pos_embed_type", "use_relative_pos")
SELF_ATTN_USE_FINAL_PROJ = model_config["self_attn_use_final_proj"]

# Weight-initialisation modes.
ATTN_LINEAR_INIT_MODE, FFN_LINEAR_INIT_MODE = _pick(
    "attn_linear_init_mode", "ffn_linear_init_mode"
)
FFN2_LINEAR_INIT_MODE, ATTN_PROJ_LINEAR_INIT_MODE = _pick(
    "ffn2_linear_init_mode", "attn_proj_linear_init_mode"
)
CONV_INIT_MODE, DOWN_UP_LINEAR_INIT_MODE = _pick(
    "conv_init_mode", "down_up_linear_init_mode"
)
GLOBAL_PROJ_LINEAR_INIT_MODE, NORM_INIT_MODE = _pick(
    "global_proj_linear_init_mode", "norm_init_mode"
)

# Timestep embedding and U-Net residual connection.
TIME_EMBED_CHANNELS_MULT, TIME_EMBED_USE_SCALE_SHIFT_NORM = _pick(
    "time_embed_channels_mult", "time_embed_use_scale_shift_norm"
)
TIME_EMBED_DROPOUT, UNET_RES_CONNECT = _pick("time_embed_dropout", "unet_res_connect")


# Echo the effective settings to the console so a run can be audited from its log.
def _print_settings(settings):
    """Print each ``(label, value)`` pair as ``<prefix><label>: <value>``.

    Args:
        settings: Iterable of ``(label, value)`` tuples, printed in order.
    """
    for label, value in settings:
        print(f"{DEBUG_PRINT_PREFIX}{label}: {value}")


# Run, data and autoencoder settings.
_print_settings(
    [
        ("Debug Mode", DEBUG_MODE),
        ("Data File", DATA_FILE),
        ("Normalized Autoencoder", NORMALIZED_AUTOENCODER),
        ("Batch Size", BATCH_SIZE),
        ("Number of Workers", NUM_WORKERS),
        ("Thresholds", THRESHOLDS),
        ("Euler Steps", EULER_STEPS),
        # Was labelled "Batch Size Interpolation" and printed a second time at
        # the end of the Flowcast section; now printed once, correctly labelled.
        ("Batch Size Autoencoder", BATCH_SIZE_AUTOENCODER),
        ("Probabilistic Samples", PROBABILISTIC_SAMPLES),
        # Label previously misspelled as "ASIHN".
        ("ASINH Transform", ASINH_TRANSFORM),
        ("Preload AE Model", PRELOAD_AE_MODEL),
        ("Latent Channels", LATENT_CHANNELS),
        ("Norm Num Groups", NORM_NUM_GROUPS),
        ("Layers Per Block", LAYERS_PER_BLOCK),
        ("Activation Function", ACT_FN),
        ("Block Out Channels", BLOCK_OUT_CHANNELS),
        ("Down Block Types", DOWN_BLOCK_TYPES),
        ("Up Block Types", UP_BLOCK_TYPES),
    ]
)

print(f"--------- {DEBUG_PRINT_PREFIX}Flowcast Config ---------")

# Latent (Flowcast) model settings.
_print_settings(
    [
        ("Base Units", BASE_UNITS),
        ("Scale Alpha", SCALE_ALPHA),
        ("Depth", DEPTH),
        ("Block Attn Patterns", BLOCK_ATTN_PATTERNS),
        ("Downsample", DOWNSAMPLE),
        ("Downsample Type", DOWNSAMPLE_TYPE),
        ("Upsample Type", UPSAMPLE_TYPE),
        ("Num Global Vectors", NUM_GLOBAL_VECTORS),
        ("ATTN_PROJ_LINEAR_INIT_MODE", ATTN_PROJ_LINEAR_INIT_MODE),
        ("Global Proj Linear Init Mode", GLOBAL_PROJ_LINEAR_INIT_MODE),
        ("Use Global Vector FFN", USE_GLOBAL_VECTOR_FFN),
        ("Use Global Self Attn", USE_GLOBAL_SELF_ATTN),
        ("Separate Global QKV", SEPARATE_GLOBAL_QKV),
        ("Global Dim Ratio", GLOBAL_DIM_RATIO),
        ("Self Pattern", SELF_PATTERN),
        ("Attn Drop", ATTN_DROP),
        ("Proj Drop", PROJ_DROP),
        ("FFN Drop", FFN_DROP),
        ("Num Heads", NUM_HEADS),
        ("FFN Activation", FFN_ACTIVATION),
        ("Gated FFN", GATED_FFN),
        ("Norm Layer", NORM_LAYER),
        ("Padding Type", PADDING_TYPE),
        ("Pos Embed Type", POS_EMBED_TYPE),
        ("Use Relative Pos", USE_RELATIVE_POS),
        ("Self Attn Use Final Proj", SELF_ATTN_USE_FINAL_PROJ),
        ("Checkpoint Level", CHECKPOINT_LEVEL),
        ("Attn Linear Init Mode", ATTN_LINEAR_INIT_MODE),
        ("FFN Linear Init Mode", FFN_LINEAR_INIT_MODE),
        ("Conv Init Mode", CONV_INIT_MODE),
        ("Down Up Linear Init Mode", DOWN_UP_LINEAR_INIT_MODE),
        ("Norm Init Mode", NORM_INIT_MODE),
        ("Time Embed Channels Mult", TIME_EMBED_CHANNELS_MULT),
        ("Time Embed Use Scale Shift Norm", TIME_EMBED_USE_SCALE_SHIFT_NORM),
        ("Time Embed Dropout", TIME_EMBED_DROPOUT),
        ("UNET Res Connect", UNET_RES_CONNECT),
    ]
)

# Output directories, all nested under the artifacts folder:
#   <artifacts>/plots
#   <artifacts>/plots/animations
#   <artifacts>/plots/metrics
PLOTS_FOLDER = f"{ARTIFACTS_FOLDER}/plots"
ANIMATIONS_FOLDER = f"{PLOTS_FOLDER}/animations"
METRICS_FOLDER = f"{PLOTS_FOLDER}/metrics"

# Create them up front; already-existing directories are left alone.
for _folder in (PLOTS_FOLDER, ANIMATIONS_FOLDER, METRICS_FOLDER):
    os.makedirs(_folder, exist_ok=True)


# Select the compute device: CUDA when available, CPU otherwise.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
print(f"{DEBUG_PRINT_PREFIX}Using device: {device}")
if device.type != "cpu":
    print(f"{DEBUG_PRINT_PREFIX}Number of GPUs available: {torch.cuda.device_count()}")
    if torch.cuda.device_count() > 1:
        print(f"{DEBUG_PRINT_PREFIX}Using {torch.cuda.device_count()} GPUs!")
else:
    print(DEBUG_PRINT_PREFIX + "CPU is used")

# Seed the Python, PyTorch and NumPy generators with the same fixed value.
for _seed_fn in (random.seed, torch.manual_seed, np.random.seed):
    _seed_fn(42)

# Location of the trained checkpoint inside the artifacts folder.
MODEL_SAVE_DIR = f"{ARTIFACTS_FOLDER}/models"
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
MODEL_SAVE_PATH = os.path.join(MODEL_SAVE_DIR, "early_stopping_model.pt")


# Helpers for calling encode/decode on a model that may be wrapped in DataParallel
def safe_encode(model, x):
    """Call ``encode(x)`` on ``model``, looking through a DataParallel wrapper.

    Args:
        model: Object exposing ``encode``, or a ``torch.nn.DataParallel``
            wrapping such an object.
        x: Input forwarded unchanged to ``encode``.

    Returns:
        Whatever the underlying ``encode`` call returns.
    """
    # DataParallel keeps the wrapped network under ``.module``.
    target = model.module if isinstance(model, torch.nn.DataParallel) else model
    return target.encode(x)


def safe_decode(model, x):
    """Call ``decode(x)`` on ``model``, looking through a DataParallel wrapper.

    Args:
        model: Object exposing ``decode``, or a ``torch.nn.DataParallel``
            wrapping such an object.
        x: Input forwarded unchanged to ``decode``.

    Returns:
        Whatever the underlying ``decode`` call returns.
    """
    # DataParallel keeps the wrapped network under ``.module``.
    target = model.module if isinstance(model, torch.nn.DataParallel) else model
    return target.decode(x)


# Start a Weights & Biases run only when the config enables it.
if ENABLE_WANDB:
    _wandb_run_config = {
        "batch_size": BATCH_SIZE,
        "num_workers": NUM_WORKERS,
        "model": "flowcast",
        "model_save_path": MODEL_SAVE_PATH,
    }
    wandb.init(
        project="arso-nowcasting-testing-cfm",
        name=RUN_ID,
        config=_wandb_run_config,
    )

# The trained checkpoint is mandatory: PRELOAD_MODEL is its path when the file
# exists and None otherwise, in which case the script aborts.
PRELOAD_MODEL = MODEL_SAVE_PATH if os.path.exists(MODEL_SAVE_PATH) else None
if PRELOAD_MODEL is None:
    raise FileNotFoundError(f"Model not found at {MODEL_SAVE_PATH}")
else:
    # Everything from here to the end of the script runs inside this branch.
    print(f"{DEBUG_PRINT_PREFIX}Model found at {MODEL_SAVE_PATH}")

    # Count the stored sequences; the HDF5 handle is closed again immediately.
    with h5py.File(DATA_FILE, "r") as hf:
        num_samples = len(hf["zm_IN"])
    indices = np.arange(num_samples)

    # Unshuffled 80/20 split: only the held-out part is used by this script.
    train_frac, val_frac, test_frac = 0.6, 0.2, 0.2
    X_train_val_idx, X_test_idx = train_test_split(
        indices,
        train_size=(train_frac + val_frac),
        random_state=42,
        shuffle=False,
    )

    test_dataset = ArsoH5Dataset(
        h5_file_path=DATA_FILE, indices=X_test_idx, channel_last=False
    )

    # Debug mode loads data in the main process and without pinned memory.
    _loader_workers = 0 if DEBUG_MODE else NUM_WORKERS
    _loader_pin_memory = not DEBUG_MODE
    test_loader = DataLoader(
        test_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=_loader_workers,
        pin_memory=_loader_pin_memory,
    )
    # Testing the Model

    # The autoencoder checkpoint is mandatory as well.
    if not os.path.exists(PRELOAD_AE_MODEL):
        raise FileNotFoundError(f"Model not found at {PRELOAD_AE_MODEL}")

    from diffusers.models.autoencoders import AutoencoderKL

    # Rebuild the autoencoder with the architecture described in the config so
    # that the checkpoint's parameter names and shapes line up.
    ae_model = AutoencoderKL(
        in_channels=1,
        out_channels=1,
        down_block_types=DOWN_BLOCK_TYPES,
        up_block_types=UP_BLOCK_TYPES,
        block_out_channels=BLOCK_OUT_CHANNELS,
        act_fn=ACT_FN,
        latent_channels=LATENT_CHANNELS,
        norm_num_groups=NORM_NUM_GROUPS,
        layers_per_block=LAYERS_PER_BLOCK,
    )

    # NOTE(security): weights_only=False unpickles arbitrary objects, so only
    # load checkpoints from a trusted source.
    checkpoint = torch.load(PRELOAD_AE_MODEL, map_location=device, weights_only=False)

    # Drop the "module." prefix that (Distributed)DataParallel adds to parameter
    # names. Only *leading* prefixes are stripped: a plain str.replace would
    # also mangle keys that merely contain the substring (e.g. "submodule.w").
    model_state_dict = {}
    for _key, _value in checkpoint["model_state_dict"].items():
        while _key.startswith("module."):
            _key = _key[len("module.") :]
        model_state_dict[_key] = _value

    ae_model.load_state_dict(model_state_dict)
    ae_model = ae_model.to(device)
    # Spread the autoencoder over all visible GPUs when there is more than one.
    if torch.cuda.device_count() > 1:
        ae_model = torch.nn.DataParallel(ae_model)
    ae_model.eval()

    # --- Infer sequence lengths and latent dimensions from one real batch ---
    # Despite their names, these two variables hold the batch tensors themselves.
    x_cond_raw_shape, x_true_raw_shape = next(iter(test_loader))

    # Axis 2 is read as the number of input frames / forecast frames.
    T_in = x_cond_raw_shape.shape[2]
    LEAD_TIME = x_true_raw_shape.shape[2]

    x_cond_padded, _ = pad_to_multiple_of_16_5d(x_cond_raw_shape)

    # Encoding the first frame (index 0 on axis 2) is enough to learn the latent size.
    x_cond_for_shape = x_cond_padded[:, :, 0].to(device)
    if NORMALIZED_AUTOENCODER:
        x_cond_for_shape = x_cond_for_shape / 57.0

    with torch.no_grad():
        encoded_obj = safe_encode(ae_model, x_cond_for_shape)
    input_encoded_shape = encoded_obj.latent_dist.mode().shape

    # Axes 1..3 of the encoded tensor are taken as channels, height and width.
    latent_C, latent_H, latent_W = input_encoded_shape[1:4]

    for _label, _shape in (
        ("Input", x_cond_raw_shape.shape),
        ("Encoded", input_encoded_shape),
        ("Output", x_true_raw_shape.shape),
    ):
        print(f"{_label} shape: {_shape}")

    # Load the best model
    # NOTE(security): weights_only=False unpickles arbitrary objects, so only
    # load checkpoints from a trusted source.
    checkpoint = torch.load(MODEL_SAVE_PATH, map_location="cpu", weights_only=False)

    # Drop the "module." prefix that (Distributed)DataParallel adds to parameter
    # names. Only *leading* prefixes are stripped: a plain str.replace would
    # also mangle keys that merely contain the substring (e.g. "submodule.w").
    checkpoint_model_state_dict = {}
    for _key, _value in checkpoint["model_state_dict"].items():
        while _key.startswith("module."):
            _key = _key[len("module.") :]
        checkpoint_model_state_dict[_key] = _value

    # Normalisation statistics stored next to the weights.
    checkpoint_mean = checkpoint["mean"]
    checkpoint_std = checkpoint["std"]

    # Create the model. Shapes are channel-last and exclude the batch dimension.
    input_shape_flowcast = (T_in, latent_H, latent_W, latent_C)
    output_shape_flowcast = (LEAD_TIME, latent_H, latent_W, latent_C)

    loaded_model = CuboidTransformerUNet(
        input_shape=input_shape_flowcast,
        target_shape=output_shape_flowcast,  # Exclude first dimension (batch size)
        base_units=BASE_UNITS,
        block_units=None,  # multiply by 2 when downsampling in each layer
        scale_alpha=SCALE_ALPHA,
        num_heads=NUM_HEADS,
        attn_drop=ATTN_DROP,
        proj_drop=PROJ_DROP,
        ffn_drop=FFN_DROP,
        downsample=DOWNSAMPLE,
        downsample_type=DOWNSAMPLE_TYPE,
        upsample_type=UPSAMPLE_TYPE,
        upsample_kernel_size=UPSAMPLE_KERNEL_SIZE,
        depth=DEPTH,
        block_attn_patterns=BLOCK_ATTN_PATTERNS,
        # global vectors
        num_global_vectors=NUM_GLOBAL_VECTORS,
        use_global_vector_ffn=USE_GLOBAL_VECTOR_FFN,
        use_global_self_attn=USE_GLOBAL_SELF_ATTN,
        separate_global_qkv=SEPARATE_GLOBAL_QKV,
        global_dim_ratio=GLOBAL_DIM_RATIO,
        # misc
        ffn_activation=FFN_ACTIVATION,
        gated_ffn=GATED_FFN,
        norm_layer=NORM_LAYER,
        padding_type=PADDING_TYPE,
        checkpoint_level=CHECKPOINT_LEVEL,
        pos_embed_type=POS_EMBED_TYPE,
        use_relative_pos=USE_RELATIVE_POS,
        self_attn_use_final_proj=SELF_ATTN_USE_FINAL_PROJ,
        # initialization
        attn_linear_init_mode=ATTN_LINEAR_INIT_MODE,
        ffn_linear_init_mode=FFN_LINEAR_INIT_MODE,
        ffn2_linear_init_mode=FFN2_LINEAR_INIT_MODE,
        attn_proj_linear_init_mode=ATTN_PROJ_LINEAR_INIT_MODE,
        conv_init_mode=CONV_INIT_MODE,
        # the same configured mode is used for both down and up projections
        down_linear_init_mode=DOWN_UP_LINEAR_INIT_MODE,
        up_linear_init_mode=DOWN_UP_LINEAR_INIT_MODE,
        global_proj_linear_init_mode=GLOBAL_PROJ_LINEAR_INIT_MODE,
        norm_init_mode=NORM_INIT_MODE,
        # timestep embedding for diffusion
        time_embed_channels_mult=TIME_EMBED_CHANNELS_MULT,
        time_embed_use_scale_shift_norm=TIME_EMBED_USE_SCALE_SHIFT_NORM,
        time_embed_dropout=TIME_EMBED_DROPOUT,
        unet_res_connect=UNET_RES_CONNECT,
        mean=checkpoint_mean,
        std=checkpoint_std,
    )
    loaded_model.load_state_dict(checkpoint_model_state_dict)
    loaded_model = loaded_model.to(device)
    # Spread the model over all visible GPUs when there is more than one.
    if torch.cuda.device_count() > 1:
        loaded_model = torch.nn.DataParallel(loaded_model)
    loaded_model.eval()

    # One streaming accumulator per forecast step; all share the same settings
    # and differ only in the lead-time index they are built with.
    _accumulator_settings = dict(
        thresholds=THRESHOLDS,
        pool_size=16,
        compute_mse=True,
        compute_threshold=True,
        compute_apsd=False,
        compute_ssim=True,
        ssim_data_range=57.0,
        device=device,
    )
    metrics_accumulators = []
    for lt in range(LEAD_TIME):
        metrics_accumulators.append(
            MetricsAccumulator(lead_time=lt, **_accumulator_settings)
        )

    # Main evaluation loop. For every test batch: encode the conditioning frames
    # into latents, draw PROBABILISTIC_SAMPLES forecasts by integrating the
    # learned flow from t=0 to t=1, decode them back to image space and buffer
    # them; the buffers are flushed into the metric accumulators periodically.
    test_bar = tqdm(test_loader, desc="Testing Model")
    # NOTE(review): `count` is incremented once per completed iteration, so it
    # always equals `idx`; it is only used for the debug early exit below.
    count = 0
    y_pred_batches = []
    y_true_batches = []
    for idx, batch in enumerate(test_bar):
        x_cond_raw, x_true_raw = batch

        # --- Preprocessing ---
        x_cond_raw = x_cond_raw.to(device, non_blocking=True)
        x_true_raw = x_true_raw.to(device, non_blocking=True)

        # padding_info is kept so the decoded predictions can be cropped back.
        x_cond_padded, padding_info = pad_to_multiple_of_16_5d(x_cond_raw)

        # Fold time into the batch axis: (B, C, T, H, W) -> (B*T, C, H, W), so
        # the autoencoder sees one frame per batch entry.
        B, C, T_in, H_p, W_p = x_cond_padded.shape
        x_cond_ae_input = x_cond_padded.permute(0, 2, 1, 3, 4).reshape(
            B * T_in, C, H_p, W_p
        )

        # presumably 57.0 is the maximum data value (the same constant is used
        # as the SSIM data range and the clamp maximum) — TODO confirm
        if NORMALIZED_AUTOENCODER:
            x_cond_ae_input = x_cond_ae_input / 57.0

        with torch.no_grad():
            encoded_obj = safe_encode(ae_model, x_cond_ae_input)
        # Use the mode of the latent distribution rather than a random draw.
        x_cond_latent = encoded_obj.latent_dist.mode()

        if ASINH_TRANSFORM:
            x_cond_latent = torch.asinh(x_cond_latent).detach()

        latent_C, latent_H, latent_W = (
            x_cond_latent.shape[1],
            x_cond_latent.shape[2],
            x_cond_latent.shape[3],
        )
        # Unfold time again: (B*T, C, H, W) -> (B, C, T, H, W).
        x_cond = x_cond_latent.reshape(B, T_in, latent_C, latent_H, latent_W).permute(
            0, 2, 1, 3, 4
        )

        # normalize() lives on the underlying model, not on the DataParallel wrapper.
        x_cond = (
            loaded_model.module.normalize(x_cond)
            if isinstance(loaded_model, torch.nn.DataParallel)
            else loaded_model.normalize(x_cond)
        )
        # Move channels last: (B, C, T, H, W) -> (B, T, H, W, C).
        x_cond = x_cond.permute(0, 2, 3, 4, 1)

        # --- Generation ---
        B, _, Hz, Wz, Cz = x_cond.shape
        T_future = x_true_raw.shape[2]

        sample_predictions = []
        for sample_idx in range(PROBABILISTIC_SAMPLES):
            # A distinct, deterministic seed per (batch, sample) pair.
            torch.manual_seed(idx * PROBABILISTIC_SAMPLES + sample_idx)

            # Gaussian noise is the initial state at t=0; the solver works on a
            # (B*T, H, W, C) view of it.
            x0_noise = torch.randn((B, T_future, Hz, Wz, Cz), device=device)
            x0_flat = x0_noise.view(B * T_future, Hz, Wz, Cz)

            def flow_dynamics(t, x_flat):
                """Return the model's velocity at time ``t`` for the flattened state."""
                x_flow_local = x_flat.view(B, T_future, Hz, Wz, Cz)
                # The model takes one time value per batch element.
                t_batched = t * torch.ones(B, device=x_flow_local.device)
                with torch.no_grad():
                    v_t = loaded_model(t_batched, x_flow_local, x_cond)
                return v_t.contiguous().view(B * T_future, Hz, Wz, Cz)

            # Solve ODE (using your preferred method).
            t_span = torch.tensor([0.0, 1.0], device=x0_flat.device)
            # EULER_STEPS == 0 selects the adaptive Heun solver; any other value
            # selects fixed-step Euler with step size 1 / EULER_STEPS.
            if EULER_STEPS == 0:
                solution = odeint(
                    flow_dynamics,
                    x0_flat,
                    t_span,
                    method="adaptive_heun",
                    rtol=1e-2,
                    atol=1e-3,
                    adjoint_params=loaded_model.parameters(),
                )
            else:
                # Euler
                euler_step_size = 1.0 / float(EULER_STEPS)
                solution = odeint(
                    flow_dynamics,
                    x0_flat,
                    t_span,
                    method="euler",
                    options=dict(step_size=euler_step_size),
                    atol=1e-3,
                    rtol=1e-2,
                    adjoint_params=loaded_model.parameters(),
                )
            # The last entry of the solution corresponds to the final time in t_span.
            x_final_flat = solution[-1]
            x_pred_sample = x_final_flat.view(B, T_future, Hz, Wz, Cz)
            # Add a sample axis so the draws can be concatenated along dim 1.
            sample_predictions.append(x_pred_sample.unsqueeze(1))

        # (B, S, T, H, W, C) with S = PROBABILISTIC_SAMPLES.
        x_pred = torch.cat(sample_predictions, dim=1)

        # --- Postprocessing ---
        x_pred_np = x_pred.detach().cpu().numpy()
        x_true_np = x_true_raw.cpu().numpy()

        # NOTE(review): mean/std are read from the model on every iteration;
        # they look loop-invariant and could be hoisted — confirm.
        mean_val = (
            loaded_model.module.mean.cpu().numpy()
            if isinstance(loaded_model, torch.nn.DataParallel)
            else loaded_model.mean.cpu().numpy()
        )
        std_val = (
            loaded_model.module.std.cpu().numpy()
            if isinstance(loaded_model, torch.nn.DataParallel)
            else loaded_model.std.cpu().numpy()
        )
        # Undo the latent normalisation.
        x_pred_np = (x_pred_np * std_val + mean_val).astype(np.float32)

        # sinh inverts the asinh applied to the conditioning latents above.
        if ASINH_TRANSFORM:
            x_pred_np = np.sinh(x_pred_np)

        # Flatten batch, sample and time so every frame is decoded independently.
        B, S, T, H_lat, W_lat, C_lat = x_pred_np.shape
        x_pred_np_reshaped = x_pred_np.reshape(B * S * T, H_lat, W_lat, C_lat)

        x_pred_tensor = torch.from_numpy(x_pred_np_reshaped).to(device)
        # Channels first again for the autoencoder: (N, H, W, C) -> (N, C, H, W).
        x_pred_tensor = x_pred_tensor.permute(0, 3, 1, 2)

        with torch.no_grad():
            decoded_chunks = []
            # Decode in chunks of BATCH_SIZE_AUTOENCODER frames, or all at once
            # when that setting is None.
            bs_ae = (
                BATCH_SIZE_AUTOENCODER
                if BATCH_SIZE_AUTOENCODER is not None
                else x_pred_tensor.shape[0]
            )
            for i in range(0, x_pred_tensor.shape[0], bs_ae):
                chunk = x_pred_tensor[i : i + bs_ae]
                decoded_chunk_obj = safe_decode(ae_model, chunk)
                decoded_chunks.append(decoded_chunk_obj.sample)
            x_pred_decoded = torch.cat(decoded_chunks, dim=0)

        # Crop away the padding added before encoding.
        x_pred_unpadded = remove_padding(x_pred_decoded, padding_info)

        # Mirror of the division by 57.0 applied to the autoencoder input.
        if NORMALIZED_AUTOENCODER:
            x_pred_unpadded = x_pred_unpadded * 57.0

        C_raw = x_true_raw.shape[1]
        H_raw = x_true_raw.shape[3]
        W_raw = x_true_raw.shape[4]
        # Restore the batch/sample/time axes, then move channels last.
        x_pred_final = x_pred_unpadded.reshape(B, S, T, C_raw, H_raw, W_raw)
        x_pred_final = x_pred_final.permute(0, 1, 2, 4, 5, 3)

        # Drop a trailing singleton channel axis.
        if x_pred_final.shape[-1] == 1:
            x_pred_final = x_pred_final.squeeze(-1)

        # Remove axis 1 of the target (must have size 1 for squeeze to succeed).
        x_true_final = x_true_np.squeeze(1)

        # Predictions are buffered as float16.
        y_pred_batches.append(x_pred_final.detach().cpu().numpy().astype(np.float16))
        y_true_batches.append(x_true_final)

        # Accumulate metrics periodically
        # The condition holds whenever idx * BATCH_SIZE has just reached or
        # passed a multiple of 400, i.e. roughly every 400 samples.
        if idx > 0 and (idx * BATCH_SIZE) % 400 < BATCH_SIZE:
            y_pred_array = np.concatenate(y_pred_batches, axis=0)
            y_true_array = np.concatenate(y_true_batches, axis=0)
            y_pred_array = post_process_samples(
                y_pred_array, clamp_min=12.0, clamp_max=57.0
            )

            for metrics_accumulator in metrics_accumulators:
                metrics_accumulator.update(y_true_array, y_pred_array)

            y_pred_batches, y_true_batches = [], []  # Reset

        # Animation for the first batch
        if idx == 0:
            # All samples of the first sequence in the batch.
            sample_pred_anim = post_process_samples(
                x_pred_final.detach().cpu().numpy()[0], clamp_min=12.0, clamp_max=57.0
            )
            for i in range(sample_pred_anim.shape[0]):
                # NOTE(review): this figure is never passed to
                # make_animation_arso; check whether the helper creates its own
                # figure, in which case this one is an empty placeholder.
                fig = plt.figure()
                anim = make_animation_arso(
                    sample_pred_anim[i], title=f"Prediction Sample {i}"
                )
                anim.save(
                    os.path.join(ANIMATIONS_FOLDER, f"output_test_animation_{i}.gif"),
                    writer="imagemagick",
                    fps=6,
                )
                plt.close(fig)

            fig = plt.figure()
            anim = make_animation_arso(x_true_final[0], title="Target")
            anim.save(
                os.path.join(ANIMATIONS_FOLDER, "target_test_animation.gif"),
                writer="imagemagick",
                fps=6,
            )
            plt.close(fig)

        # The check runs before the increment, so debug mode stops after the
        # twelfth batch (count == 11).
        if DEBUG_MODE and count > 10:
            print(f"{DEBUG_PRINT_PREFIX}Breaking early due to DEBUG_MODE")
            break
        count += 1

    # Flush the samples buffered since the last periodic metrics update.
    if y_pred_batches:
        y_pred_array = np.concatenate(y_pred_batches, axis=0)
        y_true_array = np.concatenate(y_true_batches, axis=0)
        y_pred_array = post_process_samples(
            y_pred_array, clamp_min=12.0, clamp_max=57.0
        )
        # Every per-lead-time accumulator receives the same arrays, exactly as
        # in the periodic update inside the loop.
        for metrics_accumulator in metrics_accumulators:
            metrics_accumulator.update(y_true_array, y_pred_array)

    # Release the buffers before the final metric computation.
    del y_pred_batches
    del y_true_batches
    gc.collect()

    # Reduce the per-lead-time accumulators to the final metrics dictionary.
    results = calculate_metrics(
        num_lead_times=LEAD_TIME,
        metrics_accumulators=metrics_accumulators,
        thresholds=THRESHOLDS,
    )

    crps_mean = results["crps_mean"]
    # Extract _from_mean metrics from results
    mse_from_mean_mean = results["mse_from_mean_mean"]
    apsd_from_mean_mean = results["apsd_from_mean_mean"]
    ssim_from_mean_mean = results["ssim_from_mean_mean"]
    csi_from_mean_m = results["csi_from_mean_m"]
    csi_pool_from_mean_m = results["csi_pool_from_mean_m"]
    hss_from_mean_m = results["hss_from_mean_m"]
    far_from_mean_m = results["far_from_mean_m"]
    pod_from_mean_m = results["pod_from_mean_m"]
    csi_from_mean_mean_dict = results["csi_from_mean_mean"]
    far_from_mean_mean_dict = results["far_from_mean_mean"]
    hss_from_mean_mean_dict = results["hss_from_mean_mean"]
    pod_from_mean_mean_dict = results["pod_from_mean_mean"]
    csi_pool_from_mean_mean_dict = results["csi_pool_from_mean_mean"]

    # Print _from_mean metrics
    print(f"CRPS: {crps_mean}")
    print("--- Metrics from Ensemble Mean ---")
    print(f"Mean MSE (from mean): {mse_from_mean_mean}")
    print(f"APSD (from mean): {apsd_from_mean_mean}")
    print(f"SSIM (from mean): {ssim_from_mean_mean}")
    print(f"CSI-M (from mean): {csi_from_mean_m}")
    print(f"CSI (16-pooled)-M (from mean): {csi_pool_from_mean_m}")
    print(f"HSS-M (from mean): {hss_from_mean_m}")
    print(f"FAR-M (from mean): {far_from_mean_m}")
    print(f"POD-M (from mean): {pod_from_mean_m}")
    print("CSI (from mean) per threshold:", csi_from_mean_mean_dict)
    print("FAR (from mean) per threshold:", far_from_mean_mean_dict)
    print("HSS (from mean) per threshold:", hss_from_mean_mean_dict)
    print("POD (from mean) per threshold:", pod_from_mean_mean_dict)
    print(
        f"CSI (16-pooled) (from mean) mean per threshold: {csi_pool_from_mean_mean_dict}"
    )
    # Print lead time metrics
    csi_m_from_mean_lead_time = results["csi_m_from_mean_lead_time"]
    csi_last_thresh_from_mean_lead_time = results["csi_last_thresh_from_mean_lead_time"]
    csi_pool_m_from_mean_lead_time = results["csi_pool_m_from_mean_lead_time"]
    csi_pool_last_thresh_from_mean_lead_time = results[
        "csi_pool_last_thresh_from_mean_lead_time"
    ]
    hss_m_from_mean_lead_time = results["hss_m_from_mean_lead_time"]
    far_m_from_mean_lead_time = results["far_m_from_mean_lead_time"]
    pod_m_from_mean_lead_time = results["pod_m_from_mean_lead_time"]
    print("--- Lead Time Metrics ---")

    # The "last_thresh" series belong to the last configured threshold, so the
    # label is derived from THRESHOLDS instead of the previously hard-coded
    # "219", which need not match the configured thresholds.
    last_threshold_label = (
        f"{float(THRESHOLDS[-1]):g}" if len(THRESHOLDS) > 0 else "n/a"
    )

    print(f"CSI-M (from mean) by lead time: {csi_m_from_mean_lead_time}")
    print(
        f"CSI-M ({last_threshold_label}) (from mean) by lead time: "
        f"{csi_last_thresh_from_mean_lead_time}"
    )
    print(
        f"CSI (16-pooled)-M (from mean) by lead time: {csi_pool_m_from_mean_lead_time}"
    )
    print(
        f"CSI (16-pooled) ({last_threshold_label}) (from mean) by lead time: "
        f"{csi_pool_last_thresh_from_mean_lead_time}"
    )
    print(f"HSS-M (from mean) by lead time: {hss_m_from_mean_lead_time}")
    print(f"FAR-M (from mean) by lead time: {far_m_from_mean_lead_time}")
    print(f"POD-M (from mean) by lead time: {pod_m_from_mean_lead_time}")

    print(DEBUG_PRINT_PREFIX + "Finished testing the model")
