# Hydra defaults list: `base` is composed first and this file (`_self_`) last,
# so any key set below takes precedence over the same key coming from `base`.
defaults:
  - base
  - _self_

# Policy definition. Each `_target_` gives the dotted path of the class that
# the node describes (Hydra instantiate convention); sibling keys are its
# constructor arguments.
policy:
  _target_: ript.algos.quest.QueST
  # Skill autoencoder over action sequences (dataset.load_obs_for_pretrain
  # notes that its training stage does not need observations).
  autoencoder:
    _target_: ript.algos.quest_modules.skill_vae.SkillVAE
    action_dim: ${task.shape_meta.action_dim}
    encoder_dim: 256
    decoder_dim: 256
    skill_block_size: ${algo.skill_block_size}
    downsample_factor: ${algo.downsample_factor}
    attn_pdrop: 0.1
    use_causal_encoder: true
    use_causal_decoder: true
    encoder_heads: 4
    encoder_layers: 2
    decoder_heads: 4
    decoder_layers: 4
    vq_type: "fsq" # "vq" or "fsq"
    # Example of explicit FSQ levels; left null so they are derived from
    # codebook_size instead (see the codebook_size comment below).
    # fsq_level: [8,5,5,5]
    fsq_level: null
    codebook_dim: 512 # only used for vq
    codebook_size: ${algo.codebook_size} # when fsq_level is null, the fsq levels are computed automatically from this value
  # Prior over skill tokens; its hidden size (n_embd) is algo.embed_dim, which
  # is described there as the stage 2 transformer hidden dim.
  policy_prior:
    _target_: ript.algos.quest_modules.skill_gpt.SkillGPT
    action_dim: ${task.shape_meta.action_dim}
    # NOTE(review): both values below are hard-coded to 1000 while
    # algo.codebook_size is 1024. Per the comment on algo.codebook_size, fsq
    # yields an effective size of 1000, so these match only for the current
    # fsq setup. Update them by hand if vq_type or codebook_size changes.
    start_token: 1000 # should be equal to actual fsq/vq codebook size
    vocab_size: 1000 # should be equal to actual fsq/vq codebook size
    # Integer division of the two algo values (32 // 4 = 8 with the values in
    # this file). Relies on an `eval` resolver that is not defined in this
    # file -- presumably registered by the project; confirm.
    block_size: ${eval:'${algo.skill_block_size} // ${algo.downsample_factor}'}
    n_layer: 6
    n_head: 6
    n_embd: ${algo.embed_dim}
    attn_pdrop: 0.1
    embd_pdrop: 0.1
    beam_size: 5 # value of k for top k sampling
    temperature: 1.0 # temperature for sampling
    device: ${device}
  # `stage` and `device` are interpolated from top-level config keys that are
  # defined outside this file.
  stage: ${stage}
  loss_fn:
    _target_: torch.nn.L1Loss
  l1_loss_scale: ${algo.l1_loss_scale}
  action_horizon: ${algo.action_horizon}
  obs_reduction: cat
  device: ${device}

name: quest

# Put hyperparameters that require tuning here.
# Several of these are referenced from the `policy` and `dataset` sections of
# this file through `${algo.*}` interpolations.
lr: 0.0001
weight_decay: 0.0001
l1_loss_scale: 0 # scale factor used in finetuning stage

embed_dim: 384 # stage 2 transformer hidden dim
lowdim_embed_dim: 128 # each lowdim obs modality is embedded to this dim
image_embed_dim: 256 # each image obs vision encoder's output is embedded to this dim

# NOTE(review): policy.policy_prior.start_token / vocab_size are hard-coded to
# the effective fsq size (1000), not interpolated from this value.
codebook_size: 1024 # note: for fsq the effective size is computed automatically to be 1000, see get_fsq_level function in ript/algos/quest_modules/skill_vae.py
skill_block_size: 32 # this is the input sequence length to the encoder
# policy.policy_prior.block_size is skill_block_size // downsample_factor.
downsample_factor: 4

action_horizon: 8 # how many predicted actions to execute
# Forwarded to dataset.frame_stack (past timestep observations and actions).
frame_stack: 1

# Dataset loading options for this algorithm.
dataset:
  seq_len: ${algo.skill_block_size} # future timestep actions; tied to the encoder input length
  frame_stack: ${algo.frame_stack} # past timestep observations and actions
  obs_seq_len: 1 # future timestep image observations
  lowdim_obs_seq_len: null # future timestep lowdim observations
  load_obs_for_pretrain: false # the autoencoder training stage does not require obs
  # NOTE(review): presumably toggles loading of next-step observations --
  # confirm against the dataset class (the original comment gave no detail).
  load_next_obs: false
  get_pad_mask: false
  pad_seq_length: true
  load_state: false