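# Base config shared by experiments.
# NOTE: `output_dir` and `ckpt` are deliberately NOT set here; they must be
# supplied separately (e.g. by a derived config or at launch — TODO confirm).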

---
# Run / environment settings
seed: 1
tpu: true
download: true
cache: true

# Presumably these turn off validation, sampling and FVD evaluation
# respectively (inferred from the names only — confirm in the training code).
no_val: true
no_sample: true
no_fvd: true

# Training
batch_size: 32
num_workers: 4
lr: 0.0001
weight_decay: 0.000001
# NOTE(review): warmup_steps is set alongside a "constant" schedule; presumably
# warmup runs first and lr is then held constant — confirm in the scheduler.
lr_schedule: "constant"
total_steps: 1000000
warmup_steps: 10000
# NOTE(review): with no_val / no_sample / no_fvd all true, some of these
# intervals may never trigger — confirm which ones are still used.
test_interval: 10000
eval_interval: 100000
viz_interval: 100000
log_interval: 100

half_precision: false
# Gradient clipping is an optimization setting, so it lives under Training.
# null presumably means "no clipping" — confirm in the training code.
clip_grad_norm: null

# Data
data_path: "encoded_h5py_dataset/minerl_marsh_v2"
eval_seq_len: 300
seq_len: 300
latent_size: 4
image_size: 128

num_partitions: 1
rng_keys: ["dropout", "sample"]
batch_keys: ["video", "actions"]

# Model
model: "perceiver_ar"
# TODO: placeholder path — replace "TODO" with the real user / log directory
# before running; the VQ-VAE checkpoint presumably cannot be loaded otherwise.
vqvae_ckpt: "/home/TODO/logs/minerl_marsh_vqgan_jax"

# Key names differ between the two sections below (dropout_rate vs dropout);
# they must match what the model code reads, so do not harmonize them here
# without updating that code.
cross_attn:
  num_heads: 8
  head_dim: 64
  dropout_rate: 0.0

transformer:
  embed_dim: 1024
  mlp_dim: 4096
  num_heads: 16
  num_layers: 16
  dropout: 0.0
  attention_dropout: 0.0
  attention_type: "full"

# Action conditioning: on/off switches first, then sizes.
use_actions: true
dropout_actions: false
action_dim: 6
action_embed_dim: 16

# Context / conditioning lengths (exact semantics are defined by the model
# code, not visible here — confirm before changing).
n_cond: 0
open_loop_ctx: 5
