# configs/models/attn_oceans_default.yaml
# Model configuration for ocean currents
#
# Key considerations:
# - Small particle count (~111), so attention is tractable
# - 2D data, so smaller hidden dims work well
# - Short sequences (9 timesteps), focus on interpolation

# Training mode
# Both values are mode names resolved by the consuming code; the set of valid
# names is not listed in this file.
train_mode: rollout_next_k_from_0
# Default only: the CLI flag --holdout-marginal replaces this value.
eval_mode: interpolate  # Will be overridden by CLI --holdout-marginal

# Rollout parameters
# NOTE(review): num_epochs looks like a general training-length setting rather
# than a rollout-specific one, despite the section heading — confirm.
num_epochs: 50000
rollout_k: null  # null: use max_train_steps
# presumably the number of integrator sub-steps taken per dt_sim — TODO confirm
substeps_per_dt: 1
# presumably the simulation time step; units are not stated here — TODO confirm
dt_sim: 0.2
# NOTE(review): the value is the single letter `v` while the comment says
# leapfrog — confirm `v` is the name the consumer maps to that integrator.
integrator: v  # leapfrog

# Loss function
# Loss identifier resolved by the consuming code; presumably a Sinkhorn-type
# loss judging by the name — TODO confirm the implementation behind it.
loss_type: geom_sinkhorn

# Velocity initialization
# `bundle` is the option name for the behavior described in the inline comment.
vel: bundle  # Use ground-truth velocities from data

# Friction (start with small learnable friction)
# `friction` is the starting value; with learnable_friction enabled it is
# presumably updated during training using friction_lr — TODO confirm.
friction: 0.1
learnable_friction: true
# Separate learning rate for the friction term (currently the same value as
# the main `lr`). Written with a decimal point and a signed exponent so that
# YAML 1.1 loaders parse it as a float rather than a string.
friction_lr: 1.0e-4

# Architecture: Attention
# `arch` selects the model; the keys below are its keyword arguments.
arch: attn_flow

# Architecture kwargs (attn_heads can be overridden via --set)
attn_heads: 1
attn_hidden_dim: 16
attn_layers: 4
# snake_case like every other key in this file; a hyphenated key (`ff-dim`)
# cannot be forwarded as a Python keyword argument and would not match a
# snake_case lookup. NOTE(review): confirm the consumer reads `ff_dim`.
ff_dim: 128
use_time: false
# presumably only takes effect when use_time is true — TODO confirm
d_time: 16
use_com: false  # Center of mass features

# Optimizer
# Main learning rate. Decimal point plus signed exponent keeps this a float
# under YAML 1.1 loaders.
lr: 1.0e-4

# Evaluation & checkpointing
# Intervals for periodic actions; whether they count epochs or optimizer steps
# is not stated in this file — TODO confirm.
eval_every: 500
# NOTE(review): 0 presumably disables periodic checkpoints — confirm the
# consumer treats 0 as "off" rather than using it as a modulus.
ckpt_every: 0
gif_every: 1000

# Particles for eval/gif
particles_per_batch: null  # Use all particles (small dataset)
# presumably null means "use all particles" here as well — TODO confirm
particles_eval: null
# NOTE(review): the file header states ~111 particles, so 400 exceeds the
# dataset size — confirm the consumer clamps this to the available count.
particles_gif: 400
gif_p0_idx: 0
gif_frame_skip: 1
gif_fps: 1

# EMA (Exponential Moving Average)
# presumably an EMA of the model weights is maintained with this decay —
# TODO confirm what the average is applied to.
use_ema: true
ema_decay: 0.999

# W&B
# Enables Weights & Biases logging to the named project.
wandb: true
wandb_project: oceans_more_data_forecast
