---
# Model to instantiate: `class_path` names the class and `init_args` holds
# the keyword arguments it is constructed with (class_path/init_args layout,
# presumably consumed by LightningCLI / jsonargparse — confirm).
model:
  class_path: desiddpayne.models.AIONCrossAttentionProbing
  init_args:
    n_outputs: 16
    num_heads: 12
    model_path: "data/aion/dec24/base"
    num_encoder_tokens: 600
    # Written with a decimal point on purpose: YAML 1.1 loaders such as
    # PyYAML resolve a bare `1e-4` to the string "1e-4", whereas `1.0e-4`
    # is a float under both YAML 1.1 and 1.2.
    lr: 1.0e-4
# Data module to instantiate, with the keyword arguments passed to it.
data:
  class_path: desiddpayne.dataset.DESIDDPayneDatasetModule
  init_args:
    data_dir: "data"
    # Input fields handed to the data module (same order as before).
    input_fields:
      - "tok_spectrum_desi"
      - "tok_parallax"
    batch_size: 256
    # NOTE(review): 0 presumably means loading in the main process
    # (PyTorch DataLoader convention) — confirm against the data module.
    num_workers: 0
    # Quoted so it stays the string "1" rather than the integer 1.
    version: "1"
# Trainer settings, plus the callbacks and logger it is configured with.
trainer:
  max_epochs: 10
  accelerator: gpu
  precision: "bf16-mixed"
  log_every_n_steps: 1
  callbacks:
    # Learning-rate monitor, configured with a per-step logging interval.
    - class_path: LearningRateMonitor
      init_args:
        logging_interval: "step"
  logger:
    class_path: WandbLogger
    init_args:
      name: "base_xatt_spec_plx"
      project: "aion_eval_desiddpayne"
      # Same directory as trainer.default_root_dir; keep the two in sync.
      save_dir: "data/AION_Eval/results"
  default_root_dir: "data/AION_Eval/results"
