#!/usr/bin/env python3
"""
train.py

Train a model to learn WLM from the population-dynamics data produced by data_generator.py.

Outputs in outdir/run_name/:
  - config.json
  - metrics.jsonl
  - ckpt_stepXXXXXXX.pt (periodic)
  - final.pt
  - compare_stepXXXXXXX.gif (optional, periodic)

Supports (--train-mode):
  - rollout_next_k_from_0 (default)
  - rollout_uniform
  - rollout_incremental

Evaluation modes:
  - forecast: Train on first portion of data, evaluate on held-out future (default)
  - interpolate: Hold out one or more marginals, train on the rest, evaluate on the held-out ones
"""
import numpy as np
import torch
import torch.nn as nn
import argparse
import matplotlib.pyplot as plt
import os
import time
import copy
from pathlib import Path
from typing import Any, Dict, Optional, List
from functools import partial
from sklearn.metrics import pairwise_distances

import wandb

from rollout import train_rollout_anchor_p0_randk, train_rollout_uniform, train_rollout_incremental, LearnableFriction
from mechanics import build_vel_provider
from potential_energy_models import build_model_and_kwargs, make_accel_from_potential

from dice import maybe_make_dice_diagnostic_gif, train_or_load_dice_bundle
from parse_save_helpers import (
    _parse_args_with_config, get_device, set_seed, ensure_dir, print_velocity_diagnostics,
    dump_json, append_jsonl, _find_complex, _wandb_sanitize, load_dice_models, now_str,
    TrainCallbackState, make_uniform_epoch_callback, make_incremental_callbacks,
    save_ckpt as save_ckpt_impl,
    do_eval as do_eval_impl,
)

from plot_utils import plot_holdout_scatter, plot_multi_holdout_scatter ,maybe_gif as maybe_gif_impl
from losses import get_w1

# ============================================================
# EMA (Exponential Moving Average)
# ============================================================
# NOTE: pass the friction module to EMA as well so its weights are averaged alongside the model's.
class EMA:
    """
    Lightweight exponential moving average (EMA) of module weights.

    Only detached clones of the ``state_dict`` tensors are stored (for the
    model and, optionally, for a friction module) rather than a second
    module instance.

    Usage:
        ema = EMA(model, friction_module=friction, decay=0.999)
        for batch in data:
            loss = train_step(model, batch)
            ema.update(model, friction_module=friction)  # call after optimizer.step()

        # For eval: temporarily apply EMA weights
        with ema.apply(model, friction_module=friction):
            evaluate(model)
    """

    def __init__(self, model: nn.Module, friction_module: Optional[nn.Module] = None, decay: float = 0.999):
        """
        Snapshot the current weights as the initial averages.

        Args:
            model: Module whose ``state_dict`` entries are averaged.
            friction_module: Optional second module averaged the same way.
            decay: Weight kept from the previous average at each update.
        """
        self.decay = decay
        self.step = 0
        # Shadow copies are independent tensors, not views of the live weights.
        self.shadow = self._snapshot(model)
        self.friction_shadow = self._snapshot(friction_module) if friction_module is not None else {}

    @staticmethod
    def _snapshot(module: nn.Module) -> dict:
        """Return detached clones of every entry in ``module.state_dict()``."""
        return {key: value.detach().clone() for key, value in module.state_dict().items()}

    @staticmethod
    def _adopt(loaded: dict, current: dict) -> dict:
        """
        Clone checkpoint tensors, moving each onto the device of the matching
        existing shadow entry (when there is one).

        Cloning keeps later in-place updates from mutating the caller's
        checkpoint; the device move keeps ``update`` working after a
        checkpoint was loaded onto a different device than the live module.
        """
        adopted = {}
        for key, value in loaded.items():
            tensor = value.detach().clone()
            reference = current.get(key)
            if reference is not None and tensor.device != reference.device:
                tensor = tensor.to(reference.device)
            adopted[key] = tensor
        return adopted

    def _blend(self, shadow: dict, state: dict) -> None:
        """Fold ``state`` into ``shadow`` in place; keys absent from ``shadow`` are ignored."""
        keep = self.decay
        for key, value in state.items():
            target = shadow.get(key)
            if target is None:
                continue
            if value.dtype.is_floating_point:
                target.mul_(keep).add_(value, alpha=1.0 - keep)
            else:
                # Non-floating entries (e.g. integer counters) cannot be
                # averaged in place; mirror the latest value instead.
                target.copy_(value)

    @torch.no_grad()
    def update(self, model: nn.Module, friction_module: Optional[nn.Module] = None):
        """Update EMA parameters. Call after optimizer.step()."""
        self.step += 1
        self._blend(self.shadow, model.state_dict())
        if friction_module is not None and self.friction_shadow:
            self._blend(self.friction_shadow, friction_module.state_dict())

    @torch.no_grad()
    def apply_to(self, model: nn.Module, friction_module: Optional[nn.Module] = None):
        """
        Load EMA weights into ``model`` (for eval).

        If ``friction_module`` is given and friction averages exist, they are
        loaded into it as well. The previous weights are not saved here; use
        ``apply`` for a temporary swap.
        """
        model.load_state_dict(self.shadow)
        if friction_module is not None and self.friction_shadow:
            friction_module.load_state_dict(self.friction_shadow)

    @torch.no_grad()
    def restore(self, model: nn.Module, backup: dict,
                friction_module: Optional[nn.Module] = None,
                friction_backup: Optional[dict] = None):
        """
        Restore original weights from ``backup``.

        ``friction_module`` is restored from ``friction_backup`` when both
        are provided.
        """
        model.load_state_dict(backup)
        if friction_module is not None and friction_backup:
            friction_module.load_state_dict(friction_backup)

    class _ApplyContext:
        """Context manager that swaps EMA weights in on entry and the originals back on exit."""

        def __init__(self, ema, model, friction_module):
            self.ema = ema
            self.model = model
            self.friction_module = friction_module
            self.backup = {}
            self.friction_backup = {}

        def __enter__(self):
            # Back up the live weights before overwriting them.
            self.backup = {k: v.clone() for k, v in self.model.state_dict().items()}
            self.model.load_state_dict(self.ema.shadow)

            # Friction is swapped only when a module was given and averages exist.
            if self.friction_module is not None and self.ema.friction_shadow:
                self.friction_backup = {k: v.clone() for k, v in self.friction_module.state_dict().items()}
                self.friction_module.load_state_dict(self.ema.friction_shadow)

        def __exit__(self, *args):
            # Restore unconditionally; returning None lets exceptions propagate.
            self.model.load_state_dict(self.backup)
            if self.friction_module is not None and self.friction_backup:
                self.friction_module.load_state_dict(self.friction_backup)

    def apply(self, model: nn.Module, friction_module: Optional[nn.Module] = None):
        """Context manager to temporarily use EMA weights for eval."""
        return self._ApplyContext(self, model, friction_module)

    def state_dict(self) -> dict:
        """Return the EMA state (step, decay, model shadow, friction shadow) for checkpointing."""
        return {
            'step': self.step,
            'decay': self.decay,
            'shadow': self.shadow,
            'friction_shadow': self.friction_shadow
        }

    def load_state_dict(self, state_dict: dict):
        """
        Restore EMA state produced by ``state_dict``.

        ``'step'`` and ``'shadow'`` are required. ``'decay'`` and
        ``'friction_shadow'`` are optional; when absent (or ``None`` for the
        friction shadow) the current values are kept.
        """
        self.step = state_dict['step']
        self.decay = state_dict.get('decay', self.decay)
        self.shadow = self._adopt(state_dict['shadow'], self.shadow)
        loaded_friction = state_dict.get('friction_shadow')
        if loaded_friction is not None:
            self.friction_shadow = self._adopt(loaded_friction, self.friction_shadow)


# ============================================================
# Train/test splits for forecast and interpolation
# ============================================================

def partition_data_forecast(
        X_em: torch.Tensor,
        V_em: Optional[torch.Tensor],
        time_grid: torch.Tensor,
        train_fraction: float,
) -> Dict[str, Any]:
    """
    Split a bundle for forecast evaluation: the leading time marginals are
    used for training and the full arrays are kept for evaluation.

    Args:
        X_em: 4-D tensor; its third axis is sliced together with ``time_grid``.
        V_em: Optional tensor sliced the same way as ``X_em``.
        time_grid: 1-D tensor of times, one entry per marginal.
        train_fraction: Fraction of marginals assigned to training. The
            resulting count is clamped to at least 2 and at most the number
            of marginals available.

    Returns:
        Dict with keys
        ``X_train``, ``V_train``, ``t_train`` (cloned leading slices),
        ``X_full``, ``V_full``, ``time_grid`` (the inputs, untouched),
        ``T_train_plus_1`` (number of training marginals),
        ``max_train_steps`` (``T_train_plus_1 - 1``) and
        ``mode_info`` (summary of the partition).
    """
    # The 4-way unpack also rejects inputs that are not 4-D.
    _, _, T_plus_1, _ = X_em.shape

    # Lower bound of 2 first, then cap at the number of marginals present.
    T_train_plus_1 = min(max(2, int(round(T_plus_1 * train_fraction))), T_plus_1)

    def _leading_window(series: torch.Tensor) -> torch.Tensor:
        # Independent copy of the first T_train_plus_1 marginals.
        return series[:, :, :T_train_plus_1, :].contiguous().clone()

    X_train = _leading_window(X_em)
    t_train = time_grid[:T_train_plus_1].contiguous().clone()
    V_train = None if V_em is None else _leading_window(V_em)

    mode_info = dict(
        eval_mode="forecast",
        train_fraction=train_fraction,
        T_train_plus_1=T_train_plus_1,
        T_total=T_plus_1,
        train_times=t_train.cpu().tolist(),
    )

    print(f"[Forecast] Train on t=[0..{T_train_plus_1 - 1}] ({train_fraction * 100:.0f}%), "
          f"eval on t=[{T_train_plus_1 - 1}..{T_plus_1 - 1}]")

    return dict(
        X_train=X_train,
        V_train=V_train,
        t_train=t_train,
        X_full=X_em,
        V_full=V_em,
        time_grid=time_grid,
        T_train_plus_1=T_train_plus_1,
        max_train_steps=T_train_plus_1 - 1,
        mode_info=mode_info,
    )


def partition_data_interpolate(
        X_em: torch.Tensor,
        V_em: Optional[torch.Tensor],
        time_grid: torch.Tensor,
        holdout_indices: List[int],
) -> Dict[str, Any]:
    """
    Split a bundle for interpolate evaluation by removing selected time
    marginals from the training data.

    Args:
        X_em: 4-D tensor; its third axis is indexed together with ``time_grid``.
        V_em: Optional tensor indexed the same way as ``X_em``.
        time_grid: 1-D tensor of times, one entry per marginal.
        holdout_indices: Marginal indices to hold out (e.g., [1, 3, 5, 7]).
            Duplicates are dropped and the result is sorted.

    Returns:
        Dict with keys
        ``X_train``, ``V_train``, ``t_train`` (copies without the holdouts),
        ``X_full``, ``V_full``, ``time_grid`` (the inputs, untouched),
        ``holdout_indices`` (normalized list), ``train_time_idx`` (kept indices),
        ``T_train_plus_1`` (number of kept marginals),
        ``max_train_steps`` (``T_train_plus_1 - 1``) and
        ``mode_info`` (summary of the partition).

    Raises:
        ValueError: If any holdout index is ``<= 0`` or ``>=`` the number of
            marginals (index 0 can never be held out).
    """
    # The 4-way unpack also rejects inputs that are not 4-D.
    _, _, T_plus_1, _ = X_em.shape

    # Normalize: integers, unique, ascending.
    holdout_indices = sorted({int(raw) for raw in holdout_indices})
    for h in holdout_indices:
        if not (0 < h < T_plus_1):
            raise ValueError(f"holdout index must be in [1, {T_plus_1 - 1}], got {h}")

    # Everything that is not held out is used for training.
    excluded = set(holdout_indices)
    train_time_idx = [i for i in range(T_plus_1) if i not in excluded]

    def _keep_train(series: torch.Tensor) -> torch.Tensor:
        # Independent copy restricted to the training marginals.
        return series[:, :, train_time_idx, :].contiguous().clone()

    X_train = _keep_train(X_em)
    t_train = time_grid[train_time_idx].contiguous().clone()
    V_train = None if V_em is None else _keep_train(V_em)

    T_train_plus_1 = len(train_time_idx)
    holdout_times = [float(time_grid[h]) for h in holdout_indices]

    mode_info = dict(
        eval_mode="interpolate",
        holdout_indices=holdout_indices,
        holdout_times=holdout_times,
        T_train_plus_1=T_train_plus_1,
        T_total=T_plus_1,
        train_time_idx=train_time_idx,
        train_times=t_train.cpu().tolist(),
    )

    print(f"[Interpolate] Holdout indices: {holdout_indices}")
    print(f"[Interpolate] Holdout times: {holdout_times}")
    print(f"[Interpolate] Train on {T_train_plus_1} marginals: {train_time_idx}")

    return dict(
        X_train=X_train,
        V_train=V_train,
        t_train=t_train,
        X_full=X_em,
        V_full=V_em,
        time_grid=time_grid,
        holdout_indices=holdout_indices,
        train_time_idx=train_time_idx,
        T_train_plus_1=T_train_plus_1,
        max_train_steps=T_train_plus_1 - 1,
        mode_info=mode_info,
    )


# ============================================================
# CLI
# ============================================================

def build_parser() -> argparse.ArgumentParser:
    """
    Build the command-line parser for train.py.

    Returns:
        An ``argparse.ArgumentParser`` with all training, evaluation, model,
        logging and EMA options registered. ``--data`` and ``--outdir`` are
        declared optional here (default ``None``).

    Notes:
        ``geom_debias`` defaults to True; ``--no-geom-debias`` turns it off
        (``--geom-debias`` is kept for backward compatibility). ``wandb``
        likewise defaults to True and is disabled with ``--no-wandb``.
    """
    p = argparse.ArgumentParser("train.py")

    # Config
    p.add_argument("--config", type=str, default=None, help="YAML/JSON config file")
    p.add_argument("--set", action="append", default=None,
                   help="Override config with dot-keys, e.g. --set arch.attn_heads=2")

    # IO
    p.add_argument("--data", type=str, required=False, default=None, help="Path to .pt bundle from data_generator.py")
    p.add_argument("--outdir", type=str, required=False, default=None, help="Base output directory")
    p.add_argument("--run-name", type=str, default="", help="Optional run name; default timestamp")
    p.add_argument("--seed", type=int, default=0)
    p.add_argument("--device", type=str, default="cuda", choices=["cpu", "cuda"])

    # Training mode
    p.add_argument("--train-mode", type=str, default="rollout_next_k_from_0",
                   choices=["rollout_next_k_from_0", "rollout_uniform", "rollout_incremental"])

    # Evaluation mode: forecast vs interpolate
    p.add_argument("--eval-mode", type=str, default="forecast",
                   choices=["forecast", "interpolate"],
                   help="forecast: train on first portion, eval on future. "
                        "interpolate: hold out middle marginal, train on rest.")
    p.add_argument("--train-fraction", type=float, default=0.5,
                   help="Fraction of data for training in forecast mode (default: 0.5)")
    p.add_argument("--holdout-marginals", type=int, nargs="+", default=None,
                   help="Marginal indices to hold out in interpolate mode (e.g., 1 3 5 7)")

    # Rollout knobs
    p.add_argument("--dt-base", type=float, default=None, help="Overrides dt from bundle meta if set")
    p.add_argument("--substeps-per-dt", type=int, default=1)
    p.add_argument("--max-train-steps", type=int, default=None,
                   help="Max training steps. If None, auto-calibrates based on eval mode.")
    p.add_argument("--integrator", type=str, default="v", choices=["v", "x", "euler"])
    p.add_argument("--particles-per-batch", type=int, default=None)
    p.add_argument("--dt-sim", type=float, default=0.2)

    # Uniform rollout
    p.add_argument("--num-epochs", type=int, default=5000)
    p.add_argument("--rollout-k", type=int, default=None)
    # Incremental rollout
    p.add_argument("--epochs-per-step", type=int, default=1000)

    # Loss
    p.add_argument("--loss-type", type=str, default="geom_sinkhorn",
                   choices=["mmd", "sw2", "sinkhorn", "geom_sinkhorn", "geom_gaussian", "geom_energy"])
    p.add_argument("--kernel-bw2", type=float, default=None,
                   help="Legacy: sigma2. GeomLoss: blur. If omitted, uses bundle blur.")

    # GeomLoss knobs
    p.add_argument("--geom-p", type=int, default=2)
    p.add_argument("--geom-scaling", type=float, default=0.9)
    # Debiasing is on by default, so the positive flag alone could never
    # disable it; the paired --no-geom-debias flag writes False to the same
    # destination (same pattern as --no-wandb below).
    p.add_argument("--geom-debias", dest="geom_debias", action="store_true", default=True,
                   help="Enable GeomLoss debiasing (default: enabled)")
    p.add_argument("--no-geom-debias", dest="geom_debias", action="store_false",
                   help="Disable GeomLoss debiasing")
    p.add_argument("--geom-backend", type=str, default=None)

    # Velocity
    p.add_argument("--vel", type=str, default='bundle',
                   choices=["bundle", "zero", "dice"],
                   help="Initial velocity mode.")

    # DICE
    p.add_argument("--dice-hidden", type=int, default=128)

    # Friction
    p.add_argument("--friction", type=float, default=0.0, help="Used if not learnable (or as init if learnable)")
    p.add_argument("--learnable-friction", action="store_true")
    p.add_argument("--friction-lr", type=float, default=1e-2)

    # Force clamp
    p.add_argument("--max-force", type=float, default=None)

    # Architecture
    p.add_argument("--arch", type=str, default="transformer",
                   choices=["transformer", "better_mlp", "mlp", "attn_flow", "hybrid_mlp_attn", "weighted_hybrid","gated_hybrid"])
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--weight-decay", type=float, default=1e-4,
                   help="Weight decay coefficient for AdamW (default: 1e-4)")

    # MLPs
    p.add_argument("--hidden-dim", type=int, default=256)
    p.add_argument("--mlp-layers", type=int, default=4)

    # Attention
    p.add_argument("--attn-hidden-dim", type=int, default=32)
    p.add_argument("--attn-layers", type=int, default=4)
    p.add_argument("--attn-heads", type=int, default=1)
    p.add_argument("--use-time", action="store_true", default=False)
    p.add_argument("--use-com", action="store_true", default=False)
    p.add_argument("--d-time", type=int, default=16)
    p.add_argument("--ff-dim", type=int, default=512)
    p.add_argument("--dropout", type=float, default=0.0)

    # Hybrid
    p.add_argument("--hybrid-weights", type=float, nargs=2, default=[1.0, 1.0],
                   metavar=("W_ONE", "W_INT"))
    p.add_argument("--normalize-int-by-n", action="store_true", default=False)
    # Gated hybrid
    p.add_argument("--gate-init", type=float, default=-3.0,
                   help="Initial gate logit for gated_hybrid (sigmoid(-3)≈0.05)")
    # Save/eval/gif frequencies
    p.add_argument("--ckpt-every", type=int, default=500)
    p.add_argument("--eval-every", type=int, default=500)
    p.add_argument("--gif-every", type=int, default=0)

    p.add_argument("--particles-eval", type=int, default=None)
    p.add_argument("--gif-p0-idx", type=int, default=0)
    p.add_argument("--particles-gif", type=int, default=1000)
    p.add_argument("--gif-frame-skip", type=int, default=5)
    p.add_argument("--gif-fps", type=int, default=5)

    # W&B (enabled unless --no-wandb is passed)
    p.set_defaults(wandb=True)
    p.add_argument("--no-wandb", dest="wandb", action="store_false")
    p.add_argument("--wandb-project", type=str, default="wlf")
    p.add_argument("--wandb-entity", type=str, default=None)
    p.add_argument("--wandb-name", type=str, default="")
    p.add_argument("--wandb-tags", type=str, default="")

    # EMA
    p.add_argument("--use-ema", action="store_true", default=False,
                   help="Use exponential moving average for model weights")
    p.add_argument("--ema-decay", type=float, default=0.999,
                   help="EMA decay rate (higher = slower/more stable)")

    return p


# ============================================================
# Main
# ============================================================

def main() -> None:
    parser = build_parser()
    args = _parse_args_with_config(parser)

    if args.data is None or args.outdir is None:
        raise SystemExit("train.py: --data and --outdir are required.")

    device = get_device(args.device)
    set_seed(int(args.seed))

    # ---- Load data ----
    bundle = torch.load(args.data, map_location="cpu", weights_only=False)
    X_em = bundle["X_em_torch"].to(device=device, dtype=torch.float32)
    time_grid = bundle["time_grid"].to(device=device, dtype=torch.float32)
    V_em = bundle.get("V_em_torch", None)

    # Sanitize X/V consistency
    if V_em is not None:
        V_em = V_em.to(device=device, dtype=torch.float32)
        valid_x = torch.isfinite(X_em).all(dim=-1)
        valid_v = torch.isfinite(V_em).all(dim=-1)
        valid_common = valid_x & valid_v
        nan_t = torch.tensor(float('nan'), device=device, dtype=X_em.dtype)
        mask_expand = valid_common.unsqueeze(-1)
        X_em = torch.where(mask_expand, X_em, nan_t)
        V_em = torch.where(mask_expand, V_em, nan_t)
        print(f"[Data] Enforced X/V consistency. Kept {valid_common.sum().item()} valid pairs.")

    meta = bundle.get("meta", {})
    blur = float(bundle.get("blur", meta.get("blur", 0.2)))
    num_p0, N, T_plus_1, d = X_em.shape
    print(f"Loaded: {num_p0} populations, {T_plus_1} marginals, {N} samples each, d={d}")

    # ---- Partition data based on eval_mode ----
    eval_mode = str(args.eval_mode).lower()

    if eval_mode == "interpolate":
        if args.holdout_marginals is None:
            raise SystemExit("--eval-mode=interpolate requires --holdout-marginals")
        holdout_indices = [int(h) for h in args.holdout_marginals]
        partition = partition_data_interpolate(X_em, V_em, time_grid, holdout_indices)
    else:  # forecast
        train_fraction = float(args.train_fraction)
        partition = partition_data_forecast(X_em, V_em, time_grid, train_fraction)

    X_train = partition["X_train"]
    V_train = partition["V_train"]
    t_train = partition["t_train"]
    T_train_plus_1 = partition["T_train_plus_1"]
    max_train_steps = partition["max_train_steps"]
    mode_info = partition["mode_info"]

    # Override max_train_steps if specified
    if args.max_train_steps is not None:
        max_train_steps = min(int(args.max_train_steps), max_train_steps)

    # ---- Resolved dt / kernel bw ----
    if args.dt_base is not None:
        dt_base = float(args.dt_base)
    else:
        dt_base = float(meta.get("dt", (time_grid[1] - time_grid[0]).item()))
    print(f"dt_base: {dt_base}")
    kernel_bw2 = float(args.kernel_bw2) if args.kernel_bw2 is not None else float(blur)
    vel_mode = args.vel

    # ---- Output directory ----
    outdir = Path(args.outdir)
    ensure_dir(outdir)
    run_name = args.run_name.strip() or outdir.name or f"{args.train_mode}_{args.arch}_{now_str()}_seed{args.seed}"

    # SDPA math kernels
    if torch.cuda.is_available():
        torch.backends.cuda.enable_flash_sdp(False)
        torch.backends.cuda.enable_mem_efficient_sdp(False)
        torch.backends.cuda.enable_math_sdp(True)

    # ---- Friction ----
    friction_obj: Optional[LearnableFriction] = None
    if bool(args.learnable_friction):
        friction_obj = LearnableFriction(init_gamma=float(args.friction)).to(device)
        friction_param = friction_obj
    else:
        friction_param = float(args.friction)
    if friction_obj is not None:
        print("[friction] init gamma =", float(friction_obj.friction_tensor().detach().cpu().item()))

    def friction_value() -> float:
        return friction_obj.value() if friction_obj is not None else float(friction_param)

    # ---- Build model ----
    model, model_kwargs = build_model_and_kwargs(args, d=d)
    model = model.to(device)

    # ---- EMA ----
    ema: Optional[EMA] = None
    if bool(getattr(args, 'use_ema', False)):
        # CHANGE: Pass friction_obj here
        ema = EMA(model, friction_module=friction_obj, decay=float(getattr(args, 'ema_decay', 0.999)))
        print(f"[EMA] Initialized with decay={ema.decay}")


    # # ---- Optimizer ----
    # if friction_obj is not None:
    #     optimizer = torch.optim.Adam([
    #         {"params": model.parameters(), "lr": float(args.lr)},
    #         {"params": [friction_obj.theta], "lr": float(args.friction_lr)},
    #     ])
    # else:
    #     optimizer = torch.optim.Adam(model.parameters(), lr=float(args.lr))

    wd = float(getattr(args, 'weight_decay', 1e-4))  # Safety getattr if you didn't update parser yet

    if friction_obj is not None:
        optimizer = torch.optim.AdamW([
            {
                "params": model.parameters(),
                "lr": float(args.lr), #1e-4
                "weight_decay": wd
            },
            {
                "params": [friction_obj.theta],
                "lr": float(args.friction_lr), #1e-3
                "weight_decay": 0.0  # <--- Important: No decay for friction
            },
        ])
    else:
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=float(args.lr),
            weight_decay=wd
        )

    # ---- Config ----
    config: Dict[str, Any] = vars(args).copy()
    config.update({
        "resolved": {
            "dt_base": float(dt_base),
            "kernel_bw2": float(kernel_bw2),
            "bundle_blur": float(blur),
            "meta": meta,
            "num_p0": int(num_p0),
            "N": int(N),
            "steps": int(max_train_steps),
            "d": int(d),
            "device": str(device),
            "vel_mode": str(vel_mode),
        },
        "model_kwargs": model_kwargs,
        "mode_info": mode_info,
    })
    _find_complex(config)
    dump_json(outdir / "config.json", config)

    # ---- W&B init ----
    wb_run = None
    if bool(args.wandb):
        tags = [t.strip() for t in str(args.wandb_tags).split(",") if t.strip()]
        tags.append(eval_mode)  # Tag with eval mode
        wb_name = args.wandb_name.strip() or run_name
        os.environ.setdefault("WANDB_SILENT", "true")
        wb_config = _wandb_sanitize(config)
        wb_run = wandb.init(
            project=str(args.wandb_project),
            entity=(None if args.wandb_entity in (None, "", "none") else str(args.wandb_entity)),
            name=wb_name,
            dir=os.environ.get("WANDB_DIR", str(outdir)),
            config=wb_config,
            tags=tags,
        )

    metrics_path = outdir / "metrics.jsonl"

    # ---- DICE models (if estimating velocity) ----
    data_path = Path(args.data)
    master_dir = outdir
    for p in outdir.parents:
        if p.name == "train":
            master_dir = p.parent
            break

    data_tag = data_path.parent.name
    dice_models: Optional[List[nn.Module]] = None
    if vel_mode == "dice":
        dice_dir = master_dir / "dice" / data_tag
        ensure_dir(dice_dir)
        dice_models_pt = dice_dir / "dice_models.pt"
        dice_bundle_path = dice_dir / "dice_bundle.pt"

        if dice_models_pt.exists():
            print(f"[dice] load teacher: {dice_models_pt}")
            dice_models = load_dice_models(
                str(dice_models_pt), device=device, d=int(d),
                hidden=int(args.dice_hidden), num_p0=int(num_p0),
            )
        else:
            print(f"[dice] train/load teacher bundle: {dice_bundle_path}")
            dice_models, _ = train_or_load_dice_bundle(
                bundle_path=dice_bundle_path, X_em_torch=X_em, time_grid=time_grid,
                d=int(d), hidden=int(args.dice_hidden), device=device,
                wandb_run=wb_run, log_prefix=f"_{run_name}",
                steps=10000, lr=1e-3, lr_end=1e-5, clip_norm=1.0,
                batch_size_t=int(time_grid.numel()), batch_size_x=128,
            )

        dice_gif_path = dice_dir / "dice_diagnostic.gif"
        maybe_make_dice_diagnostic_gif(
            save_path=str(dice_gif_path), X_em=X_em, time_grid=time_grid,
            dice_models=dice_models, pop_idx=int(args.gif_p0_idx),
            wandb_run=wb_run, wandb_step=0,
        )

    # ---- Velocity provider ----
    vel_provider = build_vel_provider(vel_mode, meta, dice_models, V_em=V_train, time_grid=t_train)

    # Print velocity diagnostics
    print_velocity_diagnostics(V_em, V_train, vel_mode, eval_mode)

    geom_backend = str(args.geom_backend) if args.geom_backend is not None else "auto"

    # ---- Accel function ----
    accel_train = make_accel_from_potential(model, create_graph=True, max_force=args.max_force)

    # ---- Save checkpoint function ----
    def save_ckpt(*, tag: str, step_idx: int, friction_value: float, friction_raw=None):
        save_ckpt_impl(
            outdir=outdir, tag=tag, step_idx=step_idx, model=model,
            optimizer=optimizer, model_kwargs=model_kwargs, meta=meta,
            config=config, train_mode=str(args.train_mode),
            learnable_friction=bool(args.learnable_friction),
            friction_value=friction_value, friction_raw=friction_raw,
            run_name=run_name, wb_run=wb_run, ema=ema,
        )

    # ---- Evaluation function ----
    if eval_mode == "interpolate":
        from mechanics import leapfrog_auto, resolve_gamma
        holdout_indices = partition["holdout_indices"]
        train_time_idx = partition["train_time_idx"]

        def do_eval(step_idx):
            print(f"--- Interpolate Eval (Holdouts {holdout_indices}) @ step {step_idx} ---")
            friction_val = friction_value()

            def _run_eval_inner():
                nonlocal model
                model.eval()

                plot_dir = outdir / "eval_plots"
                ensure_dir(plot_dir)

                dt_substeps = float(dt_base) / int(args.substeps_per_dt)
                dt_requested = float(args.dt_sim)
                if 0 < dt_requested < float(dt_base):
                    dt_micro = dt_requested
                else:
                    dt_micro = dt_substeps

                metrics = {"friction_value": friction_val, "eval_dt_micro": dt_micro}
                accel_eval = make_accel_from_potential(model, create_graph=False, max_force=args.max_force)

                # Storage for plotting
                all_true = []  # List of (x_true, holdout_idx)
                all_pred = []  # List of (x_pred, holdout_idx)
                w1_values = []

                for h_idx in holdout_indices:
                    # Find the previous training time (for path B)
                    prev_train_idx = max([t for t in train_time_idx if t < h_idx], default=0)

                    # Get ground truth at holdout
                    mask_h = torch.isfinite(X_em[0, :, h_idx, :]).all(dim=-1)
                    y_true = X_em[0, :, h_idx, :][mask_h].to(device)

                    # Get start position and velocity
                    t_start_val = float(time_grid[prev_train_idx].item())
                    t_end_val = float(time_grid[h_idx].item())

                    mask_s = torch.isfinite(X_em[0, :, prev_train_idx, :]).all(dim=-1)
                    x0 = X_em[0, :, prev_train_idx, :][mask_s].to(device)

                    if "bundle" in str(vel_mode).lower() and V_em is not None:
                        v0 = V_em[0, :, prev_train_idx, :][mask_s].to(device)
                    else:
                        v0 = vel_provider(0, x0, t_start_val) if vel_provider else torch.zeros_like(x0)

                    # Integrate
                    steps = int(round((t_end_val - t_start_val) / dt_micro))
                    if steps <= 0:
                        x_pred = x0
                    else:
                        x_pred = leapfrog_auto(
                            x0, v0, accel_eval, dt_micro, steps,
                            resolve_gamma(friction_param), return_all=False, t_start=t_start_val
                        )
                        if isinstance(x_pred, tuple):
                            x_pred = x_pred[0]

                    # Clean predictions
                    mask_p = torch.isfinite(x_pred).all(dim=-1)
                    x_pred_clean = x_pred[mask_p]

                    # Compute W1
                    x_np = x_pred_clean.detach().cpu().numpy()
                    y_np = y_true.detach().cpu().numpy()

                    if x_np.shape[0] == 0 or y_np.shape[0] == 0:
                        w1 = float("nan")
                    else:
                        M = pairwise_distances(x_np, y_np, metric='euclidean')
                        w1 = get_w1(M)

                    metrics[f"eval_w1_h{h_idx}"] = w1
                    w1_values.append(w1)

                    # Store for plotting
                    all_true.append((y_true.detach().cpu(), h_idx))
                    all_pred.append((x_pred_clean.detach().cpu(), h_idx))

                    print(f"  Holdout {h_idx} (t={time_grid[h_idx].item():.2f}): "
                          f"W1={w1:.4f}, start_idx={prev_train_idx}")

                # Compute mean W1
                valid_w1 = [w for w in w1_values if not np.isnan(w)]
                if valid_w1:
                    metrics["eval_w1_mean"] = float(np.mean(valid_w1))
                    metrics["eval_w1_std"] = float(np.std(valid_w1))
                else:
                    metrics["eval_w1_mean"] = float("nan")
                    metrics["eval_w1_std"] = float("nan")

                print(f"  Mean W1: {metrics['eval_w1_mean']:.4f} ± {metrics['eval_w1_std']:.4f}")

                # Create multi-holdout scatter plot
                if wb_run is not None or True:  # Always save locally
                    fig = plot_multi_holdout_scatter(
                        all_true, all_pred, time_grid,
                        title=f"Interpolation @ step {step_idx}"
                    )
                    fname = f"scatter_multi_step{step_idx:07d}.png"
                    local_path = plot_dir / fname
                    fig.savefig(local_path, dpi=150, bbox_inches='tight')
                    plt.close(fig)

                    if wb_run is not None:
                        wb_run.log({
                            "eval/scatter_multi": wandb.Image(str(local_path)),
                            "eval/w1_mean": metrics["eval_w1_mean"],
                        }, step=step_idx)

                model.train()
                return metrics

            # Run with EMA if available
            if ema is not None:
                print(f"  [Using EMA weights (step={ema.step})]")
                # CHANGE: Pass friction_obj to apply
                with ema.apply(model, friction_module=friction_obj):
                    return _run_eval_inner()
            else:
                return _run_eval_inner()
    else:
        # Forecast mode - use standard eval
        do_eval = partial(
            do_eval_impl,
            model=model,
            friction_value_fn=friction_value,
            X_em=X_em,  # Uses X_full from partition
            time_grid=time_grid,  # Uses full time_grid from partition
            dt_base=float(dt_base),
            substeps_per_dt=int(args.substeps_per_dt),
            loss_type=args.loss_type,
            kernel_bw2=float(kernel_bw2),
            vel_provider=vel_provider,
            friction=friction_param,
            verlet=str(args.integrator),
            geom_p=int(args.geom_p),
            geom_scaling=float(args.geom_scaling),
            geom_debias=bool(args.geom_debias),
            geom_backend=args.geom_backend,
            max_force=args.max_force,
            particles_eval=(None if args.particles_eval is None else int(args.particles_eval)),
            vel_mode=str(vel_mode),
            V_em=V_em,
            T_train_plus_1=int(T_train_plus_1),
            dt_sim=getattr(args, "dt_sim", None),
            ema=ema,
            eval_mode="forecast",  # NEW
            holdout_idx=None,  # NEW
        )


    maybe_gif = partial(
        maybe_gif_impl,
        gif_every=int(args.gif_every),
        gif_p0_idx=int(args.gif_p0_idx),
        particles_gif=(None if args.particles_gif is None else int(args.particles_gif)),
        gif_frame_skip=int(args.gif_frame_skip),
        gif_fps=int(args.gif_fps),
        substeps_per_dt=int(args.substeps_per_dt),
        integrator_name=str(args.integrator),
        max_force=args.max_force,
        model=model,
        X_em=X_em,
        time_grid=time_grid,
        dt_base=float(dt_base),
        vel_provider=vel_provider,
        vel_mode=str(vel_mode),
        V_em=V_em,
        friction=friction_param,
        outdir=outdir,
        device=device,
        wb_run=wb_run,
    )

    # ---- Training ----
    t0_wall = time.time()
    cb_state = TrainCallbackState()
    step_idx = 0

    # EMA update wrapper
    def wrap_callback_with_ema(callback_fn):
        """Return a callback that also refreshes the EMA after each invocation.

        When no EMA tracker is configured (``ema is None``) ``callback_fn`` is
        handed back untouched. Otherwise the returned wrapper forwards all of
        its arguments to ``callback_fn`` (skipped when ``callback_fn`` is
        falsy), then calls ``ema.update(model, friction_module=friction_obj)``
        and finally returns whatever the callback produced (``None`` when
        there was no callback).
        """
        if ema is not None:
            # Parameter names deliberately differ from the enclosing scope's
            # ``args`` so the argparse namespace is not shadowed in here.
            def wrapped(*cb_args, **cb_kwargs):
                if callback_fn:
                    outcome = callback_fn(*cb_args, **cb_kwargs)
                else:
                    outcome = None
                # The EMA update runs after the callback, with friction_obj
                # passed alongside the model.
                ema.update(model, friction_module=friction_obj)
                return outcome

            return wrapped

        # EMA disabled: nothing to hook in.
        return callback_fn

    if str(args.train_mode) == "rollout_next_k_from_0":
        rollout_k = int(args.rollout_k) if args.rollout_k is not None else int(max_train_steps)
        remaining = int(args.num_epochs)
        chunk = max(1, int(args.eval_every))

        while remaining > 0:
            cur = min(chunk, remaining)
            cb_state.uniform_step_base = int(step_idx)

            uniform_epoch_cb = make_uniform_epoch_callback(
                state=cb_state, args=args, wb_run=wb_run,
                save_ckpt=save_ckpt, maybe_gif=maybe_gif,
                friction_value_fn=friction_value,
                friction_raw_fn=(lambda: float(friction_obj.theta.detach().cpu().item())
                if friction_obj is not None else None),
            )

            def anchor_epoch_cb(*cb_args):
                """Adapt a 3- or 4-argument epoch callback to ``uniform_epoch_cb``.

                Accepted positional forms are ``(epoch, loss, fric)`` and
                ``(epoch, k, loss, fric)``; the extra second value of the
                4-argument form is ignored. The remaining values are coerced to
                ``int``/``float``/``float`` and forwarded to
                ``uniform_epoch_cb``, whose result is returned.

                Raises:
                    TypeError: if called with any other number of arguments.
                """
                if len(cb_args) not in (3, 4):
                    raise TypeError(f"Unexpected callback signature: {len(cb_args)} args")

                # In both accepted forms the epoch comes first and the
                # (loss, friction) pair comes last.
                epoch = cb_args[0]
                loss, fric = cb_args[-2:]
                return uniform_epoch_cb(int(epoch), float(loss), float(fric))

            anchor_epoch_cb_ema = wrap_callback_with_ema(anchor_epoch_cb)

            tr = train_rollout_anchor_p0_randk(
                X_em_torch=X_train, time_grid=t_train, accel_train=accel_train,
                optimizer=optimizer, dt_base=float(dt_base), num_epochs=int(cur),
                max_train_steps=int(rollout_k), substeps_per_dt=int(args.substeps_per_dt),
                kernel_bw2=float(kernel_bw2), loss_type=args.loss_type,
                particles_per_batch=args.particles_per_batch, vel_provider=vel_provider,
                friction=friction_param, debug=False, name=run_name,
                verlet=str(args.integrator), geom_p=int(args.geom_p),
                geom_scaling=float(args.geom_scaling), geom_debias=bool(args.geom_debias),
                geom_backend=args.geom_backend, epoch_callback=anchor_epoch_cb_ema,
            )

            step_idx += int(cur)
            remaining -= int(cur)

            if wb_run is not None and isinstance(tr, dict):
                wb_run.log({
                    "train_summary/loss_avg": float(tr.get("train_loss_avg", 0.0)),
                    "train_summary/loss_last": float(tr.get("train_loss_last", 0.0)),
                }, step=int(cb_state.global_step))

            ev = do_eval(step_idx=int(step_idx))
            ev["wall_s"] = float(time.time() - t0_wall)
            append_jsonl(metrics_path, {"type": "eval", **ev})

            if wb_run is not None:
                # Handle both forecast and interpolate mode keys
                if "train_w1_avg" in ev:
                    # Forecast mode
                    wb_run.log({
                        "eval/train_w1": float(ev["train_w1_avg"]),
                        "eval/train_w1_se": float(ev.get("train_w1_se", 0.0)),
                        "eval/forecast_w1": float(ev["forecast_w1_avg"]),
                        "eval/forecast_w1_se": float(ev.get("forecast_w1_se", 0.0)),
                        "eval/friction": float(ev["friction_value"]),
                    }, step=int(cb_state.global_step))
                    print(f"[eval] step={step_idx} "
                          f"train_w1={ev['train_w1_avg']:.4f}±{ev.get('train_w1_se', 0.0):.4f} "
                          f"forecast_w1={ev['forecast_w1_avg']:.4f}±{ev.get('forecast_w1_se', 0.0):.4f}")
                else:
                    # Interpolate mode - uses pathA/pathB keys
                    w1_A = ev.get("eval_w1_pathA", float('nan'))
                    w1_B = ev.get("eval_w1_pathB", float('nan'))
                    wb_run.log({
                        "eval/w1_pathA": float(w1_A),
                        "eval/w1_pathB": float(w1_B),
                        "eval/friction": float(ev["friction_value"]),
                    }, step=int(cb_state.global_step))
                    print(f"[eval] step={step_idx} "
                          f"w1_pathA={w1_A:.4f} w1_pathB={w1_B:.4f}")

    elif str(args.train_mode) == "rollout_uniform":
        rollout_k = int(args.rollout_k) if args.rollout_k is not None else int(max_train_steps)
        remaining = int(args.num_epochs)
        chunk = max(1, int(args.eval_every))

        while remaining > 0:
            cur = min(chunk, remaining)
            cb_state.uniform_step_base = int(step_idx)

            uniform_epoch_cb = make_uniform_epoch_callback(
                state=cb_state, args=args, wb_run=wb_run,
                save_ckpt=save_ckpt, maybe_gif=maybe_gif,
                friction_value_fn=friction_value,
                friction_raw_fn=(lambda: float(friction_obj.theta.detach().cpu().item())
                if friction_obj is not None else None),
            )
            uniform_epoch_cb_ema = wrap_callback_with_ema(uniform_epoch_cb)

            tr = train_rollout_uniform(
                X_em_torch=X_train, time_grid=t_train, accel_train=accel_train,
                optimizer=optimizer, dt_base=float(dt_base), num_epochs=int(cur),
                max_train_steps=rollout_k, substeps_per_dt=int(args.substeps_per_dt),
                kernel_bw2=float(kernel_bw2), loss_type=args.loss_type,
                particles_per_batch=args.particles_per_batch, vel_provider=vel_provider,
                friction=friction_param, debug=False, name=run_name,
                verlet=str(args.integrator), geom_p=int(args.geom_p),
                geom_scaling=float(args.geom_scaling), geom_debias=bool(args.geom_debias),
                geom_backend=args.geom_backend, epoch_callback=uniform_epoch_cb_ema,
            )

            step_idx += int(cur)
            remaining -= int(cur)

            if wb_run is not None and isinstance(tr, dict):
                wb_run.log({
                    "train_summary/loss_avg": float(tr.get("train_loss_avg", 0.0)),
                    "train_summary/loss_last": float(tr.get("train_loss_last", 0.0)),
                }, step=int(cb_state.global_step))

            ev = do_eval(step_idx=int(step_idx))
            ev["wall_s"] = float(time.time() - t0_wall)
            append_jsonl(metrics_path, {"type": "eval", **ev})

            if wb_run is not None:
                # Handle both forecast and interpolate mode keys
                if "train_w1_avg" in ev:
                    # Forecast mode
                    wb_run.log({
                        "eval/train_w1": float(ev.get("train_w1_avg", float('nan'))),
                        "eval/train_w1_se": float(ev.get("train_w1_se", 0.0)),
                        "eval/forecast_w1": float(ev.get("forecast_w1_avg", float('nan'))),
                        "eval/forecast_w1_se": float(ev.get("forecast_w1_se", 0.0)),
                        "eval/friction": float(ev["friction_value"]),
                    }, step=int(cb_state.global_step))
                    print(f"[eval] step={step_idx} "
                          f"train_w1={ev.get('train_w1_avg', float('nan')):.4f}±{ev.get('train_w1_se', 0.0):.4f} "
                          f"forecast_w1={ev.get('forecast_w1_avg', float('nan')):.4f}±{ev.get('forecast_w1_se', 0.0):.4f}")
                else:
                    # Interpolate mode
                    w1_A = ev.get("eval_w1_pathA", float('nan'))
                    w1_B = ev.get("eval_w1_pathB", float('nan'))
                    wb_run.log({
                        "eval/w1_pathA": float(w1_A),
                        "eval/w1_pathB": float(w1_B),
                        "eval/friction": float(ev["friction_value"]),
                    }, step=int(cb_state.global_step))
                    print(f"[eval] step={step_idx} "
                          f"w1_pathA={w1_A:.4f} w1_pathB={w1_B:.4f}")

    elif str(args.train_mode) == "rollout_incremental":
        inc_epoch_cb, stage_end_cb, save_cb, gif_cb = make_incremental_callbacks(
            state=cb_state, args=args, wb_run=wb_run,
            metrics_path=metrics_path, append_jsonl_fn=append_jsonl,
            save_ckpt=save_ckpt, do_eval=do_eval, maybe_gif=maybe_gif,
            friction_value_fn=friction_value,
            friction_raw_fn=(lambda: float(friction_obj.theta.detach().cpu().item())
            if friction_obj is not None else None),
            t0_wall=float(t0_wall),
        )
        inc_epoch_cb_ema = wrap_callback_with_ema(inc_epoch_cb)

        train_rollout_incremental(
            X_em_torch=X_train, time_grid=t_train, accel_train=accel_train,
            optimizer=optimizer, dt_base=float(dt_base),
            epochs_per_step=int(args.epochs_per_step), max_train_steps=max_train_steps,
            substeps_per_dt=int(args.substeps_per_dt), kernel_bw2=float(kernel_bw2),
            loss_type=args.loss_type, particles_per_batch=args.particles_per_batch,
            vel_provider=vel_provider, friction=friction_param,
            debug=False, name=run_name, verlet=str(args.integrator),
            save_every_k=1, gif_every_k=int(args.gif_every),
            save_callback=save_cb, gif_callback=gif_cb,
            stage_end_callback=stage_end_cb, epoch_callback=inc_epoch_cb_ema,
        )

        step_idx = int(args.epochs_per_step) * int(max_train_steps)
        ev = do_eval(step_idx=int(step_idx))
        ev["wall_s"] = float(time.time() - t0_wall)
        append_jsonl(metrics_path, {"type": "eval", **ev})

        if wb_run is not None:
            wb_run.log({
                "eval/train_w1": float(ev.get("train_w1_avg", float('nan'))),
                "eval/train_w1_se": float(ev.get("train_w1_se", 0.0)),
                "eval/forecast_w1": float(ev.get("forecast_w1_avg", float('nan'))),
                "eval/forecast_w1_se": float(ev.get("forecast_w1_se", 0.0)),
                "eval/friction": float(ev["friction_value"]),
            }, step=int(cb_state.global_step))

        print(f"[eval] step={step_idx} "
              f"train_w1={ev.get('train_w1_avg', float('nan')):.4f}±{ev.get('train_w1_se', 0.0):.4f} "
              f"forecast_w1={ev.get('forecast_w1_avg', float('nan')):.4f}±{ev.get('forecast_w1_se', 0.0):.4f}")
        maybe_gif(step_idx)

    # ---- Final save ----
    save_ckpt(
        tag="final",
        step_idx=int(step_idx),
        friction_value=float(friction_value()),
        friction_raw=(float(friction_obj.theta.detach().cpu().item()) if friction_obj is not None else None),
    )

    if wb_run is not None:
        wb_run.finish()


# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()