# Base config shared by experiments.
# NOTE: this file does NOT set output_dir or ckpt; those must be supplied
# separately for each run.

# Run-level settings.
seed: 1
tpu: true
download: true
cache: true

# Stage toggles. Each `no_*` flag presumably skips the named stage when true
# (validation, sampling, FVD computation) — confirm against the training script.
no_val: true
no_sample: false
no_fvd: true

# Training
# Keep small floats in dotted decimal form: PyYAML (YAML 1.1 resolver) reads an
# exponent with no dot, e.g. `1e-4`, as a string rather than a float.
batch_size: 32
num_workers: 4
lr: 0.0001
weight_decay: 0.00001
lr_schedule: "cosine"
total_steps: 1000000
warmup_steps: 5000
# Periodic intervals; presumably counted in training steps — confirm.
test_interval: 10000
eval_interval: 100000
viz_interval: 100000
log_interval: 100

half_precision: false
# Gradient-norm clipping threshold; null leaves it unset (presumably no
# clipping — confirm against the trainer).
clip_grad_norm: null

# Data
# Relative path; presumably resolved against the working directory or a data
# root — confirm.
data_path: "encoded_ae_h5py_dataset/minerl_marsh_v2"
eval_seq_len: 300
seq_len: 300
image_size: 128
channels: 3

num_partitions: 1
# Names of the RNG streams and batch entries the trainer uses (presumably
# looked up by name — confirm against the consumer).
rng_keys: ["dropout", "sample", "noise"]
batch_keys: ["video", "actions"]

# Model
model: "latent_fdm"
# NOTE(review): this path still contains a "TODO" placeholder segment and looks
# like it must be replaced with a real checkpoint location before use.
# Presumably the autoencoder checkpoint — confirm.
ae_ckpt: "/home/TODO/logs/hier_video/minerl_marsh_ae_jax"

# Presumably latent/autoencoder settings (mode "vq", latent scaling factor,
# fraction of frames decoded) — confirm against the model code.
mode: "vq"
scale_factor: 0.23149166745591104
decode_fraction: 1.0
# U-Net hyperparameters. num_res_blocks and channel_mult both have 4 entries;
# they appear to be per-resolution-level settings — confirm.
unet:
  model_channels: 128
  num_res_blocks: [1, 1, 2, 2]
  num_head_dim: 64
  attention_resolutions: [4, 2]
  # Presumably should match the dimensionality of the "actions" batch entry.
  action_dim: 6
  action_embed_dim: 16
  dropout: 0.1
  channel_mult: [1, 2, 2, 2]
  use_scale_shift_norm: true

# Actions
# Presumably: use_actions conditions the model on the "actions" batch entry,
# and dropout_actions randomly drops that conditioning — confirm.
use_actions: true
dropout_actions: false

# Sampling
num_steps: 50
sample_method: "ddim"
# Extra keyword arguments, presumably forwarded to the sampler selected by
# sample_method — confirm.
sampler_kwargs:
  gamma: 1.0
# Presumably the number of context frames used for open-loop sampling — confirm.
open_loop_ctx: 36
