# Base config for experiments
# NOTE: does NOT include output_dir or ckpt

# Run-level settings. Their effect is defined by the code that loads this file
# (not visible here); the readings below are inferred from the key names only.
#   seed     - presumably the base RNG seed
#   tpu      - presumably selects TPU execution
#   download - presumably allows fetching remote data (TODO confirm what is fetched)
#   cache    - presumably enables local caching (TODO confirm what is cached)
seed: 1
tpu: true
download: true
cache: true

# Negatively-phrased switches: going by the `no_` prefix, `true` turns the
# named step OFF (verify against the consumer). As written, `val` and `fvd`
# would be skipped while `sample` stays on.
no_val: true
no_sample: false
no_fvd: true

# Training
# The *_steps / *_interval values are assumed to be counts of training steps
# (inferred from the names - confirm in the training loop).
batch_size: 32
num_workers: 4
lr: 0.0001
lr_schedule: 'constant'
weight_decay: 0.00001
total_steps: 500000
warmup_steps: 5000  # 1% of total_steps
# Intervals by size: log (100) < test (10000) < eval = viz (100000).
test_interval: 10000
eval_interval: 100000
viz_interval: 100000
log_interval: 100

# NOTE(review): presumably toggles reduced-precision compute; off here.
half_precision: false

# Data
# Relative path; what it is resolved against is decided by the loader, which
# is not part of this file.
data_path: 'encoded_h5py_dataset/minerl_marsh_v2'
eval_seq_len: 100  # same value as seq_len
seq_len: 100
image_size: 128
# NOTE(review): going by its name this is a gradient-clipping threshold, i.e.
# an optimizer setting rather than a data one. It is kept in place so the key
# order of the file does not change.
clip_grad_norm: 100
channels: 3

num_partitions: 1
# Two name lists; presumably the consumer looks up RNG streams and batch
# fields under exactly these strings - confirm before renaming any of them.
rng_keys:
  - 'dropout'
  - 'sample'
batch_keys:
  - 'video'
  - 'actions'

# Model
model: 'teco'
# NOTE(review): the `TODO` path segment looks like a placeholder; this value
# presumably has to be overridden with a real checkpoint location.
vqvae_ckpt: '/home/TODO/logs/hier_video/minerl_marsh_vqgan_jax'

# The encoder and decoder are mirrored: both list the same `depths`, and the
# decoder applies them in reverse order. `blocks` is the one value that
# differs between the two sections (4 vs 8).
encoder: # depths are listed in encoder order
  depths: [256, 512] # 16x16 -> 8x8
  blocks: 4

decoder: # same depths as the encoder, applied in reverse
  depths: [256, 512] # written in encoder order (16x16 -> 8x8); reversed when decoding
  blocks: 8

# 8x8 -> 2x2 is a 4x reduction per side, so the value 2 presumably counts
# halvings rather than being the reduction factor itself - TODO confirm.
z_ds: 2
# Transformer settings: mlp_dim is 4 x embed_dim, embed_dim / num_heads = 64,
# and both dropout rates are zero.
z_tfm_kwargs:
  embed_dim: 1024
  mlp_dim: 4096
  num_heads: 16
  num_layers: 12
  dropout: 0.0
  attention_dropout: 0.0

z_git:
  vocab_dim: 256
  mask_schedule: 'cosine'
  # Half the depth of z_tfm_kwargs (6 layers vs 12) with the same ratios:
  # mlp_dim is 4 x embed_dim and embed_dim / num_heads = 64. No dropout.
  tfm_kwargs:
    embed_dim: 768
    mlp_dim: 3072
    num_heads: 12
    num_layers: 6
    dropout: 0.0
    attention_dropout: 0.0

# NOTE(review): presumably the weight of a transition-related loss term - confirm.
trans_weight: 1
embedding_dim: 128
# Codebook with 1024 entries; proj_dim (32) is a quarter of embedding_dim (128).
# NOTE(review): how the two dimensions are used together is not visible here.
codebook:
  n_codes: 1024
  proj_dim: 32

# NOTE(review): presumably the number of conditioning frames - confirm.
n_cond: 1
# One quarter; what exactly is subsampled for decoding is defined by the consumer.
decode_fraction: 0.25

# Actions
# Readings inferred from the key names only - verify against the model code:
#   use_actions      - presumably enables action conditioning
#   action_dim       - presumably the size of the action input
#   dropout_actions  - presumably toggles dropping actions during training
#   action_embed_dim - presumably the width of the action embedding
use_actions: true
action_dim: 6
dropout_actions: false
action_embed_dim: 16

# Sampling
# NOTE(review): T_draft / T_revise look like step counts for a draft phase and
# a revise phase, and M like a repeat count - confirm against the sampler.
T_draft: 8
T_revise: 8
M: 2
# NOTE(review): presumably the number of context frames given before open-loop
# prediction; 36 is below seq_len (100) - confirm.
open_loop_ctx: 36
