#!/usr/bin/env python

import argparse
import random
from functools import partial

import dawgz
import numpy as np
import ot
from filelock import FileLock
from omegaconf import DictConfig, open_dict
from tqdm import tqdm

import wandb
from lola.autoencoder import get_autoencoder
from lola.crps import CRPS
from lola.data import field_postprocess
from lola.emulation import decode_traj, emulate_rollout, emulate_surrogate, encode_traj
from lola.fourier import isotropic_power_spectrum
from lola.hydra import compose
from lola.plot import draw_movie
from lola.surrogate import MaskedSurrogate


def train(
    runid: str,
    cfg: DictConfig,
    wandb_project: str = "lola-sm",
    target: str = "state",
):
    import os
    import re
    from pathlib import Path

    import torch
    import torch.distributed as dist
    from einops import rearrange
    from omegaconf import OmegaConf, open_dict
    from torch.nn.parallel import DistributedDataParallel
    from tqdm import trange

    import wandb
    from lola.data import (
        MiniWellDataset,
        field_preprocess,
        find_hdf5,
        get_dataloader,
        get_well_inputs,
        get_well_multi_dataset,
    )
    from lola.emulation import random_context_mask
    from lola.nn.utils import load_state_dict
    from lola.optim import get_optimizer, safe_gd_step
    from lola.surrogate import RegressionLoss, get_surrogate
    from lola.utils import randseed

    # Check if we're running in a distributed environment
    use_ddp = "RANK" in os.environ or "LOCAL_RANK" in os.environ

    if use_ddp:
        print("Initializing DDP...")
        # DDP
        dist.init_process_group(backend="nccl")
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        device_id = os.environ.get("LOCAL_RANK", rank)
        device_id = int(device_id)
        print(
            f"DDP initialized: rank={rank}, world_size={world_size}, device_id={device_id}"
        )
    else:
        # Single GPU
        rank = 0
        world_size = 1
        device_id = 0
        print("Single GPU mode")

    device = torch.device(f"cuda:{device_id}") if torch.cuda.is_available() else "cpu"
    torch.cuda.set_device(device)

    # Performance
    torch.set_float32_matmul_precision("high")

    # Config
    batch_size = cfg.train.batch_size
    assert cfg.train.epoch_size % batch_size == 0
    assert batch_size % (cfg.train.accumulation * world_size) == 0
    assert cfg.valid.epoch_size % cfg.valid.batch_size == 0
    assert cfg.valid.batch_size % world_size == 0

    if cfg.ae_run:
        space = re.search(r"f\d+c\d+", cfg.ae_run).group()
        # Load autoencoder if it exists
        if (Path(cfg.ae_run) / "autoencoder").exists():
            ae_cfg = OmegaConf.load(cfg.ae_run / "autoencoder/config.yaml").ae

            state = torch.load(
                cfg.ae_run / "autoencoder/state.pth",
                weights_only=True,
                map_location=device,
            )

            autoencoder = get_autoencoder(**ae_cfg)
            autoencoder.load_state_dict(state)
            autoencoder.to(device)
            autoencoder.requires_grad_(False)
            autoencoder.eval()
        elif (Path(cfg.ae_run)).exists():
            ae_cfg = OmegaConf.load(Path(cfg.ae_run) / "config.yaml").ae

            state = torch.load(
                Path(cfg.ae_run) / "state.pth",
                weights_only=True,
                map_location=device,
            )

            autoencoder = get_autoencoder(**ae_cfg)
            autoencoder.load_state_dict(state)
            autoencoder.to(device)
            autoencoder.requires_grad_(False)
            autoencoder.eval()
    else:
        space = "pixel"
        autoencoder = None

    surrogate_run_path = None
    if cfg.finetune:
        if hasattr(cfg, "surrogate_run") and cfg.surrogate_run:
            surrogate_run_path = cfg.surrogate_run
            if "euler" in cfg.dataset.name:
                assert (
                    "euler" in cfg.surrogate_run
                ), f"Surrogate run {cfg.surrogate_run} does not match dataset {cfg.dataset.name}."
            if "rayleigh" in cfg.dataset.name:
                assert (
                    "rayleigh" in cfg.surrogate_run
                ), f"Surrogate run {cfg.surrogate_run} does not match dataset {cfg.dataset.name}."
            if "shear_flow" in cfg.dataset.name:
                assert (
                    "shear_flow" in cfg.surrogate_run
                ), f"Surrogate run {cfg.surrogate_run} does not match dataset {cfg.dataset.name}."

        if surrogate_run_path:
            surrogate_runpath = Path(surrogate_run_path)
            surrogate_cfg = OmegaConf.load(surrogate_runpath / "config.yaml")
            state = torch.load(
                surrogate_runpath / f"{target}.pth",
                weights_only=True,
                map_location=device,
            )
            new_state_dict = {}
            for k, v in state.items():
                name = k.replace("module.", "") if k.startswith("module.") else k
                new_state_dict[name] = v
            if hasattr(surrogate_cfg, "surrogate"):
                print("Loading surrogate model...")
                surrogate = get_surrogate(**surrogate_cfg.surrogate)
                surrogate.load_state_dict(new_state_dict)
                surrogate.to(device)
                surrogate.eval()
            else:
                raise ValueError(
                    f"Could not find surrogate configuration in the provided run {surrogate_runpath}."
                )

    if cfg.finetune:
        runname = f"{runid}_{cfg.dataset.name}FT_E{cfg.train.epochs}_{space}_{cfg.surrogate.name}"
    else:
        runname = f"{runid}_{cfg.dataset.name}_{space}_{cfg.surrogate.name}"

    runpath = Path(f"{cfg.server.storage}/runs/sm/{runname}")
    runpath = runpath.expanduser().resolve()
    runpath.mkdir(parents=True, exist_ok=True)

    with open_dict(cfg):
        cfg.name = runname
        cfg.path = str(runpath)
        cfg.seed = randseed(runid)

        if cfg.ae_run:
            cfg.ae_run = os.path.realpath(os.path.expanduser(cfg.ae_run), strict=True)
            if "euler" in cfg.dataset.name:
                assert (
                    "euler" in cfg.ae_run
                ), f"Surrogate run {cfg.surrogate_run} does not match dataset {cfg.dataset.name}."
            if "rayleigh" in cfg.dataset.name:
                assert (
                    "rayleigh" in cfg.ae_run
                ), f"Surrogate run {cfg.surrogate_run} does not match dataset {cfg.dataset.name}."
            if "shear_flow" in cfg.dataset.name:
                assert (
                    "shear_flow" in cfg.ae_run
                ), f"Surrogate run {cfg.surrogate_run} does not match dataset {cfg.dataset.name}."

    if rank == 0 and cfg.ae_run:
        os.symlink(cfg.ae_run, runpath / "autoencoder")

    if use_ddp:
        dist.barrier(device_ids=[device_id])

    # Data
    if cfg.ae_run:
        files = {
            split: [
                file
                for physic in cfg.dataset.physics
                for file in find_hdf5(
                    path=runpath / "autoencoder/cache" / physic / split,
                    include_filters=cfg.dataset.include_filters,
                )
            ]
            for split in ("train", "valid")
        }

        dataset = {
            split: MiniWellDataset.from_files(
                files=files[split],
                steps=cfg.trajectory.length,
                stride=cfg.trajectory.stride,
            )
            for split in ("train", "valid")
        }
    else:
        dataset = {
            split: get_well_multi_dataset(
                path=(
                    cfg.server.datasets
                    if not hasattr(cfg.dataset, "dataset_path")
                    else cfg.dataset.dataset_path
                ),
                physics=cfg.dataset.physics,
                split=split,
                steps=cfg.trajectory.length,
                min_dt_stride=cfg.trajectory.stride,
                max_dt_stride=cfg.trajectory.stride,
                include_filters=cfg.dataset.include_filters,
                augment=cfg.dataset.augment,
            )
            for split in ("train", "valid")
        }

    # Get rollout dataset for evaluation
    rollout_dataset = {
        "valid": get_well_multi_dataset(
            path=(
                cfg.server.datasets
                if not hasattr(cfg.dataset, "dataset_path")
                else cfg.dataset.dataset_path
            ),
            physics=cfg.dataset.physics,
            split="valid",
            steps=-1,
            include_filters=cfg.dataset.include_filters,
            augment=[s for s in cfg.dataset.augment if "random" not in s],
        ),
        "test": get_well_multi_dataset(
            path=(
                cfg.server.datasets
                if not hasattr(cfg.dataset, "dataset_path")
                else cfg.dataset.dataset_path
            ),
            physics=cfg.dataset.physics,
            split="test",
            steps=-1,
            include_filters=cfg.dataset.include_filters,
            augment=[s for s in cfg.dataset.augment if "random" not in s],
        ),
    }

    train_loader, valid_loader = [
        get_dataloader(
            dataset=dataset[split],
            batch_size=(
                cfg.train.batch_size // cfg.train.accumulation // world_size
                if split == "train"
                else cfg.valid.batch_size // world_size
            ),
            shuffle=True,
            infinite=True,
            num_workers=cfg.compute.cpus_per_gpu,
            rank=rank,
            world_size=world_size,
            seed=cfg.seed,
        )
        for split in ("train", "valid")
    ]

    if cfg.ae_run:
        preprocess = lambda x: x
    else:
        preprocess = partial(
            field_preprocess,
            mean=torch.as_tensor(cfg.dataset.stats.mean, device=device),
            std=torch.as_tensor(cfg.dataset.stats.std, device=device),
            transform=cfg.dataset.transform,
        )

    rollout_preprocess = partial(
        field_preprocess,
        mean=torch.as_tensor(cfg.dataset.stats.mean, device=device),
        std=torch.as_tensor(cfg.dataset.stats.std, device=device),
        transform=cfg.dataset.transform,
    )

    postprocess = partial(
        field_postprocess,
        mean=torch.as_tensor(cfg.dataset.stats.mean, device=device),
        std=torch.as_tensor(cfg.dataset.stats.std, device=device),
        transform=cfg.dataset.transform,
    )

    if hasattr(cfg.dataset, "dimensions"):
        spatial = len(cfg.dataset.dimensions)
    else:
        spatial = 2

    x, label = get_well_inputs(next(valid_loader))
    x = rearrange(x, "B L ... C -> B C L ...")

    # Model, optimizer & scheduler
    with open_dict(cfg):
        cfg.surrogate.channels = x.shape[1]
        cfg.surrogate.label_features = label.shape[1]
        cfg.surrogate.spatial = len(cfg.dataset.dimensions) + 1

    if not cfg.finetune:
        surrogate = get_surrogate(**cfg.surrogate).to(device)

    surrogate_loss = RegressionLoss(
        losses=["mse"] if cfg.ae_run else ["vmse"],
    ).to(device)
    crps_loss = CRPS()

    # Stem
    if cfg.fork.run is None:
        counter = {
            "epoch": 0,
            "update_samples": 0,
            "update_steps": 0,
        }
    else:
        stem = wandb.Api().run(path=cfg.fork.run)
        stem_name = Path(stem.config["path"]).name
        stem_path = Path(f"{cfg.server.storage}/runs/crps/{stem_name}")
        stem_path = stem_path.expanduser().resolve()
        stem_state = torch.load(
            stem_path / f"{cfg.fork.target}.pth", weights_only=True, map_location=device
        )

        counter = {
            "epoch": stem.summary["_step"] + 1,
            "update_samples": stem.summary["train/samples"],
            "update_steps": stem.summary["train/update_steps"],
        }

    if use_ddp:
        surrogate = DistributedDataParallel(
            module=surrogate,
            device_ids=[device_id],
            find_unused_parameters=True,
        )

    surrogate_type = (
        type(surrogate.module) if hasattr(surrogate, "module") else type(surrogate)
    )
    is_masked_surrogate = surrogate_type == MaskedSurrogate

    optimizer, scheduler = get_optimizer(
        params=surrogate.parameters(),
        epochs=cfg.train.epochs,
        **cfg.optim,
    )

    # W&B
    if rank == 0:
        OmegaConf.save(cfg, runpath / "config.yaml")

    if rank == 0:
        run = wandb.init(
            entity=cfg.wandb.entity,
            project=wandb_project,
            group=cfg.dataset.name,
            id=runid,
            name=runname,
            config=OmegaConf.to_container(cfg),
        )

    # Training loop
    if rank == 0:
        epochs = trange(cfg.train.epochs, ncols=88, ascii=True)
    else:
        epochs = range(cfg.train.epochs)

    best_valid_loss = float("inf")

    # For validation on the same set of trajectories - Get a random set of indices up to the length of valid_dataset
    num_val_indices = getattr(cfg.val_eval, "num_val_indices", 20)
    valid_dataset_len = len(rollout_dataset["valid"])
    # Use a separate Random instance with fixed seed for reproducible validation indices
    val_rng = random.Random(42)
    indices = val_rng.sample(
        range(valid_dataset_len), min(num_val_indices, valid_dataset_len)
    )
    print("Validation indices", indices)
    if "euler" in cfg.dataset.name:
        num_test_indices = 1000
        cfg.val_eval.start = 0
    elif "rayleigh" in cfg.dataset.name:
        num_test_indices = 175
        cfg.val_eval.start = 5
    elif "shear_flow" in cfg.dataset.name:
        num_test_indices = 112
        cfg.val_eval.start = 5
    elif "cns_2D" in cfg.dataset.name:
        num_test_indices = 4000
        cfg.val_eval.start = 0
    else:
        raise ValueError(
            f"No default number of test indices for dataset {cfg.dataset.name}."
        )
    test_dataset_len = len(rollout_dataset["test"])
    test_rng = random.Random(123)
    test_indices = test_rng.sample(
        range(test_dataset_len), min(num_test_indices, test_dataset_len)
    )
    print("Test indices", test_indices)
    record = cfg.val_eval.record

    for epoch_idx, epoch in enumerate(epochs):
        ## Train
        surrogate.train()

        losses, grads = [], []

        for i in range(
            cfg.train.accumulation * cfg.train.epoch_size // cfg.train.batch_size
        ):
            x, label = get_well_inputs(next(train_loader), device=device)
            label = label.to(x.dtype)
            x = preprocess(x)
            x = rearrange(x, "B L ... C -> B C L ...")

            mask = random_context_mask(x, **cfg.trajectory.context)

            if (i + 1) % cfg.train.accumulation == 0:
                if is_masked_surrogate:
                    y = surrogate(x * mask, mask=mask, label=label)
                else:
                    assert (
                        cfg.surrogate.name == "FFNO"
                    ), f"Unexpected surrogate {cfg.surrogate.name} without MaskedSurrogate."
                    assert hasattr(cfg.surrogate, "time_history")
                    assert hasattr(cfg.surrogate, "time_future")
                    history_x = x[
                        :,
                        :,
                        -(
                            cfg.surrogate.time_history + cfg.surrogate.time_future
                        ) : -cfg.surrogate.time_future,
                        ...,
                    ]
                    x = x[:, :, -cfg.surrogate.time_future :, ...]
                    y = surrogate(history_x, label=label)

                loss = surrogate_loss(x, y) * cfg.train.loss_multiplier
                loss_acc = loss / cfg.train.accumulation
                loss_acc.backward()

                grad_norm = safe_gd_step(optimizer, grad_clip=cfg.optim.grad_clip)
                grads.append(grad_norm)
            else:
                with surrogate.no_sync():
                    if is_masked_surrogate:
                        y = surrogate(x * mask, mask=mask, label=label)
                    else:
                        assert (
                            cfg.surrogate.name == "FFNO"
                        ), f"Unexpected surrogate {cfg.surrogate.name} without MaskedSurrogate."
                        assert hasattr(cfg.surrogate, "time_history")
                        assert hasattr(cfg.surrogate, "time_future")
                        history_x = x[
                            :,
                            :,
                            -(
                                cfg.surrogate.time_history + cfg.surrogate.time_future
                            ) : -cfg.surrogate.time_future,
                            ...,
                        ]
                        x = x[:, :, -cfg.surrogate.time_future :, ...]
                        y = surrogate(history_x, label=label)

                    loss = surrogate_loss(x, y) * cfg.train.loss_multiplier
                    loss_acc = loss / cfg.train.accumulation
                    loss_acc.backward()

            losses.append(loss.detach())

        losses = torch.stack(losses)
        grads = torch.stack(grads)

        # Gather losses and grads only if using DDP
        if use_ddp:
            if rank == 0:
                losses_list = [torch.empty_like(losses) for _ in range(world_size)]
                grads_list = [torch.empty_like(grads) for _ in range(world_size)]
            else:
                losses_list = None
                grads_list = None

            dist.gather(losses, losses_list, dst=0)
            dist.gather(grads, grads_list, dst=0)

            if rank == 0:
                losses = torch.cat(losses_list).cpu()
                grads = torch.cat(grads_list).cpu()
        else:
            losses = losses.cpu()
            grads = grads.cpu()

        if rank == 0:
            logs = {}
            logs["train/loss/mean"] = losses.mean().item()
            logs["train/loss/std"] = losses.std().item()
            logs["train/grad_norm/mean"] = grads.mean().item()
            logs["train/grad_norm/std"] = grads.std().item()
            logs["train/learning_rate"] = optimizer.param_groups[0]["lr"]

        if use_ddp:
            del losses_list, grads_list
        del losses, grads

        ## Eval
        surrogate.eval()

        ## Stats for rollout
        if (
            (epoch_idx % cfg.val_eval.interval == 0)
            and cfg.val_eval.enabled
            and epoch_idx != 0
        ) or (epoch_idx == cfg.train.epochs - 1):
            if rank == 0:
                print(f"Epoch {epoch_idx}, performing validation rollout evaluation...")

            if epoch_idx == cfg.train.epochs - 1:
                rollout_splits = ["valid", "test"]
            else:
                rollout_splits = ["valid"]

            for split in rollout_splits:
                if rank == 0:
                    print(f"Rollout evaluation on {split} set...")

                # Initialize metrics dictionary with field-specific and aggregate tracking
                all_metrics = {}
                # Get number of fields
                field_names = cfg.dataset.fields
                num_fields = len(field_names)

                # Get ensemble sizes to track (only the ones we'll save)
                ensemble_sizes_to_save = getattr(
                    cfg.val_eval,
                    "ensemble_sizes_to_save",
                    list(range(1, cfg.val_eval.samples + 1)),
                )

                if cfg.val_eval.samples not in ensemble_sizes_to_save:
                    ensemble_sizes_to_save.append(cfg.val_eval.samples)
                    ensemble_sizes_to_save = sorted(ensemble_sizes_to_save)
                assert (
                    max(ensemble_sizes_to_save) <= cfg.val_eval.samples
                ), "Ensemble sizes to save should not exceed the number of samples."

                # Initialize metrics for each field and aggregate
                metric_names = [
                    "vrmse",
                    "rmse",
                    "nrmse",
                    "spread",
                    "spread_skill",
                    "rmse_p",
                    "rmse_p_low",
                    "rmse_p_mid",
                    "rmse_p_high",
                    "rmse_p_sub",
                    "crps",
                ]

                for metric_name in metric_names:
                    # Only track ensemble sizes we'll save
                    for ensemble_i in ensemble_sizes_to_save:
                        if ensemble_i <= cfg.val_eval.samples:  # Safety check
                            ensemble_idx = (
                                f"_{ensemble_i}"
                                if ensemble_i != cfg.val_eval.samples
                                else ""
                            )

                            all_metrics[f"{metric_name}{ensemble_idx}_all"] = []
                            for field in range(num_fields):
                                all_metrics[
                                    f"{metric_name}{ensemble_idx}_field_{field_names[field]}"
                                ] = []

                current_indices = indices if split == "valid" else test_indices

                for idx, index in enumerate(
                    tqdm(current_indices, desc="Rollouts", ncols=88, ascii=True)
                ):
                    # Video plotting condition
                    video_plotting = (record > 0) and (
                        (idx in [0, 1, 2]) or (index in [82, 30, 170])
                    )
                    # Get data
                    x, label = get_well_inputs(
                        rollout_dataset[split][index], device=device
                    )
                    label = label.to(x.dtype)
                    x = x[
                        max(
                            0,
                            cfg.val_eval.start
                            - (cfg.val_eval.context - 1) * cfg.trajectory.stride,
                        ) :: cfg.trajectory.stride
                    ]
                    x = rollout_preprocess(x)
                    x = rearrange(x, "L ... C -> C L ...")

                    with torch.no_grad():
                        z = encode_traj(autoencoder, x)

                    compression = x.numel() / z.numel()

                    method = "surrogate"
                    emulate = lambda mask, z_obs, noise, i: emulate_surrogate(
                        surrogate,
                        mask,
                        z_obs,
                        label=label,
                    )

                    with torch.no_grad():
                        # Emulate trajectory
                        z_hat = emulate_rollout(
                            emulate,
                            z,
                            window=cfg.trajectory.length,
                            rollout=z.shape[1],
                            context=(
                                cfg.val_eval.context
                                if is_masked_surrogate
                                else cfg.surrogate.time_history
                            ),
                            overlap=(
                                cfg.val_eval.overlap
                                if is_masked_surrogate
                                else cfg.surrogate.time_history
                            ),
                            batch=1,
                            masked=is_masked_surrogate,
                        )

                        if "euler" in cfg.dataset.name:
                            chunk_size = 128
                        elif "gravity" in cfg.dataset.name:
                            chunk_size = 128
                        else:
                            chunk_size = 256

                        x_hat = decode_traj(
                            autoencoder,
                            z_hat,
                            batched=True,
                            noisy=False,
                            chunk_size=chunk_size,
                        )

                    del z_hat

                    ## Postprocess
                    x = postprocess(x, dim=-spatial - 2)
                    x_hat = postprocess(x_hat, dim=-spatial - 2)

                    if split == "test" and rank == 0:
                        lines = []
                    # Compute metrics
                    for field in range(x.shape[0]):
                        for t in range(cfg.val_eval.context - 1, x.shape[1]):
                            true_t = (
                                t - cfg.val_eval.context + 1
                            ) * cfg.trajectory.stride
                            u, v = x[field, t], x_hat[:, field, t]

                            # Moments (these don't change with ensemble size)
                            m1 = torch.mean(u)
                            m2 = torch.mean(u**2)

                            # Fourier analysis (ground truth doesn't change)
                            p_u, k = isotropic_power_spectrum(u, spatial=spatial)
                            bins = torch.logspace(k[0].log2(), -1.0, steps=4, base=2)

                            # Storage for ensemble-dependent metrics
                            metrics_by_ensemble = {}

                            # Iterate over ensemble sizes
                            for ensemble_i in ensemble_sizes_to_save:
                                v_subset = v[
                                    :ensemble_i
                                ]  # Use first ensemble_i samples

                                # Spread
                                if ensemble_i > 1:
                                    # see https://doi.org/10.1175/JHM-D-14-0008.1
                                    spread = torch.mean(
                                        torch.square(
                                            v_subset - torch.mean(v_subset, dim=0)
                                        )
                                    )
                                    spread = torch.sqrt(
                                        (ensemble_i + 1) / (ensemble_i - 1) * spread
                                    )
                                else:
                                    spread = 0.0

                                # Skill metrics
                                se = torch.square(u - torch.mean(v_subset, dim=0))
                                mse = torch.mean(se)
                                rmse = torch.sqrt(mse)
                                nrmse = torch.sqrt(mse / (torch.mean(u**2) + 1e-6))
                                vrmse = torch.sqrt(mse / (torch.var(u) + 1e-6))

                                # Spread_skill ratio
                                spread_skill = (spread + 1e-3) / (rmse + 1e-3)

                                # Fourier metrics
                                p_v, _ = isotropic_power_spectrum(
                                    v_subset, spatial=spatial
                                )
                                p_v = torch.mean(p_v, dim=0)
                                se_p = torch.square(1 - (p_v + 1e-6) / (p_u + 1e-6))
                                rmse_p = torch.sqrt(torch.mean(se_p))

                                fourier_extras = []
                                for i in range(4):
                                    if i < 3:
                                        mask = torch.logical_and(
                                            bins[i] <= k, k <= bins[i + 1]
                                        )
                                    else:
                                        mask = bins[i] <= k
                                    fourier_extras.append(
                                        torch.sqrt(torch.mean(se_p[mask])).item()
                                    )

                                extras = []
                                ## Wasserstein
                                w_uv = ot.lp.wasserstein_1d(
                                    u.flatten(),
                                    v_subset.flatten(),
                                    p=1.0,
                                )
                                extras.append(w_uv.item())
                                ## Sliced EMD (only makes sense for density)
                                if "density" in cfg.dataset.fields[field] and False:
                                    coo = torch.cartesian_prod(
                                        *(
                                            torch.linspace(0, 1, size, device=u.device)
                                            for size in u.shape
                                        )
                                    )
                                    edm = ot.sliced.sliced_wasserstein_distance(
                                        coo,
                                        coo,
                                        a=u.flatten() / u.sum(),
                                        b=v_subset.mean(dim=0).flatten()
                                        / v_subset.mean(dim=0).sum(),
                                        p=1.0,
                                        n_projections=16,
                                        seed=42,
                                    )

                                    extras.append(edm.item())
                                else:
                                    extras.append(None)

                                # CRPS
                                crps_i = crps_loss(
                                    predictions=v_subset[None, :, None, None, :],
                                    target=u[None, None, None, :],
                                    mask=None,
                                    mem_efficient=False,
                                ).mean()

                                # Store all metrics for this ensemble size
                                metrics_by_ensemble[ensemble_i] = {
                                    "vrmse": vrmse.item(),
                                    "rmse": rmse.item(),
                                    "nrmse": nrmse.item(),
                                    "spread": (
                                        spread
                                        if isinstance(spread, float)
                                        else spread.item()
                                    ),
                                    "spread_skill": spread_skill.item(),
                                    "rmse_p": rmse_p.item(),
                                    "rmse_p_low": fourier_extras[0],
                                    "rmse_p_mid": fourier_extras[1],
                                    "rmse_p_high": fourier_extras[2],
                                    "rmse_p_sub": fourier_extras[3],
                                    "crps": crps_i.item(),
                                    "wasserstein": extras[0],
                                    "emd": extras[1],
                                }

                            # For CSV output (test split) - store one row per ensemble size
                            if split == "test":
                                # Store one row for each ensemble size
                                for ensemble_i in ensemble_sizes_to_save:
                                    if (
                                        ensemble_i in metrics_by_ensemble
                                    ):  # Only if we computed it
                                        ensemble_metrics = metrics_by_ensemble[
                                            ensemble_i
                                        ]

                                        line = f"{runid},state,{compression:.1f},crps,"
                                        line += f"train,None,{cfg.val_eval.context},{cfg.val_eval.overlap},1.0,"
                                        line += f"{split},{index},{cfg.val_eval.start},{cfg.seed},"
                                        line += f"{field_names[field]},{(t - cfg.val_eval.context + 1) * cfg.trajectory.stride},False,"
                                        line += (
                                            f"{ensemble_i},"  # Add ensemble size here
                                        )
                                        line += f"{m1},{m2},"
                                        line += f"{ensemble_metrics['spread']},{ensemble_metrics['spread_skill']},"
                                        line += f"{ensemble_metrics['rmse']},{ensemble_metrics['nrmse']},{ensemble_metrics['vrmse']},"
                                        line += f"{ensemble_metrics['rmse_p']},"

                                        # Add Fourier extras
                                        fourier_values = [
                                            ensemble_metrics["rmse_p_low"],
                                            ensemble_metrics["rmse_p_mid"],
                                            ensemble_metrics["rmse_p_high"],
                                            ensemble_metrics["rmse_p_sub"],
                                        ]
                                        line += ",".join(map(str, fourier_values)) + ","
                                        # Add CRPS value for this ensemble size
                                        line += f"{ensemble_metrics['crps']}" + ","

                                        # Add Wasserstein and Sliced EMD
                                        extra_values = [
                                            ensemble_metrics["wasserstein"],
                                            ensemble_metrics["emd"],
                                        ]
                                        line += ",".join(map(str, extra_values))

                                        # Add label parameters
                                        if hasattr(label, "tolist"):
                                            line += (
                                                f",{','.join(map(str, label.tolist()))}"
                                            )

                                        line += "\n"
                                        lines.append(line)

                            # Store metrics for logging (per-field and field-averaged)
                            for ensemble_i in ensemble_sizes_to_save:
                                ensemble_idx = (
                                    f"_{ensemble_i}"
                                    if ensemble_i != cfg.val_eval.samples
                                    else ""
                                )

                                # Store per-field metrics for this ensemble size
                                for metric_name, value in metrics_by_ensemble[
                                    ensemble_i
                                ].items():
                                    if metric_name not in ["wasserstein", "emd"]:
                                        all_metrics[
                                            f"{metric_name}{ensemble_idx}_field_{field_names[field]}"
                                        ].append((true_t, value))

                    # After processing all samples, compute field-averaged metrics across all timesteps
                    for ensemble_i in ensemble_sizes_to_save:
                        ensemble_idx = (
                            f"_{ensemble_i}"
                            if ensemble_i != cfg.val_eval.samples
                            else ""
                        )

                        for metric_name in metric_names:
                            # Get all field keys for this metric and ensemble size
                            field_keys = [
                                f"{metric_name}{ensemble_idx}_field_{field_names[field]}"
                                for field in range(num_fields)
                            ]

                            # Check if all field keys exist and have data
                            if all(
                                key in all_metrics and all_metrics[key]
                                for key in field_keys
                            ):
                                # Get all timesteps from the first field (they should all be the same)
                                timesteps = [t for t, _ in all_metrics[field_keys[0]]]

                                # Create the field-averaged list
                                field_averaged_values = []

                                for i, timestep in enumerate(timesteps):
                                    # Collect values from all fields for this timestep index
                                    field_values = []
                                    for field_key in field_keys:
                                        if i < len(all_metrics[field_key]):
                                            _, value = all_metrics[field_key][
                                                i
                                            ]  # Get value at index i
                                            field_values.append(value)

                                    # Average across fields for this timestep
                                    if (
                                        len(field_values) == num_fields
                                    ):  # Ensure we have all fields
                                        avg_value = sum(field_values) / len(
                                            field_values
                                        )
                                        field_averaged_values.append(
                                            (timestep, avg_value)
                                        )

                                # Store the field-averaged values
                                all_metrics[f"{metric_name}{ensemble_idx}_all"] = (
                                    field_averaged_values
                                )

                    if split == "test" and rank == 0:
                        # Save
                        stats_path = (
                            runpath / "test_stats.csv"
                        )  # Single file for all samples
                        with FileLock(str(stats_path) + ".lock"):
                            with open(stats_path, "a") as f:  # Use write mode
                                # Only write header if file is empty/new
                                if stats_path.stat().st_size == 0:
                                    header = "run,target,compression,method,settings,guidance,context,overlap,speed,"
                                    header += (
                                        "split,index,start,seed,field,time,relative,"
                                    )
                                    header += "ensemble_size,"  # Add this column
                                    header += "m1,m2,spread,spread_skill,rmse,nrmse,vrmse,rmse_p,"
                                    header += (
                                        "rmse_p_low,rmse_p_mid,rmse_p_high,rmse_p_sub,"
                                    )
                                    header += "crps,"  # Single CRPS value per row
                                    header += "wasserstein,emd"
                                    if (
                                        hasattr(label, "tolist")
                                        and len(label.tolist()) > 0
                                    ):
                                        param_names = getattr(
                                            cfg.dataset,
                                            "param_names",
                                            [
                                                f"param_{i}"
                                                for i in range(len(label.tolist()))
                                            ],
                                        )
                                        header += f",{','.join(param_names)}"
                                    header += "\n"
                                    f.write(header)
                                f.writelines(lines)

                    if (record > 0) and (
                        (idx in [0, 1, 2]) or (index in [82, 30, 170])
                    ):
                        if x.shape[-1] < x.shape[-2]:
                            x, x_hat = x.mT, x_hat.mT
                        frames = torch.stack((x, *x_hat[:record]))
                        frames = rearrange(frames, "N C L H W -> L N C H W")
                        # Get rid of the first conditioning
                        frames = frames[cfg.val_eval.context :]

                        # Ensure the directory exists before saving the movie
                        movie_dir = runpath / f"epoch_{epoch_idx}"
                        movie_dir.mkdir(parents=True, exist_ok=True)
                        draw_movie(
                            frames,
                            file=(
                                movie_dir
                                / f"{runname}_{target}_{split}_{index:06d}_{cfg.val_eval.start:03d}_{cfg.val_eval.context}_{cfg.val_eval.overlap}_{cfg.seed}.mp4"
                            ),
                            fps=4.0 / cfg.trajectory.stride,
                            isolate={2},
                        )

                        del frames
                    del x, x_hat
                # Log aggregated rollout metrics across all samples
                if rank == 0 and len(all_metrics[f"{metric_names[0]}_all"]) > 0:
                    # Helper function to split by time ranges
                    def split_by_time_range(data, ranges):
                        """Split data into time ranges. ranges is list of (min, max) tuples."""
                        splits = [[] for _ in ranges]
                        for timestep, value in data:
                            for i, (t_min, t_max) in enumerate(ranges):
                                if t_min <= timestep <= t_max:
                                    splits[i].append(value)
                                    break
                        return splits

                    # Helper functions using NumPy
                    def compute_mean(data_list):
                        values = np.array(
                            [v for idx, v in data_list if idx >= cfg.val_eval.context]
                        )
                        return float(np.mean(values))

                    def compute_median(data_list):
                        values = np.array(
                            [v for idx, v in data_list if idx >= cfg.val_eval.context]
                        )
                        return float(np.median(values))

                    # Define time ranges: 1-40, 41-80, 81-200
                    if "rayleigh" in cfg.dataset.name:
                        time_ranges = [(1, 40), (41, 80), (81, 200)]
                        range_names = ["1_40", "41_80", "81_200"]
                    elif "euler" in cfg.dataset.name:
                        time_ranges = [(1, 20), (21, 60), (61, 100)]
                        range_names = ["1_20", "21_60", "61_100"]
                    elif "shear_flow" in cfg.dataset.name:
                        time_ranges = [(1, 40), (41, 80), (81, 200)]
                        range_names = ["1_40", "41_80", "81_200"]
                    elif "cns" in cfg.dataset.name:
                        time_ranges = [(1, 4), (4, 10), (10, 21)]
                        range_names = ["1_4", "4_10", "10_21"]
                    else:
                        raise ValueError(
                            f"No time ranges defined for dataset {cfg.dataset.name}."
                        )

                    # Log aggregate metrics (averaged across all fields) - for ALL ensemble sizes
                    for ensemble_i in ensemble_sizes_to_save:
                        if ensemble_i <= cfg.val_eval.samples:  # Safety check
                            ensemble_idx = (
                                f"_{ensemble_i}"
                                if ensemble_i != cfg.val_eval.samples
                                else ""
                            )

                            for metric_name in metric_names:
                                metric_key = f"{metric_name}{ensemble_idx}_all"
                                if (
                                    metric_key in all_metrics
                                    and all_metrics[metric_key]
                                ):
                                    data = all_metrics[metric_key]
                                    logs[
                                        f"{split}/rollout/{metric_name}{ensemble_idx}_mean"
                                    ] = compute_mean(data)
                                    logs[
                                        f"{split}/rollout/{metric_name}{ensemble_idx}_median"
                                    ] = compute_median(data)

                                    # Time range breakdowns
                                    splits = split_by_time_range(data, time_ranges)
                                    for i, range_name in enumerate(range_names):
                                        if splits[i]:
                                            logs[
                                                f"{split}/rollout/{metric_name}{ensemble_idx}_mean_{range_name}"
                                            ] = sum(splits[i]) / len(splits[i])
                                            logs[
                                                f"{split}/rollout/{metric_name}{ensemble_idx}_median_{range_name}"
                                            ] = float(np.median(splits[i]))

                    # Log per-field metrics - for ALL ensemble sizes
                    for ensemble_i in ensemble_sizes_to_save:
                        if ensemble_i <= cfg.val_eval.samples:  # Safety check
                            ensemble_idx = (
                                f"_{ensemble_i}"
                                if ensemble_i != cfg.val_eval.samples
                                else ""
                            )

                            for field in range(num_fields):
                                for metric_name in metric_names:
                                    metric_key = f"{metric_name}{ensemble_idx}_field_{field_names[field]}"
                                    if (
                                        metric_key in all_metrics
                                        and all_metrics[metric_key]
                                    ):
                                        data = all_metrics[metric_key]
                                        logs[
                                            f"{split}/rollout/{metric_name}{ensemble_idx}_field_{field_names[field]}_mean"
                                        ] = compute_mean(data)
                                        logs[
                                            f"{split}/rollout/{metric_name}{ensemble_idx}_field_{field_names[field]}_median"
                                        ] = compute_median(data)

                                        # Time range breakdowns per field
                                        splits = split_by_time_range(data, time_ranges)
                                        for i, range_name in enumerate(range_names):
                                            if splits[i]:
                                                logs[
                                                    f"{split}/rollout/{metric_name}{ensemble_idx}_field_{field_names[field]}_mean_{range_name}"
                                                ] = sum(splits[i]) / len(splits[i])
                                                logs[
                                                    f"{split}/rollout/{metric_name}{ensemble_idx}_field_{field_names[field]}_median_{range_name}"
                                                ] = float(np.median(splits[i]))

        losses = []

        with torch.no_grad():
            for _ in range(cfg.valid.epoch_size // cfg.valid.batch_size):
                x, label = get_well_inputs(next(valid_loader), device=device)
                label = label.to(x.dtype)
                x = preprocess(x)
                x = rearrange(x, "B L ... C -> B C L ...")

                mask = random_context_mask(x, **cfg.trajectory.context)

                if is_masked_surrogate:
                    y = surrogate(x * mask, mask=mask, label=label)
                else:
                    history_x = x[
                        :,
                        :,
                        -(
                            cfg.surrogate.time_history + cfg.surrogate.time_future
                        ) : -cfg.surrogate.time_future,
                        ...,
                    ]
                    x = x[:, :, -cfg.surrogate.time_future :, ...]
                    y = surrogate(history_x, label=label)

                loss = surrogate_loss(x, y) * cfg.train.loss_multiplier
                losses.append(loss)

        losses = torch.stack(losses)

        if use_ddp:
            if rank == 0:
                losses_list = [torch.empty_like(losses) for _ in range(world_size)]
            else:
                losses_list = None

            dist.gather(losses, losses_list, dst=0)

            if rank == 0:
                losses = torch.cat(losses_list).cpu()
        else:
            losses = losses.cpu()

        if rank == 0:
            logs["valid/loss/mean"] = losses.mean().item()
            logs["valid/loss/std"] = losses.std().item()

            epochs.set_postfix(
                lt=logs["train/loss/mean"],
                lv=logs["valid/loss/mean"],
            )

            run.log(logs, step=epoch)

        if use_ddp:
            del losses_list
        del losses

        ## LR scheduler
        scheduler.step()

        ## Checkpoint
        if rank == 0:
            state = surrogate.state_dict()

            # Save complete checkpoint with optimizer and scheduler states
            checkpoint = {
                "model_state_dict": state,
                "optimizer_state_dict": optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
                "epoch": counter["epoch"],
                "best_valid_loss": best_valid_loss,
                "counter": counter,
                # Add any other training state you want to preserve
            }
            torch.save(checkpoint, runpath / "checkpoint.pth")

            # Also save just the model weights for compatibility
            torch.save(state, runpath / "state.pth")

            if logs["valid/loss/mean"] < best_valid_loss:
                best_valid_loss = logs["valid/loss/mean"]
                torch.save(state, runpath / "state_best.pth")
                torch.save(checkpoint, runpath / "checkpoint_best.pth")

            del state, checkpoint

        if use_ddp:
            dist.barrier(device_ids=[device_id])

    # W&B
    if rank == 0:
        run.finish()

    # DDP
    if use_ddp:
        dist.destroy_process_group()


if __name__ == "__main__":
    # Parser: positional config overrides plus one optional run mode.
    parser = argparse.ArgumentParser()
    parser.add_argument("overrides", nargs="*", type=str)
    # --debug and --finetune select different config files; combining them
    # used to silently ignore --finetune, so reject the combination outright.
    mode = parser.add_mutually_exclusive_group()
    mode.add_argument("--debug", action="store_true", help="Debug mode")
    mode.add_argument("--finetune", action="store_true", help="Finetune mode")

    args = parser.parse_args()

    # Config: each mode selects its own config file and W&B project.
    # Non-debug runs use "lola-sm", the same project as train()'s default, so
    # regular runs keep logging where they always have; debug runs are routed
    # to a separate project so they do not pollute real experiment results.
    if args.debug:
        config_file = "./experiments/configs/train_surrogate_debug.yaml"
        wandb_project = "lola_sm_debug"
    elif args.finetune:
        config_file = "./experiments/configs/finetune_surrogate.yaml"
        wandb_project = "lola-sm"
    else:
        config_file = "./experiments/configs/train_surrogate.yaml"
        wandb_project = "lola-sm"

    cfg = compose(config_file=config_file, overrides=args.overrides)
    if args.finetune:
        print("Using finetuning config")

    # Job: unique W&B run id, prefixed in debug mode so it is easy to spot.
    runid = wandb.util.generate_id()
    if args.debug:
        runid = "debug_" + runid

    # Multi-node runs rendezvous through the SLURM node name / job id;
    # single-node runs use torchrun's standalone mode.
    if cfg.compute.nodes > 1:
        interpreter = f"torchrun --nnodes {cfg.compute.nodes} --nproc-per-node {cfg.compute.gpus} --rdzv_backend=c10d --rdzv_endpoint=$SLURMD_NODENAME:12345 --rdzv_id=$SLURM_JOB_ID"
    else:
        interpreter = (
            f"torchrun --nnodes 1 --nproc-per-node {cfg.compute.gpus} --standalone"
        )

    # Debug runs execute locally (async backend); everything else goes to SLURM.
    backend = "async" if args.debug else "slurm"
    with open_dict(cfg):
        cfg.debug = args.debug

    dawgz.schedule(
        dawgz.job(
            # Pass the mode-specific W&B project through to train().
            f=partial(train, runid, cfg, wandb_project=wandb_project),
            name=f"sm {runid}",
            nodes=cfg.compute.nodes,
            cpus=cfg.compute.cpus_per_gpu * cfg.compute.gpus,
            gpus=cfg.compute.gpus,
            ram=cfg.compute.ram,
            time=cfg.compute.time,
            partition=cfg.server.partition,
            constraint=cfg.server.constraint,
            exclude=cfg.server.exclude,
        ),
        name=f"training sm {runid}",
        backend=backend,
        interpreter=interpreter,
        env=[
            f"export OMP_NUM_THREADS={cfg.compute.cpus_per_gpu}",
            "export WANDB_SILENT=true",
            "export XDG_CACHE_HOME=$HOME/.cache",
        ],
    )
