#!/usr/bin/env python

import argparse
import random
from functools import partial

import dawgz
import numpy as np
import ot
from filelock import FileLock
from omegaconf import DictConfig, open_dict
from tqdm import tqdm

import wandb
from lola.autoencoder import get_autoencoder
from lola.crps import CRPS
from lola.data import field_postprocess
from lola.emulation import (
    decode_traj,
    emulate_diffusion,
    emulate_rollout,
    emulate_surrogate,
    encode_traj,
)
from lola.fourier import isotropic_power_spectrum
from lola.hydra import compose
from lola.optim import get_staged_optimizer
from lola.plot import draw_movie
from lola.surrogate import get_surrogate
from lola.utils import load_common_weights


def train(
    runid: str,
    cfg: DictConfig,
    wandb_project: str = "lola_prob",
    target: str = "state",
):
    import os
    import re
    from pathlib import Path

    import torch
    import torch.distributed as dist
    from einops import rearrange
    from omegaconf import OmegaConf, open_dict
    from torch.nn.parallel import DistributedDataParallel
    from tqdm import trange

    import wandb
    from lola.data import (
        MiniWellDataset,
        field_preprocess,
        find_hdf5,
        get_dataloader,
        get_well_inputs,
        get_well_multi_dataset,
    )
    from lola.diffusion import DenoiserLoss, get_denoiser
    from lola.emulation import random_context_mask
    from lola.nn.utils import load_state_dict
    from lola.optim import get_optimizer, safe_gd_step
    from lola.utils import randseed

    # Check if we're running in a distributed environment
    use_ddp = "RANK" in os.environ or "LOCAL_RANK" in os.environ

    if use_ddp:
        print("Initializing DDP...")
        # DDP
        dist.init_process_group(backend="nccl")
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        device_id = os.environ.get("LOCAL_RANK", rank)
        device_id = int(device_id)
        print(
            f"DDP initialized: rank={rank}, world_size={world_size}, device_id={device_id}"
        )
    else:
        # Single GPU
        rank = 0
        world_size = 1
        device_id = 0
        print("Single GPU mode")

    device = torch.device(f"cuda:{device_id}") if torch.cuda.is_available() else "cpu"
    torch.cuda.set_device(device)

    # Performance
    torch.set_float32_matmul_precision("high")

    # Config
    batch_size = cfg.train.batch_size
    assert cfg.train.epoch_size % cfg.train.batch_size == 0
    assert cfg.train.batch_size % (cfg.train.accumulation * world_size) == 0
    assert cfg.valid.epoch_size % cfg.valid.batch_size == 0
    assert cfg.valid.batch_size % world_size == 0

    # Load surrogate model if fine-tuning from a surrogate
    surrogate_run_path = None
    if cfg.load_surrogate:
        if cfg.surrogate_run:
            surrogate_run_path = cfg.surrogate_run

    if surrogate_run_path:
        surrogate_runpath = Path(surrogate_run_path)
        surrogate_cfg = OmegaConf.load(surrogate_runpath / "config.yaml")
        state = torch.load(
            surrogate_runpath / f"{target}.pth", weights_only=True, map_location=device
        )
        new_state_dict = {}
        for k, v in state.items():
            name = k.replace("module.", "") if k.startswith("module.") else k
            new_state_dict[name] = v
        if hasattr(surrogate_cfg, "surrogate"):
            print("Loading surrogate model...")
            surrogate = get_surrogate(**surrogate_cfg.surrogate)
            surrogate.load_state_dict(new_state_dict)
            surrogate.to(device)
            surrogate.eval()
        else:
            raise ValueError(
                f"Could not find surrogate configuration in the provided run {surrogate_runpath}."
            )

        # Also load autoencoder
        if (surrogate_runpath / "autoencoder").exists():
            ae_cfg = OmegaConf.load(surrogate_runpath / "autoencoder/config.yaml").ae

            state = torch.load(
                surrogate_runpath / "autoencoder/state.pth",
                weights_only=True,
                map_location=device,
            )

            autoencoder = get_autoencoder(**ae_cfg)
            autoencoder.load_state_dict(state)
            autoencoder.to(device)
            autoencoder.requires_grad_(False)
            autoencoder.eval()
        else:
            autoencoder = None

        if surrogate_cfg.ae_run:
            space = re.search(r"f\d+c\d+", surrogate_cfg.ae_run).group()
        else:
            space = "pixel"
    else:
        surrogate = None
        # Autoencoder
        if cfg.load_ae and (Path(cfg.ae_run)).exists():
            ae_cfg = OmegaConf.load(Path(cfg.ae_run) / "config.yaml").ae

            state = torch.load(
                Path(cfg.ae_run) / "state.pth",
                weights_only=True,
                map_location=device,
            )

            autoencoder = get_autoencoder(**ae_cfg)
            autoencoder.load_state_dict(state)
            autoencoder.to(device)
            autoencoder.requires_grad_(False)
            autoencoder.eval()
        else:
            autoencoder = None

        if cfg.ae_run:
            space = re.search(r"f\d+c\d+", cfg.ae_run).group()
        else:
            space = "pixel"

    if cfg.load_surrogate:
        runname = f"{runid}_{cfg.dataset.name}_{space}_{cfg.denoiser.name}_{cfg.staged_training.threshold_lr}_{cfg.optim.learning_rate}"
    else:
        assert surrogate is None
        runname = f"scratch_{runid}_{cfg.dataset.name}_{space}_{cfg.denoiser.name}_{cfg.optim.learning_rate}"

    runpath = Path(f"{cfg.server.storage}/runs/dm/{runname}")
    runpath = runpath.expanduser().resolve()
    runpath.mkdir(parents=True, exist_ok=True)

    with open_dict(cfg):
        cfg.name = runname
        cfg.path = str(runpath)
        cfg.seed = randseed(runid)

        if surrogate is not None:
            if surrogate_cfg.ae_run:
                ae_run_path = surrogate_cfg.ae_run
                cfg.ae_run = os.path.realpath(
                    os.path.expanduser(ae_run_path), strict=True
                )
        else:
            if cfg.ae_run:
                cfg.ae_run = os.path.realpath(
                    os.path.expanduser(cfg.ae_run), strict=True
                )

    if rank == 0 and cfg.ae_run:
        os.symlink(cfg.ae_run, runpath / "autoencoder")

    if use_ddp:
        dist.barrier(device_ids=[device_id])

    counter = {
        "epoch": 0,
        "update_samples": 0,
        "update_steps": 0,
    }

    # Data
    if cfg.ae_run:
        files = {
            split: [
                file
                for physic in cfg.dataset.physics
                for file in find_hdf5(
                    path=runpath / "autoencoder/cache" / physic / split,
                    include_filters=cfg.dataset.include_filters,
                )
            ]
            for split in ("train", "valid")
        }

        dataset = {
            split: MiniWellDataset.from_files(
                files=files[split],
                steps=cfg.trajectory.length,
                stride=cfg.trajectory.stride,
            )
            for split in ("train", "valid")
        }
    else:
        dataset = {
            split: get_well_multi_dataset(
                path=cfg.server.datasets,
                physics=cfg.dataset.physics,
                split=split,
                steps=cfg.trajectory.length,
                min_dt_stride=cfg.trajectory.stride,
                max_dt_stride=cfg.trajectory.stride,
                include_filters=cfg.dataset.include_filters,
                augment=cfg.dataset.augment,
            )
            for split in ("train", "valid")
        }

    # Get rollout dataset for evaluation
    rollout_dataset = {
        "valid": get_well_multi_dataset(
            path=cfg.server.datasets,
            physics=cfg.dataset.physics,
            split="valid",
            steps=-1,
            include_filters=cfg.dataset.include_filters,
            augment=[s for s in cfg.dataset.augment if "random" not in s],
        ),
        "test": get_well_multi_dataset(
            path=cfg.server.datasets,
            physics=cfg.dataset.physics,
            split="test",
            steps=-1,
            include_filters=cfg.dataset.include_filters,
            augment=[s for s in cfg.dataset.augment if "random" not in s],
        ),
    }

    train_loader, valid_loader = [
        get_dataloader(
            dataset=dataset[split],
            batch_size=(
                cfg.train.batch_size // cfg.train.accumulation // world_size
                if split == "train"
                else cfg.valid.batch_size // world_size
            ),
            shuffle=True,
            infinite=True,
            num_workers=cfg.compute.cpus_per_gpu,
            rank=rank,
            world_size=world_size,
            seed=cfg.seed,
        )
        for split in ("train", "valid")
    ]

    if cfg.ae_run:
        preprocess = lambda x: x
    else:
        preprocess = partial(
            field_preprocess,
            mean=torch.as_tensor(cfg.dataset.stats.mean, device=device),
            std=torch.as_tensor(cfg.dataset.stats.std, device=device),
            transform=cfg.dataset.transform,
        )

    rollout_preprocess = partial(
        field_preprocess,
        mean=torch.as_tensor(cfg.dataset.stats.mean, device=device),
        std=torch.as_tensor(cfg.dataset.stats.std, device=device),
        transform=cfg.dataset.transform,
    )

    postprocess = partial(
        field_postprocess,
        mean=torch.as_tensor(cfg.dataset.stats.mean, device=device),
        std=torch.as_tensor(cfg.dataset.stats.std, device=device),
        transform=cfg.dataset.transform,
    )

    if hasattr(cfg.dataset, "dimensions"):
        spatial = len(cfg.dataset.dimensions)
    else:
        spatial = 2

    x, label = get_well_inputs(next(valid_loader))
    x = rearrange(x, "B L ... C -> B C L ...")

    # Model, optimizer & scheduler
    with open_dict(cfg):
        cfg.denoiser.channels = x.shape[1]
        cfg.denoiser.label_features = label.shape[1]
        cfg.denoiser.spatial = len(cfg.dataset.dimensions) + 1
        cfg.denoiser.masked = True

    denoiser = get_denoiser(**cfg.denoiser).to(device)
    denoiser_loss = DenoiserLoss(**cfg.denoiser.loss).to(device)
    num_params = sum(p.numel() for p in denoiser.parameters())
    print("Number of parameters:", num_params)
    if surrogate is not None:
        num_params_initial = sum(p.numel() for p in surrogate.parameters())
        print("Number of parameters initial surrogate:", num_params_initial)

    if surrogate is not None:
        model_info = load_common_weights(
            denoiser.backbone,
            surrogate.state_dict(),
            strict=True,
        )

    crps_loss = CRPS()

    if cfg.fork.run:
        stem_path = Path(cfg.fork.run).expanduser().resolve()
        stem_state = torch.load(
            stem_path / f"{cfg.fork.target}.pth", weights_only=True, map_location=device
        )

        load_state_dict(denoiser, stem_state, strict=cfg.fork.strict)

        del stem_state

    if use_ddp:
        denoiser = DistributedDataParallel(
            module=denoiser,
            device_ids=[device_id],
        )
        model_for_optim = denoiser.module
    else:
        model_for_optim = denoiser

    if surrogate is not None:
        # Get surrogate state dict for identifying common parameters
        surrogate_state = surrogate.state_dict()

        # Create staged optimizer and scheduler
        staged_cfg = cfg.get("staged_training", {})
        threshold_lr = staged_cfg.get("threshold_lr", 1e-5)
        common_params_kwargs = staged_cfg.get("common_params_kwargs", {})
        new_params_kwargs = staged_cfg.get("new_params_kwargs", {})
        common_scheduler = staged_cfg.get("common_scheduler", "cosine")
        common_warmup = staged_cfg.get("common_warmup", 0)
        new_scheduler = staged_cfg.get("new_scheduler", "cosine")
        new_warmup = staged_cfg.get("new_warmup", 0)

        optimizer, scheduler = get_staged_optimizer(
            model=model_for_optim.backbone,
            surrogate_state_dict=surrogate_state,
            threshold_lr=threshold_lr,
            common_params_kwargs=common_params_kwargs,
            new_params_kwargs=new_params_kwargs,
            common_scheduler=common_scheduler,
            common_warmup=common_warmup,
            new_scheduler=new_scheduler,
            new_warmup=new_warmup,
            epochs=cfg.train.epochs,
            **cfg.optim,
        )
    else:
        optimizer, scheduler = get_optimizer(
            params=denoiser.parameters(),
            epochs=cfg.train.epochs,
            **cfg.optim,
        )

    # W&B
    if rank == 0:
        OmegaConf.save(cfg, runpath / "config.yaml")

    if rank == 0:
        run = wandb.init(
            entity=cfg.wandb.entity,
            project="lola_prob" if wandb_project is None else wandb_project,
            group=cfg.dataset.name,
            id=runid,
            name=runname,
            config=OmegaConf.to_container(cfg),
        )

    # Training loop
    if rank == 0:
        epochs = trange(cfg.train.epochs, ncols=88, ascii=True)
    else:
        epochs = range(cfg.train.epochs)

    best_valid_loss = float("inf")

    # For validation on the same set of trajectories - Get a random set of indices up to the length of valid_dataset
    num_val_indices = getattr(cfg.val_eval, "num_val_indices", 20)
    valid_dataset_len = len(rollout_dataset["valid"])
    # Use a separate Random instance with fixed seed for reproducible validation indices
    val_rng = random.Random(42)
    indices = val_rng.sample(
        range(valid_dataset_len), min(num_val_indices, valid_dataset_len)
    )
    print("Validation indices", indices)
    if "euler" in cfg.dataset.name:
        num_test_indices = 1000
        cfg.val_eval.start = 0
    elif "rayleigh" in cfg.dataset.name:
        num_test_indices = 175
        cfg.val_eval.start = 5
    elif "shear_flow" in cfg.dataset.name:
        num_test_indices = 112
        cfg.val_eval.start = 5
    else:
        raise ValueError(
            f"No default number of test indices for dataset {cfg.dataset.name}."
        )
    if not cfg["debug"]:
        test_dataset_len = len(rollout_dataset["test"])
    else:
        test_dataset_len = 3
    test_rng = random.Random(123)
    test_indices = test_rng.sample(
        range(test_dataset_len), min(num_test_indices, test_dataset_len)
    )
    print("Test indices", test_indices)
    record = cfg.val_eval.record

    for epoch_idx, _ in enumerate(epochs):
        ## Train
        denoiser.train()

        losses, grads = [], []

        for i in range(
            cfg.train.accumulation * cfg.train.epoch_size // cfg.train.batch_size
        ):
            x, label = get_well_inputs(next(train_loader), device=device)
            x = preprocess(x)
            x = rearrange(x, "B L ... C -> B C L ...")

            mask = random_context_mask(x, **cfg.trajectory.context)

            if (i + 1) % cfg.train.accumulation == 0:
                loss = denoiser_loss(denoiser, x, mask=mask, label=label)
                loss_acc = loss / cfg.train.accumulation
                loss_acc.backward()

                grad_norm = safe_gd_step(optimizer, grad_clip=cfg.optim.grad_clip)
                grads.append(grad_norm)
                counter["update_samples"] += batch_size
                counter["update_steps"] += 1
            else:
                if use_ddp:
                    with denoiser.no_sync():
                        loss = denoiser_loss(denoiser, x, mask=mask, label=label)
                        loss_acc = loss / cfg.train.accumulation
                        loss_acc.backward()
                else:
                    loss = denoiser_loss(denoiser, x, mask=mask, label=label)
                    loss_acc = loss / cfg.train.accumulation
                    loss_acc.backward()

            losses.append(loss.detach())

        losses = torch.stack(losses)
        grads = torch.stack(grads)

        if use_ddp:
            if rank == 0:
                losses_list = [torch.empty_like(losses) for _ in range(world_size)]
                grads_list = [torch.empty_like(grads) for _ in range(world_size)]
            else:
                losses_list = None
                grads_list = None

            dist.gather(losses, losses_list, dst=0)
            dist.gather(grads, grads_list, dst=0)

            if rank == 0:
                losses = torch.cat(losses_list).cpu()
                grads = torch.cat(grads_list).cpu()
        else:
            losses = losses.cpu()
            grads = grads.cpu()

        if rank == 0:
            logs = {}
            logs["train/loss/mean"] = losses.mean().item()
            logs["train/loss/std"] = losses.std().item()
            logs["train/grad_norm/mean"] = grads.mean().item()
            logs["train/grad_norm/std"] = grads.std().item()
            logs["train/samples"] = counter["update_samples"]
            if surrogate is not None:
                # Add learning rate info to logs
                lrs = scheduler.get_last_lr()
                logs["train/lr_common"] = lrs[0]  # Common parameters lr
                if len(lrs) > 1:
                    logs["train/lr_new"] = lrs[1]  # New parameters lr
            else:
                logs["train/learning_rate"] = optimizer.param_groups[0]["lr"]

        if use_ddp:
            del losses_list, grads_list
        del losses, grads

        # Rollout eval
        denoiser.eval()
        if (
            (epoch_idx % cfg.val_eval.interval == 0)
            and cfg.val_eval.enabled
            and epoch_idx != 0
        ) or (epoch_idx == cfg.train.epochs - 1):
            if rank == 0:
                print(f"Epoch {epoch_idx}, performing validation rollout evaluation...")

            if epoch_idx == cfg.train.epochs - 1:
                rollout_splits = ["valid", "test"]
            else:
                rollout_splits = ["valid"]

            for split in rollout_splits:
                if rank == 0:
                    print(f"Rollout evaluation on {split} set...")

                # Initialize metrics dictionary with field-specific and aggregate tracking
                all_metrics = {}
                # Get number of fields
                field_names = cfg.dataset.fields
                num_fields = len(field_names)

                # Get ensemble sizes to track (only the ones we'll save)
                ensemble_sizes_to_save = getattr(
                    cfg.val_eval,
                    "ensemble_sizes_to_save",
                    list(range(1, cfg.val_eval.samples + 1)),
                )

                if cfg.val_eval.samples not in ensemble_sizes_to_save:
                    ensemble_sizes_to_save.append(cfg.val_eval.samples)
                    ensemble_sizes_to_save = sorted(ensemble_sizes_to_save)
                assert (
                    max(ensemble_sizes_to_save) <= cfg.val_eval.samples
                ), "Ensemble sizes to save should not exceed the number of samples."

                # Initialize metrics for each field and aggregate
                metric_names = [
                    "vrmse",
                    "rmse",
                    "nrmse",
                    "spread",
                    "spread_skill",
                    "rmse_p",
                    "rmse_p_low",
                    "rmse_p_mid",
                    "rmse_p_high",
                    "rmse_p_sub",
                    "crps",
                ]

                for metric_name in metric_names:
                    # Only track ensemble sizes we'll save
                    for ensemble_i in ensemble_sizes_to_save:
                        if ensemble_i <= cfg.val_eval.samples:  # Safety check
                            ensemble_idx = (
                                f"_{ensemble_i}"
                                if ensemble_i != cfg.val_eval.samples
                                else ""
                            )

                            all_metrics[f"{metric_name}{ensemble_idx}_all"] = []
                            for field in range(num_fields):
                                all_metrics[
                                    f"{metric_name}{ensemble_idx}_field_{field_names[field]}"
                                ] = []

                current_indices = indices if split == "valid" else test_indices
                for idx, index in enumerate(
                    tqdm(current_indices, desc="Rollouts", ncols=88, ascii=True)
                ):
                    # Video plotting condition
                    video_plotting = (record > 0) and (
                        (idx in [0, 1, 2]) or (index in [82, 30, 170])
                    )
                    # Get data
                    x, label = get_well_inputs(
                        rollout_dataset[split][index], device=device
                    )
                    x = x[
                        max(
                            0,
                            cfg.val_eval.start
                            - (cfg.val_eval.context - 1) * cfg.trajectory.stride,
                        ) :: cfg.trajectory.stride
                    ]
                    x = rollout_preprocess(x)
                    x = rearrange(x, "L ... C -> C L ...")

                    with torch.no_grad():
                        z = encode_traj(autoencoder, x)

                    compression = x.numel() / z.numel()

                    ## Emulate
                    if hasattr(cfg, "denoiser"):
                        method = "diffusion"
                        settings = f"{cfg.val_eval.sampling.algorithm}{cfg.val_eval.sampling.steps}"

                        emulate = lambda mask, z_obs, noise, i: emulate_diffusion(
                            denoiser,
                            mask,
                            z_obs,
                            label=label,  # noqa: B023
                            **cfg.val_eval.sampling,
                        )
                        if surrogate is not None:
                            emulate_orig = (
                                lambda mask, z_obs, noise, i: emulate_surrogate(
                                    surrogate,
                                    mask,
                                    z_obs,
                                    label=label,  # noqa: B023
                                )
                            )
                        else:
                            emulate_orig = None
                    elif hasattr(cfg, "surrogate"):
                        method = "surrogate"
                        settings = None
                        emulate = lambda mask, z_obs, noise, i: emulate_surrogate(
                            surrogate,
                            mask,
                            z_obs,
                            label=label,  # noqa: B023
                        )
                        emulate_orig = lambda mask, z_obs, noise, i: emulate_surrogate(
                            surrogate,
                            mask,
                            z_obs,
                            label=label,  # noqa: B023
                        )
                    else:
                        method = "autoencoder"
                        settings = None

                    with torch.no_grad():
                        if method in ("diffusion", "surrogate", "flow_matching"):
                            # Emulate trajectory
                            z_hat = emulate_rollout(
                                emulate,
                                z,
                                window=cfg.trajectory.length,
                                rollout=z.shape[1],
                                context=cfg.val_eval.context,
                                overlap=cfg.val_eval.overlap,
                                batch=cfg.val_eval.samples,
                                crps_noise_emb=0,
                            )
                            if video_plotting and surrogate is not None:
                                z_hat_surrogate = emulate_rollout(
                                    emulate_orig,
                                    z,
                                    window=cfg.trajectory.length,
                                    rollout=z.shape[1],
                                    context=cfg.val_eval.context,
                                    overlap=cfg.val_eval.overlap,
                                    batch=1,
                                    crps_noise_emb=0,
                                )
                            else:
                                z_hat_surrogate = None
                        else:
                            z_hat = z.expand(cfg.val_eval.samples, *z.shape)

                        if "euler" in cfg.dataset.name:
                            chunk_size = 128
                        elif "gravity" in cfg.dataset.name:
                            chunk_size = 128
                        elif "shear_flow" in cfg.dataset.name:
                            chunk_size = 128
                        else:
                            chunk_size = 256

                        x_hat = decode_traj(
                            autoencoder,
                            z_hat,
                            batched=True,
                            noisy=False,
                            chunk_size=chunk_size,
                        )
                        if z_hat_surrogate is not None:
                            x_hat_surrogate = decode_traj(
                                autoencoder,
                                z_hat_surrogate,
                                batched=True,
                                noisy=False,
                                chunk_size=chunk_size,
                            )
                        else:
                            x_hat_surrogate = None

                    del z_hat, z_hat_surrogate

                    ## Postprocess
                    x = postprocess(x, dim=-spatial - 2)
                    x_hat = postprocess(x_hat, dim=-spatial - 2)
                    if x_hat_surrogate is not None:
                        x_hat_surrogate = postprocess(x_hat_surrogate, dim=-spatial - 2)

                    if split == "test" and rank == 0:
                        lines = []
                    # Compute metrics
                    for field in range(x.shape[0]):
                        for t in range(cfg.val_eval.context - 1, x.shape[1]):
                            true_t = (
                                t - cfg.val_eval.context + 1
                            ) * cfg.trajectory.stride
                            u, v = x[field, t], x_hat[:, field, t]

                            # Moments (these don't change with ensemble size)
                            m1 = torch.mean(u)
                            m2 = torch.mean(u**2)

                            # Fourier analysis (ground truth doesn't change)
                            p_u, k = isotropic_power_spectrum(u, spatial=spatial)
                            bins = torch.logspace(k[0].log2(), -1.0, steps=4, base=2)

                            # Storage for ensemble-dependent metrics
                            metrics_by_ensemble = {}

                            # Iterate over ensemble sizes
                            for ensemble_i in ensemble_sizes_to_save:
                                v_subset = v[
                                    :ensemble_i
                                ]  # Use first ensemble_i samples

                                # Spread
                                if ensemble_i > 1:
                                    # see https://doi.org/10.1175/JHM-D-14-0008.1
                                    spread = torch.mean(
                                        torch.square(
                                            v_subset - torch.mean(v_subset, dim=0)
                                        )
                                    )
                                    spread = torch.sqrt(
                                        (ensemble_i + 1) / (ensemble_i - 1) * spread
                                    )
                                else:
                                    spread = 0.0

                                # Skill metrics
                                se = torch.square(u - torch.mean(v_subset, dim=0))
                                mse = torch.mean(se)
                                rmse = torch.sqrt(mse)
                                nrmse = torch.sqrt(mse / (torch.mean(u**2) + 1e-6))
                                vrmse = torch.sqrt(mse / (torch.var(u) + 1e-6))

                                # Spread_skill ratio
                                spread_skill = (spread + 1e-3) / (rmse + 1e-3)

                                # Fourier metrics
                                p_v, _ = isotropic_power_spectrum(
                                    v_subset, spatial=spatial
                                )
                                p_v = torch.mean(p_v, dim=0)
                                se_p = torch.square(1 - (p_v + 1e-6) / (p_u + 1e-6))
                                rmse_p = torch.sqrt(torch.mean(se_p))

                                fourier_extras = []
                                for i in range(4):
                                    if i < 3:
                                        mask = torch.logical_and(
                                            bins[i] <= k, k <= bins[i + 1]
                                        )
                                    else:
                                        mask = bins[i] <= k
                                    fourier_extras.append(
                                        torch.sqrt(torch.mean(se_p[mask])).item()
                                    )

                                extras = []
                                ## Wasserstein
                                w_uv = ot.lp.wasserstein_1d(
                                    u.flatten(),
                                    v_subset.flatten(),
                                    p=1.0,
                                )
                                extras.append(w_uv.item())
                                ## Sliced EMD (only makes sense for density)
                                extras.append(None)

                                # CRPS
                                crps_i = crps_loss(
                                    predictions=v_subset[None, :, None, None, :],
                                    target=u[None, None, None, :],
                                    mask=None,
                                    mem_efficient=False,
                                ).mean()

                                # Store all metrics for this ensemble size
                                metrics_by_ensemble[ensemble_i] = {
                                    "vrmse": vrmse.item(),
                                    "rmse": rmse.item(),
                                    "nrmse": nrmse.item(),
                                    "spread": (
                                        spread
                                        if isinstance(spread, float)
                                        else spread.item()
                                    ),
                                    "spread_skill": spread_skill.item(),
                                    "rmse_p": rmse_p.item(),
                                    "rmse_p_low": fourier_extras[0],
                                    "rmse_p_mid": fourier_extras[1],
                                    "rmse_p_high": fourier_extras[2],
                                    "rmse_p_sub": fourier_extras[3],
                                    "crps": crps_i.item(),
                                    "wasserstein": extras[0],
                                    "emd": extras[1],
                                }

                            # For CSV output (test split) - store one row per ensemble size
                            if split == "test" and rank == 0:
                                # Store one row for each ensemble size
                                for ensemble_i in ensemble_sizes_to_save:
                                    if (
                                        ensemble_i in metrics_by_ensemble
                                    ):  # Only if we computed it
                                        ensemble_metrics = metrics_by_ensemble[
                                            ensemble_i
                                        ]

                                        line = (
                                            f"{runid},state,{compression:.1f},crps,{0},"
                                        )
                                        line += f"train,None,{cfg.val_eval.context},{cfg.val_eval.overlap},1.0,"
                                        line += f"{split},{index},{cfg.val_eval.start},{cfg.seed},"
                                        line += f"{field_names[field]},{(t - cfg.val_eval.context + 1) * cfg.trajectory.stride},False,"
                                        line += (
                                            f"{ensemble_i},"  # Add ensemble size here
                                        )
                                        line += f"{m1},{m2},"
                                        line += f"{ensemble_metrics['spread']},{ensemble_metrics['spread_skill']},"
                                        line += f"{ensemble_metrics['rmse']},{ensemble_metrics['nrmse']},{ensemble_metrics['vrmse']},"
                                        line += f"{ensemble_metrics['rmse_p']},"

                                        # Add Fourier extras
                                        fourier_values = [
                                            ensemble_metrics["rmse_p_low"],
                                            ensemble_metrics["rmse_p_mid"],
                                            ensemble_metrics["rmse_p_high"],
                                            ensemble_metrics["rmse_p_sub"],
                                        ]
                                        line += ",".join(map(str, fourier_values)) + ","
                                        # Add CRPS value for this ensemble size
                                        line += f"{ensemble_metrics['crps']}" + ","

                                        # Add Wasserstein and Sliced EMD
                                        extra_values = [
                                            ensemble_metrics["wasserstein"],
                                            ensemble_metrics["emd"],
                                        ]
                                        line += ",".join(map(str, extra_values))

                                        # Add label parameters
                                        if hasattr(label, "tolist"):
                                            line += (
                                                f",{','.join(map(str, label.tolist()))}"
                                            )

                                        line += "\n"
                                        lines.append(line)

                            # Store metrics for logging (per-field and field-averaged)
                            for ensemble_i in ensemble_sizes_to_save:
                                ensemble_idx = (
                                    f"_{ensemble_i}"
                                    if ensemble_i != cfg.val_eval.samples
                                    else ""
                                )

                                # Store per-field metrics for this ensemble size
                                for metric_name, value in metrics_by_ensemble[
                                    ensemble_i
                                ].items():
                                    if metric_name not in ["wasserstein", "emd"]:
                                        all_metrics[
                                            f"{metric_name}{ensemble_idx}_field_{field_names[field]}"
                                        ].append((true_t, value))

                    # After processing all samples, compute field-averaged metrics across all timesteps
                    for ensemble_i in ensemble_sizes_to_save:
                        ensemble_idx = (
                            f"_{ensemble_i}"
                            if ensemble_i != cfg.val_eval.samples
                            else ""
                        )

                        for metric_name in metric_names:
                            # Get all field keys for this metric and ensemble size
                            field_keys = [
                                f"{metric_name}{ensemble_idx}_field_{field_names[field]}"
                                for field in range(num_fields)
                            ]

                            # Check if all field keys exist and have data
                            if all(
                                key in all_metrics and all_metrics[key]
                                for key in field_keys
                            ):
                                # Get all timesteps from the first field (they should all be the same)
                                timesteps = [t for t, _ in all_metrics[field_keys[0]]]

                                # Create the field-averaged list
                                field_averaged_values = []

                                for i, timestep in enumerate(timesteps):
                                    # Collect values from all fields for this timestep index
                                    field_values = []
                                    for field_key in field_keys:
                                        if i < len(all_metrics[field_key]):
                                            _, value = all_metrics[field_key][
                                                i
                                            ]  # Get value at index i
                                            field_values.append(value)

                                    # Average across fields for this timestep
                                    if (
                                        len(field_values) == num_fields
                                    ):  # Ensure we have all fields
                                        avg_value = sum(field_values) / len(
                                            field_values
                                        )
                                        field_averaged_values.append(
                                            (timestep, avg_value)
                                        )

                                # Store the field-averaged values
                                all_metrics[f"{metric_name}{ensemble_idx}_all"] = (
                                    field_averaged_values
                                )

                    if split == "test" and rank == 0:
                        # Save
                        stats_path = (
                            runpath / "test_stats.csv"
                        )  # Single file for all samples
                        with FileLock(str(stats_path) + ".lock"):
                            with open(stats_path, "a") as f:  # Use write mode
                                # Only write header if file is empty/new
                                if stats_path.stat().st_size == 0:
                                    header = "run,target,compression,method,noise_emb_size,settings,guidance,context,overlap,speed,"
                                    header += (
                                        "split,index,start,seed,field,time,relative,"
                                    )
                                    header += "ensemble_size,"  # Add this column
                                    header += "m1,m2,spread,spread_skill,rmse,nrmse,vrmse,rmse_p,"
                                    header += (
                                        "rmse_p_low,rmse_p_mid,rmse_p_high,rmse_p_sub,"
                                    )
                                    header += "crps,"  # Single CRPS value per row
                                    header += "wasserstein,emd"
                                    if (
                                        hasattr(label, "tolist")
                                        and len(label.tolist()) > 0
                                    ):
                                        param_names = getattr(
                                            cfg.dataset,
                                            "param_names",
                                            [
                                                f"param_{i}"
                                                for i in range(len(label.tolist()))
                                            ],
                                        )
                                        header += f",{','.join(param_names)}"
                                    header += "\n"
                                    f.write(header)
                                f.writelines(lines)

                    if (record > 0) and (
                        (idx in [0, 1, 2]) or (index in [82, 30, 170])
                    ):
                        if x.shape[-1] < x.shape[-2]:
                            x, x_hat = x.mT, x_hat.mT
                            if x_hat_surrogate is not None:
                                x_hat_surrogate = x_hat_surrogate.mT
                        if x_hat_surrogate is not None:
                            frames = torch.stack(
                                (x, x_hat_surrogate[0], *x_hat[:record])
                            )
                        else:
                            frames = torch.stack((x, *x_hat[:record]))
                        frames = rearrange(frames, "N C L H W -> L N C H W")
                        # Get rid of the first conditioning
                        frames = frames[cfg.val_eval.context :]

                        # Ensure the directory exists before saving the movie
                        movie_dir = runpath / f"epoch_{epoch_idx}"
                        movie_dir.mkdir(parents=True, exist_ok=True)
                        draw_movie(
                            frames,
                            file=(
                                movie_dir
                                / f"{runname}_{target}_{split}_{index:06d}_{cfg.val_eval.start:03d}_{cfg.val_eval.context}_{cfg.val_eval.overlap}_{cfg.seed}.mp4"
                            ),
                            fps=4.0 / cfg.trajectory.stride,
                            isolate={2},
                        )

                        del frames
                    del x, x_hat
                # Log aggregated rollout metrics across all samples
                if rank == 0 and len(all_metrics[f"{metric_names[0]}_all"]) > 0:
                    # Helper function to split by time ranges
                    def split_by_time_range(data, ranges):
                        """Bucket values by the inclusive time range their timestep falls in.

                        Args:
                            data: iterable of ``(timestep, value)`` pairs.
                            ranges: sequence of ``(t_min, t_max)`` tuples, both
                                bounds inclusive.

                        Returns:
                            One list per entry of ``ranges`` (same order) holding
                            the values whose timestep lies in that range. A value
                            goes into the first matching range only; values that
                            match no range are dropped.
                        """
                        buckets = [[] for _ in ranges]
                        for timestep, value in data:
                            # First range containing the timestep wins.
                            target_bucket = next(
                                (
                                    bucket
                                    for bucket, (lower, upper) in zip(buckets, ranges)
                                    if lower <= timestep <= upper
                                ),
                                None,
                            )
                            if target_bucket is not None:
                                target_bucket.append(value)
                        return buckets

                    # Helper functions using NumPy
                    def compute_mean(data_list):
                        """Return the mean of the values kept from ``data_list``.

                        ``data_list`` is iterated as ``(step, value)`` pairs; only
                        pairs whose first element is at least
                        ``cfg.val_eval.context`` contribute to the result.
                        """
                        kept = [
                            val
                            for step, val in data_list
                            if step >= cfg.val_eval.context
                        ]
                        return float(np.mean(np.array(kept)))

                    def compute_median(data_list):
                        """Return the median of the values kept from ``data_list``.

                        ``data_list`` is iterated as ``(step, value)`` pairs; only
                        pairs whose first element is at least
                        ``cfg.val_eval.context`` contribute to the result.
                        """
                        kept = [
                            val
                            for step, val in data_list
                            if step >= cfg.val_eval.context
                        ]
                        return float(np.median(np.array(kept)))

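                    # Define dataset-specific time ranges for the breakdowns below
                    # (1-40, 41-80, 81-200 for rayleigh and shear_flow;
                    # 1-20, 21-60, 61-100 for euler)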
                    if "rayleigh" in cfg.dataset.name:
                        time_ranges = [(1, 40), (41, 80), (81, 200)]
                        range_names = ["1_40", "41_80", "81_200"]
                    elif "euler" in cfg.dataset.name:
                        time_ranges = [(1, 20), (21, 60), (61, 100)]
                        range_names = ["1_20", "21_60", "61_100"]
                    elif "shear_flow" in cfg.dataset.name:
                        time_ranges = [(1, 40), (41, 80), (81, 200)]
                        range_names = ["1_40", "41_80", "81_200"]
                    else:
                        raise ValueError(
                            f"No time ranges defined for dataset {cfg.dataset.name}."
                        )

                    # Log aggregate metrics (averaged across all fields) - for ALL ensemble sizes
                    for ensemble_i in ensemble_sizes_to_save:
                        if ensemble_i <= cfg.val_eval.samples:  # Safety check
                            ensemble_idx = (
                                f"_{ensemble_i}"
                                if ensemble_i != cfg.val_eval.samples
                                else ""
                            )

                            for metric_name in metric_names:
                                metric_key = f"{metric_name}{ensemble_idx}_all"
                                if (
                                    metric_key in all_metrics
                                    and all_metrics[metric_key]
                                ):
                                    data = all_metrics[metric_key]
                                    logs[
                                        f"{split}/rollout/{metric_name}{ensemble_idx}_mean"
                                    ] = compute_mean(data)
                                    logs[
                                        f"{split}/rollout/{metric_name}{ensemble_idx}_median"
                                    ] = compute_median(data)

                                    # Time range breakdowns
                                    splits = split_by_time_range(data, time_ranges)
                                    for i, range_name in enumerate(range_names):
                                        if splits[i]:
                                            logs[
                                                f"{split}/rollout/{metric_name}{ensemble_idx}_mean_{range_name}"
                                            ] = sum(splits[i]) / len(splits[i])
                                            logs[
                                                f"{split}/rollout/{metric_name}{ensemble_idx}_median_{range_name}"
                                            ] = float(np.median(splits[i]))

                    # Log per-field metrics - for ALL ensemble sizes
                    for ensemble_i in ensemble_sizes_to_save:
                        if ensemble_i <= cfg.val_eval.samples:  # Safety check
                            ensemble_idx = (
                                f"_{ensemble_i}"
                                if ensemble_i != cfg.val_eval.samples
                                else ""
                            )

                            for field in range(num_fields):
                                for metric_name in metric_names:
                                    metric_key = f"{metric_name}{ensemble_idx}_field_{field_names[field]}"
                                    if (
                                        metric_key in all_metrics
                                        and all_metrics[metric_key]
                                    ):
                                        data = all_metrics[metric_key]
                                        logs[
                                            f"{split}/rollout/{metric_name}{ensemble_idx}_field_{field_names[field]}_mean"
                                        ] = compute_mean(data)
                                        logs[
                                            f"{split}/rollout/{metric_name}{ensemble_idx}_field_{field_names[field]}_median"
                                        ] = compute_median(data)

                                        # Time range breakdowns per field
                                        splits = split_by_time_range(data, time_ranges)
                                        for i, range_name in enumerate(range_names):
                                            if splits[i]:
                                                logs[
                                                    f"{split}/rollout/{metric_name}{ensemble_idx}_field_{field_names[field]}_mean_{range_name}"
                                                ] = sum(splits[i]) / len(splits[i])
                                                logs[
                                                    f"{split}/rollout/{metric_name}{ensemble_idx}_field_{field_names[field]}_median_{range_name}"
                                                ] = float(np.median(splits[i]))

        ## Eval
        denoiser.eval()

        losses = []

        with torch.no_grad():
            for _ in range(cfg.valid.epoch_size // cfg.valid.batch_size):
                x, label = get_well_inputs(next(valid_loader), device=device)
                x = preprocess(x)
                x = rearrange(x, "B L ... C -> B C L ...")

                mask = random_context_mask(x, **cfg.trajectory.context)

                loss = denoiser_loss(denoiser, x, mask=mask, label=label)
                losses.append(loss)

        losses = torch.stack(losses)

        if use_ddp:
            if rank == 0:
                losses_list = [torch.empty_like(losses) for _ in range(world_size)]
            else:
                losses_list = None

            dist.gather(losses, losses_list, dst=0)

            if rank == 0:
                losses = torch.cat(losses_list).cpu()
        else:
            losses = losses.cpu()

        if rank == 0:
            logs["valid/loss/mean"] = losses.mean().item()
            logs["valid/loss/std"] = losses.std().item()

            epochs.set_postfix(
                lt=logs["train/loss/mean"],
                lv=logs["valid/loss/mean"],
            )

            run.log(logs, step=counter["epoch"])
            counter["epoch"] += 1

        if use_ddp:
            del losses_list
        del losses

        ## LR scheduler
        scheduler.step()

        ## Checkpoint
        if rank == 0:
            if use_ddp:
                state = denoiser.module.state_dict()
            else:
                state = denoiser.state_dict()

            torch.save(state, runpath / "state.pth")

            if logs["valid/loss/mean"] < best_valid_loss:
                best_valid_loss = logs["valid/loss/mean"]

                torch.save(state, runpath / "state_best.pth")

            del state

        if use_ddp:
            dist.barrier(device_ids=[device_id])

    # W&B
    if rank == 0:
        run.finish()

    # DDP
    if use_ddp:
        dist.destroy_process_group()


if __name__ == "__main__":
    # Script entry point: parse the CLI, compose the config, then either train
    # in-process (--single_gpu) or hand the training job to dawgz.

    # Parser
    parser = argparse.ArgumentParser()
    parser.add_argument("overrides", nargs="*", type=str)
    parser.add_argument(
        "--single_gpu", action="store_true", help="Run on a single GPU (for debugging)"
    )
    parser.add_argument("--debug", action="store_true", help="Debug mode")
    args = parser.parse_args()

    # Debug runs use their own config file, W&B project and run-id prefix.
    if args.debug:
        config_file = "./experiments/configs/train_diffusion_debug.yaml"
        wandb_project = "lola_diffusion_debug"
        runid_prefix = "debug_"
    else:
        config_file = "./experiments/configs/train_diffusion.yaml"
        wandb_project = "lola_prob"
        runid_prefix = "diff_"

    # Config
    cfg = compose(config_file=config_file, overrides=args.overrides)

    # Job
    runid = runid_prefix + wandb.util.generate_id()
    with open_dict(cfg):
        # Record the mode in the config itself (store_true yields a bool).
        cfg["debug"] = args.debug

    if args.single_gpu:
        # Run directly on single GPU without SLURM
        print(f"Running on single GPU without SLURM, runid: {runid}")
        train(runid, cfg, wandb_project=wandb_project)
    else:
        print("Scheduling SLURM job...")
        compute = cfg.compute
        server = cfg.server

        # Multi-node runs need a rendezvous endpoint; one node runs standalone.
        if compute.nodes > 1:
            interpreter = f"torchrun --nnodes {cfg.compute.nodes} --nproc-per-node {cfg.compute.gpus} --rdzv_backend=c10d --rdzv_endpoint=$SLURMD_NODENAME:12345 --rdzv_id=$SLURM_JOB_ID"
        else:
            interpreter = (
                f"torchrun --nnodes 1 --nproc-per-node {cfg.compute.gpus} --standalone"
            )

        job_config = dawgz.job(
            f=partial(train, runid, cfg, wandb_project),
            name=f"crps {runid}",
            nodes=compute.nodes,
            cpus=compute.cpus_per_gpu * compute.gpus,
            gpus=compute.gpus,
            ram=compute.ram,
            time=compute.time,
            partition=server.partition,
            constraint=server.constraint,
            exclude=server.exclude,
        )

        print(f"Job config created: {job_config}")

        # Debug jobs use the "async" backend; everything else goes to "slurm".
        backend = "async" if args.debug else "slurm"

        dawgz.schedule(
            job_config,
            name=f"training crps {runid}",
            backend=backend,
            interpreter=interpreter,
            env=[
                f"export OMP_NUM_THREADS={cfg.compute.cpus_per_gpu}",
                "export WANDB_SILENT=true",
                "export XDG_CACHE_HOME=$HOME/.cache",
                "export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True",
            ],
        )
        print("Job scheduled successfully")
