# @package __global__

# Hydra defaults list: the entries below are composed in the order given.
# `_self_` marks the position at which this file's own keys are merged, so
# they take precedence over the entries listed before it.
defaults:
  - /solver/default
  - /conditioner: none
  - _self_
  # `override` swaps the /dset choice made earlier in the composition
  # (presumably by /solver/default -- confirm) for audio/default.
  - override /dset: audio/default

# Autocast switch and its dtype.
# NOTE(review): presumably forwarded to torch autocast by the solver -- confirm.
autocast: true
autocast_dtype: float16

# Solver identifier; the lookup of this name happens outside this file.
solver: CSM_raw_fMRI

# Both are unset (null) in this config.
# NOTE(review): presumably supplied per experiment or on the command line -- confirm.
compression_model_checkpoint: null
use_layer: null

# Cache settings. As configured here no cache path is set and writing is off.
# NOTE(review): write_shard / write_num_shards presumably select which shard
# of the cache a job writes (index / total) -- confirm against the consumer.
cache:
  path: null  # explicit null; a bare `path:` parses to the same value
  write: false
  write_shard: 0
  write_num_shards: 1

# Dataset settings. The scalar keys directly under `dataset` come first; the
# train / valid / evaluate / generate sub-sections carry per-split values.
# NOTE(review): per-split keys (e.g. batch_size) presumably take precedence
# over the top-level ones for that split -- confirm against the loader.
dataset:
  batch_size: 64
  num_workers: 10
  segment_duration: 64
  min_fMRI_duration: null
  max_fMRI_duration: null
  tr_precision: 0.2
  max_read_retry: 1
  max_tr: 256.0
  return_info: true
  normalize: false
  train:
    num_samples: 569650  # raw num: 11393 * 50
  valid:
    num_samples: 4740  # raw num: 474 * 10
  evaluate:
    batch_size: 32
    num_samples: 1130  # raw num: 113 * 10
  generate:
    batch_size: 16
    num_samples: 16

# Generation stage settings.
# NOTE(review): `every` is presumably a period counted in epochs -- confirm.
generate:
  every: 25
  num_workers: 5
  block_num: 2
  # Only prompted generation is enabled in this config.
  lm:
    prompted_samples: true
    unprompted_samples: false
    gen_gt_samples: false
    prompt_duration: null   # if not set, will use dataset.generate.segment_duration / 4
    gen_duration: null      # if not set, will use dataset.generate.segment_duration
    remove_prompts: false

# Evaluation stage settings.
# NOTE(review): `every` is presumably a period counted in epochs -- confirm.
evaluate:
  every: 25
  num_workers: 5
  metrics:
    base: true  # same evaluation metrics as train & valid

# Checkpoint settings.
# NOTE(review): presumably saves every 25 epochs and retains the 10 most
# recent checkpoints -- confirm the exact semantics with the checkpoint code.
checkpoint:
  save_last: true
  save_every: 25
  keep_last: 10
  keep_every_states: null

# Optimization settings: AdamW with max_norm 1.0 and EMA enabled.
# Floats in scientific notation are written with a decimal point (1.0e-4)
# so that YAML 1.1 loaders such as plain PyYAML resolve them as floats
# instead of strings; the numeric values are unchanged.
optim:
  epochs: 500
  updates_per_epoch: 2000
  lr: 1.0e-4  # same settings as the paper
  optimizer: adamw
  max_norm: 1.0
  eager_sync: true
  adam:
    betas: [0.9, 0.95]
    weight_decay: 0.1
    eps: 1.0e-8
  ema:
    use: true
    updates: 10
    device: cuda

# Learning-rate schedule: cosine scheduler selected, configured below.
# NOTE(review): `warmup` is presumably counted in optimizer updates and
# lr_min_ratio 0.0 presumably means decay all the way to zero -- confirm.
schedule:
  lr_scheduler: cosine
  cosine:
    warmup: 4000
    lr_min_ratio: 0.0
    cycle_length: 1.0