# See help here on how to configure hyper-parameters with config files:
# https://lightning.ai/docs/pytorch/stable/cli/lightning_cli_advanced.html
# Fixed random seed for the run (presumably for reproducibility; see the CLI docs above).
seed_everything: 42
trainer:
  # Hardware: a single GPU on a single node.
  accelerator: gpu
  strategy: auto
  devices: 1
  num_nodes: 1
  precision: 32-true
  logger:
    class_path: lightning.pytorch.loggers.TensorBoardLogger
    init_args:
      # NOTE: The logs will be saved in `save_dir/lightning_logs/version`,
      # i.e. `dynaclr/lightning_logs/time_interval_1` with the values below.
      save_dir: dynaclr
      version: time_interval_1
      # NOTE(review): graph logging presumably relies on
      # `model.example_input_array_shape` -- confirm against the model class.
      log_graph: true
  callbacks:
    # Record the learning rate at every optimizer step.
    - class_path: lightning.pytorch.callbacks.LearningRateMonitor
      init_args:
        logging_interval: step
    # Checkpoint once per epoch, monitoring `loss/val`: keep the 4 best
    # checkpoints plus the most recent one.
    # NOTE(review): `loss/val` must match the metric name logged by the model,
    # which is not visible in this file.
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        monitor: loss/val
        every_n_epochs: 1
        save_top_k: 4
        save_last: true
  fast_dev_run: false
  max_epochs: 100
  log_every_n_steps: 10
  enable_checkpointing: true
  inference_mode: true
  # NOTE(review): with `devices: 1` and `num_nodes: 1` the two distributed
  # options below presumably have no effect; they matter for multi-GPU runs.
  use_distributed_sampler: true
  sync_batchnorm: true
model:
  # Scalar hyperparameters first; nested components (encoder, loss) follow.
  lr: 0.0002
  log_batches_per_epoch: 3
  log_samples_per_batch: 2
  # The last two entries (128 x 128) equal `data.final_yx_patch_size`.
  # NOTE(review): the leading entries presumably are batch, channel and depth
  # -- confirm against the model class.
  example_input_array_shape: [1, 1, 1, 128, 128]
  # Encoder: single input channel and a stack depth of 1, with a stem whose
  # kernel and stride are both [1, 4, 4].
  encoder:
    class_path: viscy.representation.contrastive.ContrastiveEncoder
    init_args:
      backbone: convnext_tiny
      in_channels: 1
      in_stack_depth: 1
      stem_kernel_size: [1, 4, 4]
      stem_stride: [1, 4, 4]
      embedding_dim: 768
      projection_dim: 32
      drop_path_rate: 0.0
  # Training objective: triplet margin loss with margin 0.5.
  loss_function:
    class_path: torch.nn.TripletMarginLoss
    init_args:
      margin: 0.5
data:
  # FIXME: Change the path to downloaded data
  # NOTE(review): the path is relative; presumably resolved against the
  # working directory of the run -- verify.
  data_path: ../../Hela_CTC.zarr
  # NOTE: for this example the tracks point to the same Zarr store as `data_path`
  tracks_path: ../../Hela_CTC.zarr
  # Every transform below lists the same `DIC` key; keep them in sync.
  source_channel:
    - DIC
  # NOTE(review): presumably a half-open range selecting a single Z slice,
  # consistent with `model.encoder.init_args.in_stack_depth: 1` -- confirm.
  z_range: [0, 1]
  # NOTE: reduce this if running out of memory
  batch_size: 16
  num_workers: 4
  # The final patch size equals the last two entries of
  # `model.example_input_array_shape`.
  # NOTE(review): the larger initial patch presumably leaves room for the
  # augmentations before cropping to the final size -- confirm.
  initial_yx_patch_size: [256, 256]
  final_yx_patch_size: [128, 128]
  # NOTE(review): the logger `version` in `trainer` is named `time_interval_1`;
  # if that naming is intentional, update it together with this value.
  time_interval: 1
  # Normalize the DIC channel: subtract the mean and divide by the std,
  # using statistics at the `fov_statistics` level.
  normalizations:
    - class_path: viscy.transforms.NormalizeSampled
      init_args:
        keys: [DIC]
        level: fov_statistics
        subtrahend: mean
        divisor: std
  # Random augmentations of the DIC channel, each applied with its own `prob`.
  # The list order is part of the configuration; do not reorder casually.
  augmentations:
    # NOTE(review): 3.14 is approximately pi; presumably radians about the
    # first axis only -- confirm against the transform's documentation.
    - class_path: viscy.transforms.RandAffined
      init_args:
        keys: [DIC]
        prob: 0.8
        scale_range: [0, 0.2, 0.2]
        rotate_range: [3.14, 0.0, 0.0]
        shear_range: [0.0, 0.01, 0.01]
        padding_mode: zeros
    - class_path: viscy.transforms.RandAdjustContrastd
      init_args:
        keys: [DIC]
        prob: 0.5
        gamma: [0.8, 1.2]
    - class_path: viscy.transforms.RandScaleIntensityd
      init_args:
        keys: [DIC]
        prob: 0.5
        factors: 0.5
    # `sigma_z` is fixed at 0.0 while `sigma_x` and `sigma_y` range over
    # [0.25, 0.75].
    - class_path: viscy.transforms.RandGaussianSmoothd
      init_args:
        keys: [DIC]
        prob: 0.5
        sigma_x: [0.25, 0.75]
        sigma_y: [0.25, 0.75]
        sigma_z: [0.0, 0.0]
    - class_path: viscy.transforms.RandGaussianNoised
      init_args:
        keys: [DIC]
        prob: 0.5
        mean: 0.0
        std: 0.2
