"""
Script for evaluating a pre-trained FlowCast diffusion model on the SEVIR test dataset.

This script performs the following steps:
1. Loads a trained FlowCast model checkpoint and its corresponding configuration.
2. Loads a pre-trained autoencoder for decoding latent representations back to pixel space.
3. Initializes the SEVIR test dataset and DataLoader.
4. Iterates through the test set, and for each sample:
    a. Encodes the input radar sequence into the latent space.
    b. Generates multiple probabilistic forecast samples using the diffusion model (DDPM or DDIM).
    c. Decodes the latent forecasts back into pixel space.
5. Accumulates metrics (MSE, CRPS, CSI, SSIM, etc.) across the test set.
6. Calculates and prints final probabilistic and deterministic metrics.
7. Generates and saves plots of metrics vs. lead time and sample forecast animations.
8. Logs results to Weights & Biases, if enabled.
"""

import gc
import sys
import os
import time
import wandb
import namegenerator
import datetime

from omegaconf import OmegaConf

sys.path.append(os.getcwd())  # Add cwd to path
os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import DataLoader, Subset
from experiments.sevir.display.cartopy import make_animation
import random
from tqdm import tqdm
from common.models.flowcast.cuboid_transformer_unet import (
    CuboidTransformerUNet,
)

from experiments.sevir.dataset.sevirfulldataset import (
    DynamicSequentialSevirDataset,
    dynamic_sequential_collate,
    post_process_samples,
)
from common.metrics.metrics_streaming_probabilistic import (
    MetricsAccumulator,
)
from common.utils.utils import calculate_metrics
import argparse
from common.diffusion.diffusion import (
    register_diffusion_buffers,
    p_sample_loop,
    ddim_sample_loop,
)


# Command-line interface: where to find the trained model, its configuration
# and the test data, plus how much of the test set to evaluate.
parser = argparse.ArgumentParser(description="Script for testing FlowCast model.")

parser.add_argument(
    "--artifacts_folder",
    type=str,
    default="saved_models/sevir/diff_flowcast",
    help="Artifacts folder to load model from",
)
parser.add_argument(
    "--config",
    type=str,
    default="experiments/sevir/runner/flowcast_diffusion/flowcast_diffusion_config.yaml",
    help="Path to the configuration file.",
)
parser.add_argument(
    "--test_data_percentage",
    type=float,
    default=1.0,
    help="Percentage of the test data to use (0.0 to 1.0).",
)
parser.add_argument(
    "--test_file",
    type=str,
    default="datasets/sevir/data/sevir_full/nowcast_testing_full.h5",
    help="Path to the SEVIR test data file (.h5).",
)
parser.add_argument(
    "--test_meta",
    type=str,
    default="datasets/sevir/data/sevir_full/nowcast_testing_full_META.csv",
    help="Path to the metadata CSV that accompanies the test data file.",
)

args = parser.parse_args()
# Reject an out-of-range fraction before any configuration or data is loaded.
if not (0.0 <= args.test_data_percentage <= 1.0):
    raise ValueError("test_data_percentage must be between 0.0 and 1.0")
config = OmegaConf.load(args.config)


# --- Assign variables from config ---
# Every value below is copied out of the OmegaConf config (or the CLI args)
# into a module-level constant so the rest of the script can use plain names.

# run_params
DEBUG_MODE = config.run_params.debug_mode
ENABLE_WANDB = config.run_params.enable_wandb
RUN_STRING = config.run_params.run_string

# test_params
BATCH_SIZE = config.test_params.micro_batch_size
NUM_WORKERS = config.test_params.num_workers
PROBABILISTIC_SAMPLES = config.test_params.probabilistic_samples
BATCH_SIZE_AUTOENCODER = config.test_params.batch_size_autoencoder
CARTOPY_FEATURES = config.test_params.cartopy_features

# data_params
LAG_TIME = config.data_params.lag_time
LEAD_TIME = config.data_params.lead_time
TIME_SPACING = config.data_params.time_spacing
ASINH_TRANSFORM = config.data_params.asinh_transform
BATCH_SIZE_INTERPOLATION = config.data_params.batch_size_interpolation

# autoencoder_params
PRELOAD_AE_MODEL = config.autoencoder_params.autoencoder_checkpoint
NORMALIZED_AUTOENCODER = config.autoencoder_params.normalized_autoencoder
LATENT_CHANNELS = config.autoencoder_params.latent_channels
NORM_NUM_GROUPS = config.autoencoder_params.norm_num_groups
LAYERS_PER_BLOCK = config.autoencoder_params.layers_per_block
ACT_FN = config.autoencoder_params.act_fn
BLOCK_OUT_CHANNELS = config.autoencoder_params.block_out_channels
DOWN_BLOCK_TYPES = config.autoencoder_params.down_block_types
UP_BLOCK_TYPES = config.autoencoder_params.up_block_types


# Unique run name: timestamp + configured run string + a generated name.
RUN_ID = (
    datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    + "_"
    + RUN_STRING
    + "_"
    + namegenerator.gen()
)
ARTIFACTS_FOLDER = args.artifacts_folder
# Prefix added to console messages when debug mode is on.
DEBUG_PRINT_PREFIX = "[DEBUG] " if DEBUG_MODE else ""
# Test data file and its metadata CSV, taken from the command line.
TEST_FILE = args.test_file
TEST_META = args.test_meta
# Thresholds handed to the metric accumulators and to calculate_metrics.
THRESHOLDS = np.array([16, 74, 133, 160, 181, 219], dtype=np.float32)


# Load the config for the latent model
# (converted to a plain Python object so it can be indexed with string keys).
model_config = OmegaConf.to_object(config.latent_model)

# Flowcast
BASE_UNITS = model_config["base_units"]
SCALE_ALPHA = model_config["scale_alpha"]
NUM_HEADS = model_config["num_heads"]
ATTN_DROP = model_config["attn_drop"]
PROJ_DROP = model_config["proj_drop"]
FFN_DROP = model_config["ffn_drop"]
DOWNSAMPLE = model_config["downsample"]
DOWNSAMPLE_TYPE = model_config["downsample_type"]
UPSAMPLE_TYPE = model_config["upsample_type"]
UPSAMPLE_KERNEL_SIZE = model_config["upsample_kernel_size"]
DEPTH = model_config["depth"]
# The same self-attention pattern is repeated once per entry in DEPTH.
BLOCK_ATTN_PATTERNS = [model_config["self_pattern"]] * len(DEPTH)
NUM_GLOBAL_VECTORS = model_config["num_global_vectors"]
USE_GLOBAL_VECTOR_FFN = model_config["use_global_vector_ffn"]
USE_GLOBAL_SELF_ATTN = model_config["use_global_self_attn"]
SEPARATE_GLOBAL_QKV = model_config["separate_global_qkv"]
GLOBAL_DIM_RATIO = model_config["global_dim_ratio"]
SELF_PATTERN = model_config["self_pattern"]
FFN_ACTIVATION = model_config["ffn_activation"]
GATED_FFN = model_config["gated_ffn"]
NORM_LAYER = model_config["norm_layer"]
PADDING_TYPE = model_config["padding_type"]
CHECKPOINT_LEVEL = model_config["checkpoint_level"]
POS_EMBED_TYPE = model_config["pos_embed_type"]
USE_RELATIVE_POS = model_config["use_relative_pos"]
SELF_ATTN_USE_FINAL_PROJ = model_config["self_attn_use_final_proj"]
ATTN_LINEAR_INIT_MODE = model_config["attn_linear_init_mode"]
FFN_LINEAR_INIT_MODE = model_config["ffn_linear_init_mode"]
FFN2_LINEAR_INIT_MODE = model_config["ffn2_linear_init_mode"]
ATTN_PROJ_LINEAR_INIT_MODE = model_config["attn_proj_linear_init_mode"]
CONV_INIT_MODE = model_config["conv_init_mode"]
DOWN_UP_LINEAR_INIT_MODE = model_config["down_up_linear_init_mode"]
GLOBAL_PROJ_LINEAR_INIT_MODE = model_config["global_proj_linear_init_mode"]
NORM_INIT_MODE = model_config["norm_init_mode"]
TIME_EMBED_CHANNELS_MULT = model_config["time_embed_channels_mult"]
TIME_EMBED_USE_SCALE_SHIFT_NORM = model_config["time_embed_use_scale_shift_norm"]
TIME_EMBED_DROPOUT = model_config["time_embed_dropout"]
UNET_RES_CONNECT = model_config["unet_res_connect"]


# Diffusion Config
diffusion_config = OmegaConf.to_object(config.diffusion_params)
TIMESTEPS = diffusion_config["timesteps"]
BETA_SCHEDULE = diffusion_config["beta_schedule"]
CLIP_DENOISED = diffusion_config["clip_denoised"]
LINEAR_START = diffusion_config["linear_start"]
LINEAR_END = diffusion_config["linear_end"]
COSINE_S = diffusion_config["cosine_s"]
GIVEN_BETAS = diffusion_config["given_betas"]
ORIGINAL_ELBO_WEIGHT = diffusion_config["original_elbo_weight"]
L_SIMPLE_WEIGHT = diffusion_config["l_simple_weight"]
# Optional keys: fall back to the defaults given here when absent.
P2_GAMMA = diffusion_config.get("p2_gamma", 0.5)
P2_K = diffusion_config.get("p2_k", 1.0)


# Test-time sampler selection: when USE_DDIM is true the evaluation loop samples
# with ddim_sample_loop (using DDIM_NUM_STEPS and DDIM_ETA); otherwise it uses
# p_sample_loop. All three keys are optional in test_params.
USE_DDIM = config.test_params.get("use_ddim", False)
DDIM_NUM_STEPS = config.test_params.get("ddim_num_steps", 50)
DDIM_ETA = config.test_params.get("ddim_eta", 1.0)

# Autoencoder hyperparameters are read from the config file (autoencoder_params above).

# Echo the resolved configuration to the console. Each setting is printed as
# "<prefix><label>: <value>"; the three groups are separated by banner lines.
_GENERAL_SETTINGS = (
    ("Debug Mode", DEBUG_MODE),
    ("Testing File", TEST_FILE),
    ("Testing Meta", TEST_META),
    ("Normalized Autoencoder", NORMALIZED_AUTOENCODER),
    ("Batch Size", BATCH_SIZE),
    ("Number of Workers", NUM_WORKERS),
    ("Lag Time", LAG_TIME),
    ("Lead Time", LEAD_TIME),
    ("Time Spacing", TIME_SPACING),
    ("Thresholds", THRESHOLDS),
    ("Batch Size Interpolation", BATCH_SIZE_INTERPOLATION),
    ("Probabilistic Samples", PROBABILISTIC_SAMPLES),
    ("ASIHN Transform", ASINH_TRANSFORM),
    ("Preload AE Model", PRELOAD_AE_MODEL),
    ("Latent Channels", LATENT_CHANNELS),
    ("Norm Num Groups", NORM_NUM_GROUPS),
    ("Layers Per Block", LAYERS_PER_BLOCK),
    ("Activation Function", ACT_FN),
    ("Block Out Channels", BLOCK_OUT_CHANNELS),
    ("Down Block Types", DOWN_BLOCK_TYPES),
    ("Up Block Types", UP_BLOCK_TYPES),
)
_FLOWCAST_SETTINGS = (
    ("Base Units", BASE_UNITS),
    ("Scale Alpha", SCALE_ALPHA),
    ("Depth", DEPTH),
    ("Block Attn Patterns", BLOCK_ATTN_PATTERNS),
    ("Downsample", DOWNSAMPLE),
    ("Downsample Type", DOWNSAMPLE_TYPE),
    ("Upsample Type", UPSAMPLE_TYPE),
    ("Num Global Vectors", NUM_GLOBAL_VECTORS),
    ("ATTN_PROJ_LINEAR_INIT_MODE", ATTN_PROJ_LINEAR_INIT_MODE),
    ("Global Proj Linear Init Mode", GLOBAL_PROJ_LINEAR_INIT_MODE),
    ("Use Global Vector FFN", USE_GLOBAL_VECTOR_FFN),
    ("Use Global Self Attn", USE_GLOBAL_SELF_ATTN),
    ("Separate Global QKV", SEPARATE_GLOBAL_QKV),
    ("Global Dim Ratio", GLOBAL_DIM_RATIO),
    ("Self Pattern", SELF_PATTERN),
    ("Attn Drop", ATTN_DROP),
    ("Proj Drop", PROJ_DROP),
    ("FFN Drop", FFN_DROP),
    ("Num Heads", NUM_HEADS),
    ("FFN Activation", FFN_ACTIVATION),
    ("Gated FFN", GATED_FFN),
    ("Norm Layer", NORM_LAYER),
    ("Padding Type", PADDING_TYPE),
    ("Pos Embed Type", POS_EMBED_TYPE),
    ("Use Relative Pos", USE_RELATIVE_POS),
    ("Self Attn Use Final Proj", SELF_ATTN_USE_FINAL_PROJ),
    ("Checkpoint Level", CHECKPOINT_LEVEL),
    ("Attn Linear Init Mode", ATTN_LINEAR_INIT_MODE),
    ("FFN Linear Init Mode", FFN_LINEAR_INIT_MODE),
    ("Conv Init Mode", CONV_INIT_MODE),
    ("Down Up Linear Init Mode", DOWN_UP_LINEAR_INIT_MODE),
    ("Norm Init Mode", NORM_INIT_MODE),
    ("Time Embed Channels Mult", TIME_EMBED_CHANNELS_MULT),
    ("Time Embed Use Scale Shift Norm", TIME_EMBED_USE_SCALE_SHIFT_NORM),
    ("Time Embed Dropout", TIME_EMBED_DROPOUT),
    ("UNET Res Connect", UNET_RES_CONNECT),
    ("Batch Size Autoencoder", BATCH_SIZE_AUTOENCODER),
)
_DIFFUSION_SETTINGS = (
    ("Timesteps", TIMESTEPS),
    ("Beta Schedule", BETA_SCHEDULE),
    ("Clip Denoised", CLIP_DENOISED),
    ("Linear Start", LINEAR_START),
    ("Linear End", LINEAR_END),
    ("Cosine S", COSINE_S),
    ("Given Betas", GIVEN_BETAS),
    ("Original Elbo Weight", ORIGINAL_ELBO_WEIGHT),
    ("L Simple Weight", L_SIMPLE_WEIGHT),
    ("Use DDIM", USE_DDIM),
    ("DDIM Num Steps", DDIM_NUM_STEPS),
    ("DDIM Eta", DDIM_ETA),
)

for _label, _value in _GENERAL_SETTINGS:
    print(f"{DEBUG_PRINT_PREFIX}{_label}: {_value}")

print(f"--------- {DEBUG_PRINT_PREFIX}Flowcast Config ---------")
for _label, _value in _FLOWCAST_SETTINGS:
    print(f"{DEBUG_PRINT_PREFIX}{_label}: {_value}")

print(f"--------- {DEBUG_PRINT_PREFIX}Diffusion Config ---------")
for _label, _value in _DIFFUSION_SETTINGS:
    print(f"{DEBUG_PRINT_PREFIX}{_label}: {_value}")
print("------------------------------------------------------------")

# Output layout under the artifacts folder:
#   plots/            - all figures
#   plots/animations/ - forecast / target animations
#   plots/metrics/    - metric plots
PLOTS_FOLDER = f"{ARTIFACTS_FOLDER}/plots"
ANIMATIONS_FOLDER = f"{PLOTS_FOLDER}/animations"
METRICS_FOLDER = f"{PLOTS_FOLDER}/metrics"
for _folder in (PLOTS_FOLDER, ANIMATIONS_FOLDER, METRICS_FOLDER):
    os.makedirs(_folder, exist_ok=True)


# Pick CUDA when it is available, otherwise fall back to the CPU, and report it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"{DEBUG_PRINT_PREFIX}Using device: {device}")
if device.type != "cpu":
    _gpu_count = torch.cuda.device_count()
    print(f"{DEBUG_PRINT_PREFIX}Number of GPUs available: {_gpu_count}")
    if _gpu_count > 1:
        print(f"{DEBUG_PRINT_PREFIX}Using {_gpu_count} GPUs!")
else:
    print(f"{DEBUG_PRINT_PREFIX}CPU is used")

# Fix the seed of every random number generator in use, for reproducibility.
for _seed_fn in (random.seed, torch.manual_seed, np.random.seed):
    _seed_fn(42)

# Location of the trained model checkpoint inside the artifacts folder.
MODEL_SAVE_DIR = f"{ARTIFACTS_FOLDER}/models"
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
MODEL_SAVE_PATH = os.path.join(MODEL_SAVE_DIR, "early_stopping_model.pt")


# Helpers that call the autoencoder's encode/decode through an optional DataParallel wrapper
def safe_encode(model, x):
    """Run ``encode`` on the autoencoder, looking through a DataParallel wrapper.

    When ``model`` is a ``torch.nn.DataParallel`` instance the call is made on
    the wrapped ``model.module``; otherwise it is made on ``model`` itself.
    Returns whatever the underlying ``encode`` returns.
    """
    target = model.module if isinstance(model, torch.nn.DataParallel) else model
    return target.encode(x)


def safe_decode(model, x):
    """Run ``decode`` on the autoencoder, looking through a DataParallel wrapper.

    When ``model`` is a ``torch.nn.DataParallel`` instance the call is made on
    the wrapped ``model.module``; otherwise it is made on ``model`` itself.
    Returns whatever the underlying ``decode`` returns.
    """
    target = model.module if isinstance(model, torch.nn.DataParallel) else model
    return target.decode(x)


if ENABLE_WANDB:
    # Start a Weights & Biases run and record the key evaluation settings with it.
    _wandb_run_config = dict(
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
        lag_time=LAG_TIME,
        lead_time=LEAD_TIME,
        time_spacing=TIME_SPACING,
        probabilistic_samples=PROBABILISTIC_SAMPLES,
        model="flowcast",
        model_save_path=MODEL_SAVE_PATH,
    )
    wandb.init(
        project="sevir-nowcasting-testing-cfm",
        name=RUN_ID,
        config=_wandb_run_config,
    )

PRELOAD_MODEL = MODEL_SAVE_PATH if os.path.exists(MODEL_SAVE_PATH) else None
if PRELOAD_MODEL is None:
    raise FileNotFoundError(f"Model not found at {MODEL_SAVE_PATH}")
else:
    print(f"{DEBUG_PRINT_PREFIX}Model found at {MODEL_SAVE_PATH}")

    # Build the SEVIR "vil" test dataset with the configured lag/lead times.
    _dataset_settings = dict(
        meta_csv=TEST_META,
        data_file=TEST_FILE,
        data_type="vil",
        raw_seq_len=49,
        lag_time=LAG_TIME,
        lead_time=LEAD_TIME,
        time_spacing=TIME_SPACING,
        stride=12,
        channel_last=False,
        debug_mode=DEBUG_MODE,
        log_transform=False,
    )
    full_test_dataset = DynamicSequentialSevirDataset(**_dataset_settings)

    # Optionally keep only the leading fraction of the dataset.
    test_dataset = full_test_dataset
    if args.test_data_percentage < 1.0:
        num_samples = int(len(full_test_dataset) * args.test_data_percentage)
        test_dataset = Subset(full_test_dataset, range(num_samples))

    # Debug mode loads data in the main process and without pinned memory.
    _full_speed_loading = not DEBUG_MODE
    test_loader = DataLoader(
        test_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        collate_fn=dynamic_sequential_collate,
        num_workers=NUM_WORKERS if _full_speed_loading else 0,
        pin_memory=_full_speed_loading,
    )
    # Testing the Model

    # The autoencoder checkpoint is mandatory: without it nothing can be
    # encoded to, or decoded from, the latent space.
    if not os.path.exists(PRELOAD_AE_MODEL):
        raise FileNotFoundError(f"Model not found at {PRELOAD_AE_MODEL}")

    from diffusers.models.autoencoders import AutoencoderKL

    _ae_settings = dict(
        in_channels=1,
        out_channels=1,
        down_block_types=DOWN_BLOCK_TYPES,
        up_block_types=UP_BLOCK_TYPES,
        block_out_channels=BLOCK_OUT_CHANNELS,
        act_fn=ACT_FN,
        latent_channels=LATENT_CHANNELS,
        norm_num_groups=NORM_NUM_GROUPS,
        layers_per_block=LAYERS_PER_BLOCK,
    )
    ae_model = AutoencoderKL(**_ae_settings)

    # Debug mode keeps the weights on the CPU; otherwise they go straight to
    # the selected device and are wrapped for multi-GPU use when possible.
    _ae_map_location = "cpu" if DEBUG_MODE else device
    checkpoint = torch.load(PRELOAD_AE_MODEL, map_location=_ae_map_location)
    ae_model.load_state_dict(checkpoint["model_state_dict"])
    if not DEBUG_MODE:
        ae_model = ae_model.to(device)
        if torch.cuda.device_count() > 1:
            ae_model = torch.nn.DataParallel(ae_model)
    ae_model.eval()

    # Probe a single batch to discover the pixel-space and latent-space shapes
    # that are needed to construct the forecasting model.
    input_shape = None
    inputs_decoded_shape = None
    output_shape = None
    for batch in test_loader:
        inputs, outputs, metadata = batch
        # Only index 0 along dimension 2 is encoded: the latent shape is all
        # that is read from the result.
        inputs_decoded = inputs[:, :, 0, :, :].to(device)
        if DEBUG_MODE:
            # Debug mode keeps the autoencoder on the CPU between uses, so it
            # is moved to the device (and wrapped) just for this call.
            ae_model.to(device)
            if torch.cuda.device_count() > 1:
                ae_model = torch.nn.DataParallel(ae_model)
            ae_model.eval()
        # Shape discovery needs no gradients; avoid building an autograd graph.
        with torch.no_grad():
            encoded_obj = safe_encode(ae_model, inputs_decoded)
        if DEBUG_MODE:
            # Unwrap and park the autoencoder back on the CPU to free GPU memory.
            if isinstance(ae_model, torch.nn.DataParallel):
                ae_model = ae_model.module
            ae_model.to("cpu")
            torch.cuda.empty_cache()
        inputs_decoded = encoded_obj.latent_dist.mode()

        print(f"Inputs shape: {inputs.shape}")
        print(f"Inputs Decoded shape: {inputs_decoded.shape}")
        print(f"Outputs shape: {outputs.shape}")
        input_shape = inputs.shape
        inputs_decoded_shape = inputs_decoded.shape
        output_shape = outputs.shape
        break
    if input_shape is None:
        # Without this guard an empty loader only surfaces later as an opaque
        # "'NoneType' object is not subscriptable" error.
        raise ValueError(
            "The test DataLoader produced no batches "
            f"(dataset size: {len(test_dataset)}); check the test files or "
            "increase --test_data_percentage."
        )
    # Load the trained forecaster checkpoint (kept on the CPU in debug mode).
    _model_map_location = "cpu" if DEBUG_MODE else device
    checkpoint = torch.load(
        MODEL_SAVE_PATH, map_location=_model_map_location, weights_only=False
    )
    checkpoint_mean = checkpoint["mean"]
    checkpoint_std = checkpoint["std"]
    checkpoint_model_state_dict = checkpoint["model_state_dict"]

    # Sequence lengths come from dimension 2 of the probed batch shapes.
    IN_TIMESTEPS = input_shape[2]
    OUTPUT_TIMESTEPS = output_shape[2]

    # Per-sample shapes handed to the model, ordered (time, latent height,
    # latent width, latent channels); no batch dimension is included.
    _probe_latent_c = inputs_decoded_shape[1]
    _probe_latent_h = inputs_decoded_shape[2]
    _probe_latent_w = inputs_decoded_shape[3]
    input_shape_flowcast = (
        IN_TIMESTEPS,
        _probe_latent_h,
        _probe_latent_w,
        _probe_latent_c,
    )
    output_shape_flowcast = (
        OUTPUT_TIMESTEPS,
        _probe_latent_h,
        _probe_latent_w,
        _probe_latent_c,
    )

    loaded_model = CuboidTransformerUNet(
        input_shape=input_shape_flowcast,
        target_shape=output_shape_flowcast,
        # architecture
        base_units=BASE_UNITS,
        block_units=None,
        scale_alpha=SCALE_ALPHA,
        num_heads=NUM_HEADS,
        attn_drop=ATTN_DROP,
        proj_drop=PROJ_DROP,
        ffn_drop=FFN_DROP,
        downsample=DOWNSAMPLE,
        downsample_type=DOWNSAMPLE_TYPE,
        upsample_type=UPSAMPLE_TYPE,
        upsample_kernel_size=UPSAMPLE_KERNEL_SIZE,
        depth=DEPTH,
        block_attn_patterns=BLOCK_ATTN_PATTERNS,
        # global vectors
        num_global_vectors=NUM_GLOBAL_VECTORS,
        use_global_vector_ffn=USE_GLOBAL_VECTOR_FFN,
        use_global_self_attn=USE_GLOBAL_SELF_ATTN,
        separate_global_qkv=SEPARATE_GLOBAL_QKV,
        global_dim_ratio=GLOBAL_DIM_RATIO,
        # misc
        ffn_activation=FFN_ACTIVATION,
        gated_ffn=GATED_FFN,
        norm_layer=NORM_LAYER,
        padding_type=PADDING_TYPE,
        checkpoint_level=CHECKPOINT_LEVEL,
        pos_embed_type=POS_EMBED_TYPE,
        use_relative_pos=USE_RELATIVE_POS,
        self_attn_use_final_proj=SELF_ATTN_USE_FINAL_PROJ,
        # initialization
        attn_linear_init_mode=ATTN_LINEAR_INIT_MODE,
        ffn_linear_init_mode=FFN_LINEAR_INIT_MODE,
        ffn2_linear_init_mode=FFN2_LINEAR_INIT_MODE,
        attn_proj_linear_init_mode=ATTN_PROJ_LINEAR_INIT_MODE,
        conv_init_mode=CONV_INIT_MODE,
        down_linear_init_mode=DOWN_UP_LINEAR_INIT_MODE,
        up_linear_init_mode=DOWN_UP_LINEAR_INIT_MODE,
        global_proj_linear_init_mode=GLOBAL_PROJ_LINEAR_INIT_MODE,
        norm_init_mode=NORM_INIT_MODE,
        # timestep embedding for diffusion
        time_embed_channels_mult=TIME_EMBED_CHANNELS_MULT,
        time_embed_use_scale_shift_norm=TIME_EMBED_USE_SCALE_SHIFT_NORM,
        time_embed_dropout=TIME_EMBED_DROPOUT,
        unet_res_connect=UNET_RES_CONNECT,
        # normalization statistics stored with the checkpoint
        mean=checkpoint_mean,
        std=checkpoint_std,
    )
    loaded_model.load_state_dict(checkpoint_model_state_dict)

    # Outside debug mode the model lives on the device for the whole run and
    # is wrapped in DataParallel when several GPUs are present.
    if not DEBUG_MODE:
        loaded_model = loaded_model.to(device)
        if torch.cuda.device_count() > 1:
            loaded_model = torch.nn.DataParallel(loaded_model)
    loaded_model.eval()

    # One streaming accumulator per forecast step; all share the same settings.
    _accumulator_settings = dict(
        thresholds=THRESHOLDS,
        pool_size=16,
        compute_mse=True,
        compute_threshold=True,
        compute_apsd=False,
        compute_ssim=True,
        ssim_data_range=255.0,
        device=device,
    )
    metrics_accumulators = []
    for _step in range(LEAD_TIME):
        metrics_accumulators.append(
            MetricsAccumulator(lead_time=_step, **_accumulator_settings)
        )

    # Noise-schedule buffers consumed by the sampling loops.
    diffusion_buffers = register_diffusion_buffers(
        beta_schedule=BETA_SCHEDULE,
        timesteps=TIMESTEPS,
        linear_start=LINEAR_START,
        linear_end=LINEAR_END,
        cosine_s=COSINE_S,
        given_betas=GIVEN_BETAS,
        device=device,
    )

    # Progress bar over the test loader, plus the running state of the loop:
    # buffered predictions / targets and timing counters.
    test_bar = tqdm(test_loader, desc="Testing Model")
    count = 0
    y_pred, y_true = [], []
    total_prediction_time = 0.0
    total_samples_processed = 0
    for idx, batch in enumerate(test_bar):
        # Each batch holds the conditioning frames, the target frames and
        # per-sample metadata.
        x_cond, x_true, metadata = batch

        # Assume x_cond has shape [B, C, T, H, W] (e.g. [1, 1, 13, 128, 128])
        B, C, T_in, H, W = x_cond.shape

        # Permute to bring the time dimension next to the batch dimension: [B, T, C, H, W]
        # Then flatten the batch and frame dimensions to process each frame individually: [B*T, C, H, W]
        x_cond = x_cond.permute(0, 2, 1, 3, 4).reshape(B * T_in, C, H, W)

        # Encode each frame individually
        with torch.no_grad():
            x_cond = x_cond.to(device)

            # In debug mode the autoencoder is parked on the CPU between uses,
            # so bring it to the device (and wrap it) just for this call.
            if DEBUG_MODE:
                ae_model.to(device)
                if torch.cuda.device_count() > 1:
                    ae_model = torch.nn.DataParallel(ae_model)
                ae_model.eval()

            # A normalized autoencoder is fed inputs divided by 255.
            if NORMALIZED_AUTOENCODER:
                x_cond = x_cond / 255.0
            # safe_encode calls encode() on the wrapped module when ae_model is a
            # DataParallel instance, so the whole batch is encoded by one module.
            encoded_obj = safe_encode(ae_model, x_cond)
            # Use the mode of the latent distribution rather than a random sample.
            x_cond = encoded_obj.latent_dist.mode()
            # Debug mode: unwrap and move the autoencoder back to the CPU.
            if DEBUG_MODE:
                if isinstance(ae_model, torch.nn.DataParallel):
                    ae_model = ae_model.module
                ae_model.to("cpu")
                torch.cuda.empty_cache()

            # Apply asinh transform if enabled
            if ASINH_TRANSFORM:
                x_cond = torch.asinh(x_cond).detach()

        # After encoding, suppose the encoded tensor has shape [B*T, latent_channels, latent_H, latent_W].
        # Reshape it back to [B, T, latent_channels, latent_H, latent_W] and then permute to [B, latent_channels, T, latent_H, latent_W]
        latent_channels, latent_H, latent_W = (
            x_cond.shape[1],
            x_cond.shape[2],
            x_cond.shape[3],
        )
        x_cond = x_cond.reshape(B, T_in, latent_channels, latent_H, latent_W).permute(
            0, 2, 1, 3, 4
        )

        # Normalize the latent condition with the forecaster's own normalize().
        # In debug mode the forecaster is parked on the CPU between batches, so
        # it is first moved to the device (and wrapped when several GPUs exist).
        if DEBUG_MODE:
            loaded_model.to(device)
            if torch.cuda.device_count() > 1:
                loaded_model = torch.nn.DataParallel(loaded_model)
            loaded_model.eval()
        # normalize() is called on the underlying module when DataParallel-wrapped.
        x_cond = (
            loaded_model.module.normalize(x_cond)
            if isinstance(loaded_model, torch.nn.DataParallel)
            else loaded_model.normalize(x_cond)
        )

        # Rearrange x_cond to (B, Tin, H, W, C)
        x_cond = x_cond.permute(0, 2, 3, 4, 1)

        B, Tin, Hz, Wz, Cz = x_cond.shape

        # For each batch we want to generate N_SAMPLES forecasts.
        x_true = x_true.squeeze(1)  # remove channel dimension
        T_future = x_true.shape[1]
        sample_predictions = []  # To collect N_SAMPLES for this batch

        # Placeholder of zeros with the condition's latent layout but T_future
        # time steps; in the visible code only its .shape is read, to tell the
        # sampler what to generate.
        x_true_downsampled_example = torch.zeros(
            (B, T_future, Hz, Wz, Cz),
            device=device,
        )

        # Draw PROBABILISTIC_SAMPLES forecasts for this batch, timing the
        # whole sampling phase with wall-clock time.
        start_time = time.time()
        for sample_idx in range(PROBABILISTIC_SAMPLES):
            # Re-seed torch's global RNG with a value unique to this
            # (batch index, sample index) pair.
            torch.manual_seed(idx * PROBABILISTIC_SAMPLES + sample_idx)

            # Sampler choice is fixed by the test-time config (USE_DDIM).
            if USE_DDIM:
                x_pred_sample = ddim_sample_loop(
                    loaded_model,
                    shape=x_true_downsampled_example.shape,
                    cond=x_cond,
                    diffusion_buffers=diffusion_buffers,
                    device=device,
                    total_timesteps=TIMESTEPS,
                    ddim_num_steps=DDIM_NUM_STEPS,
                    eta=DDIM_ETA,
                    clip_denoised=CLIP_DENOISED,
                )
            else:
                x_pred_sample = p_sample_loop(
                    loaded_model,
                    shape=x_true_downsampled_example.shape,
                    cond=x_cond,
                    diffusion_buffers=diffusion_buffers,
                    device=device,
                    timesteps=TIMESTEPS,
                    clip_denoised=CLIP_DENOISED,
                )

            sample_predictions.append(x_pred_sample.unsqueeze(1))  # add sample dim

        # Concatenate the samples along the new dimension 1.
        # Assuming the sampler returns the requested shape, x_pred is
        # (B, N_SAMPLES, T_future, Hz, Wz, Cz) in latent space - TODO confirm.
        x_pred = torch.cat(sample_predictions, dim=1)

        end_time = time.time()
        # Running totals; they are only accumulated here.
        total_prediction_time += end_time - start_time
        total_samples_processed += B

        # Debug mode keeps the forecaster on the device only while sampling:
        # unwrap it and park it on the CPU again.
        if DEBUG_MODE:
            if isinstance(loaded_model, torch.nn.DataParallel):
                loaded_model = loaded_model.module
            loaded_model.to("cpu")
            torch.cuda.empty_cache()

        # Bring predictions and ground truth to the host as NumPy arrays.
        x_pred = x_pred.detach().cpu().numpy()
        x_true_np = x_true.cpu().numpy()

        # Undo the latent normalization using the mean/std stored on the model
        # (read from the wrapped module when DataParallel is in use).
        _flowcast_core = (
            loaded_model.module
            if isinstance(loaded_model, torch.nn.DataParallel)
            else loaded_model
        )
        mean_val = _flowcast_core.mean.item()
        std_val = _flowcast_core.std.item()
        x_pred = (x_pred * std_val + mean_val).astype(np.float32)

        # Reverse asinh transform if necessary.
        if ASINH_TRANSFORM:
            x_pred = np.sinh(x_pred)

        # Assume x_pred is a NumPy array with shape [B, S, T, H, W, C]
        B, S, T, H, W, C = x_pred.shape

        # Merge batch, sample, and time dimensions: [B*S*T, H, W, C]
        x_pred = x_pred.reshape(B * S * T, H, W, C)

        # Convert to a torch tensor if it is not already one
        if isinstance(x_pred, np.ndarray):
            x_pred = torch.from_numpy(x_pred).to(device)

        # Permute to bring channels to the first dimension: [B*S*T, C, H, W]
        x_pred = x_pred.permute(0, 3, 1, 2)

        # Decode each frame individually.
        # Debug mode: bring the autoencoder back to the device for decoding.
        if DEBUG_MODE:
            ae_model.to(device)
            if torch.cuda.device_count() > 1:
                ae_model = torch.nn.DataParallel(ae_model)
            ae_model.eval()
        with torch.no_grad():
            if BATCH_SIZE_AUTOENCODER is not None:
                # x_pred has shape [B*S*T, C, H, W]
                # Decode the latent frames in chunks of at most BATCH_SIZE_AUTOENCODER
                # frames. Note: despite its name, encoded_chunks collects DECODED chunks.
                encoded_chunks = []
                for i in range(0, x_pred.shape[0], BATCH_SIZE_AUTOENCODER):
                    chunk = x_pred[i : i + BATCH_SIZE_AUTOENCODER]
                    # safe_decode calls decode() on the wrapped module when
                    # ae_model is a DataParallel instance.
                    decoded_chunk_obj = safe_decode(ae_model, chunk)
                    final_decoded_chunk = decoded_chunk_obj.sample
                    encoded_chunks.append(
                        final_decoded_chunk
                    )  # Add the final decoded chunk

                # Concatenate the chunks
                x_pred = torch.cat(encoded_chunks, dim=0)
            else:
                # No chunk size configured: decode all frames in a single call.
                decoded_obj_fallback = safe_decode(ae_model, x_pred)
                x_pred = decoded_obj_fallback.sample

        # Debug mode: unwrap and move the autoencoder back to the CPU.
        if DEBUG_MODE:
            if isinstance(ae_model, torch.nn.DataParallel):
                ae_model = ae_model.module
            ae_model.to("cpu")
            torch.cuda.empty_cache()
        # Undo the division by 255 applied before encoding.
        if NORMALIZED_AUTOENCODER:
            x_pred = x_pred * 255.0
        # Assume the decoder returns a tensor of shape [B*S*T, new_channels, new_H, new_W]
        new_channels, new_H, new_W = x_pred.shape[1], x_pred.shape[2], x_pred.shape[3]

        # Reshape back to separate batch, samples, and time: [B, S, T, new_channels, new_H, new_W]
        x_pred = x_pred.reshape(B, S, T, new_channels, new_H, new_W)

        #  Put channels last
        x_pred = x_pred.permute(0, 1, 2, 4, 5, 3)

        # Remove channel dimension if it is 1
        if x_pred.shape[-1] == 1:
            x_pred = x_pred.squeeze(-1)

        # Convert to float16 (presumably to cut memory while batches are
        # buffered before the next metric update - TODO confirm).
        x_pred = x_pred.cpu().detach().numpy().astype(np.float16)

        # Append current batch results.
        y_pred.append(x_pred)  # shape: (B, N_SAMPLES, T_future, H, W)
        y_true.append(x_true_np)  # shape: (B, T_future, H, W)

        # Flush the buffered predictions into the metric accumulators every
        # `accumulate_every` batches, sized so that roughly 400 forecast
        # samples (batch size x probabilistic samples) are buffered at a time.
        # The interval is clamped to at least 1: when
        # BATCH_SIZE * PROBABILISTIC_SAMPLES exceeds 400 the integer interval
        # would be 0 and `idx % 0` would raise ZeroDivisionError.
        accumulate_every = max(1, int((400 / BATCH_SIZE) / PROBABILISTIC_SAMPLES))
        if idx % accumulate_every == 0 and idx > 0:

            y_pred_array = np.concatenate(y_pred, axis=0)
            y_pred_array = post_process_samples(
                y_pred_array, clamp_min=0.0, clamp_max=255.0
            )
            y_true_array = np.concatenate(y_true, axis=0)

            # Every accumulator receives the same full arrays.
            for lead_time, metrics_accumulator in enumerate(metrics_accumulators):
                metrics_accumulator.update(y_true_array, y_pred_array)

            batch_size_y_true = y_pred_array.shape[0]

            # Reset the buffers now that their content has been accumulated.
            y_pred = []
            y_true = []

            # Calculate partial metrics over everything accumulated so far.
            results = calculate_metrics(
                num_lead_times=LEAD_TIME,
                metrics_accumulators=metrics_accumulators,
                thresholds=THRESHOLDS,
            )

            # Log the partial results to wandb
            if ENABLE_WANDB:
                global_step = idx * batch_size_y_true
                # NOTE(review): the accumulators are created with
                # compute_apsd=False; confirm calculate_metrics still returns
                # an "apsd_mean" entry in that case.
                wandb.log(
                    {
                        "partial_mse": results["mse_mean"],
                        "partial_apsd": results["apsd_mean"],
                        "partial_crps": results["crps_mean"],
                        "partial_csi_m": results["csi_m"],
                        "partial_ssim": results["ssim_mean"],
                        "partial_csi_pool_m": results["csi_pool_m"],
                        "partial_hss_m": results["hss_m"],
                        "partial_far_m": results["far_m"],
                        "partial_pod_m": results["pod_m"],
                    },
                    step=global_step,
                )

        # Render GIF animations for the first batch only.
        if idx == 0:
            animations_dir = os.path.join(PLOTS_FOLDER, "animations")

            # First item of the batch, shape: (N_SAMPLES, T, H, W), post-processed
            # with the [0, 255] clamp.
            sample_pred = post_process_samples(
                x_pred[0], clamp_min=0.0, clamp_max=255.0
            )

            # One animation per forecast sample.
            for i, sample_pred_plot in enumerate(sample_pred):
                fig1 = plt.figure()
                anim = make_animation(
                    sample_pred_plot,
                    metadata[0],
                    title="Outputs",
                    fig=fig1,
                    cartopy_features=CARTOPY_FEATURES,
                )
                output_gif_path = os.path.join(
                    animations_dir, f"output_test_animation{i}.gif"
                )
                anim.save(output_gif_path, writer="imagemagick", fps=6)
                plt.close(fig1)

            # Matching target sequence for the same batch item.
            fig2 = plt.figure()
            anim = make_animation(
                x_true_np[0],
                metadata[0],
                title="Target",
                fig=fig2,
                cartopy_features=CARTOPY_FEATURES,
            )
            target_gif_path = os.path.join(animations_dir, "target_test_animation.gif")
            anim.save(target_gif_path, writer="imagemagick", fps=6)
            plt.close(fig2)

        # Count this loop iteration as processed.
        count += 1
        # In debug mode, leave the loop once more than 10 iterations have been counted.
        if DEBUG_MODE and count > 10:
            print(f"{DEBUG_PRINT_PREFIX}Breaking early due to DEBUG_MODE")
            break

    # Flush whatever is still buffered after the last periodic update.
    if y_pred:
        y_true_array = np.concatenate(y_true, axis=0)
        y_pred_array = post_process_samples(
            np.concatenate(y_pred, axis=0), clamp_min=0.0, clamp_max=255.0
        )
        # Feed the leftover samples to every metrics accumulator.
        for metrics_accumulator in metrics_accumulators:
            metrics_accumulator.update(y_true_array, y_pred_array)

    # The buffers are no longer needed; drop them and run a garbage collection pass.
    del y_pred, y_true
    gc.collect()

    # Compute the final metrics from everything held in the accumulators.
    results = calculate_metrics(
        num_lead_times=LEAD_TIME,
        metrics_accumulators=metrics_accumulators,
        thresholds=THRESHOLDS,
    )

    crps_mean = results["crps_mean"]

    # Print the results.
    # The timing line is skipped when no samples were processed, which avoids a
    # division by zero.
    if total_samples_processed > 0:
        average_time_per_prediction = total_prediction_time / total_samples_processed
        print(
            f"Average time per ensemble prediction: {average_time_per_prediction:.4f} seconds"
        )
    print(f"CRPS: {crps_mean}")

    # Extract _from_mean metrics from results
    # (presumably metrics computed on the ensemble-mean forecast, per the header
    # printed below — confirm against calculate_metrics).
    mse_from_mean_mean = results["mse_from_mean_mean"]
    apsd_from_mean_mean = results["apsd_from_mean_mean"]
    ssim_from_mean_mean = results["ssim_from_mean_mean"]
    csi_from_mean_m = results["csi_from_mean_m"]
    csi_pool_from_mean_m = results["csi_pool_from_mean_m"]
    hss_from_mean_m = results["hss_from_mean_m"]
    far_from_mean_m = results["far_from_mean_m"]
    pod_from_mean_m = results["pod_from_mean_m"]
    csi_from_mean_mean_dict = results["csi_from_mean_mean"]
    far_from_mean_mean_dict = results["far_from_mean_mean"]
    hss_from_mean_mean_dict = results["hss_from_mean_mean"]
    pod_from_mean_mean_dict = results["pod_from_mean_mean"]
    csi_pool_from_mean_mean_dict = results["csi_pool_from_mean_mean"]

    # Print _from_mean metrics
    print("--- Metrics from Ensemble Mean ---")
    print(f"Mean MSE (from mean): {mse_from_mean_mean}")
    print(f"APSD (from mean): {apsd_from_mean_mean}")
    print(f"SSIM (from mean): {ssim_from_mean_mean}")
    print(f"CSI-M (from mean): {csi_from_mean_m}")
    print(f"CSI (16-pooled)-M (from mean): {csi_pool_from_mean_m}")
    print(f"HSS-M (from mean): {hss_from_mean_m}")
    print(f"FAR-M (from mean): {far_from_mean_m}")
    print(f"POD-M (from mean): {pod_from_mean_m}")
    print("CSI (from mean) per threshold:", csi_from_mean_mean_dict)
    print("FAR (from mean) per threshold:", far_from_mean_mean_dict)
    print("HSS (from mean) per threshold:", hss_from_mean_mean_dict)
    print("POD (from mean) per threshold:", pod_from_mean_mean_dict)
    # NOTE(review): this label says "per lead time" while the sibling *_mean_dict
    # values above are labelled "per threshold" — confirm which one is correct.
    print(
        f"CSI (16-pooled) (from mean) mean per lead time: {csi_pool_from_mean_mean_dict}"
    )
    # Extract and print the per-lead-time metrics
    csi_m_from_mean_lead_time = results["csi_m_from_mean_lead_time"]
    csi_last_thresh_from_mean_lead_time = results["csi_last_thresh_from_mean_lead_time"]
    csi_pool_m_from_mean_lead_time = results["csi_pool_m_from_mean_lead_time"]
    csi_pool_last_thresh_from_mean_lead_time = results[
        "csi_pool_last_thresh_from_mean_lead_time"
    ]
    hss_m_from_mean_lead_time = results["hss_m_from_mean_lead_time"]
    far_m_from_mean_lead_time = results["far_m_from_mean_lead_time"]
    pod_m_from_mean_lead_time = results["pod_m_from_mean_lead_time"]
    print("--- Lead Time Metrics ---")

    print(f"CSI-M (from mean) by lead time: {csi_m_from_mean_lead_time}")
    # NOTE(review): the "(219)" in the labels below is hard-coded; presumably it is
    # the last entry of THRESHOLDS (the keys say "last_thresh") — confirm.
    print(
        f"CSI-M (219) (from mean) by lead time: {csi_last_thresh_from_mean_lead_time}"
    )
    print(
        f"CSI (16-pooled)-M (from mean) by lead time: {csi_pool_m_from_mean_lead_time}"
    )
    print(
        f"CSI (16-pooled) (219) (from mean) by lead time: {csi_pool_last_thresh_from_mean_lead_time}"
    )
    print(f"HSS-M (from mean) by lead time: {hss_m_from_mean_lead_time}")
    print(f"FAR-M (from mean) by lead time: {far_m_from_mean_lead_time}")
    print(f"POD-M (from mean) by lead time: {pod_m_from_mean_lead_time}")

    print(DEBUG_PRINT_PREFIX + "Finished testing the model")
