[Custom]
# Baseline value for every option; the experiment sections further down only
# list the options they change.
# NOTE(review): values are spelled as Python literals (True/False, quoted
# strings), so the loader presumably evaluates them as such - confirm.
# General options
log_level = "info"
no_progress_bar = True
device = "cuda"
seed = 1234

# Training options
epochs = 10
batch_size = 128
lr = 5e-4
checkpoint_every = 5000
dataset = 'dsprites'
experiment = 'custom'

# Model options
model = 'Burgess'
loss = "betaH"
latent_dim = 10
rec_dist = "bernoulli"

# betaH options; presumably candidate values for betaH_B: [1, 2, 4, 6, 8, 16]
# NOTE(review): the trailing "# 8" below is only ignored if the loader strips
# or tolerates inline comments (Python's configparser keeps them in the value
# by default) - confirm before relying on it.
betaH_B = 1 # 8

# betaB options
betaB_initC = 0
# presumably candidate values for betaB_finC: [5, 10, 25, 50, 75, 100]
betaB_finC = 25
# 100 is what most implementations online use; 1000 is the value used in
# Locatello, and it is the one set here.
betaB_G = 1000

# factor options; presumably candidate values for factor_G: [10, 20, 30, 40, 50, 100]
factor_G = 6
lr_disc = 5e-5

# btcvae options
btcvae_A = 1
btcvae_G = 1
# presumably candidate values for btcvae_B: [1, 2, 4, 6, 8, 10]
btcvae_B = 6

# Evaluation options
is_metrics = False
no_test = False
is_eval_only = False
eval_batchsize = 320

# ### DATASET COMMON ###
# same number of epochs for comparisons

[Common_dsprites]
# Shared settings for experiments on the dsprites dataset.
dataset = 'dsprites'
checkpoint_every = 100
is_metrics = True
# 300000 steps, roughly 26 epochs
# The epochs override below is disabled (commented out with ';').
;epochs = 30


[Common_chairs]
# Shared settings for experiments on the chairs dataset.
dataset = 'chairs'
checkpoint_every = 1000
# The epochs override below is disabled (commented out with ';').
;epochs = 300
is_metrics = False
[Common_celeba]
# Shared settings for experiments on the celeba dataset.
dataset = 'celeba'
checkpoint_every = 100
epochs = 200
[Common_mnist]
# Shared settings for experiments on the mnist dataset.
dataset = 'mnist'
checkpoint_every = 100
epochs = 400
[Common_fashion]
# Shared settings for experiments on the fashion dataset.
dataset = 'fashion'
checkpoint_every = 100
epochs = 400

# ### LOSS COMMON ###

[Common_VAE]
# Shared settings for experiments that use the VAE loss.
loss = "VAE"
lr = 5e-4
[Common_betaH]
# Shared settings for experiments that use the betaH loss.
loss = "betaH"
lr = 5e-4
[Common_betaB]
# Shared settings for experiments that use the betaB loss.
loss = "betaB"
lr = 1e-3
[Common_factor]
# Shared settings for experiments that use the factor loss.
loss = "factor"
lr = 1e-4
[Common_btcvae]
# Shared settings for experiments that use the btcvae loss.
loss = "btcvae"
lr = 5e-4

# ### EXPERIMENT SPECIFIC ###
# additional hyperparameter changes besides the common ones

# BETA H

[betaH_dsprites]
# betaH loss on dsprites: beta as in the paper
betaH_B = 4
[betaH_celeba]
# betaH loss on celeba: beta value as in https://github.com/1Konny/Beta-VAE
betaH_B = 10
[betaH_chairs]
# betaH loss on chairs: beta value as in https://github.com/1Konny/Beta-VAE
betaH_B = 4

[betaH_mnist]
# betaH loss on mnist: beta value as in https://github.com/1Konny/Beta-VAE
betaH_B = 4

# BETA B

[betaB_dsprites]
# betaB loss on dsprites: final capacity as in the paper
betaB_finC = 25
[betaB_celeba]
# betaB loss on celeba: final capacity as in the paper
betaB_finC = 50
[betaB_chairs]
# betaB loss on chairs: final capacity override (no source is given for this value)
betaB_finC = 25

# FACTOR

[factor_chairs]
# factor loss on chairs: overrides factor_G and the discriminator learning rate
# (lr_disc is presumably the discriminator's learning rate - confirm).
factor_G = 3.2
lr_disc = 1e-5
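# beta value as in https://github.com/1Konny/FactorVAE/blob/master/utils.py
# NOTE(review): the factor sections only set factor_G and lr_disc - confirm which key this note refers to.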
[factor_dsprites]
# factor loss on dsprites: overrides factor_G and lr_disc.
factor_G = 6.4
lr_disc = 1e-4
[factor_celeba]
# factor loss on celeba: overrides factor_G and lr_disc.
factor_G = 6.4
lr_disc = 1e-5

# BTCVAE
# use the same values as the factor sections
[btcvae_chairs]
# btcvae loss on chairs: btcvae_B is copied from factor_G in [factor_chairs]
# through a ${section:key} reference.
btcvae_B = ${factor_chairs:factor_G}
[btcvae_dsprites]
# btcvae loss on dsprites: btcvae_B is copied from factor_G in [factor_dsprites]
# through a ${section:key} reference.
btcvae_B = ${factor_dsprites:factor_G}
[btcvae_celeba]
# btcvae loss on celeba: btcvae_B is copied from factor_G in [factor_celeba]
# through a ${section:key} reference.
btcvae_B = ${factor_celeba:factor_G}

# Other
# these sections don't use the Common_* sections by default (a section must be named <loss>_<data> for that)!

[best_celeba]
# Stand-alone btcvae preset for celeba. It sets dataset and loss itself and
# copies epochs, checkpoint_every and lr from [Common_celeba] and
# [Common_btcvae] through ${section:key} references.
btcvae_A = -10
btcvae_B = 20
dataset = 'celeba'
loss = "btcvae"
epochs = ${Common_celeba:epochs}
checkpoint_every = ${Common_celeba:checkpoint_every}
lr = ${Common_btcvae:lr}
# Unlike the "bernoulli" baseline in [Custom], this preset uses "laplace".
rec_dist = "laplace"

[best_dsprites]
# Stand-alone btcvae preset for dsprites. It sets dataset and loss itself and
# copies checkpoint_every and lr from [Common_dsprites] and [Common_btcvae]
# through ${section:key} references.
btcvae_A = -5
btcvae_B = 10
dataset = 'dsprites'
loss = "btcvae"
# [Common_dsprites] has its epochs entry commented out, so a
# ${Common_dsprites:epochs} reference cannot be resolved. Use the value that
# section last carried (30) as a literal instead.
epochs = 30
checkpoint_every = ${Common_dsprites:checkpoint_every}
lr = ${Common_btcvae:lr}

[debug]
# Quick-run preset: a single epoch, debug-level logging, and no_test enabled.
epochs = 1
log_level = "debug"
no_test = True
