# Experiment-tracker settings.
# NOTE(review): `wwandb` / `wcomet` presumably switch the Weights & Biases and
# Comet loggers on or off — confirm against the code that reads this file.
wandb_config:
  project_name: "v2-RATE-MemoryMaze"
  wwandb: false
  wcomet: true

# Data-related options; exact semantics are defined by the consuming code.
# NOTE(review): `gamma` is presumably a return discount factor and `normalize`
# an integer on/off flag — confirm.
data_config:
  gamma: 1.0
  normalize: 1
  only_non_zero_rewards: true

# Optimizer, schedule and training-loop settings.
# Trailing "other value noted" remarks carry over alternative values that were
# already recorded next to each key.
training_config:
  learning_rate: 0.0003  # other value noted: 1e-3
  lr_end_factor: 0.1  # old: 0.001
  beta_1: 0.9
  beta_2: 0.95
  weight_decay: 0.1
  batch_size: 64  # other value noted: 128
  warmup_steps: 1000
  final_tokens: 500000  # other value noted: 1000000
  grad_norm_clip: 1.0  # float or 'None'
  epochs: 80  # other value noted: 10
  ckpt_epoch: 1
  use_erl_stop: true

  # Inference during training.
  online_inference: true

  # WARNING - IMPORTANT:
  # To train DT with a context length of 90 (or RATE with 3 segments of
  # context 30), set context_length=30 and sections=3.
  #   RATE/GRATE: L = context_length             S = sections
  #   DT:         L = sections * context_length  S = 1
  context_length: 30
  sections: 3


# Model architecture hyper-parameters.
# Trailing "other value(s) noted" remarks carry over alternative values that
# were already recorded next to each key.
model_config:
  mode: "memory_maze"
  STATE_DIM: 3
  # NOTE(review): the earlier note on this key read "2 if OHE else 4", which
  # does not match the value 6 — confirm the intended action encoding.
  ACTION_DIM: 6
  n_token: 10000  # vocab_size
  n_layer: 6  # other values noted: 3, 8
  n_head: 8  # other values noted: 1, 10
  n_head_ca: 4  # number of cross-attention heads (flagged: keep at 4)
  d_model: 64  # other value noted: 128
  d_head: 64  # other value noted: 128; earlier note: "divider of d_model"
  # NOTE(review): the earlier note on this key read "> d_model", but the value
  # equals d_model — confirm whether that constraint still applies.
  d_inner: 64  # other value noted: 128
  dropout: 0.5  # other value noted: 0.2
  dropatt: 0.2  # other value noted: 0.05
  mem_len: 300
  ext_len: 0
  tie_weight: false
  num_mem_tokens: 15  # other value noted: 5; d_head * nmt = diff params
  mem_at_end: true
  mrv_act: "relu"
  skip_dec_ffn: true  # skip FFN in transformer decoder

# Settings for online inference rollouts.
online_inference_config:
  use_argmax: false
  # NOTE(review): the earlier trailing note here read "3.2 # 90: 3.2 - top10";
  # it looks like a leftover from a different value — confirm or remove.
  episode_timeout: 1000
  desired_return_1: 18.1  # top-10 mean over 1k steps





