#!/usr/bin/env python

import argparse
import random
import time as timing
from functools import partial
from pathlib import Path

import dawgz
import numpy as np
import ot
from filelock import FileLock
from omegaconf import DictConfig, open_dict
from tqdm import tqdm

import wandb
from lola.autoencoder import get_autoencoder
from lola.crps import CRPS, get_ensemble_predictions
from lola.data import field_postprocess
from lola.emulation import (
    decode_traj,
    emulate_crps,
    emulate_rollout,
    emulate_surrogate,
    encode_traj,
)
from lola.fourier import isotropic_power_spectrum
from lola.hydra import compose
from lola.optim import get_staged_optimizer
from lola.plot import draw_movie
from lola.surrogate import get_surrogate
from lola.utils import load_common_weights


def train(
    runid: str,
    cfg: DictConfig,
    wandb_project: str = "lola_prob",
    target: str = "state",
):
    import os
    import re
    from pathlib import Path

    import torch
    import torch.distributed as dist
    from einops import rearrange
    from omegaconf import OmegaConf, open_dict
    from torch.nn.parallel import DistributedDataParallel
    from tqdm import trange

    import wandb
    from lola.data import (
        MiniWellDataset,
        field_preprocess,
        find_hdf5,
        get_dataloader,
        get_well_inputs,
        get_well_multi_dataset,
    )
    from lola.emulation import random_context_mask
    from lola.nn.utils import load_state_dict
    from lola.optim import get_optimizer, safe_gd_step
    from lola.utils import randseed

    # Startup diagnostics: report the run id, the RANK/LOCAL_RANK environment
    # variables and CUDA availability.
    print(f"Starting train function with runid: {runid}")
    print(
        f"Environment variables: RANK={os.environ.get('RANK')}, LOCAL_RANK={os.environ.get('LOCAL_RANK')}"
    )
    print(
        f"CUDA available: {torch.cuda.is_available()}, device count: {torch.cuda.device_count()}"
    )

    # Check if we're running in a distributed environment
    use_ddp = "RANK" in os.environ or "LOCAL_RANK" in os.environ

    if use_ddp:
        print("Initializing DDP...")
        # DDP
        dist.init_process_group(backend="nccl")
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        device_id = os.environ.get("LOCAL_RANK", rank)
        device_id = int(device_id)
        print(
            f"DDP initialized: rank={rank}, world_size={world_size}, device_id={device_id}"
        )
    else:
        # Single GPU
        rank = 0
        world_size = 1
        device_id = 0
        print("Single GPU mode")

    device = torch.device(f"cuda:{device_id}") if torch.cuda.is_available() else "cpu"
    torch.cuda.set_device(device)

    # Performance
    torch.set_float32_matmul_precision("high")

    # Config
    batch_size = cfg.train.batch_size
    assert cfg.train.epoch_size % batch_size == 0
    assert batch_size % (cfg.train.accumulation * world_size) == 0
    assert cfg.valid.epoch_size % cfg.valid.batch_size == 0
    assert cfg.valid.batch_size % world_size == 0

    # Load surrogate model if fine-tuning from a surrogate
    surrogate_run_path = None
    if cfg.surrogate_run:
        surrogate_run_path = cfg.surrogate_run

    if surrogate_run_path:
        surrogate_runpath = Path(surrogate_run_path)
        surrogate_cfg = OmegaConf.load(surrogate_runpath / "config.yaml")
        state = torch.load(
            surrogate_runpath / f"{target}.pth", weights_only=True, map_location=device
        )
        new_state_dict = {}
        for k, v in state.items():
            name = k.replace("module.", "") if k.startswith("module.") else k
            new_state_dict[name] = v
        if hasattr(surrogate_cfg, "surrogate"):
            print("Loading surrogate model...")
            surrogate = get_surrogate(**surrogate_cfg.surrogate)
            surrogate.load_state_dict(new_state_dict)
            surrogate.to(device)
            surrogate.eval()
        else:
            raise ValueError(
                f"Could not find surrogate configuration in the provided run {surrogate_runpath}."
            )

        # Also load autoencoder
        if (surrogate_runpath / "autoencoder").exists():
            ae_cfg = OmegaConf.load(surrogate_runpath / "autoencoder/config.yaml").ae

            state = torch.load(
                surrogate_runpath / "autoencoder/state.pth",
                weights_only=True,
                map_location=device,
            )

            autoencoder = get_autoencoder(**ae_cfg)
            autoencoder.load_state_dict(state)
            autoencoder.to(device)
            autoencoder.requires_grad_(False)
            autoencoder.eval()
        else:
            autoencoder = None

        if surrogate_cfg.ae_run:
            space = re.search(r"f\d+c\d+", surrogate_cfg.ae_run).group()
        else:
            space = "pixel"
    else:
        # Start training from scratch
        surrogate = None

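        # Load the autoencoder from cfg.ae_run (when cfg.load_ae is set and the
        # path exists), freeze its parameters and switch it to eval mode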
        if cfg.load_ae and Path(cfg.ae_run).exists():
            ae_cfg = OmegaConf.load(Path(cfg.ae_run) / "config.yaml").ae

            state = torch.load(
                Path(cfg.ae_run) / "state.pth",
                weights_only=True,
                map_location=device,
            )

            autoencoder = get_autoencoder(**ae_cfg)
            autoencoder.load_state_dict(state)
            autoencoder.to(device)
            autoencoder.requires_grad_(False)
            autoencoder.eval()
        else:
            autoencoder = None

        if cfg.ae_run:
            space = re.search(r"f\d+c\d+", cfg.ae_run).group()
        else:
            space = "pixel"

    if cfg.load_surrogate:
        runname = f"{runid}_{cfg.dataset.name}_NS{cfg.ensemble_size}_{space}_{cfg.surrogate.name}_{cfg.noise_emb_features}_{cfg.staged_training.threshold_lr}_{cfg.optim.learning_rate}_CRPS"
    else:
        assert surrogate is None
        runname = f"scratch_{runid}_{cfg.dataset.name}_NS{cfg.ensemble_size}_{space}_{cfg.noise_emb_features}_{cfg.optim.learning_rate}_CRPS"

    runpath = Path(f"{cfg.server.storage}/runs/crps/{runname}")
    runpath = runpath.expanduser().resolve()
    runpath.mkdir(parents=True, exist_ok=True)

    with open_dict(cfg):
        cfg.name = runname
        cfg.path = str(runpath)
        cfg.seed = randseed(runid)

        if surrogate is not None:
            if surrogate_cfg.ae_run:
                ae_run_path = surrogate_cfg.ae_run
                cfg.ae_run = os.path.realpath(
                    os.path.expanduser(ae_run_path), strict=True
                )
        else:
            if cfg.ae_run:
                cfg.ae_run = os.path.realpath(
                    os.path.expanduser(cfg.ae_run), strict=True
                )

    if rank == 0 and cfg.ae_run:
        os.symlink(cfg.ae_run, runpath / "autoencoder")

    if use_ddp:
        dist.barrier(device_ids=[device_id])

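    # Stem: start the counters from zero, or, when forking from an existing
    # W&B run (cfg.fork.run), load its checkpoint and resume its counters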
    if cfg.fork.run is None:
        counter = {
            "epoch": 0,
            "update_samples": 0,
            "update_steps": 0,
        }
    else:
        stem = wandb.Api().run(path=cfg.fork.run)
        stem_name = Path(stem.config["path"]).name
        stem_path = Path(f"{cfg.server.storage}/runs/crps/{stem_name}")
        stem_path = stem_path.expanduser().resolve()
        stem_state = torch.load(
            stem_path / f"{cfg.fork.target}.pth", weights_only=True, map_location=device
        )

        counter = {
            "epoch": stem.summary["_step"] + 1,
            "update_samples": stem.summary["train/samples"],
            "update_steps": stem.summary["train/update_steps"],
        }

    # Data
    if cfg.ae_run:
        files = {
            split: [
                file
                for physic in cfg.dataset.physics
                for file in find_hdf5(
                    path=runpath / "autoencoder/cache" / physic / split,
                    include_filters=cfg.dataset.include_filters,
                )
            ]
            for split in ("train", "valid")
        }

        dataset = {
            split: MiniWellDataset.from_files(
                files=files[split],
                steps=cfg.trajectory.length,
                stride=cfg.trajectory.stride,
            )
            for split in ("train", "valid")
        }
    else:
        dataset = {
            split: get_well_multi_dataset(
                path=cfg.server.datasets,
                physics=cfg.dataset.physics,
                split=split,
                steps=cfg.trajectory.length,
                min_dt_stride=cfg.trajectory.stride,
                max_dt_stride=cfg.trajectory.stride,
                include_filters=cfg.dataset.include_filters,
                augment=cfg.dataset.augment,
            )
            for split in ("train", "valid")
        }

    # Get rollout dataset for evaluation
    rollout_dataset = {
        "valid": get_well_multi_dataset(
            path=cfg.server.datasets,
            physics=cfg.dataset.physics,
            split="valid",
            steps=-1,
            include_filters=cfg.dataset.include_filters,
            augment=[s for s in cfg.dataset.augment if "random" not in s],
        ),
        "test": get_well_multi_dataset(
            path=cfg.server.datasets,
            physics=cfg.dataset.physics,
            split="test",
            steps=-1,
            include_filters=cfg.dataset.include_filters,
            augment=[s for s in cfg.dataset.augment if "random" not in s],
        ),
    }

    train_loader, valid_loader = [
        get_dataloader(
            dataset=dataset[split],
            batch_size=(
                batch_size // cfg.train.accumulation // world_size
                if split == "train"
                else cfg.valid.batch_size // world_size
            ),
            shuffle=True,
            infinite=True,
            num_workers=cfg.compute.cpus_per_gpu,
            rank=rank,
            world_size=world_size,
            seed=cfg.seed,
        )
        for split in ("train", "valid")
    ]

    if cfg.ae_run:
        preprocess = lambda x: x
    else:
        preprocess = partial(
            field_preprocess,
            mean=torch.as_tensor(cfg.dataset.stats.mean, device=device),
            std=torch.as_tensor(cfg.dataset.stats.std, device=device),
            transform=cfg.dataset.transform,
        )

    rollout_preprocess = partial(
        field_preprocess,
        mean=torch.as_tensor(cfg.dataset.stats.mean, device=device),
        std=torch.as_tensor(cfg.dataset.stats.std, device=device),
        transform=cfg.dataset.transform,
    )

    postprocess = partial(
        field_postprocess,
        mean=torch.as_tensor(cfg.dataset.stats.mean, device=device),
        std=torch.as_tensor(cfg.dataset.stats.std, device=device),
        transform=cfg.dataset.transform,
    )

    if hasattr(cfg.dataset, "dimensions"):
        spatial = len(cfg.dataset.dimensions)
    else:
        spatial = 2

    x, label = get_well_inputs(next(valid_loader))
    x = rearrange(x, "B L ... C -> B C L ...")

    # Model, optimizer & scheduler
    with open_dict(cfg):
        cfg.surrogate.channels = x.shape[1]
        cfg.surrogate.label_features = label.shape[1]
        cfg.surrogate.noise_emb_features = cfg.noise_emb_features
        cfg.surrogate.spatial = len(cfg.dataset.dimensions) + 1

    ft_surrogate = get_surrogate(**cfg.surrogate).to(device)
    num_params = sum(p.numel() for p in ft_surrogate.parameters())
    print("Number of parameters:", num_params)
    if surrogate is not None:
        num_params_initial = sum(p.numel() for p in surrogate.parameters())
        print("Number of parameters initial surrogate:", num_params_initial)

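    # Load common weights
    # NOTE(review): `surrogate_ft_start` is never assigned before this point, and
    # the `del` below makes it a local name, so a truthy cfg.finetune presumably
    # raises UnboundLocalError here -- confirm and define it or drop the branch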
    if cfg.finetune and surrogate_ft_start is not None:
        print("Loading weights from CRPS fine-tuning start surrogate...")
        model_info = load_common_weights(
            ft_surrogate,
            surrogate_ft_start.state_dict(),
            strict=True,
        )
        del surrogate_ft_start
    elif surrogate is not None:
        print("Loading weights from surrogate...")
        model_info = load_common_weights(
            ft_surrogate,
            surrogate.state_dict(),
            strict=True,
        )
    surrogate_loss = CRPS()

    if cfg.fork.run is not None:
        load_state_dict(ft_surrogate, stem_state, strict=cfg.fork.strict)
        del stem_state

    if use_ddp:
        ft_surrogate = DistributedDataParallel(
            module=ft_surrogate,
            device_ids=[device_id],
        )
        model_for_optim = ft_surrogate.module
    else:
        ft_surrogate = ft_surrogate
        model_for_optim = ft_surrogate

    if surrogate is not None:
        # Get surrogate state dict for identifying common parameters
        surrogate_state = surrogate.state_dict()

        # Create staged optimizer and scheduler
        staged_cfg = cfg.get("staged_training", {})
        threshold_lr = staged_cfg.get("threshold_lr", 1e-5)
        common_params_kwargs = staged_cfg.get("common_params_kwargs", {})
        new_params_kwargs = staged_cfg.get("new_params_kwargs", {})
        common_scheduler = staged_cfg.get("common_scheduler", "cosine")
        common_warmup = staged_cfg.get("common_warmup", 0)
        new_scheduler = staged_cfg.get("new_scheduler", "cosine")
        new_warmup = staged_cfg.get("new_warmup", 0)

        optimizer, scheduler = get_staged_optimizer(
            model=model_for_optim,
            surrogate_state_dict=surrogate_state,
            threshold_lr=threshold_lr,
            common_params_kwargs=common_params_kwargs,
            new_params_kwargs=new_params_kwargs,
            common_scheduler=common_scheduler,
            common_warmup=common_warmup,
            new_scheduler=new_scheduler,
            new_warmup=new_warmup,
            epochs=cfg.train.epochs,
            **cfg.optim,
        )
    else:
        optimizer, scheduler = get_optimizer(
            params=ft_surrogate.parameters(),
            epochs=cfg.train.epochs,
            **cfg.optim,
        )

    # W&B
    if rank == 0:
        OmegaConf.save(cfg, runpath / "config.yaml")

    if rank == 0:
        run = wandb.init(
            entity=cfg.wandb.entity,
            project=wandb_project,
            group=cfg.dataset.name,
            id=runid,
            name=runname,
            config=OmegaConf.to_container(cfg),
        )

    # Training loop
    if rank == 0:
        epochs = trange(cfg.train.epochs, ncols=88, ascii=True)
    else:
        epochs = range(cfg.train.epochs)

    best_valid_loss = float("inf")
    start_time = timing.time()
    epoch_times = []

    # Validate on the same trajectories at every evaluation: draw a random subset
    # of at most num_val_indices indices from the validation rollout dataset
    num_val_indices = getattr(cfg.val_eval, "num_val_indices", 20)
    valid_dataset_len = len(rollout_dataset["valid"])
    # Use a separate Random instance with fixed seed for reproducible validation indices
    val_rng = random.Random(42)
    indices = val_rng.sample(
        range(valid_dataset_len), min(num_val_indices, valid_dataset_len)
    )
    print("Validation indices", indices)
    if "euler" in cfg.dataset.name:
        num_test_indices = 1000
        cfg.val_eval.start = 0
    elif "rayleigh" in cfg.dataset.name:
        num_test_indices = 175
        cfg.val_eval.start = 5
    elif "shear_flow" in cfg.dataset.name:
        num_test_indices = 112
        cfg.val_eval.start = 5
    else:
        raise ValueError(
            f"No default number of test indices for dataset {cfg.dataset.name}."
        )
    if not cfg["debug"]:
        test_dataset_len = len(rollout_dataset["test"])
    else:
        test_dataset_len = 3
    test_rng = random.Random(123)
    test_indices = test_rng.sample(
        range(test_dataset_len), min(num_test_indices, test_dataset_len)
    )
    print("Test indices", test_indices)
    record = cfg.val_eval.record

    for epoch_idx, _ in enumerate(epochs):
        epoch_start = timing.time()

        ## Train
        ft_surrogate.train()

        losses, grads = [], []

        for i in range(cfg.train.epoch_size // (batch_size // cfg.train.accumulation)):
            x, label = get_well_inputs(next(train_loader), device=device)
            x = preprocess(x)
            x = rearrange(x, "B L ... C -> B C L ...")

            mask = random_context_mask(x, **cfg.trajectory.context)

            if (i + 1) % cfg.train.accumulation == 0:
                preds = get_ensemble_predictions(
                    ft_surrogate,
                    inputs=x,
                    mask=mask,
                    label=label,
                    noise_emb_features=cfg.noise_emb_features,
                    n_samples=cfg.ensemble_size,
                    device=device,
                )

                loss = surrogate_loss(
                    predictions=preds,
                    target=x,
                    mask=mask,
                    mem_efficient=False,
                )

                loss_acc = loss.mean() / cfg.train.accumulation
                loss_acc.backward()

                grad_norm = safe_gd_step(optimizer, grad_clip=cfg.optim.grad_clip)
                grads.append(grad_norm)

                counter["update_samples"] += batch_size
                counter["update_steps"] += 1
            else:
                if use_ddp:
                    with ft_surrogate.no_sync():
                        preds = get_ensemble_predictions(
                            ft_surrogate,
                            inputs=x,
                            mask=mask,
                            label=label,
                            noise_emb_features=cfg.noise_emb_features,
                            n_samples=cfg.ensemble_size,
                            device=device,
                        )

                        loss = surrogate_loss(
                            predictions=preds,
                            target=x,
                            mask=mask,
                            mem_efficient=False,
                        )

                        loss_acc = loss.mean() / cfg.train.accumulation
                        loss_acc.backward()
                else:
                    preds = get_ensemble_predictions(
                        ft_surrogate,
                        inputs=x,
                        mask=mask,
                        label=label,
                        noise_emb_features=cfg.noise_emb_features,
                        n_samples=cfg.ensemble_size,
                        device=device,
                    )

                    loss = surrogate_loss(
                        predictions=preds,
                        target=x,
                        mask=mask,
                        mem_efficient=False,
                    )
                    loss_acc = loss.mean() / cfg.train.accumulation
                    loss_acc.backward()

            losses.append(loss.detach())

        losses = torch.stack(losses)
        grads = torch.stack(grads)

        # Gather losses and grads only if using DDP
        if use_ddp:
            if rank == 0:
                losses_list = [torch.empty_like(losses) for _ in range(world_size)]
                grads_list = [torch.empty_like(grads) for _ in range(world_size)]
            else:
                losses_list = None
                grads_list = None

            dist.gather(losses, losses_list, dst=0)
            dist.gather(grads, grads_list, dst=0)

            if rank == 0:
                losses = torch.cat(losses_list).cpu()
                grads = torch.cat(grads_list).cpu()
        else:
            losses = losses.cpu()
            grads = grads.cpu()

        if rank == 0:
            logs = {}
            logs["train/loss/mean"] = losses.mean().item()
            logs["train/loss/std"] = losses.std().item()
            logs["train/grad_norm/mean"] = grads.mean().item()
            logs["train/grad_norm/std"] = grads.std().item()
            logs["train/update_steps"] = counter["update_steps"]
            logs["train/samples"] = counter["update_samples"]
            if surrogate is not None:
                # Add learning rate info to logs
                lrs = scheduler.get_last_lr()
                logs["train/lr_common"] = lrs[0]  # Common parameters lr
                if len(lrs) > 1:
                    logs["train/lr_new"] = lrs[1]  # New parameters lr
            else:
                logs["train/learning_rate"] = optimizer.param_groups[0]["lr"]

        if use_ddp:
            del losses_list, grads_list
        del losses, grads

        ## Eval
        ft_surrogate.eval()

        ## Stats for rollout
        if (
            (epoch_idx % cfg.val_eval.interval == 0)
            and cfg.val_eval.enabled
            and epoch_idx != 0
        ) or (epoch_idx == cfg.train.epochs - 1):
            if rank == 0:
                print(f"Epoch {epoch_idx}, performing validation rollout evaluation...")

            if epoch_idx == cfg.train.epochs - 1:
                rollout_splits = ["valid", "test"]
            else:
                rollout_splits = ["valid"]

            for split in rollout_splits:
                if rank == 0:
                    print(f"Rollout evaluation on {split} set...")

                # Initialize metrics dictionary with field-specific and aggregate tracking
                all_metrics = {}
                # Get number of fields
                field_names = cfg.dataset.fields
                num_fields = len(field_names)

                # Get ensemble sizes to track (only the ones we'll save)
                ensemble_sizes_to_save = getattr(
                    cfg.val_eval,
                    "ensemble_sizes_to_save",
                    list(range(1, cfg.val_eval.samples + 1)),
                )

                if cfg.val_eval.samples not in ensemble_sizes_to_save:
                    ensemble_sizes_to_save.append(cfg.val_eval.samples)
                    ensemble_sizes_to_save = sorted(ensemble_sizes_to_save)
                assert (
                    max(ensemble_sizes_to_save) <= cfg.val_eval.samples
                ), "Ensemble sizes to save should not exceed the number of samples."

                # Initialize metrics for each field and aggregate
                metric_names = [
                    "vrmse",
                    "rmse",
                    "nrmse",
                    "spread",
                    "spread_skill",
                    "rmse_p",
                    "rmse_p_low",
                    "rmse_p_mid",
                    "rmse_p_high",
                    "rmse_p_sub",
                    "crps",
                ]

                for metric_name in metric_names:
                    # Only track ensemble sizes we'll save
                    for ensemble_i in ensemble_sizes_to_save:
                        if ensemble_i <= cfg.val_eval.samples:  # Safety check
                            ensemble_idx = (
                                f"_{ensemble_i}"
                                if ensemble_i != cfg.val_eval.samples
                                else ""
                            )

                            all_metrics[f"{metric_name}{ensemble_idx}_all"] = []
                            for field in range(num_fields):
                                all_metrics[
                                    f"{metric_name}{ensemble_idx}_field_{field_names[field]}"
                                ] = []

                current_indices = indices if split == "valid" else test_indices

                for idx, index in enumerate(
                    tqdm(current_indices, desc="Rollouts", ncols=88, ascii=True)
                ):
                    # Video plotting condition
                    video_plotting = (record > 0) and (
                        (idx in [0, 1, 2]) or (index in [82, 30, 170])
                    )
                    # Get data
                    x, label = get_well_inputs(
                        rollout_dataset[split][index], device=device
                    )
                    x = x[
                        max(
                            0,
                            cfg.val_eval.start
                            - (cfg.val_eval.context - 1) * cfg.trajectory.stride,
                        ) :: cfg.trajectory.stride
                    ]
                    x = rollout_preprocess(x)
                    x = rearrange(x, "L ... C -> C L ...")

                    with torch.no_grad():
                        z = encode_traj(autoencoder, x)

                    compression = x.numel() / z.numel()

                    if hasattr(cfg, "surrogate"):
                        method = "surrogate"
                        emulate = lambda mask, z_obs, noise, i: emulate_crps(
                            ft_surrogate,
                            mask,
                            z_obs,
                            label=label,
                            noise=noise,
                        )
                        if surrogate is not None:
                            emulate_orig = (
                                lambda mask, z_obs, noise, i: emulate_surrogate(
                                    surrogate,
                                    mask,
                                    z_obs,
                                    label=label,  # noqa: B023
                                )
                            )
                        else:
                            emulate_orig = None
                    else:
                        method = "autoencoder"

                    with torch.no_grad():
                        if method in ("diffusion", "surrogate", "flow_matching"):
                            # Emulate trajectory
                            z_hat = emulate_rollout(
                                emulate,
                                z,
                                window=cfg.trajectory.length,
                                rollout=z.shape[1],
                                context=cfg.val_eval.context,
                                overlap=cfg.val_eval.overlap,
                                batch=cfg.val_eval.samples,
                                crps_noise_emb=cfg.noise_emb_features,
                            )
                            if video_plotting and surrogate is not None:
                                z_hat_surrogate = emulate_rollout(
                                    emulate_orig,
                                    z,
                                    window=cfg.trajectory.length,
                                    rollout=z.shape[1],
                                    context=cfg.val_eval.context,
                                    overlap=cfg.val_eval.overlap,
                                    batch=1,
                                    crps_noise_emb=0,
                                )
                            else:
                                z_hat_surrogate = None
                        else:
                            z_hat = z.expand(cfg.val_eval.samples, *z.shape)

                        if "euler" in cfg.dataset.name:
                            chunk_size = 128
                        elif "gravity" in cfg.dataset.name:
                            chunk_size = 128
                        elif "shear_flow" in cfg.dataset.name:
                            chunk_size = 128
                        else:
                            chunk_size = 256

                        x_hat = decode_traj(
                            autoencoder,
                            z_hat,
                            batched=True,
                            noisy=False,
                            chunk_size=chunk_size,
                        )
                        if z_hat_surrogate is not None:
                            x_hat_surrogate = decode_traj(
                                autoencoder,
                                z_hat_surrogate,
                                batched=True,
                                noisy=False,
                                chunk_size=chunk_size,
                            )
                        else:
                            x_hat_surrogate = None

                    del z_hat, z_hat_surrogate

                    ## Postprocess
                    x = postprocess(x, dim=-spatial - 2)
                    x_hat = postprocess(x_hat, dim=-spatial - 2)
                    if x_hat_surrogate is not None:
                        x_hat_surrogate = postprocess(x_hat_surrogate, dim=-spatial - 2)

                    if split == "test" and rank == 0:
                        lines = []
                    # Compute metrics
                    for field in range(x.shape[0]):
                        for t in range(cfg.val_eval.context - 1, x.shape[1]):
                            true_t = (
                                t - cfg.val_eval.context + 1
                            ) * cfg.trajectory.stride
                            u, v = x[field, t], x_hat[:, field, t]

                            # Moments (these don't change with ensemble size)
                            m1 = torch.mean(u)
                            m2 = torch.mean(u**2)

                            # Fourier analysis (ground truth doesn't change)
                            p_u, k = isotropic_power_spectrum(u, spatial=spatial)
                            bins = torch.logspace(k[0].log2(), -1.0, steps=4, base=2)

                            # Storage for ensemble-dependent metrics
                            metrics_by_ensemble = {}

                            # Iterate over ensemble sizes
                            for ensemble_i in ensemble_sizes_to_save:
                                v_subset = v[
                                    :ensemble_i
                                ]  # Use first ensemble_i samples

                                # Spread
                                if ensemble_i > 1:
                                    # see https://doi.org/10.1175/JHM-D-14-0008.1
                                    spread = torch.mean(
                                        torch.square(
                                            v_subset - torch.mean(v_subset, dim=0)
                                        )
                                    )
                                    spread = torch.sqrt(
                                        (ensemble_i + 1) / (ensemble_i - 1) * spread
                                    )
                                else:
                                    spread = 0.0

                                # Skill metrics
                                se = torch.square(u - torch.mean(v_subset, dim=0))
                                mse = torch.mean(se)
                                rmse = torch.sqrt(mse)
                                nrmse = torch.sqrt(mse / (torch.mean(u**2) + 1e-6))
                                vrmse = torch.sqrt(mse / (torch.var(u) + 1e-6))

                                # Spread_skill ratio
                                spread_skill = (spread + 1e-3) / (rmse + 1e-3)

                                # Fourier metrics
                                p_v, _ = isotropic_power_spectrum(
                                    v_subset, spatial=spatial
                                )
                                p_v = torch.mean(p_v, dim=0)
                                se_p = torch.square(1 - (p_v + 1e-6) / (p_u + 1e-6))
                                rmse_p = torch.sqrt(torch.mean(se_p))

                                fourier_extras = []
                                for i in range(4):
                                    if i < 3:
                                        mask = torch.logical_and(
                                            bins[i] <= k, k <= bins[i + 1]
                                        )
                                    else:
                                        mask = bins[i] <= k
                                    fourier_extras.append(
                                        torch.sqrt(torch.mean(se_p[mask])).item()
                                    )

                                extras = []
                                ## Wasserstein
                                w_uv = ot.lp.wasserstein_1d(
                                    u.flatten(),
                                    v_subset.flatten(),
                                    p=1.0,
                                )
                                extras.append(w_uv.item())
                                extras.append(None)

                                # CRPS
                                crps_i = surrogate_loss(
                                    predictions=v_subset[None, :, None, None, :],
                                    target=u[None, None, None, :],
                                    mask=None,
                                    mem_efficient=False,
                                ).mean()

                                # Store all metrics for this ensemble size
                                metrics_by_ensemble[ensemble_i] = {
                                    "vrmse": vrmse.item(),
                                    "rmse": rmse.item(),
                                    "nrmse": nrmse.item(),
                                    "spread": (
                                        spread
                                        if isinstance(spread, float)
                                        else spread.item()
                                    ),
                                    "spread_skill": spread_skill.item(),
                                    "rmse_p": rmse_p.item(),
                                    "rmse_p_low": fourier_extras[0],
                                    "rmse_p_mid": fourier_extras[1],
                                    "rmse_p_high": fourier_extras[2],
                                    "rmse_p_sub": fourier_extras[3],
                                    "crps": crps_i.item(),
                                    "wasserstein": extras[0],
                                    "emd": extras[1],
                                }

                            # For CSV output (test split) - store one row per ensemble size
                            if split == "test" and rank == 0:
                                # Store one row for each ensemble size
                                for ensemble_i in ensemble_sizes_to_save:
                                    if (
                                        ensemble_i in metrics_by_ensemble
                                    ):  # Only if we computed it
                                        ensemble_metrics = metrics_by_ensemble[
                                            ensemble_i
                                        ]

                                        line = f"{runid},state,{compression:.1f},crps,{cfg.noise_emb_features},"
                                        line += f"train,None,{cfg.val_eval.context},{cfg.val_eval.overlap},1.0,"
                                        line += f"{split},{index},{cfg.val_eval.start},{cfg.seed},"
                                        line += f"{field_names[field]},{(t - cfg.val_eval.context + 1) * cfg.trajectory.stride},False,"
                                        line += (
                                            f"{ensemble_i},"  # Add ensemble size here
                                        )
                                        line += f"{m1},{m2},"
                                        line += f"{ensemble_metrics['spread']},{ensemble_metrics['spread_skill']},"
                                        line += f"{ensemble_metrics['rmse']},{ensemble_metrics['nrmse']},{ensemble_metrics['vrmse']},"
                                        line += f"{ensemble_metrics['rmse_p']},"

                                        # Add Fourier extras
                                        fourier_values = [
                                            ensemble_metrics["rmse_p_low"],
                                            ensemble_metrics["rmse_p_mid"],
                                            ensemble_metrics["rmse_p_high"],
                                            ensemble_metrics["rmse_p_sub"],
                                        ]
                                        line += ",".join(map(str, fourier_values)) + ","
                                        # Add CRPS value for this ensemble size
                                        line += f"{ensemble_metrics['crps']}" + ","

                                        # Add Wasserstein and Sliced EMD
                                        extra_values = [
                                            ensemble_metrics["wasserstein"],
                                            ensemble_metrics["emd"],
                                        ]
                                        line += ",".join(map(str, extra_values))

                                        # Add label parameters
                                        if hasattr(label, "tolist"):
                                            line += (
                                                f",{','.join(map(str, label.tolist()))}"
                                            )

                                        line += "\n"
                                        lines.append(line)

                            # Store metrics for logging (per-field and field-averaged)
                            for ensemble_i in ensemble_sizes_to_save:
                                ensemble_idx = (
                                    f"_{ensemble_i}"
                                    if ensemble_i != cfg.val_eval.samples
                                    else ""
                                )

                                # Store per-field metrics for this ensemble size
                                for metric_name, value in metrics_by_ensemble[
                                    ensemble_i
                                ].items():
                                    if metric_name not in ["wasserstein", "emd"]:
                                        all_metrics[
                                            f"{metric_name}{ensemble_idx}_field_{field_names[field]}"
                                        ].append((true_t, value))

                    # After processing all samples, compute field-averaged metrics across all timesteps
                    for ensemble_i in ensemble_sizes_to_save:
                        ensemble_idx = (
                            f"_{ensemble_i}"
                            if ensemble_i != cfg.val_eval.samples
                            else ""
                        )

                        for metric_name in metric_names:
                            # Get all field keys for this metric and ensemble size
                            field_keys = [
                                f"{metric_name}{ensemble_idx}_field_{field_names[field]}"
                                for field in range(num_fields)
                            ]

                            # Check if all field keys exist and have data
                            if all(
                                key in all_metrics and all_metrics[key]
                                for key in field_keys
                            ):
                                # Get all timesteps from the first field (they should all be the same)
                                timesteps = [t for t, _ in all_metrics[field_keys[0]]]

                                # Create the field-averaged list
                                field_averaged_values = []

                                for i, timestep in enumerate(timesteps):
                                    # Collect values from all fields for this timestep index
                                    field_values = []
                                    for field_key in field_keys:
                                        if i < len(all_metrics[field_key]):
                                            _, value = all_metrics[field_key][
                                                i
                                            ]  # Get value at index i
                                            field_values.append(value)

                                    # Average across fields for this timestep
                                    if (
                                        len(field_values) == num_fields
                                    ):  # Ensure we have all fields
                                        avg_value = sum(field_values) / len(
                                            field_values
                                        )
                                        field_averaged_values.append(
                                            (timestep, avg_value)
                                        )

                                # Store the field-averaged values
                                all_metrics[f"{metric_name}{ensemble_idx}_all"] = (
                                    field_averaged_values
                                )

                    if split == "test" and rank == 0:
                        # Save
                        stats_path = (
                            runpath / "test_stats.csv"
                        )  # Single file for all samples
                        with FileLock(str(stats_path) + ".lock"):
                            with open(stats_path, "a") as f:  # Use write mode
                                # Only write header if file is empty/new
                                if stats_path.stat().st_size == 0:
                                    header = "run,target,compression,method,noise_emb_size,settings,guidance,context,overlap,speed,"
                                    header += (
                                        "split,index,start,seed,field,time,relative,"
                                    )
                                    header += "ensemble_size,"  # Add this column
                                    header += "m1,m2,spread,spread_skill,rmse,nrmse,vrmse,rmse_p,"
                                    header += (
                                        "rmse_p_low,rmse_p_mid,rmse_p_high,rmse_p_sub,"
                                    )
                                    header += "crps,"  # Single CRPS value per row
                                    header += "wasserstein,emd"
                                    if (
                                        hasattr(label, "tolist")
                                        and len(label.tolist()) > 0
                                    ):
                                        param_names = getattr(
                                            cfg.dataset,
                                            "param_names",
                                            [
                                                f"param_{i}"
                                                for i in range(len(label.tolist()))
                                            ],
                                        )
                                        header += f",{','.join(param_names)}"
                                    header += "\n"
                                    f.write(header)
                                f.writelines(lines)

                    if (record > 0) and (
                        (idx in [0, 1, 2]) or (index in [82, 30, 170])
                    ):
                        if x.shape[-1] < x.shape[-2]:
                            x, x_hat = x.mT, x_hat.mT
                            if x_hat_surrogate is not None:
                                x_hat_surrogate = x_hat_surrogate.mT
                        if x_hat_surrogate is not None:
                            frames = torch.stack(
                                (x, x_hat_surrogate[0], *x_hat[:record])
                            )
                        else:
                            frames = torch.stack((x, *x_hat[:record]))
                        frames = rearrange(frames, "N C L H W -> L N C H W")
                        # Get rid of the first conditioning
                        frames = frames[cfg.val_eval.context :]

                        # Ensure the directory exists before saving the movie
                        movie_dir = runpath / f"epoch_{epoch_idx}"
                        movie_dir.mkdir(parents=True, exist_ok=True)
                        draw_movie(
                            frames,
                            file=(
                                movie_dir
                                / f"{runname}_{target}_{split}_{index:06d}_{cfg.val_eval.start:03d}_{cfg.val_eval.context}_{cfg.val_eval.overlap}_{cfg.seed}.mp4"
                            ),
                            fps=4.0 / cfg.trajectory.stride,
                            isolate={2},
                        )

                        del frames
                    del x, x_hat
                # Log aggregated rollout metrics across all samples
                if rank == 0 and len(all_metrics[f"{metric_names[0]}_all"]) > 0:
                    # Helper function to split by time ranges
                    def split_by_time_range(data, ranges):
                        """Bucket values by the first inclusive range containing their timestep.

                        Args:
                            data: iterable of ``(timestep, value)`` pairs.
                            ranges: list of ``(min, max)`` tuples; both bounds are inclusive.

                        Returns:
                            One list of values per entry of ``ranges``, in the same order.
                            A pair whose timestep lies in no range is dropped; a pair
                            matching several ranges goes to the earliest one only.
                        """
                        buckets = [[] for _ in ranges]
                        for step, val in data:
                            # Position of the first range that contains this timestep, if any.
                            hit = next(
                                (
                                    pos
                                    for pos, (lo, hi) in enumerate(ranges)
                                    if lo <= step <= hi
                                ),
                                None,
                            )
                            if hit is not None:
                                buckets[hit].append(val)
                        return buckets

                    # Helper functions using NumPy
                    def compute_mean(data_list):
                        """Return the mean of the values recorded at or after the context length.

                        ``data_list`` holds ``(timestep, value)`` pairs; pairs whose
                        timestep is below ``cfg.val_eval.context`` are left out.
                        """
                        kept = [
                            value
                            for step, value in data_list
                            if step >= cfg.val_eval.context
                        ]
                        return float(np.mean(np.array(kept)))

                    def compute_median(data_list):
                        """Return the median of the values recorded at or after the context length.

                        ``data_list`` holds ``(timestep, value)`` pairs; pairs whose
                        timestep is below ``cfg.val_eval.context`` are left out.
                        """
                        selected = []
                        for step, value in data_list:
                            if step >= cfg.val_eval.context:
                                selected.append(value)
                        return float(np.median(np.array(selected)))

                    # Define dataset-specific time ranges for the per-range breakdowns
                    if "rayleigh" in cfg.dataset.name:
                        time_ranges = [(1, 40), (41, 80), (81, 200)]
                        range_names = ["1_40", "41_80", "81_200"]
                    elif "euler" in cfg.dataset.name:
                        time_ranges = [(1, 20), (21, 60), (61, 100)]
                        range_names = ["1_20", "21_60", "61_100"]
                    elif "shear_flow" in cfg.dataset.name:
                        time_ranges = [(1, 40), (41, 80), (81, 200)]
                        range_names = ["1_40", "41_80", "81_200"]
                    else:
                        raise ValueError(
                            f"No time ranges defined for dataset {cfg.dataset.name}."
                        )

                    # Log aggregate metrics (averaged across all fields) - for ALL ensemble sizes
                    for ensemble_i in ensemble_sizes_to_save:
                        if ensemble_i <= cfg.val_eval.samples:  # Safety check
                            ensemble_idx = (
                                f"_{ensemble_i}"
                                if ensemble_i != cfg.val_eval.samples
                                else ""
                            )

                            for metric_name in metric_names:
                                metric_key = f"{metric_name}{ensemble_idx}_all"
                                if (
                                    metric_key in all_metrics
                                    and all_metrics[metric_key]
                                ):
                                    data = all_metrics[metric_key]
                                    logs[
                                        f"{split}/rollout/{metric_name}{ensemble_idx}_mean"
                                    ] = compute_mean(data)
                                    logs[
                                        f"{split}/rollout/{metric_name}{ensemble_idx}_median"
                                    ] = compute_median(data)

                                    # Time range breakdowns
                                    splits = split_by_time_range(data, time_ranges)
                                    for i, range_name in enumerate(range_names):
                                        if splits[i]:
                                            logs[
                                                f"{split}/rollout/{metric_name}{ensemble_idx}_mean_{range_name}"
                                            ] = sum(splits[i]) / len(splits[i])
                                            logs[
                                                f"{split}/rollout/{metric_name}{ensemble_idx}_median_{range_name}"
                                            ] = float(np.median(splits[i]))

                    # Log per-field metrics - for ALL ensemble sizes
                    for ensemble_i in ensemble_sizes_to_save:
                        if ensemble_i <= cfg.val_eval.samples:  # Safety check
                            ensemble_idx = (
                                f"_{ensemble_i}"
                                if ensemble_i != cfg.val_eval.samples
                                else ""
                            )

                            for field in range(num_fields):
                                for metric_name in metric_names:
                                    metric_key = f"{metric_name}{ensemble_idx}_field_{field_names[field]}"
                                    if (
                                        metric_key in all_metrics
                                        and all_metrics[metric_key]
                                    ):
                                        data = all_metrics[metric_key]
                                        logs[
                                            f"{split}/rollout/{metric_name}{ensemble_idx}_field_{field_names[field]}_mean"
                                        ] = compute_mean(data)
                                        logs[
                                            f"{split}/rollout/{metric_name}{ensemble_idx}_field_{field_names[field]}_median"
                                        ] = compute_median(data)

                                        # Time range breakdowns per field
                                        splits = split_by_time_range(data, time_ranges)
                                        for i, range_name in enumerate(range_names):
                                            if splits[i]:
                                                logs[
                                                    f"{split}/rollout/{metric_name}{ensemble_idx}_field_{field_names[field]}_mean_{range_name}"
                                                ] = sum(splits[i]) / len(splits[i])
                                                logs[
                                                    f"{split}/rollout/{metric_name}{ensemble_idx}_field_{field_names[field]}_median_{range_name}"
                                                ] = float(np.median(splits[i]))
        # Stats for validation loss
        ft_surrogate.eval()

        losses = []

        with torch.no_grad():
            for _ in range(cfg.valid.epoch_size // cfg.valid.batch_size):
                x, label = get_well_inputs(next(valid_loader), device=device)
                x = preprocess(x)
                x = rearrange(x, "B L ... C -> B C L ...")

                mask = random_context_mask(x, **cfg.trajectory.context)

                preds = get_ensemble_predictions(
                    ft_surrogate,
                    inputs=x,
                    mask=mask,
                    label=label,
                    noise_emb_features=cfg.noise_emb_features,
                    n_samples=cfg.ensemble_size,
                    device=device,
                )

                loss = surrogate_loss(
                    predictions=preds,
                    target=x,
                    mask=mask,
                    mem_efficient=False,
                )
                losses.append(loss)

        losses = torch.stack(losses)

        if use_ddp:
            if rank == 0:
                losses_list = [torch.empty_like(losses) for _ in range(world_size)]
            else:
                losses_list = None

            dist.gather(losses, losses_list, dst=0)

            if rank == 0:
                losses = torch.cat(losses_list).cpu()
        else:
            losses = losses.cpu()

        epoch_end = timing.time()
        epoch_duration = epoch_end - epoch_start
        epoch_times.append(epoch_duration)

        if rank == 0:
            # Calculate timing statistics
            avg_epoch_time = sum(epoch_times) / len(epoch_times)
            remaining_epochs = cfg.train.epochs - (epoch_idx + 1)
            eta_seconds = avg_epoch_time * remaining_epochs
            eta_hours = eta_seconds / 3600
            total_elapsed = (epoch_end - start_time) / 3600

            # Log epoch statistics
            logs["valid/loss/mean"] = losses.mean().item()
            logs["valid/loss/std"] = losses.std().item()

            epochs.set_postfix(
                lt=logs["train/loss/mean"],
                lv=logs["valid/loss/mean"],
            )

            # Update main progress bar with comprehensive info
            epochs.set_postfix(
                train_loss=f"{logs['train/loss/mean']:.4f}",
                valid_loss=f"{logs['valid/loss/mean']:.4f}",
                epoch_time=f"{epoch_duration:.1f}s",
                avg_time=f"{avg_epoch_time:.1f}s",
                eta=f"{eta_hours:.1f}h",
                elapsed=f"{total_elapsed:.1f}h",
            )

            # Log timing to wandb
            logs["timing/epoch_duration"] = epoch_duration
            logs["timing/avg_epoch_duration"] = avg_epoch_time
            logs["timing/eta_hours"] = eta_hours
            logs["timing/total_elapsed_hours"] = total_elapsed
            logs["timing/epochs_per_hour"] = 1 / (avg_epoch_time / 3600)

            run.log(logs, step=counter["epoch"])

            counter["epoch"] += 1

        if use_ddp:
            del losses_list
        del losses

        ## LR scheduler
        scheduler.step()

        ## Checkpoint
        if rank == 0:
            if use_ddp:
                state = ft_surrogate.module.state_dict()
            else:
                state = ft_surrogate.state_dict()

            # Save complete checkpoint with optimizer and scheduler states
            checkpoint = {
                "model_state_dict": state,
                "optimizer_state_dict": optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
                "epoch": counter["epoch"],
                "best_valid_loss": best_valid_loss,
                "counter": counter,
                # Add any other training state you want to preserve
            }
            torch.save(checkpoint, runpath / "checkpoint.pth")

            # Also save just the model weights for compatibility
            torch.save(state, runpath / "state.pth")

            if logs["valid/loss/mean"] < best_valid_loss:
                best_valid_loss = logs["valid/loss/mean"]
                torch.save(state, runpath / "state_best.pth")
                torch.save(checkpoint, runpath / "checkpoint_best.pth")

            del state, checkpoint

        if use_ddp:
            dist.barrier(device_ids=[device_id])

    # W&B
    if rank == 0:
        run.finish()

    # DDP
    if use_ddp:
        dist.destroy_process_group()


if __name__ == "__main__":
    # Entry point: parse the CLI, compose the config, then either train in this
    # process (--single_gpu) or hand the training job to dawgz for scheduling.

    # Parser
    parser = argparse.ArgumentParser()
    parser.add_argument("overrides", nargs="*", type=str)
    parser.add_argument(
        "--single_gpu", action="store_true", help="Run on a single GPU (for debugging)"
    )
    parser.add_argument("--debug", action="store_true", help="Debug mode")

    args = parser.parse_args()

    # Config: debug mode selects both a different config file and W&B project.
    if args.debug:
        config_file = "./experiments/configs/train_crps_debug.yaml"
        wandb_project = "lola_prob_debug"
    else:
        config_file = "./experiments/configs/train_crps.yaml"
        wandb_project = "lola_prob"

    cfg = compose(config_file=config_file, overrides=args.overrides)

    # Job
    runid = wandb.util.generate_id()
    if args.debug:
        runid = "debug_" + runid

    # Record the mode in the config itself.
    with open_dict(cfg):
        cfg["debug"] = args.debug

    if args.single_gpu:
        # Run directly on single GPU without SLURM
        print(f"Running on single GPU without SLURM, runid: {runid}")
        train(runid, cfg, wandb_project=wandb_project)
    else:
        print("Scheduling SLURM job...")
        if cfg.compute.nodes > 1:
            interpreter = f"torchrun --nnodes {cfg.compute.nodes} --nproc-per-node {cfg.compute.gpus} --rdzv_backend=c10d --rdzv_endpoint=$SLURMD_NODENAME:12345 --rdzv_id=$SLURM_JOB_ID"
        else:
            interpreter = (
                f"torchrun --nnodes 1 --nproc-per-node {cfg.compute.gpus} --standalone"
            )

        print(f"Interpreter: {interpreter}")

        job_config = dawgz.job(
            f=partial(train, runid, cfg, wandb_project),
            name=f"crps {runid}",
            nodes=cfg.compute.nodes,
            cpus=cfg.compute.cpus_per_gpu * cfg.compute.gpus,
            gpus=cfg.compute.gpus,
            ram=cfg.compute.ram,
            time=cfg.compute.time,
            partition=cfg.server.partition,
            constraint=cfg.server.constraint,
            exclude=cfg.server.exclude,
        )

        print(f"Job config created: {job_config}")

        # Debug runs use the "async" backend; everything else goes to SLURM.
        backend = "async" if args.debug else "slurm"

        try:
            dawgz.schedule(
                job_config,
                name=f"training crps {runid}",
                backend=backend,
                interpreter=interpreter,
                env=[
                    "export OMP_NUM_THREADS=" + f"{cfg.compute.cpus_per_gpu}",
                    "export WANDB_SILENT=true",
                    "export XDG_CACHE_HOME=$HOME/.cache",
                    "export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True",
                ],
            )
        except Exception as e:
            print(f"Error scheduling job: {e}")
            import traceback

            traceback.print_exc()
            # Report the failure through the exit status so a calling shell or
            # sweep script can detect it instead of seeing a clean exit.
            raise SystemExit(1) from e
        else:
            print("Job scheduled successfully")
