# ==============================================================================
# TRAINING
# ==============================================================================
# Optimisation settings. The trailing "# 500" and "# 100" notes are carried
# over unchanged; presumably earlier/alternative values -- TODO confirm.
epochs: 2000  # 500
# Keep the decimal point and the signed exponent: YAML 1.1 loaders only
# recognise scientific notation as a float when both are present.
lr: 1.0e-4
sizes:  # 100
  bs: 64
  k: 1


# ----------- LOSS -----------
# NOTE(review): names suggest a KL-term weight and the point at which its
# warm-up ends -- confirm units/semantics against the training loop.
kl_factor: 1
kl_end_warmup: 20


# ==============================================================================
# JOINT MODEL
# ==============================================================================
# Settings shared by both modalities; stoc_dim holds one list per modality
# key (x1, x2).
learn_prior: false
stoc_dim:
  x1: [100]
  x2: [100]
stoc_dist: normal


# ==============================================================================
# VAE: X1
# ==============================================================================
x1_input_shape: [3, 64, 64]


# ----------- DETERMINISTIC LAYER -----------
# Bottom-up stack for x1: an outer list holding one inner list of layer
# specs, one spec per line.
# NOTE(review): t/c/k/s/p presumably stand for type / channels / kernel /
# stride / padding -- confirm against the code that consumes these specs.
x1_det_specs_bu:
  - - {t: conv, c: 32, k: 4, s: 2, p: 1}
    - {t: conv, c: 64, k: 4, s: 2, p: 1}
    - {t: conv, c: 128, k: 4, s: 2, p: 1}
    - {t: conv, c: 128, k: 1, s: 1, p: 0}  # extra
    - {t: conv, c: 256, k: 4, s: 2, p: 1}
    - {t: dense, out: 512}

# Top-down stack for x1: an outer list holding one inner list of layer specs.
# The dense entry's `out` (4096) equals the product of its reshape target
# (256 * 4 * 4).
# NOTE(review): the dense+reshape entry is listed last and the 3-channel
# convt first, so the list looks ordered from the output end back toward the
# latent end -- confirm against the model builder. `op` is presumably the
# transposed-convolution output padding -- confirm.
x1_det_specs_td:
  - - {t: convt, c: 3, k: 4, s: 2, p: 1, op: 0}
    - {t: convt, c: 32, k: 4, s: 2, p: 1, op: 0}
    - {t: convt, c: 64, k: 4, s: 2, p: 1, op: 0}
    - {t: convt, c: 64, k: 1, s: 1, p: 0, op: 0}  # extra
    - {t: convt, c: 64, k: 1, s: 1, p: 0, op: 0}  # extra
    - {t: convt, c: 128, k: 4, s: 2, p: 1, op: 0}
    - {t: dense, out: 4096, reshape: [256, 4, 4]}


# ----------- STOCHASTIC LAYER -----------
# Each list below holds a single layer spec. The merge layer is an explicit
# empty string (not null); keep it quoted so it stays a string.
x1_stoc_specs:
  - {t: dense}
x1_stoc_upsampling:
  - {t: dense, reshape: [512]}
x1_merge_layer: ""


# ----------- RECONSTRUCTION LAYER -----------
# Output head for x1: a single spec mapping (not a list), plus the name of the
# reconstruction distribution and a numeric factor.
# NOTE(review): x1_rec_factor is presumably a loss weight -- confirm.
x1_rec_specs:
  t: conv
  k: 3
x1_rec_dist: sigmoid
x1_rec_factor: 1


# ----------- GENERIC LAYER -----------
x1_nonlin: swish


# ==============================================================================
# VAE: X2
# ==============================================================================
# One-element shape list (the original trailing comma did not add an entry).
x2_input_shape: [1024]


# ----------- DETERMINISTIC LAYER -----------
# Bottom-up stack for x2: an outer list holding one inner list of dense layer
# specs, one per line. The two entries tagged "extra" shared a single
# "# extra" note in the original layout.
x2_det_specs_bu:
  - - {t: dense, out: 1024}
    - {t: dense, out: 768}
    - {t: dense, out: 1024}  # extra
    - {t: dense, out: 768}  # extra
    - {t: dense, out: 768}
    - {t: dense, out: 768}
# Top-down stack for x2: an outer list holding one inner list of dense layer
# specs, one per line. The four entries tagged "extra" came from two original
# lines that each carried one "# extra" note.
# NOTE(review): the 1024-wide entries come first and the 768-wide ones last,
# so the list looks ordered from the output end back toward the latent end --
# confirm against the model builder.
x2_det_specs_td:
  - - {t: dense, out: 1024}
    - {t: dense, out: 1024}
    - {t: dense, out: 1024}  # extra
    - {t: dense, out: 1024}  # extra
    - {t: dense, out: 1024}  # extra
    - {t: dense, out: 1024}  # extra
    - {t: dense, out: 768}
    - {t: dense, out: 768}


# ----------- STOCHASTIC LAYER -----------
# x2_stoc_specs is a one-item list of layer specs. The remaining three keys
# are explicit empty strings (not null); keep them quoted so they stay strings.
x2_stoc_specs:
  - {t: dense}
x2_stoc_upsampling: ""
x2_merge_layer: ""
x2_prior_layer: ""


# ----------- RECONSTRUCTION LAYER -----------
# Output head for x2: a single spec mapping (not a list), plus the name of the
# reconstruction distribution and a numeric factor.
# NOTE(review): x2_rec_factor is presumably a loss weight -- confirm.
x2_rec_specs:
  t: dense
x2_rec_dist: normal
x2_rec_factor: 1


# ----------- GENERIC LAYER -----------
x2_nonlin: leaky_relu


# ==============================================================================
# MISC
# ==============================================================================
# Run-level settings. n_modalities matches the two modality sections (x1, x2)
# defined in this file.
n_modalities: 2
eval_bs: 1
exp_name: flat_new
trial: trial
seed: 23