import gc
import logging
import os
import pickle
import time
from concurrent.futures import Future
from contextlib import nullcontext
from pathlib import Path
from random import shuffle
from subprocess import CalledProcessError
from typing import Any, Callable, Literal, Optional

import h5py
import numpy as np
import ot
import pandas as pd
import torch
import torch.distributed as dist
from einops import rearrange
from filelock import FileLock
from the_well.benchmark.metrics import (
    long_time_metrics,
    make_video,
    plot_all_time_metrics,
    validation_metric_suite,
    validation_plots,
)
from the_well.data.datamodule import AbstractDataModule
from the_well.data.datasets import WellDataset
from the_well.data.utils import flatten_field_names
from torch.amp.grad_scaler import GradScaler
from torch.distributed.fsdp.api import ShardingStrategy
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from torch.utils.data import DataLoader

import wandb
from crps_retrofitting.metrics.crps import CRPS, WeightedCRPS
from crps_retrofitting.metrics.energy_score import ES
from crps_retrofitting.metrics.fourier import isotropic_power_spectrum
from crps_retrofitting.models.isotropic_model import IsotropicModelWithNoise
from crps_retrofitting.models.poseidon_model import PoseidonWrapper, ScOTWithNoise
from crps_retrofitting.trainer.checkpoints import CheckPointLoader
from crps_retrofitting.trainer.normalization_strat import (
    BaseRevNormalization,
    SamplewiseRevNormalization,
    normalize_target,
)

logger = logging.getLogger(__name__)


def expand_mask_to_match(mask, target):
    """Broadcast a per-sample spatial mask to the full shape of ``target``.

    ``mask`` is laid out as ``B H [W D] 1`` and ``target`` as
    ``B T H [W D] C``. A time axis is inserted at position 1 and the
    result is expanded along time and channels; batch and spatial axes
    keep their own sizes. The returned tensor is an expanded view of
    ``mask`` (``Tensor.expand`` does not copy data).
    """
    n_steps = target.shape[1]
    n_channels = target.shape[-1]
    # Everything between (batch, time) and the trailing channel axis is spatial.
    n_spatial = len(target.shape) - 3
    sizes = [-1, n_steps] + [-1] * n_spatial + [n_channels]
    return mask.unsqueeze(1).expand(*sizes)


def get_grad_norm_local(model) -> torch.Tensor:
    """Compute the L2 norm of all gradients held on this device.

    The squared per-parameter gradient norms are summed and the square
    root of the total is returned. Parameters whose ``.grad`` is ``None``
    are skipped.

    When no parameter has a gradient, a zero-valued scalar tensor is
    returned (on the device/dtype of the first parameter if there is one)
    rather than a Python float, so callers can always treat the result as
    a tensor (e.g. ``.clone()`` it before a collective).

    From https://github.com/pytorch/pytorch/issues/88621
    """
    total_sq = None
    for p in model.parameters():
        if p.grad is None:
            continue
        # Norm is taken in the parameter's own dtype, as in the reference snippet.
        sq_norm = torch.linalg.vector_norm(p.grad, dtype=p.dtype) ** 2
        if total_sq is None:
            total_sq = sq_norm
        else:
            total_sq += sq_norm

    if total_sq is None:
        # Nothing to accumulate: still honour the tensor return contract.
        first_param = next(iter(model.parameters()), None)
        if first_param is None:
            return torch.zeros(())
        return torch.zeros((), dtype=first_param.dtype, device=first_param.device)
    return total_sq**0.5


def get_grad_norm_fsdp(
    model, rank, world_size, sharding_strategy=ShardingStrategy.FULL_SHARD
) -> torch.Tensor:
    """Combine the per-device gradient norms into a single global norm.

    The local norm from ``get_grad_norm_local`` is squared, sum-reduced
    across all ranks with ``dist.all_reduce`` and square-rooted. Under
    ``ShardingStrategy.NO_SHARD`` the summed square is divided by
    ``world_size`` before the square root (presumably because every rank
    then holds the full gradients - confirm against the FSDP setup).

    ``rank`` is accepted for interface compatibility but is not used.

    From https://github.com/pytorch/pytorch/issues/88621
    """
    device_norm = get_grad_norm_local(model)
    # Reduce a detached copy so the collective never touches autograd state.
    squared_total = device_norm.clone().detach().requires_grad_(False) ** 2
    dist.all_reduce(squared_total, op=torch.distributed.ReduceOp.SUM)
    if sharding_strategy == ShardingStrategy.NO_SHARD:
        squared_total = squared_total / world_size
    return squared_total**0.5


def param_norm(parameters):
    """Return the global L2 norm of ``parameters`` as a Python float.

    Each tensor contributes the sum of its squared entries (pulled to the
    host with ``.item()``); the square root of the grand total is returned.
    Runs under ``torch.no_grad()`` so nothing is recorded for autograd.
    """
    with torch.no_grad():
        squared_sum = sum(tensor.pow(2).sum().item() for tensor in parameters)
        return squared_sum**0.5


class Trainer:
    grad_scaler: GradScaler

    def __init__(
        self,
        experiment_name: str,
        viz_folder: str,
        formatter: Callable,
        model: torch.nn.Module,
        datamodule: AbstractDataModule,
        revin: BaseRevNormalization,
        optimizer: torch.optim.Optimizer,
        loss_fn: Callable,
        prediction_type: str,
        max_epoch: int,
        val_frequency: int,
        rollout_val_frequency: int,
        max_rollout_steps: int,
        short_validation_length: int,
        checkpointer: CheckPointLoader,
        num_time_intervals: int,
        lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
        device=torch.device("cuda"),
        device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
        sampling_rank_strategy: Literal["gpu", "node", "world"] = "node",
        reuse_batches: bool = False,
        distribution_type: str = "local",
        rank: int = 0,
        world_size: int = 1,
        enable_amp: bool = False,
        amp_type: str = "float16",  # bfloat not supported in FFT
        grad_acc_steps: int = 1,
        clip_gradient: float = 0.0,
        loss_multiplier: float = 1.0,
        minimum_context: int = 1,
        masked_loss_for_objects: bool = True,
        video_validation: bool = False,
        image_validation: bool = False,
        skip_spectral_metrics: bool = False,
        gradient_log_level: int = 0,
        log_interval: int = 10,
        wandb_logging: bool = True,
        start_epoch: int = 1,
        lr_scheduler_per_step: bool = False,
        start_val_loss: Optional[float] = None,
        enable_staged_learning: Optional[bool] = False,
        common_params_warmup_epochs: Optional[int] = 5,
        validation_ensemble_size: Optional[int] = 1,
        ensemble_sizes_to_save: Optional[list] = None,
        max_num_samples: Optional[int] = 1000000,
        max_spectral_val_samples: Optional[int] = 5,
        num_repeats: Optional[int] = 1,
        poseidon_one_sample_norm: bool = False,
        start_traj: Optional[int] = 0,
        scale_noise: bool = False,
    ):
        """
        Class in charge of the training loop. It performs train, validation and test.

        Parameters
        ----------
        experiment_name:
            The name of the training experiment to be run
        viz_folder:
            The folder where visualizations are saved
        formatter:
            Callable that initializes formatter object that maps between Well and model formats.
            It is called with no arguments, once per dataset found in the training metadata.
        model:
            PyTorch model used for training.
        datamodule:
            A datamodule that provides dataloaders for each split (train, valid, and test)
        revin:
            Callable invoked as ``revin(datamodule.train_dataset, device)``; the object it
            returns is stored as ``self.revin`` and used for (de)normalization.
        optimizer:
            A Pytorch optimizer to perform the backprop (e.g. Adam)
        loss_fn:
            A loss function that evaluates the model predictions to be used for training.
            It is added to the validation suite unless a metric of the same class name is
            already part of it.
        prediction_type:
            The type of prediction to make. Options are "delta" or "full". "delta" predicts the change in the
            field from the previous timestep. "full" predicts the full field at the next timestep.
            This only affects training since validation losses are computed on reconstructed fields
            either way.
        max_epoch:
            Number of epochs to train the model.
            One epoch correspond to a full loop over the datamodule's training dataloader
        val_frequency:
            The frequency in terms of number of epochs to perform the validation
        rollout_val_frequency:
            The frequency in terms of number of epochs to perform the rollout validation
        max_rollout_steps:
            The maximum number of timesteps to rollout the model during long validation.
        short_validation_length:
            Stored on the trainer as-is; presumably the length of the short validation
            pass (consumed outside the constructor - confirm against callers).
        checkpointer:
            Checkpoint handler; ``save_model_if_necessary`` delegates to its
            ``save_if_necessary`` method.
        num_time_intervals:
            The number of time intervals to bin the loss over for logging purposes.
        lr_scheduler:
            A Pytorch learning rate scheduler to update the learning rate during training
        device:
            A Pytorch device (e.g. "cuda" or "cpu")
        device_mesh:
            Device mesh used for distributed training. If it has an "fsdp" mesh dimension,
            that dimension's process group is used as the sync group.
        sampling_rank_strategy:
            Controls ``self.sampling_rank``: "gpu" uses the global rank, "node" uses the
            index of the sync group this rank belongs to, anything else uses 0.
        reuse_batches:
            A boolean flag to reuse batches during training. If True, the same batch is used for
            two steps of training (separated by a new batch). If False, a new batch is used for each step.
        distribution_type:
            The type of distribution to use. Options are "local", "ddp", "fsdp", "hsdp".
            "local"/"ddp" (case-insensitive) use the standard grad scaler, everything else
            uses the sharded grad scaler.
        rank:
            The rank of the current GPU in the PyTorch world.
        world_size:
            The total number of GPUs in the PyTorch world
        enable_amp:
            A boolean flag to enable automatic mixed precision training
        amp_type:
            The type of automatic mixed precision to use. Options are "float16" or "bfloat16".
            Gradient scaling is only enabled for AMP with a type other than "bfloat16".
        grad_acc_steps:
            The number of gradient accumulation steps to perform between optimizer steps
        clip_gradient:
            The maximum gradient norm to clip to. If 0, no clipping is performed.
        loss_multiplier:
            A float to multiply the loss by before backpropagating. Useful for scaling loss for different
            datasets or loss functions.
        minimum_context:
            The minimum number of timesteps needed to evaluate the loss for a given sample
        masked_loss_for_objects:
            If True and the batch has a constant field named "mask", predictions are
            zero-filled at masked positions before the loss is evaluated.
        video_validation:
            A boolean flag to enable saving rollouts to disk during validation
        image_validation:
            A boolean flag to enable saving images to disk during validation
        skip_spectral_metrics:
            A boolean flag to skip spectral metrics during validation
        gradient_log_level:
            An integer representing the level of gradient logging. 0 is no logging, 1 full synced gradient only
        log_interval:
            An integer representing how often to log training information. This results in gpu-cpu sync.
        wandb_logging:
            A boolean flag to enable logging to Weights and Biases
        start_epoch:
            The epoch to start training from. Useful for resuming training.
            Must not exceed ``max_epoch``.
        lr_scheduler_per_step:
            A boolean flag to update the learning rate after each optimizer step instead of each epoch.
        start_val_loss:
            The validation loss to start from. Useful for resuming training.
        enable_staged_learning:
            A boolean flag to enable staged learning. If True, the learning rate will be updated in stages.
        common_params_warmup_epochs:
            The number of epochs to warm up the learning rate for common parameters.
            NOTE(review): accepted but not stored by this constructor.
        validation_ensemble_size:
            Ensemble size at validation time. Forced to 1 for deterministic models.
        ensemble_sizes_to_save:
            Ensemble sizes to keep (``None`` means none requested). The list is copied, never
            mutated; ``validation_ensemble_size`` is always included and the result is sorted.
            Ignored (set to ``[1]``) for deterministic models.
        max_num_samples:
            Maximum number of samples that can be predicted at a time.
        max_spectral_val_samples:
            Stored on the trainer as-is; presumably caps the number of samples used for
            spectral validation metrics (confirm against callers).
        num_repeats:
            Number of repeats to average over at each rollout step for the jittering trick.
        poseidon_one_sample_norm:
            If True and the model is a ``PoseidonWrapper``, rollout normalization statistics
            are computed from the last input timestep only.
        start_traj:
            The trajectory index to start from during validation.
        scale_noise:
            Stored on the trainer as-is (consumed outside the constructor).
        """
        self.experiment_name = experiment_name
        self.viz_folder = viz_folder
        self.wandb_logging = wandb_logging
        self.video_validation = video_validation
        self.image_validation = image_validation
        self.gradient_log_level = gradient_log_level
        self.log_interval = log_interval
        self.device = device
        self.model = model

        # Deterministic or probabilistic: decided once so that the ensemble
        # configuration below and `is_deterministic` can never disagree.
        is_probabilistic = isinstance(model, IsotropicModelWithNoise) or (
            hasattr(model, "sc_ot") and isinstance(model.sc_ot, ScOTWithNoise)
        )
        self.is_deterministic = not is_probabilistic

        # Number of samples for the model - if deterministic, this should be 1
        if is_probabilistic:
            self.validation_ensemble_size = validation_ensemble_size
            self.num_samples = model.num_samples
            # Work on a private copy: never mutate the caller's list (or a
            # shared default) when adding the validation ensemble size.
            requested_sizes = (
                [] if ensemble_sizes_to_save is None else list(ensemble_sizes_to_save)
            )
            if self.validation_ensemble_size not in requested_sizes:
                requested_sizes.append(self.validation_ensemble_size)
            self.ensemble_sizes_to_save = sorted(requested_sizes)
        else:
            self.validation_ensemble_size = 1
            self.num_samples = 1
            self.ensemble_sizes_to_save = [1]
        assert (
            max(self.ensemble_sizes_to_save) <= self.validation_ensemble_size
        ), f"Ensemble sizes to save {max(self.ensemble_sizes_to_save)} cannot be larger than validation ensemble size {self.validation_ensemble_size}"
        self.max_num_samples = max_num_samples
        self.max_spectral_val_samples = max_spectral_val_samples
        self.num_repeats = num_repeats
        self.poseidon_one_sample_norm = poseidon_one_sample_norm
        self.datamodule = datamodule
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.loss_fn = loss_fn
        self.prediction_type = prediction_type

        # Always start from a copy of the imported suite so the appends below
        # cannot leak into the module-level `validation_metric_suite`.
        self.validation_suite = list(validation_metric_suite)
        validation_metrics_names = [m.__class__.__name__ for m in self.validation_suite]
        if self.loss_fn.__class__.__name__ not in validation_metrics_names:
            self.validation_suite.append(self.loss_fn)
        validation_metrics_names = [m.__class__.__name__ for m in self.validation_suite]
        # Probabilistic scores are always part of validation.
        if "CRPS" not in validation_metrics_names:
            self.validation_suite.append(CRPS())
        if "WeightedCRPS" not in validation_metrics_names:
            self.validation_suite.append(WeightedCRPS())
        if "ES" not in validation_metrics_names:
            self.validation_suite.append(ES())
        print("Validation suite", self.validation_suite)
        self.skip_spectral_metrics = skip_spectral_metrics
        assert (
            max_epoch + 1 > start_epoch
        ), f"Expect to train for at least one epoch but request starting from {start_epoch} until {max_epoch} epochs."
        # These starting parameters are just for resuming runs
        self.start_epoch = start_epoch
        self.start_val_loss = start_val_loss
        self.enable_staged_learning = enable_staged_learning
        # Run logistics
        self.max_epoch = max_epoch
        self.val_frequency = val_frequency
        self.rollout_val_frequency = rollout_val_frequency
        self.max_rollout_steps = max_rollout_steps
        self.start_traj = start_traj
        self.scale_noise = scale_noise
        self.short_validation_length = short_validation_length
        self.num_time_intervals = num_time_intervals
        self.enable_amp = enable_amp
        self.grad_acc_steps = grad_acc_steps
        self.clip_gradient = clip_gradient
        self.loss_multiplier = loss_multiplier
        self.minimum_context = minimum_context
        self.lr_scheduler_per_step = lr_scheduler_per_step
        self.masked_loss_for_objects = masked_loss_for_objects
        self.amp_type = torch.bfloat16 if amp_type == "bfloat16" else torch.float16
        self.checkpointer = checkpointer
        # If local or DDP, can use standard grad scaler
        if distribution_type.upper() in ["LOCAL", "DDP"]:
            self.grad_scaler = torch.GradScaler(
                device=self.device.type, enabled=enable_amp and amp_type != "bfloat16"
            )
        # Otherwise need sharded version.
        else:
            self.grad_scaler = ShardedGradScaler(
                device=self.device.type, enabled=enable_amp and amp_type != "bfloat16"
            )

        self.device_mesh = device_mesh
        self.is_distributed = device_mesh is not None
        self.distribution_type = distribution_type
        self.rank = rank
        self.world_size = world_size
        self.sampling_rank_strategy = sampling_rank_strategy
        self.reuse_batches = reuse_batches
        # Get derived rank info about which nodes must be synced from the device mesh
        if self.device_mesh is not None and "fsdp" in self.device_mesh.mesh_dim_names:
            self.sync_group = self.device_mesh.get_group(mesh_dim="fsdp")
            self.sync_group_size = self.sync_group.size()
        else:  # Local or DDP
            self.sync_group = None
            self.sync_group_size = 1
        # Local or DDP do not need all-gather to do forward pass
        self.sync_group_rank = self.rank // self.sync_group_size
        self.rank_in_sync_group = self.rank % self.sync_group_size
        self.num_sync_groups = self.world_size // self.sync_group_size
        if (
            self.sampling_rank_strategy == "gpu"
        ):  # This means sample different dataset per GPU
            self.sampling_rank = self.rank
        elif (
            self.sampling_rank_strategy == "node"
        ):  # This means sample different dataset per node
            self.sampling_rank = self.sync_group_rank
        else:  # This means sample single dataset per step
            self.sampling_rank = 0

        self.dset_metadata = self.datamodule.train_dataset.dset_to_metadata
        self.revin = revin(self.datamodule.train_dataset, self.device)

        # Initial formatter for each dataset - right now these are all identical
        # but we might want to differentiate them in the future.
        self.formatter_dict = {}
        for metadata in self.dset_metadata.values():
            self.formatter_dict[metadata.dataset_name] = formatter()

    def _get_first_sample_idx(self, y_pred: torch.Tensor, rollout_length: int):
        # y_pred can either be
        # - [B, L, H, [W], [D], C] or
        # - [B, N_ensemble, L, H, [W], [D], C]
        if y_pred.shape[1] == rollout_length:
            return (0,), (0,)
        elif y_pred.shape[2] == rollout_length:
            return (0, 0), (0, 0)
        else:
            logger.warning(
                f"Unexpected prediction tensor shape for video creation: {y_pred.shape}"
            )
            return

    def save_model_if_necessary(
        self, epoch: int, validation_loss: float, last: bool = False
    ) -> Optional[Future]:
        """Save the model checkpoint.
        Force checkpointing if last.
        """
        torch.cuda.empty_cache()
        gc.collect()
        checkpoint_future = self.checkpointer.save_if_necessary(
            self.model, self.optimizer, validation_loss, epoch, force=last
        )
        return checkpoint_future

    def rollout_model(
        self,
        model,
        batch,
        formatter,
        train=True,
        fake_pass=False,
        train_rollout_steps: int = None,
        noise_scale: float = 1.0,
    ):
        """Rollout the model for as many steps as we have data for.

        predict_normalized: bool - If true, output normalized prediction. During one-step training,
            predict normalized values to reduce precision issues/extra FLOPs. During rollout,
            denormalize the output for loss calculation. If multiple steps used during training,
            throw error because not currently supported.
        """
        # If in train, predict the number of members set as num_samples
        # More freedom at validation time
        if self.model.training:
            self.num_samples = getattr(self.model, "num_samples", 1)
        else:
            self.num_samples = self.validation_ensemble_size

        metadata = batch["metadata"]
        batch = {
            k: (
                v.to(self.device, non_blocking=True)
                if k not in {"metadata", "boundary_conditions"}
                else v
            )
            for k, v in batch.items()
        }
        # Extract mask and move to device for loss eval
        if (
            self.masked_loss_for_objects
            and "mask" in batch["metadata"].constant_field_names[0]
        ):
            mask_index = batch["metadata"].constant_field_names[0].index("mask")
            mask = batch["constant_fields"][..., mask_index : mask_index + 1]
            mask = mask.to(self.device, dtype=torch.bool, non_blocking=True)
        else:
            mask = None

        inputs, y_ref = formatter.process_input(
            batch,
            causal_in_time=model.causal_in_time,
            predict_delta=self.prediction_type == "delta",
            train=train,
        )

        # Inputs T B C H [W D], y_ref B T H [W D] C
        # If causal, during training don't include initial context in rollout length
        B, T_in = batch["input_fields"].shape[:2]
        if model.causal_in_time:
            max_rollout_steps = self.max_rollout_steps + (T_in - 1)
        else:
            max_rollout_steps = self.max_rollout_steps
        rollout_steps = min(
            y_ref.shape[1], max_rollout_steps
        )  # Number of timesteps in target
        train_rollout_limit = T_in if (train and model.causal_in_time) else 1

        # If we explicitly ask for train_rollout_steps, we use that instead of the limit.
        if train and train_rollout_steps is not None:
            # Ensure we don't ask for more data than we have
            rollout_steps = min(y_ref.shape[1], train_rollout_steps)
        else:
            if rollout_steps > train_rollout_limit and train:
                raise ValueError(
                    "Multiple step prediction in train mode not yet supported"
                )
        y_ref = y_ref[:, :rollout_steps]
        # Create a moving batch of one step at a time
        moving_batch = batch
        y_preds = []
        # Fake pass is just a convenience for distributed validation without communcation code
        if fake_pass:
            return y_ref, y_ref
        # Rollout the model - Causal in time gets more predictions from the first step
        for i in range(train_rollout_limit - 1, rollout_steps):
            # Don't fill causal_in_time here since that only affects y_ref
            inputs, _ = formatter.process_input(moving_batch)
            inputs = list(inputs)
            with torch.no_grad():
                # For Poseidon compute the normalisation stats based on one sample only as it is Markovian
                if type(model) == PoseidonWrapper and self.poseidon_one_sample_norm:
                    normalization_stats = self.revin.compute_stats(
                        inputs[0][-1:], metadata, epsilon=1e-6
                    )
                else:
                    normalization_stats = self.revin.compute_stats(
                        inputs[0], metadata, epsilon=1e-5
                    )
            # NOTE - Currently assuming only [0] (fields) needs normalization
            normalized_inputs = inputs[:]  # Map type bugs out
            normalized_inputs[0] = self.revin.normalize_stdmean(
                normalized_inputs[0], normalization_stats
            )
            num_samples = (
                self.num_samples if i == train_rollout_limit - 1 else 1
            )  # We create the ensemble in the first rollout step and then we just propagate those samples

            # If doing deterministic predictions
            if self.is_deterministic:
                repeated_y_pred = []
                for repeat in range(self.num_repeats):
                    y_pred = model(
                        normalized_inputs[0],
                        normalized_inputs[1],
                        normalized_inputs[2].tolist(),
                        metadata=metadata,
                        num_samples=num_samples,
                    )  # [T, B, C, H, [W], [D]]
                    # During validation, don't maintain full inner predictions
                    if not train and model.causal_in_time:
                        y_pred = y_pred[-1:]  # y_pred is T first, y_ref is not
                    repeated_y_pred.append(y_pred)
                y_pred = torch.stack(repeated_y_pred, dim=0).mean(dim=0)
            # If creating probabilistic predictions
            else:
                repeated_y_pred = []
                cond_noise = (
                    torch.randn((T_in, B * self.num_samples, model.noise_dim)).to(
                        self.device
                    )
                    * noise_scale
                )
                for repeat in range(self.num_repeats):
                    y_pred = []
                    samples_generated = 0
                    # In case the number of samples we want to generate is higher than self.max_num_samples, we generate
                    # self.max_num_samples at a time and accumulate them into y_pred.
                    # NOTE : To make sure we correctly associate each batch with its ensemble members, we always split the batch
                    # in B sized chunks and then stack them back together
                    while samples_generated < self.num_samples:
                        # For the first rollout step, num_samples will be the actual ensemble size but after that
                        # it becomes 1 (because each ensemble is propagated independently). Hence we add the else
                        if i == train_rollout_limit - 1:
                            current_num_samples = min(
                                self.max_num_samples, num_samples - samples_generated
                            )
                            y_pred_temp = model(
                                normalized_inputs[0],
                                normalized_inputs[1],
                                normalized_inputs[2].tolist(),
                                metadata=metadata,
                                num_samples=current_num_samples,
                                cond_noise=cond_noise[
                                    :,
                                    samples_generated : samples_generated
                                    + B * current_num_samples,
                                ],
                            )  # [T, B*num_samples, C, H, [W], [D]]
                        else:
                            current_num_samples = min(
                                self.max_num_samples * B,
                                normalized_inputs[0].shape[1] - samples_generated,
                            )
                            y_pred_temp = model(
                                normalized_inputs[0][
                                    :,
                                    samples_generated : samples_generated
                                    + current_num_samples,
                                ],
                                normalized_inputs[1],
                                normalized_inputs[2].tolist(),
                                metadata=metadata,
                                num_samples=num_samples,
                                cond_noise=cond_noise[
                                    :,
                                    samples_generated : samples_generated
                                    + B * current_num_samples,
                                ],
                            )  # [T, B*num_samples, C, H, [W], [D]]

                        samples_generated += current_num_samples
                        # During validation, don't maintain full inner predictions
                        if not train and model.causal_in_time:
                            y_pred_temp = y_pred_temp[
                                -1:
                            ]  # y_pred is T first, y_ref is not

                        # y_pred_temp shape: [T, B * current_num_samples, C, H, [W], [D]]
                        # Target shape: [T, B, current_num_samples, C, H, [W], [D]]
                        original_shape = y_pred_temp.shape
                        T, _, C = original_shape[:3]
                        spatial_dims = original_shape[3:]  # (H, W) or (H, W, D)

                        # Reshape to separate batch and ensemble dimensions
                        if i == train_rollout_limit - 1:
                            new_shape = (T, B, current_num_samples, C) + spatial_dims
                        else:
                            new_shape = (
                                T,
                                B,
                                current_num_samples // B,
                                C,
                            ) + spatial_dims
                        if y_pred_temp.is_contiguous() == False:
                            y_pred_temp = y_pred_temp.contiguous()
                        y_pred_temp = y_pred_temp.view(new_shape)
                        y_pred.append(y_pred_temp)
                        del y_pred_temp
                        torch.cuda.empty_cache()

                    y_pred = torch.cat(y_pred, dim=2)  # [T, B, num_samples, C, H, W, D]
                    y_pred = y_pred.reshape(
                        T, B * self.num_samples, *y_pred.shape[3:]
                    )  # [T, B*num_samples, C, H, W, D]
                    repeated_y_pred.append(y_pred)

                y_pred = torch.stack(repeated_y_pred, dim=0).mean(dim=0)

            # Maintain normalised predictions
            y_pred_normalized = y_pred

            # Train used normalized values to avoid precision loss
            # Validation on the other hand, reconstructs predictions on original scale
            needs_denorm = (not train) or (i < rollout_steps - 1)

            y_pred_denormalized = None
            if needs_denorm:
                if self.prediction_type == "delta":
                    # y_pred - (T_all or T=-1 depending on causal or not), B, C, H, [W, D]. Different from y_ref
                    if i == train_rollout_limit - 1:
                        # Repeat normalization stats + inputs for each member of the ensemble in the first rollout step

                        normalization_stats.delta_std = (
                            normalization_stats.delta_std.unsqueeze(
                                2
                            )  # add ensemble dim
                            .expand(
                                -1,
                                -1,
                                num_samples,
                                *([-1] * (normalization_stats.delta_std.ndim - 2)),
                            )
                            .reshape(
                                normalization_stats.delta_std.shape[0],
                                normalization_stats.delta_std.shape[1] * num_samples,
                                *normalization_stats.delta_std.shape[2:],
                            )
                        )
                        normalization_stats.delta_mean = (
                            normalization_stats.delta_mean.unsqueeze(
                                2
                            )  # add ensemble dim
                            .expand(
                                -1,
                                -1,
                                num_samples,
                                *([-1] * (normalization_stats.delta_mean.ndim - 2)),
                            )
                            .reshape(
                                normalization_stats.delta_mean.shape[0],
                                normalization_stats.delta_mean.shape[1] * num_samples,
                                *normalization_stats.delta_mean.shape[2:],
                            )
                        )

                        inputs[0] = (
                            inputs[0]
                            .unsqueeze(2)
                            .expand(-1, -1, num_samples, *([-1] * (inputs[0].ndim - 2)))
                            .reshape(
                                inputs[0].shape[0],
                                inputs[0].shape[1] * num_samples,
                                *inputs[0].shape[2:],
                            )
                        )
                    with torch.autocast(
                        self.device.type, enabled=False, dtype=self.amp_type
                    ):
                        y_pred_denormalized = inputs[0][
                            -y_pred.shape[0] :
                        ].float() + self.revin.denormalize_delta(
                            y_pred_normalized, normalization_stats
                        )  # Unnormalize delta and add to input
                elif self.prediction_type == "full":
                    if i == train_rollout_limit - 1:
                        if isinstance(self.revin, SamplewiseRevNormalization):
                            # Repeat normalization stats for each member of the ensemble
                            normalization_stats.sample_std = (
                                normalization_stats.sample_std.repeat_interleave(
                                    num_samples, dim=1
                                )
                            )
                            normalization_stats.sample_mean = (
                                normalization_stats.sample_mean.repeat_interleave(
                                    num_samples, dim=1
                                )
                            )
                    y_pred_denormalized = self.revin.denormalize_stdmean(
                        y_pred_normalized, normalization_stats
                    )
                else:
                    raise ValueError(
                        f"Invalid prediction type {self.prediction_type}. Valid types are delta/full"
                    )
                # Process the denormalized output (Cut channels, etc.)
                y_pred_denormalized = formatter.process_output(
                    y_pred_denormalized, metadata
                )[..., : y_ref.shape[-1]]

                # Apply mask to denormalized output
                if mask is not None:
                    mask_pred_denorm = expand_mask_to_match(mask, y_pred_denormalized)
                    y_pred_denormalized.masked_fill_(mask_pred_denorm, 0)

                # Apply padding mask
                y_pred_denormalized = y_pred_denormalized.masked_fill(
                    ~batch["padded_field_mask"], 0.0
                )

            # At training time we use the normalized prediction for loss calculation
            if train:
                y_pred_loss = y_pred_normalized
                y_pred_loss = formatter.process_output(y_pred_loss, metadata)[
                    ..., : y_ref.shape[-1]
                ]
                # TODO - redo losses to accept losses since this will be more efficient there
                if mask is not None:
                    mask_pred_loss = expand_mask_to_match(mask, y_pred_loss)
                    y_pred_loss.masked_fill_(mask_pred_loss, 0)
                y_pred_loss = y_pred_loss.masked_fill(~batch["padded_field_mask"], 0.0)
            else:
                y_pred_loss = y_pred_denormalized

            # If not last step, update moving batch for autoregressive prediction
            # TODO - for anyone updating this later, it's the primary reason why
            # multiple steps isn't currently supported since we want to recompute
            # normalization stats at each step, but also want to compute loss
            # on normalized values
            if i != rollout_steps - 1:
                # If we are using a stochastic model, repeat the batch for each member of the ensemble
                # at the first step to then propagate them.
                if not self.is_deterministic and i == train_rollout_limit - 1:
                    repeats = (
                        y_pred_denormalized.shape[0]
                        // moving_batch["input_fields"].shape[0]
                    )
                    moving_batch["input_fields"] = torch.cat(
                        [
                            moving_batch["input_fields"][:, 1:].repeat_interleave(
                                repeats, dim=0
                            ),
                            y_pred_denormalized[:, -1:],
                        ],
                        dim=1,
                    )
                else:
                    moving_batch["input_fields"] = torch.cat(
                        [
                            moving_batch["input_fields"][:, 1:],
                            y_pred_denormalized[:, -1:],
                        ],
                        dim=1,
                    )
            # For causal models, we get use full predictions for the first batch and
            # incremental predictions for subsequent batches - concat 1:T to y_ref for loss eval
            # TODO - test this works - currently getting non-causal working then looping back
            if model.causal_in_time and i == train_rollout_limit - 1:
                y_preds.append(y_pred_loss)
            else:
                y_preds.append(y_pred_loss[:, -1:])
        y_pred_out = torch.cat(y_preds, dim=1)
        # Post-processing y_ref depending on train - if train, normalize y_ref before loss calc
        # If not train, we already denormalized the prediction
        if train:
            mean = (
                normalization_stats.sample_mean
                if self.prediction_type == "full"
                else normalization_stats.delta_mean
            )
            std = (
                normalization_stats.sample_std
                if self.prediction_type == "full"
                else normalization_stats.delta_std
            )
            y_ref = normalize_target(y_ref, mean, std, formatter, metadata, self.device)
        if mask is not None:
            mask_ref = expand_mask_to_match(mask, y_ref)
            y_ref.masked_fill_(mask_ref, 0)

        del moving_batch, batch, mask  # Free up batch memory when done
        return y_pred_out, y_ref

    def temporal_split_losses(
        self, loss_values, temporal_loss_intervals, loss_name, dset_name, fname="full"
    ):
        """Average a per-time-step loss over the full range and over time windows.

        Args:
            loss_values: Array-like indexed by time step along its first axis;
                it must support slicing, ``.shape`` and ``.mean()``.
            temporal_loss_intervals: Sequence of integer boundaries. Each pair
                of consecutive entries ``(start, end)`` selects
                ``loss_values[start:end]``.
            loss_name: Loss name embedded in the output keys.
            dset_name: Dataset name used as the key prefix.
            fname: Field name embedded in the output keys (``"full"`` by default).

        Returns:
            A dict with ``"{dset_name}/{fname}_{loss_name}_T=all"`` mapped to the
            mean of ``loss_values`` and, when more than one window is given, one
            ``"..._T={start}:{end}"`` entry per non-empty window.
        """
        key_prefix = f"{dset_name}/{fname}_{loss_name}_T="
        # Average over the whole time range
        new_losses = {f"{key_prefix}all": loss_values.mean()}
        # Don't compute sublosses if we only have one interval
        if len(temporal_loss_intervals) == 2:
            return new_losses
        # Break it down by time interval
        for start_ind, end_ind in zip(
            temporal_loss_intervals, temporal_loss_intervals[1:]
        ):
            window = loss_values[start_ind:end_ind]
            # Coinciding boundaries select no time steps; the mean of an empty
            # slice is NaN, so skip the window instead of logging a NaN entry.
            if window.shape[0] == 0:
                continue
            new_losses[f"{key_prefix}{start_ind}:{end_ind}"] = window.mean()
        return new_losses

    def split_up_losses(
        self, loss_values, loss_name, dset_name, field_names
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        """Split a (time step, field) loss table by field and by time window.

        Args:
            loss_values: 2-D tensor indexed as ``[time step, field]``; the
                batch dimension has already been averaged out by the caller.
            loss_name: Loss name embedded in the output keys.
            dset_name: Dataset name used as the key prefix.
            field_names: Names for the columns of ``loss_values``, in order.

        Returns:
            ``(new_losses, time_logs)``. ``new_losses`` holds the entries
            produced by ``temporal_split_losses`` for every field and for the
            mean over fields (named ``"full"``). ``time_logs`` maps
            ``"{dset_name}/{fname}_{loss_name}_rollout"`` to the per-time-step
            loss curve moved to CPU.
        """
        new_losses: dict[str, Any] = {}
        time_logs: dict[str, Any] = {}
        time_steps = loss_values.shape[0]  # we already average over batch
        num_time_intervals = min(time_steps, self.num_time_intervals)
        # Window boundaries are log-spaced between step 1 and the last step
        log_edges = np.linspace(0, np.log(time_steps), num_time_intervals)
        edges = [int(np.exp(x)) for x in log_edges]
        if edges:
            # exp(log(T)) can land just below T, in which case int() truncation
            # would drop the final time step from the last window.
            edges[-1] = time_steps
        # Keep boundaries strictly increasing: for short rollouts truncation
        # maps neighbouring edges to the same integer, which would create
        # empty windows whose mean is NaN.
        temporal_loss_intervals = [0]
        for edge in edges:
            if edge > temporal_loss_intervals[-1]:
                temporal_loss_intervals.append(edge)
        # Split up losses by field
        for i, fname in enumerate(field_names):
            field_curve = loss_values[:, i]
            time_logs[f"{dset_name}/{fname}_{loss_name}_rollout"] = field_curve.cpu()
            new_losses |= self.temporal_split_losses(
                field_curve, temporal_loss_intervals, loss_name, dset_name, fname
            )
        # Compute average over all fields
        field_mean = loss_values.mean(1)
        new_losses |= self.temporal_split_losses(
            field_mean, temporal_loss_intervals, loss_name, dset_name, "full"
        )
        time_logs[f"{dset_name}/full_{loss_name}_rollout"] = field_mean.cpu()
        return new_losses, time_logs

    @torch.no_grad()
    def validation_loop(
        self,
        dataloaders: list[WellDataset],
        valid_or_test: str = "valid",
        full=False,
        epoch: int = 0,
    ) -> tuple[float, dict[str, Any]]:
        """Run validation by looping over the dataloader.

        Validate the same dataset over FSDP groups since they're locked
        by syncs but distribute over replication (DDP) groups.
        """
        self.model.eval()
        validation_loss = 0.0
        loss_dict: dict[str, Any] = {}
        time_logs: dict[str, Any] = {}
        plot_dicts: dict[str, Any] = {}
        metadatas = []
        # Timeouts get really annoying - barrier at start makes it slightly better
        if torch.distributed.is_initialized():
            torch.distributed.barrier()
        # Each dataset being validated gets separate loader
        for i, dataloader in enumerate(dataloaders):
            # Grab metadata for the current dataset
            assert (
                len(dataloader.dataset.sub_dsets) == 1
            ), "Only one dataset per validation dataloader"
            dataset = dataloader.dataset.sub_dsets[
                0
            ]  # There is only one dset by design
            current_metadata = dataset.metadata
            metadatas.append(current_metadata)
            dset_name = current_metadata.dataset_name
            field_names = flatten_field_names(current_metadata, include_constants=False)
            inner_loss_dict = {}
            rank_assignment = (
                i % self.num_sync_groups
            )  # Use same data across one sync group
            # Only print if we're doing something on this node
            if rank_assignment == self.sync_group_rank:
                logger.info(
                    f"Validating dataset {dataset.metadata.dataset_name} with full_trajectory_mode={dataset.full_trajectory_mode} on rank {self.rank}"
                )
            count = 0
            denom = (
                len(dataloader)
                if full
                else min(len(dataloader), self.short_validation_length)
            )
            # Instantiate lines to save and directory if doing full validation
            if full:
                save_dir = f"{self.viz_folder}/{current_metadata.dataset_name}/rollout_losses/{valid_or_test}/epoch_{epoch}"
                save_dir = Path(save_dir)
                os.makedirs(save_dir, exist_ok=True)
                stats_path = save_dir / f"{valid_or_test}_stats.csv"
                # If the file exists, clear it out
                if stats_path.exists():
                    print(f"Clearing out old stats file {stats_path}")
                    stats_path.unlink()
            batch_idx = 0

            with torch.autocast(
                device_type=self.device.type,
                enabled=self.enable_amp,
                dtype=self.amp_type,
            ):
                for j, batch in enumerate(dataloader):
                    lines = []
                    # Validation datasets don't automatically add metadata
                    start_time = time.time()

                    if self.start_traj > 0:
                        # Original shapes:
                        # input_fields: [B, T_in, H, D, W, C]
                        # output_fields: [B, T_out, H, D, W, C]

                        T_in = batch["input_fields"].shape[1]
                        T_out = batch["output_fields"].shape[1]
                        total_steps = T_in + T_out

                        # Check if we have enough trajectory steps
                        if self.start_traj + T_in > total_steps:
                            raise ValueError(
                                f"start_traj ({self.start_traj}) + T_in ({T_in}) exceeds "
                                f"total trajectory length ({total_steps})"
                            )

                        # Concatenate input and output to get full trajectory
                        # Shape: [B, T_in + T_out, H, D, W, C]
                        full_trajectory = torch.cat(
                            [batch["input_fields"], batch["output_fields"]], dim=1
                        )

                        # Extract new input starting from start_traj
                        # New input: steps [start_traj : start_traj + T_in]
                        batch["input_fields"] = full_trajectory[
                            :, self.start_traj : self.start_traj + T_in
                        ]

                        # Extract new output: remaining steps [start_traj + T_in : end]
                        # This will have length (T_in + T_out - start_traj - T_in) = (T_out - start_traj)
                        batch["output_fields"] = full_trajectory[
                            :, self.start_traj + T_in :
                        ]

                    # Rollout for length of target - fake pass if not evaluating on this node
                    # so that we get the right field names for reduction
                    y_pred, y_ref = self.rollout_model(
                        self.model,
                        batch,
                        self.formatter_dict[dset_name],
                        train=False,
                        fake_pass=(rank_assignment != self.sync_group_rank),
                    )
                    # Go through losses
                    model_time = time.time() - start_time
                    y_pred, y_ref = (
                        y_pred[..., batch["padded_field_mask"]],
                        y_ref[..., batch["padded_field_mask"]],
                    )
                    if self.num_samples > 1:
                        # y_pred: (B * num_samples, T, ...)
                        B = y_pred.shape[0] // self.num_samples
                        T = y_pred.shape[1]
                        # Split into num_samples tensors of shape (B, T, ...) -> (B, num_samples, T, ...)
                        original_shape = y_pred.shape
                        remaining_dims = original_shape[1:]

                        # Reshape to separate batch and ensemble dimensions
                        new_shape = (B, self.num_samples) + remaining_dims
                        # Ensure the tensor is contiguous if it isn't already
                        if not y_pred.is_contiguous():
                            y_pred = y_pred.contiguous()
                        y_pred = y_pred.view(new_shape)
                        expand_shape = (B, self.num_samples) + y_ref.shape[1:]
                        y_ref = y_ref.unsqueeze(1).expand(expand_shape)
                    assert (
                        y_ref.shape == y_pred.shape
                    ), f"Mismatching shapes between reference {y_ref.shape} and prediction {y_pred.shape}"

                    used_field_names = [
                        f
                        for i, f in enumerate(field_names)
                        if batch["padded_field_mask"][i]
                    ]
                    # Validation
                    for loss_fn in self.validation_suite:
                        # Mean over batch and time per field
                        if (
                            self.skip_spectral_metrics
                            and "spectr" in loss_fn.__class__.__name__
                        ):
                            continue
                        # SpreadSkill ratio can only be computed for probabilistic models
                        if (
                            "SpreadSkill" in loss_fn.__class__.__name__
                            and self.is_deterministic
                        ):
                            continue
                        # For the CRPS loss if we have too many ensemble members, we need to compute
                        # a memory-efficient version of the loss.
                        if (
                            "CRPS" in loss_fn.__class__.__name__
                            or "ES" in loss_fn.__class__.__name__
                        ) and not self.is_deterministic:
                            loss = {}
                            mem_efficient = self.num_samples > self.max_num_samples
                            for i in range(1, self.num_samples + 1):
                                name = (
                                    f"{loss_fn.__class__.__name__}_{i}"
                                    if i != self.num_samples
                                    else loss_fn.__class__.__name__
                                )
                                loss[name] = loss_fn(
                                    y_pred[:, :i],
                                    y_ref[:, :i],
                                    current_metadata,
                                    mem_efficient=mem_efficient,
                                )
                        # Same with spectral losses when we have multiple ensemble members
                        elif (
                            "spectr" in loss_fn.__class__.__name__
                            and self.num_samples > 1
                        ):
                            loss = None
                            for k in range(
                                0, self.num_samples, self.max_spectral_val_samples
                            ):  # Batch of self.max_spectral_val_samples ensemble members
                                y_pred_batch = y_pred[
                                    :, k : k + self.max_spectral_val_samples
                                ]
                                y_ref_batch = y_ref[
                                    :, k : k + self.max_spectral_val_samples
                                ]
                                loss_batch = loss_fn(
                                    y_pred_batch, y_ref_batch, current_metadata
                                )
                                if loss is None:
                                    # Initialize lists for each key
                                    loss = {key: [v] for key, v in loss_batch.items()}
                                else:
                                    for key, v in loss_batch.items():
                                        loss[key].append(v)
                                del loss_batch
                            for key in loss:
                                loss[key] = torch.cat(loss[key], dim=1)
                        elif (
                            "VRMSE" in loss_fn.__class__.__name__
                            and not self.is_deterministic
                        ):
                            for i in range(1, self.num_samples + 1):
                                name = (
                                    f"{loss_fn.__class__.__name__}_{i}"
                                    if i != self.num_samples
                                    else loss_fn.__class__.__name__
                                )
                                # We take the VRMSE between the GT and ensemble mean
                                loss[name] = loss_fn(
                                    y_pred[:, :i].mean(dim=1),
                                    y_ref[:, :i].mean(dim=1),
                                    current_metadata,
                                )
                        elif (
                            "SpreadSkill" in loss_fn.__class__.__name__
                            and not self.is_deterministic
                        ):
                            if self.num_samples > 1:
                                for i in range(2, self.num_samples + 1):
                                    name = (
                                        f"{loss_fn.__class__.__name__}_{i}"
                                        if i != self.num_samples
                                        else loss_fn.__class__.__name__
                                    )
                                    loss[name] = loss_fn(
                                        y_pred[:, :i], y_ref[:, :i], current_metadata
                                    )
                        else:
                            if self.is_deterministic:
                                loss = loss_fn(y_pred, y_ref, current_metadata)
                            else:
                                loss = loss_fn(
                                    y_pred.mean(dim=1),
                                    y_ref.mean(dim=1),
                                    current_metadata,
                                )
                        # Some losses return multiple values for efficiency
                        if not isinstance(loss, dict):
                            loss = {loss_fn.__class__.__name__: loss}
                        # Split the losses and update the logging dictionary
                        for k, v in loss.items():
                            if v.ndim == 2:
                                # For energy score
                                sub_loss = v.mean(0)
                                sub_loss = sub_loss.unsqueeze(-1).repeat(
                                    1, len(used_field_names)
                                )
                            elif v.ndim == 3:
                                # For CRPS that already ensembles over ensemble members
                                sub_loss = v.mean(0)  # Take the batch mean
                            elif v.ndim == 4:
                                sub_loss = v.mean(
                                    dim=(0, 1)
                                )  # Take the batch and ensemble mean
                            new_losses, new_time_logs = self.split_up_losses(
                                sub_loss, k, dset_name, used_field_names
                            )
                            # TODO get better way to include spectral error.
                            if (
                                k in long_time_metrics
                                or "spectral_error" in k
                                or "CRPS" in k
                                or "VRMSE" in k
                                or "ES" in k
                            ):
                                # time_logs |= new_time_logs
                                for k, v in new_time_logs.items():
                                    time_logs.setdefault(k, []).extend(
                                        v if isinstance(v, list) else [v]
                                    )
                            for loss_name, loss_value in new_losses.items():
                                inner_loss_dict[loss_name] = (
                                    inner_loss_dict.get(loss_name, 0.0)
                                    + loss_value / denom
                                )
                                # Let's just store the VRMSE since that's what I'm actually looking at on aggregate.
                                if "full_VRMSE_T=all" in loss_name:
                                    vrmse = loss_value.item()

                    # Lola-style metrics
                    if full:
                        # y_ref, y_pred: [1, N_ensemble, T, H, W, D, C]
                        spatial_dims = tuple(
                            range(-current_metadata.n_spatial_dims - 1, -1)
                        )
                        spatial = len(spatial_dims)
                        time_dim = min(spatial_dims) - 1
                        x = y_ref[0].movedim(time_dim, -2)
                        x_hat = y_pred[0].movedim(time_dim, -2)

                        es_loss = ES()
                        crps_loss = CRPS()
                        weighted_crps_loss = WeightedCRPS()

                        for t in range(x.shape[-2]):
                            # Compute energy score (which is over all fields)
                            if not self.is_deterministic:
                                pred_t = x_hat[..., t, :]
                                truth_t = x[0, ..., t, :]
                            else:
                                pred_t = x_hat[..., t, :]
                                truth_t = x[..., t, :]

                            es_by_ensemble = {}
                            for ensemble_i in self.ensemble_sizes_to_save:
                                if self.is_deterministic:
                                    # Handle deterministic case (ES = L2 Error)
                                    pred_t_subset = pred_t[None, ...]
                                else:
                                    pred_t_subset = pred_t[:ensemble_i]

                                # Compute the score
                                es_i = es_loss(
                                    predictions=pred_t_subset[None, :, None, ...],
                                    target=truth_t.view(1, 1, 1, *truth_t.shape).expand(
                                        1, ensemble_i, 1, *truth_t.shape
                                    ),
                                    metadata=current_metadata,
                                    mem_efficient=False,
                                ).item()
                                es_by_ensemble[ensemble_i] = es_i

                            for field in range(x.shape[-1]):
                                if not self.is_deterministic:
                                    u, v = (
                                        x[0, ..., t, field],
                                        x_hat[..., t, field],
                                    )  # Take just first sample for ground truth since they are all the same
                                else:
                                    u, v = (
                                        x[..., t, field],
                                        x_hat[..., t, field],
                                    )
                                # Moments (these don't change with ensemble size)
                                m1 = torch.mean(u)
                                m2 = torch.mean(u**2)

                                # Fourier analysis (ground truth doesn't change)
                                p_u, k = isotropic_power_spectrum(u, spatial=spatial)
                                if len(p_u.shape) == 2:
                                    p_u = torch.mean(p_u, dim=0)
                                bins = torch.logspace(
                                    k[0].log2(), -1.0, steps=4, base=2
                                )

                                # Storage for ensemble-dependent metrics
                                metrics_by_ensemble = {}

                                for ensemble_i in self.ensemble_sizes_to_save:
                                    if self.is_deterministic:
                                        v_subset = v[None, ...]
                                    else:
                                        v_subset = v[
                                            :ensemble_i
                                        ]  # use the first ensemble_i members

                                    # Spread
                                    if ensemble_i > 1:
                                        # see https://doi.org/10.1175/JHM-D-14-0008.1
                                        spread = torch.mean(
                                            torch.square(
                                                v_subset - torch.mean(v_subset, dim=0)
                                            )
                                        )
                                        spread = torch.sqrt(
                                            (ensemble_i + 1) / (ensemble_i - 1) * spread
                                        )
                                    else:
                                        spread = 0.0

                                    # Skill metrics
                                    se = torch.square(u - torch.mean(v_subset, dim=0))
                                    mse = torch.mean(se)
                                    rmse = torch.sqrt(mse)
                                    nrmse = torch.sqrt(mse / (torch.mean(u**2) + 1e-6))
                                    vrmse = torch.sqrt(mse / (torch.var(u) + 1e-6))

                                    # Spread_skill ratio
                                    spread_skill = (spread + 1e-3) / (rmse + 1e-3)

                                    # Fourier metrics
                                    p_v, _ = isotropic_power_spectrum(
                                        v_subset, spatial=spatial
                                    )
                                    p_v = torch.mean(p_v, dim=0)
                                    se_p = torch.square(1 - (p_v + 1e-6) / (p_u + 1e-6))
                                    rmse_p = torch.sqrt(torch.mean(se_p))

                                    fourier_extras = []
                                    for i in range(4):
                                        if i < 3:
                                            mask = torch.logical_and(
                                                bins[i] <= k, k <= bins[i + 1]
                                            )
                                        else:
                                            mask = bins[i] <= k
                                        fourier_extras.append(
                                            torch.sqrt(torch.mean(se_p[mask])).item()
                                        )

                                    extras = []
                                    ## Wasserstein
                                    w_uv = ot.lp.wasserstein_1d(
                                        u.flatten(),
                                        v_subset.flatten(),
                                        p=1.0,
                                    )
                                    extras.append(w_uv.item())
                                    ## Sliced EMD, but we don't do it anymore
                                    extras.append(None)

                                    # CRPS
                                    crps_i = crps_loss(
                                        predictions=v_subset[None, :, None, ..., None],
                                        target=u.view(1, 1, 1, *u.shape, 1).expand(
                                            1, ensemble_i, 1, *u.shape, 1
                                        ),
                                        metadata=current_metadata,
                                        mem_efficient=False,
                                    ).mean()

                                    # Weighted CRPS
                                    weighted_crps_i = weighted_crps_loss(
                                        predictions=v_subset[None, :, None, ..., None],
                                        target=u.view(1, 1, 1, *u.shape, 1).expand(
                                            1, ensemble_i, 1, *u.shape, 1
                                        ),
                                        metadata=current_metadata,
                                        mem_efficient=False,
                                    ).mean()

                                    # Store all metrics for this ensemble size
                                    metrics_by_ensemble[ensemble_i] = {
                                        "vrmse": vrmse.item(),
                                        "rmse": rmse.item(),
                                        "nrmse": nrmse.item(),
                                        "spread": (
                                            spread
                                            if isinstance(spread, float)
                                            else spread.item()
                                        ),
                                        "spread_skill": spread_skill.item(),
                                        "rmse_p": rmse_p.item(),
                                        "rmse_p_low": fourier_extras[0],
                                        "rmse_p_mid": fourier_extras[1],
                                        "rmse_p_high": fourier_extras[2],
                                        "rmse_p_sub": fourier_extras[3],
                                        "crps": crps_i.item(),
                                        "frequency_crps": 2 * weighted_crps_i.item()
                                        - crps_i.item(),  # assumes equal weights
                                        "weighted_crps": weighted_crps_i.item(),
                                        "energy_score": es_by_ensemble[ensemble_i],
                                        "wasserstein": extras[0],
                                        "emd": extras[1],
                                    }

                                # Store one row for each ensemble size
                                for ensemble_i in self.ensemble_sizes_to_save:
                                    if (
                                        ensemble_i in metrics_by_ensemble
                                    ):  # Only if we computed it
                                        ensemble_metrics = metrics_by_ensemble[
                                            ensemble_i
                                        ]

                                        # line = f"{runid},state,{compression:.1f},crps,{cfg.noise_emb_features},"
                                        # line += f"train,None,{cfg.val_eval.context},{cfg.val_eval.overlap},1.0,"
                                        line = f"{valid_or_test},{j},"
                                        line += f"{used_field_names[field]},{t},"
                                        line += (
                                            f"{ensemble_i},"  # Add ensemble size here
                                        )
                                        line += f"{m1},{m2},"
                                        line += f"{ensemble_metrics['spread']},{ensemble_metrics['spread_skill']},"
                                        line += f"{ensemble_metrics['rmse']},{ensemble_metrics['nrmse']},{ensemble_metrics['vrmse']},"
                                        line += f"{ensemble_metrics['rmse_p']},"

                                        # Add Fourier extras
                                        fourier_values = [
                                            ensemble_metrics["rmse_p_low"],
                                            ensemble_metrics["rmse_p_mid"],
                                            ensemble_metrics["rmse_p_high"],
                                            ensemble_metrics["rmse_p_sub"],
                                        ]
                                        line += ",".join(map(str, fourier_values)) + ","
                                        # Add CRPS value for this ensemble size
                                        line += f"{ensemble_metrics['crps']}" + ","
                                        line += (
                                            f"{ensemble_metrics['frequency_crps']}"
                                            + ","
                                        )
                                        line += (
                                            f"{ensemble_metrics['weighted_crps']}" + ","
                                        )

                                        # Add Energy Score value for this ensemble size
                                        line += (
                                            f"{ensemble_metrics['energy_score']}" + ","
                                        )

                                        # Add Wasserstein and Sliced EMD
                                        extra_values = [
                                            ensemble_metrics["wasserstein"],
                                            ensemble_metrics["emd"],
                                        ]
                                        line += ",".join(map(str, extra_values))

                                        line += "\n"
                                        lines.append(line)

                    total_time = time.time() - start_time
                    max_mem_GB = torch.cuda.max_memory_allocated() / 1024**3
                    max_mem_reserved_GB = torch.cuda.max_memory_reserved() / 1024**3

                    # Save metrics for new sample
                    # Save all_stats as a pickle file if this rank should save
                    if full:
                        if (
                            self.rank_in_sync_group == 0
                            and rank_assignment == self.sync_group_rank
                        ):
                            # Note: stats_path is a Path object now.
                            with FileLock(str(stats_path) + ".lock"):
                                with open(stats_path, "a") as f:  # Use append mode
                                    # Only write header if file is empty/new
                                    if stats_path.stat().st_size == 0:
                                        header = "split,sample_index,field,time,"
                                        header += "ensemble_size,"  # Add this column
                                        header += "m1,m2,spread,spread_skill,rmse,nrmse,vrmse,rmse_p,"
                                        header += "rmse_p_low,rmse_p_mid,rmse_p_high,rmse_p_sub,"
                                        header += "crps,frequency_crps,weighted_crps,"  # Single CRPS value per row
                                        header += "energy_score,"  # Single Energy Score value per row
                                        header += "wasserstein,emd"
                                        header += "\n"
                                        f.write(header)
                                    print(f"Sample {j}")
                                    f.writelines(lines)

                    # Only print out if local device actually doing something
                    if rank_assignment == self.sync_group_rank:
                        logger.info(
                            f"{valid_or_test}: {dset_name}, Batch {j+1}/{denom}, Rank {self.rank:>3}: Field-time-averaged VRMSE {vrmse:7.4f}, mem {max_mem_GB:5.2f} GB, mem reserved {max_mem_reserved_GB:5.2f} GB, total_time {total_time:5.3f}s, model {model_time:5.4f}s"
                        )
                    if torch.cuda.is_available():
                        torch.cuda.reset_peak_memory_stats()
                    count += 1
                    batch_idx += 1
                    if dataset.full_trajectory_mode:
                        if self.video_validation and count < 3:
                            rollout_length = batch["output_fields"].shape[1]
                            pred_idx, ref_idx = self._get_first_sample_idx(
                                y_pred, rollout_length
                            )
                            try:
                                make_video(
                                    y_pred[pred_idx],
                                    y_ref[ref_idx],
                                    current_metadata,
                                    self.viz_folder,
                                    f"{epoch}_{j}",  # For the file name
                                    field_name_overrides=used_field_names,  # Fields actually used
                                    size_multiplier=0.4,  # Shrinking for bulk runs, but visuals tuned around 1
                                )
                                gc.collect()
                            except CalledProcessError as e:
                                logger.warning(
                                    f"Error in making video due to FFMPEG: {e}. Skipping video."
                                )
                    if not full and count >= self.short_validation_length:
                        break

                # If there is a FSDP group, average loss over the group
                if self.sync_group:
                    for k, v in inner_loss_dict.items():
                        dist.all_reduce(
                            inner_loss_dict[k],
                            op=dist.ReduceOp.AVG,
                            group=self.sync_group,
                        )
                # Update the overall loss dict with the new loss
                for k, v in inner_loss_dict.items():
                    loss_dict[k] = loss_dict.get(k, 0.0) + inner_loss_dict[k]
                # Last batch plots - too much work to combine from batches - will be noisy
                if (
                    self.rank_in_sync_group == 0
                    and rank_assignment == self.sync_group_rank
                ):
                    if self.image_validation:
                        for plot_fn in validation_plots:
                            if (
                                self.skip_spectral_metrics
                                and "spectr" in plot_fn.__name__
                            ):
                                continue
                            rollout_length = batch["output_fields"].shape[1]
                            pred_idx, ref_idx = self._get_first_sample_idx(
                                y_pred, rollout_length
                            )
                            # Convert (0,) to slice [:], (0,0) to [:,0] since image plotting expects batch dimension
                            pred_idx = (
                                slice(None) if pred_idx == (0,) else (slice(None), 0)
                            )
                            ref_idx = (
                                slice(None) if ref_idx == (0,) else (slice(None), 0)
                            )
                            plot_fn(
                                y_pred[pred_idx],
                                y_ref[ref_idx],
                                current_metadata,
                                self.viz_folder,  # Temporary until we port over the resume logic
                                epoch,
                            )

                    if dataset.full_trajectory_mode:
                        # Only plot if we have more than one timestep, but then track loss over timesteps
                        if self.image_validation:
                            plot_all_time_metrics(
                                time_logs,
                                current_metadata,
                                self.viz_folder,
                                valid_or_test,
                                epoch,
                            )
                        if self.video_validation:
                            try:
                                # Clear GPU cache
                                torch.cuda.empty_cache()
                                # Force Python to delete unused variables
                                gc.collect()
                                rollout_length = batch["output_fields"].shape[1]
                                pred_idx, ref_idx = self._get_first_sample_idx(
                                    y_pred, rollout_length
                                )
                                make_video(
                                    y_pred[pred_idx],  # First sample only in batch
                                    y_ref[ref_idx],  # First sample only in batch
                                    current_metadata,
                                    self.viz_folder,
                                    epoch,  # For the file name
                                    field_name_overrides=used_field_names,  # Fields actually used
                                    size_multiplier=0.4,  # Shrinking for bulk runs, but visuals tuned around 1
                                )
                            except CalledProcessError as e:
                                logger.warning(
                                    f"Error in making video due to FFMPEG: {e}. Skipping video."
                                )

        if self.is_distributed:
            # Wait for all ranks to finish
            logger.debug(f"Rank {self.rank} waiting for barrier")
            dist.barrier()
            logger.debug(f"Rank {self.rank} passed barrier")
            # NOTE - Cleaner to just divide through instead of finding ddp group which may not exist
            for k, v in loss_dict.items():
                dist.all_reduce(
                    loss_dict[k],
                    op=dist.ReduceOp.SUM,
                )
                loss_dict[k] = loss_dict[k] / self.sync_group_size
        # Single score validation loss is average of all losses on the training metric
        validation_loss = sum(
            [
                loss_dict[
                    f"{metadata.dataset_name}/full_{self.loss_fn.__class__.__name__}_T=all"
                ].item()
                for metadata in metadatas
            ]
        ) / len(metadatas)
        loss_dict = {f"{valid_or_test}_{k}": v.item() for k, v in loss_dict.items()}
        loss_dict |= plot_dicts
        # Misc metrics
        loss_dict["param_norm"] = param_norm(self.model.parameters())

        return validation_loss, loss_dict

    def train_one_epoch(
        self, epoch: int, dataloader: DataLoader
    ) -> tuple[float, dict[str, Any]]:
        """Train the model for one epoch by looping over the dataloader.

        Runs ``len(dataloader)`` iterations. The loss of each iteration is divided
        by ``self.grad_acc_steps`` and the optimizer is stepped every
        ``self.grad_acc_steps`` iterations. When ``self.reuse_batches`` is set,
        freshly loaded batches are cached and later replayed in groups of
        ``self.grad_acc_steps`` instead of drawing new data.

        Parameters
        ----------
        epoch: int
            The current epoch. Drives the noise-scale schedule and is used in logs.
        dataloader: DataLoader
            Dataloader providing the training batches.

        Returns
        -------
        tuple[float, dict[str, Any]]
            The epoch loss (mean of the per-batch losses; a 0-dim tensor once at
            least one batch has been processed) and a dict of training logs
            (per-step timing averages, peak memory, learning rates, ...).
        """
        self.model.train()
        # Hard-coded schedule for the noise scale handed to the rollout:
        # 0 for epochs < 5, 0.1 for epochs < 10, 1 afterwards.
        if self.scale_noise:
            if epoch < 5:
                noise_scale = 0.0
            elif epoch < 10:
                noise_scale = 0.1
            else:
                noise_scale = 1.0
        else:
            noise_scale = 1.0
        num_batches = len(dataloader)
        epoch_loss = 0.0
        avg_grad_norm = torch.tensor(0.0).to(self.device)
        last_grad_norm = torch.tensor(0.0).to(self.device)
        train_logs: dict[str, Any] = {}
        batch_start = time.time()
        interval_start = time.time()
        # When using grad accumulation, it makes sense to zero gradient outside first, then after optimizer step
        self.optimizer.zero_grad()  # Set to none now default
        overall_batch_queue = []
        current_batch_queue = []
        data_iter = iter(dataloader)
        i = 0
        while i < num_batches:
            # If reuse batches is on, we want to cache grad_acc_steps worth of
            # batches and reuse them. We want the order to be new -> cached -> new
            # Current last batch will just be new.
            if self.reuse_batches and (
                len(overall_batch_queue) >= 2 * self.grad_acc_steps
                or len(current_batch_queue) > 0
            ):
                # Reuse the batch
                if len(current_batch_queue) == 0:
                    # If grad acc > 1, shuffle so we have unique batches at least
                    if self.grad_acc_steps > 1:
                        shuffle(overall_batch_queue)
                    current_batch_queue = overall_batch_queue[: self.grad_acc_steps]
                    overall_batch_queue = overall_batch_queue[self.grad_acc_steps :]
                batch = current_batch_queue.pop(0)
            else:
                batch = next(data_iter)
                if self.reuse_batches:
                    # Shallow copy: the cached dict is unaffected by the key
                    # reassignment below.
                    overall_batch_queue.append(batch.copy())
            batch["padded_field_mask"] = batch["padded_field_mask"].to(
                self.device, non_blocking=True
            )
            # Only step the optimizer on the last iteration of each accumulation window
            update_grad = (i + 1) % self.grad_acc_steps == 0
            # Skip gradient synchronization on non-update steps unless running locally
            with (
                nullcontext()
                if (update_grad or self.distribution_type == "local")
                else self.model.no_sync()
            ):
                with torch.autocast(
                    device_type=self.device.type,
                    enabled=self.enable_amp,
                    dtype=self.amp_type,
                ):
                    data_time = time.time() - batch_start
                    current_metadata = batch["metadata"]
                    dset_name = current_metadata.dataset_name
                    y_pred, y_ref = self.rollout_model(
                        self.model,
                        batch,
                        self.formatter_dict[dset_name],
                        train_rollout_steps=self.train_rollout_steps,
                        noise_scale=noise_scale,
                    )
                    # If T > self.minimum_context, then optimize only the predictions with the minimum context
                    # By default this is just removing the zero-context prediction.
                    if y_pred.shape[1] > self.minimum_context:
                        y_ref = y_ref[:, self.minimum_context :]
                        y_pred = y_pred[:, self.minimum_context :]
                    forward_time = time.time() - batch_start - data_time
                    if self.num_samples > 1:
                        # y_pred: (B * num_samples, T, ...) -> (B, num_samples, T, ...)
                        B = y_pred.shape[0] // self.num_samples
                        new_shape = (B, self.num_samples) + y_pred.shape[1:]
                        # view() requires contiguous memory
                        if not y_pred.is_contiguous():
                            y_pred = y_pred.contiguous()
                        y_pred = y_pred.view(new_shape)
                        # Use expand for memory-efficient broadcasting (no data copy)
                        expand_shape = (B, self.num_samples) + y_ref.shape[1:]
                        y_ref = y_ref.unsqueeze(1).expand(expand_shape)
                    assert (
                        y_ref.shape == y_pred.shape
                    ), f"Mismatching shapes between reference {y_ref.shape} and prediction {y_pred.shape}"
                    loss = (
                        self.loss_multiplier
                        * self.loss_fn(y_pred, y_ref, current_metadata).mean()
                        / self.grad_acc_steps
                    )
                    del y_pred, y_ref  # Free up a little before the BW pass
                self.grad_scaler.scale(loss).backward()
                backward_time = time.time() - batch_start - forward_time - data_time
            # On update_grad steps, we actually perform the steps
            if update_grad:
                if self.clip_gradient > 0 or self.gradient_log_level > 0:
                    self.grad_scaler.unscale_(self.optimizer)
                if self.gradient_log_level == 1:
                    if hasattr(self.model, "sharding_strategy"):
                        last_grad_norm = get_grad_norm_fsdp(
                            self.model,
                            self.rank,
                            self.world_size,
                            self.model.sharding_strategy,
                        )
                    else:
                        last_grad_norm = get_grad_norm_local(self.model)
                        # NOTE(review): the running average is only updated on this
                        # non-sharded path, so it stays at 0 for models exposing
                        # `sharding_strategy` - confirm this is intended.
                        avg_grad_norm += last_grad_norm.detach() / (
                            num_batches / self.grad_acc_steps
                        )
                if self.clip_gradient > 0:
                    if hasattr(self.model, "clip_grad_norm_"):
                        self.model.clip_grad_norm_(
                            self.clip_gradient,
                            norm_type=2.0,
                        )
                    else:
                        torch.nn.utils.clip_grad_norm_(
                            self.model.parameters(), self.clip_gradient, norm_type=2.0
                        )
                self.grad_scaler.step(self.optimizer)
                self.grad_scaler.update()
                self.optimizer.zero_grad()  # Set to none is now default
                if self.lr_scheduler_per_step and self.lr_scheduler:
                    self.lr_scheduler.step()
            total_time = time.time() - batch_start
            optimizer_time = total_time - forward_time - backward_time - data_time
            # Syncing for all reduce anyway so may as well compute synchronous metrics
            # Multiply by grad_acc_steps to undo the accumulation scaling of the loss.
            epoch_loss += (self.grad_acc_steps * loss.detach()) / num_batches
            max_mem_GB = torch.cuda.max_memory_allocated() / 1024**3
            if i % self.log_interval == 0:
                timing = (time.time() - interval_start) / self.log_interval
                interval_start = time.time()
                logger.info(
                    f"Epoch {epoch:>4}, Batch {i+1}/{num_batches}, "
                    f"Rank {self.rank:>3}, SyncStep: {update_grad}:\n\t"
                    f"Data: {current_metadata.dataset_name:<32}, "
                    f"loss {(self.grad_acc_steps * loss.item()) ** 0.5:7.4f}, "
                    # f"loss {(self.grad_acc_steps * loss.item()):7.4f}, "
                    f"mem {max_mem_GB:5.2f} GB, total_time {timing:5.3f}s, "
                    f"data {data_time:5.4f}s, fwd {forward_time:5.3f}s, "
                    f"bw {backward_time:5.3f}s, opt {optimizer_time:5.3f}s"
                )
            # Reset the peak-memory counter so max_mem_GB is measured per iteration
            if torch.cuda.is_available():
                torch.cuda.reset_peak_memory_stats()
            batch_start = time.time()
            # Log elapsed times in train_log - NOTE: only accurate if cuda syncing, but can be interpreted either way
            # Each running average is read back under the same key it is stored
            # under, so the contributions of all iterations are summed.
            for log_key, elapsed in (
                ("avg_data_loading_time", data_time),
                ("avg_forward_time", forward_time),
                ("avg_backward_time", backward_time),
                ("avg_optimizer_time", optimizer_time),
                ("avg_time_per_step", total_time),
            ):
                train_logs[log_key] = train_logs.get(log_key, 0) + elapsed / num_batches
            train_logs["peak_memory"] = max(
                train_logs.get("peak_memory", 0), max_mem_GB
            )
            i += 1
        train_logs["train_loss"] = epoch_loss
        if self.gradient_log_level >= 1:
            train_logs["avg_grad_norm"] = avg_grad_norm.item()
        if self.lr_scheduler:
            if not self.lr_scheduler_per_step:
                self.lr_scheduler.step()

            # Enhanced LR logging for staged learning
            if (
                hasattr(self.optimizer, "param_groups")
                and len(self.optimizer.param_groups) > 1
            ):
                # Extract LRs by parameter group position
                for group_idx, group in enumerate(self.optimizer.param_groups):
                    # Group 0: new params, every later group is logged as "common"
                    param_type = "new" if group_idx == 0 else "common"
                    train_logs[f"lr_{param_type}"] = group["lr"]

                # Log additional information for staged learning
                if hasattr(self.lr_scheduler, "warmup_epochs"):
                    current_epoch = epoch - self.start_epoch + 1
                    train_logs["lr_warmup_progress"] = min(
                        1.0, current_epoch / self.lr_scheduler.warmup_epochs
                    )
                    train_logs["lr_warmup_complete"] = (
                        current_epoch >= self.lr_scheduler.warmup_epochs
                    )

                # Keep backward compatibility - use the new params LR as the main LR
                train_logs["lr"] = train_logs["lr_new"]

                # Log staging info
                if self.rank == 0:  # Only log from rank 0 to avoid spam
                    logger.info(
                        f"Epoch {epoch}: LR - New: {train_logs['lr_new']:.9f}, Common: {train_logs['lr_common']:.6f}"
                    )
            else:
                # Single parameter group - use existing logic
                train_logs["lr"] = self.lr_scheduler.get_last_lr()[0]
        return epoch_loss, train_logs

    def validate_if_necessary(
        self,
        epoch: int,
        one_step_dataloaders: list[DataLoader],
        rollout_dataloaders: list[DataLoader],
        valid_or_test: Literal["valid", "test"] = "valid",
    ):
        """Run whichever evaluation passes are due at the given epoch.

        The one-step pass runs on the first epoch, every ``self.val_frequency``
        epochs, from ``self.max_epoch`` onwards, and always for the test split.
        The rollout pass runs every ``self.rollout_val_frequency`` epochs, from
        ``self.max_epoch`` onwards, and always for the test split. Passes that
        run at or after ``self.max_epoch`` or on the test split use ``full=True``.

        Parameters
        ----------
        epoch: int
            The current epoch. Used for scheduling and logging.
        one_step_dataloaders: list[DataLoader]
            List of dataloaders for one step validation
        rollout_dataloaders: list[DataLoader]
            List of dataloaders for rollout validation
        valid_or_test: str
            Which split is being evaluated. Options are "valid" or "test"

        Returns
        -------
        tuple
            ``(val_loss, rollout_val_loss)``; an entry is ``None`` when the
            corresponding pass did not run.
        """
        is_test = valid_or_test == "test"
        val_loss = None
        rollout_val_loss = None

        # One-step pass: scheduled epochs, final epoch(s), test split, or first epoch.
        one_step_due = (
            epoch % self.val_frequency == 0
            or epoch >= self.max_epoch
            or is_test
            or epoch == 1
        )
        if one_step_due:
            logger.info(
                f"Epoch {epoch}/{self.max_epoch}: starting {valid_or_test} validation"
            )
            run_full = epoch >= self.max_epoch or is_test
            val_loss, loss_dict = self.validation_loop(
                one_step_dataloaders,
                valid_or_test=valid_or_test,
                full=run_full,
                epoch=epoch,
            )
            logger.info(
                f"Epoch {epoch}/{self.max_epoch}: {valid_or_test} loss {val_loss}"
            )
            loss_dict.update({f"{valid_or_test}": val_loss, "epoch": epoch})

            # Only the global rank 0 process reports to wandb
            if self.wandb_logging and self.rank == 0:
                wandb.log(loss_dict)

        # Rollout pass: scheduled epochs, final epoch(s), or test split.
        rollout_due = (
            epoch % self.rollout_val_frequency == 0
            or epoch >= self.max_epoch
            or is_test
        )
        if rollout_due:
            logger.info(
                f"Epoch {epoch}/{self.max_epoch}: starting rollout {valid_or_test} validation"
            )
            run_full = epoch >= self.max_epoch or is_test
            rollout_val_loss, rollout_val_loss_dict = self.validation_loop(
                rollout_dataloaders,
                valid_or_test=f"rollout_{valid_or_test}",
                full=run_full,
                epoch=epoch,
            )
            logger.info(
                f"Epoch {epoch}/{self.max_epoch}: rollout {valid_or_test} loss {rollout_val_loss}"
            )
            rollout_val_loss_dict.update(
                {
                    f"rollout_{valid_or_test}": rollout_val_loss,
                    "epoch": epoch,
                }
            )
            if self.wandb_logging and self.rank == 0:
                wandb.log(rollout_val_loss_dict)

        return val_loss, rollout_val_loss

    def train(self):
        """Run training, validation and test. The training is run for multiple epochs.

        For every epoch from ``self.start_epoch`` to ``self.max_epoch`` (inclusive)
        this trains for one epoch, rebuilds the validation dataloaders, runs the
        validation passes that are due and requests a checkpoint. Afterwards the
        test split is evaluated once. If the epoch range is empty, only the test
        evaluation runs.
        """
        checkpoint_future = None
        val_loss = self.start_val_loss
        train_dataloader = self.datamodule.train_dataloader(self.sampling_rank)
        for epoch in range(
            self.start_epoch, self.max_epoch + 1
        ):  # I like 1 indexing for epochs
            # NOTE - only update train sampler because we want to sample same valid data every time
            if self.is_distributed:
                train_dataloader.sampler.set_epoch(epoch)
            # Empty mem caches before train loop
            torch.cuda.empty_cache()
            gc.collect()
            logger.info(f"Epoch {epoch}/{self.max_epoch}: starting training")
            train_loss, train_logs = self.train_one_epoch(epoch, train_dataloader)
            logger.info(
                f"Epoch {epoch}/{self.max_epoch}: training loss {train_loss:.4f}"
            )
            train_logs |= {"train": train_loss, "epoch": epoch}
            if self.wandb_logging and self.rank == 0:
                wandb.log(train_logs)
            # Empty mem caches before val
            torch.cuda.empty_cache()
            gc.collect()

            # Recreate loader every time so we're using same data in val
            val_dataloaders = self.datamodule.val_dataloaders(
                replicas=self.sync_group_size,
                rank=self.rank_in_sync_group,
                full=(epoch == self.max_epoch),
            )
            rollout_val_dataloaders = self.datamodule.rollout_val_dataloaders(
                replicas=self.sync_group_size, rank=self.rank_in_sync_group
            )
            maybe_val_loss, rollout_loss = self.validate_if_necessary(
                epoch,
                val_dataloaders,
                rollout_val_dataloaders,
            )
            # Keep the previous validation loss when no one-step validation ran this epoch
            val_loss = maybe_val_loss if maybe_val_loss is not None else val_loss
            if checkpoint_future is not None:
                logger.debug(
                    f"Wait for previous checkpointing {checkpoint_future} to complete."
                )
                checkpoint_future.result()  # Make sure previous checkpoint has finished before starting next.
            # Save "last" every epoch plus various intervals/best results
            checkpoint_future = self.save_model_if_necessary(
                epoch, val_loss, last=(epoch == self.max_epoch)
            )
        # Do test validation. After a completed loop the loop variable equals
        # self.max_epoch; using the attribute directly also covers the case where
        # the epoch range is empty and the loop variable was never bound.
        test_dataloaders = self.datamodule.test_dataloaders()
        rollout_test_dataloaders = self.datamodule.rollout_test_dataloaders()
        self.validate_if_necessary(
            self.max_epoch,
            test_dataloaders,
            rollout_test_dataloaders,
            valid_or_test="test",
        )
        # Wait for the final checkpoint as well, so that any error raised while
        # writing it is surfaced instead of being silently dropped.
        if checkpoint_future is not None:
            checkpoint_future.result()

    def validate(self):
        """Evaluate the validation split and then the test split, without training.

        Both evaluations are requested at epoch ``self.max_epoch + 1``.
        """
        datamodule = self.datamodule
        # Build every loader up front, before any evaluation starts.
        valid_loaders = datamodule.val_dataloaders()
        valid_rollout_loaders = datamodule.rollout_val_dataloaders()
        test_loaders = datamodule.test_dataloaders()
        test_rollout_loaders = datamodule.rollout_test_dataloaders()

        eval_epoch = self.max_epoch + 1
        evaluation_plan = (
            ("valid", valid_loaders, valid_rollout_loaders),
            ("test", test_loaders, test_rollout_loaders),
        )
        for split, one_step_loaders, rollout_loaders in evaluation_plan:
            self.validate_if_necessary(
                eval_epoch,
                one_step_loaders,
                rollout_loaders,
                valid_or_test=split,
            )
