# Dotted class path this config node points at. The remaining top-level keys
# in this file are presumably handed to that class as constructor keyword
# arguments (Hydra-style `_target_` instantiation — confirm in the launcher).
_target_: lfads_torch.model.LFADS

# --------- architecture --------- #
# Integer hyperparameters. The key names suggest feature dimensions and
# sequence lengths, but that reading is inferred from names only —
# NOTE(review): confirm each against lfads_torch.model.LFADS.
encod_data_dim: 182
encod_seq_len: 140
# Same value as encod_seq_len in this config.
recon_seq_len: 140
ext_input_dim: 0
# Also consumed by the CoordinatedDropout transform further down, through the
# ${model.ic_enc_seq_len} interpolation.
ic_enc_seq_len: 0
ic_enc_dim: 64
# ci_enc_dim, con_dim, and co_dim are all 0 in this file — presumably this
# switches the controller pathway off (TODO confirm).
ci_enc_dim: 0
ci_lag: 1
con_dim: 0
# Also used as the co_prior shape through ${model.co_dim}.
co_dim: 0
# Also used as the ic_prior shape through ${model.ic_dim}.
ic_dim: 64
gen_dim: 100
# Also used as the readout layer's in_features through ${model.fac_dim}.
fac_dim: 40

# --------- readin / readout --------- #
# readin is a one-element list holding a torch.nn.Identity module spec.
readin:
  - _target_: torch.nn.Identity
# readout is a torch.nn.ModuleList spec whose `modules` list holds a single
# FanInLinear layer spec.
readout:
  _target_: torch.nn.ModuleList
  modules:
    - _target_: lfads_torch.modules.readin_readout.FanInLinear
      # Interpolated from the fac_dim key under the `model` config node, so
      # this file is expected to be mounted at `model` in the composed config.
      in_features: ${model.fac_dim}
      # Literal value. It equals encod_data_dim in this file but is not tied
      # to it by interpolation; if the two are meant to track each other they
      # must be edited together (NOTE(review): confirm whether they should).
      out_features: 182

# --------- augmentation --------- #
# Training-time stack: one CoordinatedDropout transform. batch_order and
# loss_order are each a one-element list containing 0 (presumably an index
# into `transforms` — confirm against AugmentationStack).
train_aug_stack:
  _target_: lfads_torch.modules.augmentations.AugmentationStack
  transforms:
    - _target_: lfads_torch.modules.augmentations.CoordinatedDropout
      cd_rate: 0.3
      cd_pass_rate: 0.0
      # Interpolated from ic_enc_seq_len under the `model` config node.
      ic_enc_seq_len: ${model.ic_enc_seq_len}
  batch_order:
    - 0
  loss_order:
    - 0
# Inference-time stack: only the class path is given, so no transforms or
# orderings are configured here and the constructor's defaults apply.
infer_aug_stack:
  _target_: lfads_torch.modules.augmentations.AugmentationStack

# --------- priors / posteriors --------- #
# reconstruction is a one-element list holding a Poisson module spec.
reconstruction:
  - _target_: lfads_torch.modules.recons.Poisson
# Canonical lowercase boolean; parses to the same value as `True`.
variational: true
# shape is interpolated from co_dim under the `model` config node, which is 0
# in this file (NOTE(review): confirm the prior accepts a zero-sized shape).
co_prior:
  _target_: lfads_torch.modules.priors.AutoregressiveMultivariateNormal
  tau: 10.0
  nvar: 0.1
  shape: ${model.co_dim}
# shape is interpolated from ic_dim under the `model` config node (64 here).
ic_prior:
  _target_: lfads_torch.modules.priors.MultivariateNormal
  mean: 0
  variance: 0.1
  shape: ${model.ic_dim}
ic_post_var_min: 1.0e-4

# --------- misc --------- #
dropout_rate: 0.3
cell_clip: 5.0
# Keep the decimal point and the signed exponent: some YAML 1.1 loaders read
# shorter forms such as `1e4` as strings rather than floats.
loss_scale: 1.0e+4
# Canonical lowercase boolean; parses to the same value as `True`.
recon_reduce_mean: true

# --------- learning rate --------- #
lr_init: 4.0e-3
# NOTE(review): lr_stop, lr_decay, and lr_patience read like scheduler
# settings, yet lr_scheduler at the end of this stanza is disabled —
# presumably they are unused in this configuration; confirm in the optimizer
# setup.
lr_stop: 1.0e-5
lr_decay: 0.95
lr_patience: 6
lr_adam_beta1: 0.9
lr_adam_beta2: 0.999
lr_adam_epsilon: 1.0e-8
# Canonical lowercase boolean; parses to the same value as `False`.
lr_scheduler: false

# --------- regularization --------- #
weight_decay: 0.0
# Every l2_*_scale below is 0.0 in this file. The *_start_epoch and
# *_increase_epoch names suggest a ramp-up schedule (TODO confirm in the
# model code).
l2_start_epoch: 0
l2_increase_epoch: 80
l2_ic_enc_scale: 0.0
l2_ci_enc_scale: 0.0
l2_gen_scale: 0.0
l2_con_scale: 0.0
# The KL epoch settings carry the same values (0 and 80) as the L2 ones.
kl_start_epoch: 0
kl_increase_epoch: 80
kl_ic_scale: 1.0e-7
# Non-zero even though co_dim is 0 in this file (NOTE(review): presumably has
# no effect when there are no controller outputs — confirm).
kl_co_scale: 1.0e-7
