# @package _global_

# Hydra defaults list: picks one option per config group; the selected configs
# are composed in the order listed here.
defaults:
  - data: data_2d
  - model: anchored_e2e
  # Alternative model options (disabled):
  # - model: probe_anchored_e2e
  # - model: probe_anchored_xa_e2e
  - callbacks: callbacks_2d
  - logger: wandb
  - wandb: wandb_2d
  - trainer: default
  - paths: default
  - extras: default
  - hydra: default
  # Further groups that are currently disabled:
  # - experiment: null
  # - hparams_search: null
  # - optional local: default
  # - debug: null
  # Replaces the `dataset` option chosen inside the `data` group with
  # `single_jet03`.
  - override data/dataset: single_jet03
  # `_self_` is last, so the keys defined further down in this file take
  # precedence over values coming from the entries above.
  - _self_

n_steps: 2  # number of steps to predict into the future

# Mirrors the top-level n_steps through interpolation; presumably picked up by
# the selected model config — confirm where model_overrides is consumed.
model_overrides:
  n_steps: ${n_steps}


# Fixed run identifier for the wandb config group (presumably used to attach
# to an existing run — confirm). Quoted so that a replacement ID consisting
# only of digits, or of digits around an `e`, is never read as a number.
wandb:
  run_id: 'a6ebjfxu'

# Overrides applied on top of the model option selected in the defaults list.
model:
  scheduler:
    warmup_steps: 1000
    # Floats are written with a decimal point: YAML 1.1 resolvers (e.g. plain
    # PyYAML) read `1e-6` as the string "1e-6", whereas `1.0e-6` is a float
    # for every loader.
    min_lr: 1.0e-6

  optimizer:
    # Same decimal-point rule as min_lr above.
    lr: 1.0e-5
    
# Overrides for the `trainer` config group. The key names match PyTorch
# Lightning Trainer arguments — presumably forwarded to it; confirm in the
# trainer config.
trainer:
  max_steps: 10000
  # NOTE(review): validation cadence is counted in epochs while the training
  # length is counted in steps; verify that at least one validation pass
  # actually happens within max_steps.
  check_val_every_n_epoch: 100
  # Quoted to make explicit that this is a string, not arithmetic or a number.
  precision: '16-mixed'
  gradient_clip_val: 4

# Free-form labels attached to this run.
tags:
  - finetune
# Alternative mode (disabled):
# mode: anchored_end_to_end
mode: anchored_subsampled_end_to_end
# Checkpoint to continue from; null means none is specified.
ckpt_path: null

# Overrides applied on top of the `data` option selected in the defaults list.
data:
  batch_size: 32  # for quick testing
  dataset:
    # presumably turns on caching inside the dataset implementation — confirm
    do_cache: true

probe_idcs: [9392, 16, 16207, 11974, 2068]

# TODO: check that supernode_radius is compatible with the box_size below.
# box_size: [[0, -0.125], [0.335, 0.125]]
# Written with a decimal point: YAML 1.1 resolvers (e.g. plain PyYAML) read
# `5e-3` as the string "5e-3", whereas `5.0e-3` is a float for every loader.
supernode_radius: 5.0e-3
n_supernodes: 4096
n_anchors: ${n_supernodes}  # theoretically n_anchors <= n_supernodes
n_queries: 2048
# Dotted import path of the attention class (presumably instantiated by the
# model code — confirm).
attn_ctor: src.models.kappa_overrides.ag_dot_product_attention.AnchoredGatedDotProductAttention

model_dim: 384
# Follows model_dim through interpolation, so both stay equal unless this key
# is overridden explicitly.
latent_dim: ${model_dim}

seed: 42
task_name: 'train'

# Dimension of the global conditioning input. null presumably disables global
# conditioning — TODO confirm against the model code.
condition_dim: null

x_dim: 2
