# Data preprocessing options.
# NOTE(review): the key names suggest these control dequantization of
# integer-valued columns, and that "none" / 0 leaves it switched off --
# confirm against the code that reads this table.
[data]
dequant_dist = "none"
int_dequant_factor = 0

# Model architecture hyperparameters.
# NOTE(review): `factor` in this table is a separate key from
# `[train.main].factor` (the LR-scheduler setting); they only share a name.
# Presumably `d_token` is the per-feature embedding width, `n_head` the
# attention head count, and `dim_t` the hidden/timestep embedding width --
# TODO confirm against the model constructor.
[unimodmlp_params]
num_layers = 2
d_token = 4
n_head = 1
factor = 32
bias = true
dim_t = 1024
use_mlp = true

# Settings for the main training run.
# NOTE(review): `check_val_every` is presumably measured in training steps
# (same unit as `steps`) -- confirm against the trainer.
[train.main]
steps = 8000
lr = 0.001
weight_decay = 0
ema_decay = 0.997
batch_size = 4096
check_val_every = 2000
# Learning-rate scheduler selection; the two keys that follow are its
# hyperparameters and are only meaningful for "reduce_lr_on_plateau".
lr_scheduler = "reduce_lr_on_plateau"
factor = 0.90           # hyperparam for reduce_lr_on_plateau
reduce_lr_patience = 50        # hyperparam for reduce_lr_on_plateau
# NOTE(review): presumably `c_lambda` / `d_lambda` weight two loss terms
# and `closs_weight_schedule` selects how one of those weights changes over
# training -- TODO confirm which term each key applies to.
closs_weight_schedule = "anneal"
c_lambda = 1.0
d_lambda = 1.0

# Settings for the sampling stage. This `batch_size` is a separate key from
# `[train.main].batch_size` and is set independently of it.
[sample]
batch_size = 10000
