# Hydra defaults list: the config groups below are composed into this config
# in the order they are listed.
defaults:
  - model: nlb_mc_maze
  - datamodule: nlb_mc_maze
  - callbacks: default
  # `_self_` comes last, so keys defined in this file are merged after (and
  # therefore take precedence over) the groups selected above.
  - _self_

# Seed value for the run.
# NOTE(review): presumably used by the run script to seed the RNGs - confirm.
seed: 0
# NOTE(review): presumably suppresses Python warnings during the run -
# confirm against the code that reads this flag.
ignore_warnings: true

# Arguments for the PyTorch Lightning Trainer. `_target_` names the class to
# construct; the remaining keys are passed to it as keyword arguments.
trainer:
  _target_: pytorch_lightning.Trainer
  gradient_clip_val: 200
  # Written without a digit separator: `1_000` resolves to an integer only
  # under YAML 1.1 rules and is read as a string by YAML 1.2 parsers.
  max_epochs: 1000
  log_every_n_steps: 5
  # Prevent console output by individual models
  enable_progress_bar: false
  enable_model_summary: false

# `callbacks` and `logger` must be composed as dicts, then passed to `trainer` as lists
callbacks:
  # Checkpointing driven by `valid/recon_smth` (lower is better): keeps the
  # top-1 checkpoint plus the last one under `lightning_checkpoints`.
  model_checkpoint:
    _target_: pytorch_lightning.callbacks.ModelCheckpoint
    monitor: valid/recon_smth
    mode: min
    save_top_k: 1
    save_last: true
    verbose: false
    dirpath: lightning_checkpoints
    auto_insert_metric_name: false
  # Early stopping on the same metric and mode as the checkpoint callback.
  early_stopping:
    _target_: lfads_torch.extensions.tune.EarlyStoppingWithBurnInPeriod
    monitor: valid/recon_smth
    mode: min
    patience: 200
    min_delta: 0
    # Interpolated from the model's L2 and KL schedule settings through the
    # `max` and `sum` resolvers, which are not defined in this file.
    # NOTE(review): presumably max(l2_start + l2_increase,
    # kl_start + kl_increase), i.e. no stopping before both ramps finish -
    # confirm against the resolver registrations.
    burn_in_period: ${max:${sum:${model.l2_start_epoch},${model.l2_increase_epoch}},${sum:${model.kl_start_epoch},${model.kl_increase_epoch}}}
  # Reports the listed metrics to Ray Tune at `validation_end`.
  tune_report_callback:
    _target_: ray.tune.integration.pytorch_lightning.TuneReportCallback
    metrics:
      - valid/recon_smth
      - cur_epoch
    # The key must stay quoted: a bare `on` is resolved as a boolean by
    # YAML 1.1 parsers.
    'on': validation_end
  # Learning-rate monitoring with a per-epoch logging interval.
  learning_rate_monitor:
    _target_: pytorch_lightning.callbacks.LearningRateMonitor
    logging_interval: epoch

logger:
  csv_logger:
    _target_: pytorch_lightning.loggers.CSVLogger
    save_dir: csv_logs
    # NOTE(review): the empty `version` and `name` presumably keep the logger
    # from adding subdirectories below `save_dir` - confirm.
    version: ''
    name: ''
  tensorboard_logger:
    _target_: pytorch_lightning.loggers.TensorBoardLogger
    # `save_dir` is itself a `_target_` entry pointing at
    # `ray.tune.get_trial_dir`, so that PTL and ray use the same TB file.
    save_dir:
      _target_: ray.tune.get_trial_dir
    version: '.'
    name: ''
  wandb_logger:
    _target_: pytorch_lightning.loggers.WandbLogger
    # Placeholder: overwrite with PROJECT_STR in the run script.
    project: null
    tags:
      - multi
      # Placeholder: overwrite with DATASET_STR in the run script.
      - null
      # Placeholder: overwrite with RUN_TAG in the run script.
      - null

# Settings for the posterior-sampling step.
posterior_sampling:
  # NOTE(review): presumably selects the best checkpoint rather than the
  # last one for sampling - confirm against the run script.
  use_best_ckpt: true
  # Function target for the sampling step; `filename` and `num_samples` are
  # its keyword arguments.
  fn:
    _target_: lfads_torch.post_run.analysis.run_posterior_sampling
    filename: lfads_output.h5
    num_samples: 50
