# Usage: python code/generation/train.py -c generation/config/EP-VAE/EP-VAE-CIFAR.txt

# Output settings
# Label for this run.
# NOTE(review): presumably also the sub-directory name used for outputs (the
# commented-out load_checkpoint example in this file ends in the same string)
# -- confirm in train.py.
exp_name = EP-VAE-CIFAR-10
# Disabled: while commented out, this file does not set output_dir at all.
# Uncomment to set the output directory explicitly.
#output_dir = generation/output

# General settings
# Compute device, numeric precision and random seed for the run.
# NOTE(review): "cuda:0" looks like a PyTorch device string (first CUDA GPU)
# -- confirm how train.py interprets it.
device = cuda:0
dtype = float32
seed = 1
# Disabled: uncomment to pass a checkpoint path to the training script.
# NOTE(review): whether this resumes training or only loads weights is not
# visible from this file -- confirm in train.py.
#load_checkpoint = generation/output/EP-VAE-CIFAR-10

# General training hyperparameters
# 100 epochs, mini-batches of 100, base learning rate 5e-4, no weight decay.
num_epochs = 100
batch_size = 100
lr = 5e-4
weight_decay = 0
# Optimizer is selected by name.
# NOTE(review): "RiemannianAdam" suggests a manifold-aware Adam variant --
# confirm which implementation train.py maps this name to.
optimizer = RiemannianAdam
# Learning-rate scheduling is switched on.
# NOTE(review): step = 50 with gamma = 0.1 reads like "multiply lr by 0.1 at
# epoch 50" (half-way through num_epochs = 100); whether the step is a
# repeating period or a single milestone is not visible here -- confirm.
use_lr_scheduler = True
lr_scheduler_step = 50
lr_scheduler_gamma = 0.1

# KL Loss hyperparameters
# Coefficient for the KL term.
# NOTE(review): presumably the weight of the KL divergence relative to the
# reconstruction term in the VAE loss -- confirm the exact formula in the
# training code.
kl_coeff = 0.024

# General validation/testing hyperparameters
# Batch size used for evaluation (larger than the training batch_size of 100).
batch_size_test = 1024
# NOTE(review): presumably enables FID computation during validation, with
# batch_size_fid_inception being the batch size fed to the Inception network
# used for FID -- confirm in the evaluation code.
calc_fid_val = True
batch_size_fid_inception = 512

# General VAE hyperparameters
# Model variant, selected by name.
model = EP-VAE
# Encoder / decoder depth (note the encoder has one more layer than the decoder).
# NOTE(review): what counts as a "layer" (conv block, stage, ...) is defined by
# the model code, not visible here.
enc_layers = 4
dec_layers = 3
# NOTE(review): presumably the latent dimensionality -- confirm.
z_dim = 128
# NOTE(review): presumably the channel count of the first convolutional layer
# -- confirm how it scales in deeper layers.
initial_filters = 64

# Hyperbolic manifold settings
# NOTE(review): presumably the curvature parameter K of the embedding manifold;
# the sign convention (K vs. -K) is not visible from this file -- confirm.
embed_K = 0.1
# Disabled: while commented out, this file does not set learn_K.
# NOTE(review): presumably makes K a trainable parameter when enabled -- confirm.
#learn_K = True

# Dataset settings
# Dataset name; the alternative noted by the author is CIFAR-100.
# The note is kept on its own line (not trailing the value) so that parsers
# without inline-comment support do not read it as part of the dataset name.
dataset = CIFAR-10
