# Top-level run settings.
# `seed` is presumably consumed for RNG seeding and `mode` selects the run
# type — confirm both against the entry point that reads this config.
seed: 42
mode: 'train'


# Settings handed to the PyTorch Lightning Trainer named by `_target_`.
trainer:
  _target_: pytorch_lightning.Trainer

  # Accelerator given to pytorch-lightning Trainer (eg `cpu` or `gpu`)
  accelerator: 'auto'
  strategy: 'ddp'
  devices: 'auto'
  # Number of distributed nodes
  num_nodes: 1

  max_epochs: 100
  # NOTE(review): absolute, machine-specific output path — override per
  # environment if this config is shared.
  default_root_dir: '/home/user/results/align_or_not'
  use_distributed_sampler: false
  deterministic: false

  check_val_every_n_epoch: 1
  num_sanity_val_steps: 1
  # avoid weird bugs during linear probing
  inference_mode: false

# Optimization hyperparameters (learning rate and weight decay).
# The mantissas carry an explicit decimal point so that YAML 1.1 loaders
# (e.g. plain PyYAML), whose float pattern requires a ".", parse these as
# floats rather than as the strings "3e-4" / "1e-4". The numeric values are
# unchanged.
optim:
  lr: 3.0e-4
  weight_decay: 1.0e-4


# Define default visual encoders for bimodal trifeatures dataset.
# Every `_target_` below is a dotted import path; the sibling keys are
# presumably passed to it as keyword arguments (Hydra-style instantiation —
# confirm against the code that consumes this config).
model:
  # Default backbones: two identically configured entries.
  encoders:
    - _target_: models.alexnet.AlexNetEncoder  # small AlexNet encoder
      latent_dim: 256
      dropout: 0.5
    - _target_: models.alexnet.AlexNetEncoder  # small AlexNet encoder
      latent_dim: 256
      dropout: 0.5

  # SSL projection head
  visual_projection:
    _target_: models.siamese.Siamese._build_mlp
    in_dim: 256  # Visual encoder dimension
    mlp_dim: 1024  # Hidden dim of MLP projection head
    out_dim: 256  # Output embed dim of MLP projection head

  # CLIP projectors
  clip_image_projection:
    _target_: torch.nn.Linear
    in_features: 256  # Visual encoder dimension
    out_features: 512  # Output dimension
    bias: false

  clip_text_projection:
    _target_: torch.nn.Linear
    in_features: 256  # Same value as the encoders' latent_dim above
    out_features: 512  # Same output dimension as clip_image_projection
    bias: false

  # MMFusion adapters: two identically configured entries.
  adapters:
    - _target_: models.input_adapters.PatchedInputAdapter
      num_channels: 256  # nb of feature maps
      stride_level: 1
      patch_size_full: 1
      dim_tokens: 512
      image_size: 6  # size of feature maps
    - _target_: models.input_adapters.PatchedInputAdapter
      num_channels: 256  # nb of feature maps
      stride_level: 1
      patch_size_full: 1
      dim_tokens: 512
      image_size: 6  # size of feature maps