# Base config for experiments
# NOTE: does NOT include output_dir or ckpt

# Run-level switches. Their meanings are inferred from the key names only;
# the code that consumes them is not part of this file — confirm there.
# presumably the base RNG seed — TODO confirm
seed: 1
# presumably selects TPU execution — TODO confirm
tpu: true
download: true
cache: true

# Negatively phrased flags: `true` presumably turns the named stage OFF, so as
# written only sampling (no_sample: false) would remain enabled — TODO confirm.
no_val: true
no_sample: false
no_fvd: true

# Training
# Optimiser hyperparameters and step-based schedules. Units and exact
# semantics are defined by the training code, which is not shown here.
batch_size: 32
num_workers: 4
# lr and weight_decay are written as plain decimals. Keep the decimal point
# when editing: YAML 1.1 loaders such as PyYAML resolve a dot-less form like
# `1e-4` as a string, not a float.
lr: 0.0001
lr_schedule: "cosine"
weight_decay: 0.00001
# null presumably means "no gradient clipping" — TODO confirm.
clip_grad_norm: null
total_steps: 1000000
warmup_steps: 5000
# The intervals below are presumably counted in training steps — TODO confirm.
test_interval: 10000
eval_interval: 100000
viz_interval: 100000
log_interval: 100

half_precision: false

# Data
# Training and evaluation currently use the same sequence length (100).
data_path: "encoded_h5py_dataset/kinetics_100"
eval_seq_len: 100
seq_len: 100
image_size: 128
channels: 3

num_shards: 4
rng_keys:
  - "dropout"
  - "sample"
# NOTE(review): "actions" is listed here although use_actions is false in the
# Actions section of this file — confirm the loader still expects the key.
batch_keys:
  - "video"
  - "actions"

# Model
model: "teco"
# NOTE(review): the "/home/TODO/" prefix looks like a placeholder — this path
# presumably has to be pointed at a real VQ-VAE checkpoint before a run.
vqvae_ckpt: "/home/TODO/logs/hier_video/kinetics_vqgan_8192_jax"

# NOTE(review): the inline comments call encoder and decoder "mirrored", yet
# `blocks` differs between them (8 vs 10) — confirm the asymmetry is intended.
# `depths` is written in the same order for both; per the inline comment the
# reversal for the decoder happens in the consuming code, not in this file.
encoder: # encoder / decoder are mirrored, with decoder depths reversed
  depths: [256, 512] # 16x16 -> 8x8
  blocks: 8

decoder: # encoder / decoder are mirrored, with decoder depths reversed
  depths: [256, 512] # 16x16 -> 8x8
  blocks: 10

# NOTE(review): if z_ds is a per-side downsampling factor, 8x8 with z_ds: 2
# gives 4x4, not the 2x2 stated in the inline comment — confirm which of the
# value or the comment is correct.
z_ds: 2 # 8x8 -> 2x2
# Transformer sizing: mlp_dim is 4 x embed_dim, and 1536 / 24 = 64 (the
# per-head width, assuming heads split embed_dim evenly). Both dropout rates
# are zero.
z_tfm_kwargs:
  embed_dim: 1536
  mlp_dim: 6144
  num_heads: 24
  num_layers: 24
  dropout: 0.0
  attention_dropout: 0.0

# Settings for the `z_git` component; its role is defined by the model code,
# which is not shown here. vocab_dim (256) equals the top-level embedding_dim.
z_git:
  vocab_dim: 256
  mask_schedule: "cosine"
  # Same proportions as z_tfm_kwargs: mlp_dim is 4 x embed_dim and
  # 1024 / 16 = 64 (per-head width, assuming an even split). Dropout is off.
  tfm_kwargs:
    embed_dim: 1024
    mlp_dim: 4096
    num_heads: 16
    num_layers: 24
    dropout: 0.0
    attention_dropout: 0.0

# presumably the weight of a transition/transformer loss term — TODO confirm
trans_weight: 1
embedding_dim: 256
# Codebook with 1024 entries; proj_dim is presumably the width codes are
# projected to for lookup — TODO confirm against the quantiser code.
codebook:
  n_codes: 1024
  proj_dim: 32

# presumably the number of conditioning frames — TODO confirm
n_cond: 1
# presumably the fraction of a sequence that is decoded per training step
# — TODO confirm
decode_fraction: 0.1

# Actions
# Action conditioning is switched off here (use_actions: false); the sizes
# below are presumably only read when it is enabled — TODO confirm.
# NOTE(review): batch_keys in the Data section still lists "actions".
use_actions: false
action_dim: 6
action_embed_dim: 16
dropout_actions: false

# Sampling
# T_draft / T_revise / M look like step and round counts for an iterative
# draft-then-revise decoding scheme — TODO confirm against the sampler code.
T_draft: 8
T_revise: 8
M: 2
# presumably the number of context frames given for open-loop generation;
# compare with seq_len: 100 in the Data section — TODO confirm
open_loop_ctx: 20
