# Base config for experiments.
# NOTE: this file does NOT set `output_dir` or `ckpt`
# (the `vqvae_ckpt` key in the Model section is a separate setting).

# Run setup. The meaning of each flag is defined by the consuming code,
# which is not part of this file; notes here are read from key names only.
seed: 1
tpu: true
download: true
cache: true

# All four `no_*` switches are enabled in this base config
# (presumably turning off validation, sampling, FVD and metrics;
# TODO confirm against the training script).
no_val: true
no_sample: true
no_fvd: true
no_metrics: true

# Training
batch_size: 32
num_workers: 4
total_steps: 1000000
half_precision: false

# Optimisation settings, grouped by key name (gradient clipping is kept
# next to the learning-rate and weight-decay keys).
lr: 0.0001
lr_schedule: 'constant'
warmup_steps: 5000
weight_decay: 0.00001
clip_grad_norm: 100

# Periodic actions. Units are assumed to be training steps; TODO confirm.
log_interval: 100
test_interval: 10000
eval_interval: 100000
viz_interval: 100000

# Data
data_path: 'encoded_h5py_dataset/minerl_marsh_v2'
seq_len: 300
eval_seq_len: 300
image_size: 128
channels: 3

num_partitions: 1
rng_keys: ['sample']
batch_keys: ['video', 'actions']

# Model
model: 'cwvae'
mode: 'vq'
# NOTE(review): the path below contains a literal "TODO" segment, so it
# looks like a placeholder; point it at a real checkpoint before use.
vqvae_ckpt: '/home/TODO/logs/hier_video/minerl_marsh_vqgan_jax'

# Encoder (`enc_*`)
enc_cnn_kernels: [4, 4, 4]
enc_cnn_filters: [256, 512, 1024]
enc_dense_layers: 3
enc_dense_embed_size: 1024

# Decoder (`dec_*`)
decoder_type: 'resnet'
dec_depths: [512, 256]
dec_blocks: 8
dec_stddev: 0.1

# Hierarchy (presumably the number of levels and the temporal
# abstraction factor between them; TODO confirm)
levels: 3
tmp_abs_factor: 6

# Cell (`cell_*`) sizes
cell_stoch_size: 256
cell_deter_size: 1024
cell_embed_size: 1024
cell_min_stddev: 0.001

# Actions and conditioning
use_actions: true
dropout_actions: false
action_dim: 6
action_embed_dim: 128
open_loop_ctx: 36
n_cond: 0
decode_fraction: 0.1
