# Hydra defaults list. `_self_` is listed first, so on conflicting keys the
# `data` and `model` group configs composed after it override values set here.
defaults:
  - _self_
  - data: multibench
  - model: comm

seed: 42 # random seed (presumably applied globally by the entry script -- confirm)

mode: "train" # run mode; the set of accepted values is defined by the consuming script

# Keyword arguments for the PyTorch Lightning Trainer named by `_target_`
# (presumably built with hydra.utils.instantiate -- confirm against the entry script).
trainer:
  _target_: pytorch_lightning.Trainer
  accelerator: "auto" # Accelerator given to pytorch-lightning Trainer (eg `cpu` or `gpu`)
  strategy: "ddp_find_unused_parameters_true" # Lightning alias for DDP with find_unused_parameters=True
  devices: "auto"
  num_nodes: 1 # Number of distributed nodes
  max_epochs: 100
  # NOTE(review): machine-specific absolute path; override it per environment.
  default_root_dir: "/home/user/results/align_or_not"
  use_distributed_sampler: false # Lightning will not inject a DistributedSampler; presumably sharding is handled by the data code -- confirm
  deterministic: false
  inference_mode: false # avoid weird bugs during linear probing

# Settings for the linear-probing evaluation callback named by `_target_`.
# The meaning of each flag is defined by that callback, which is not visible here.
linear_probing:
  _target_: evaluation.linear_probe.LinearProbingCallback
  use_sklearn: false
  fastsearch: false
  logging_level: "INFO"
  frequency: "by_epoch" # presumably runs the probe once per epoch -- confirm against the callback


# Optimizer hyper-parameters.
# Floats carry an explicit mantissa dot: YAML 1.1 loaders such as PyYAML read
# dot-less exponents like `1e-3` as strings, whereas `1.0e-3` is a float everywhere.
optim:
  lr: 1.0e-3
  weight_decay: 1.0e-2

# Define encoders/modalities to load MultiBench data
visionandtouch: # Image + Force to predict proprioception (regression task)
  modalities:
    - "image"
    - "force"
  task: "ee_yaw_next"
  kwargs: # augmentations to apply
    augmentations: # presumably one entry per modality, in `modalities` order -- confirm
      - "simclr"
      - "noise+drop"
  encoders: # listed in the same order as `modalities`
    - _target_: timm.create_model # Image encoder
      model_name: resnet18
      pretrained: false
      num_classes: 0
    - _target_: models.robotics.ForceEncoder # Force encoder
      z_dim: 128

  # CLIP projectors (one per modality; both map to a shared 128-d space)
  clip_projection1:
    _target_: torch.nn.Linear
    in_features: 512
    out_features: 128
    bias: false

  clip_projection2:
    _target_: torch.nn.Linear
    in_features: 128 # equals the force encoder z_dim
    out_features: 128
    bias: false

  # SSL projectors
  projection_head1:
    _target_: models.siamese.Siamese._build_mlp
    in_dim: 512 # Visual encoder dimension
    mlp_dim: 1024 # Hidden dim of MLP projection head
    out_dim: 256 # Output embed dim of MLP projection head

  projection_head2:
    _target_: models.siamese.Siamese._build_mlp
    in_dim: 128 # Force encoder dimension (z_dim above)
    mlp_dim: 1024 # Hidden dim of MLP projection head
    out_dim: 256 # Output embed dim of MLP projection head

  adapters: # listed in the same order as `encoders`
    - _target_: models.input_adapters.PatchedInputAdapter
      num_channels: 512 # nb of feature maps
      stride_level: 1
      patch_size_full: 1
      dim_tokens: 128
      image_size: 4 # size of feature maps
    - _target_: models.input_adapters.FeaturesInputAdapter
      n_features: 128
      dim_tokens: 128

visionandtouch-bin: # Image + Proprioception to predict contact (binary task)
  modalities:
    - "image"
    - "proprio"
  task: "contact_next"
  kwargs: # augmentations to apply
    augmentations: # presumably one entry per modality, in `modalities` order -- confirm
      - "simclr"
      - "noise"
  encoders: # listed in the same order as `modalities`
    - _target_: timm.create_model # Image encoder
      model_name: resnet18
      pretrained: false
      num_classes: 0
    - _target_: models.robotics.ProprioEncoder # Proprioception encoder
      z_dim: 512

  # CLIP projectors (one per modality; both map to a shared 256-d space)
  clip_projection1:
    _target_: torch.nn.Linear
    in_features: 512
    out_features: 256
    bias: false

  clip_projection2:
    _target_: torch.nn.Linear
    in_features: 512 # equals the proprio encoder z_dim
    out_features: 256
    bias: false

  # SSL projectors
  projection_head1:
    _target_: models.siamese.Siamese._build_mlp
    in_dim: 512 # Visual encoder dimension
    mlp_dim: 1024 # Hidden dim of MLP projection head
    out_dim: 256 # Output embed dim of MLP projection head

  projection_head2:
    _target_: models.siamese.Siamese._build_mlp
    in_dim: 512 # Proprio encoder dimension (z_dim above)
    mlp_dim: 1024 # Hidden dim of MLP projection head
    out_dim: 256 # Output embed dim of MLP projection head

  adapters: # listed in the same order as `encoders`
    - _target_: models.input_adapters.PatchedInputAdapter
      num_channels: 512 # nb of feature maps
      stride_level: 1
      patch_size_full: 1
      dim_tokens: 512
      image_size: 4 # size of feature maps
    - _target_: models.input_adapters.SimpleFeaturesInputAdapter


# MIMIC: static tabular patient info + health-recording time series.
mimic:
  modalities:
    - "tabular"
    - "timeseries"
  task: 7 # task identifier interpreted by the data loader (meaning not visible here)
  kwargs: # augmentations to apply
    augmentations: # presumably one entry per modality, in `modalities` order -- confirm
      - "noise"
      - "drop+noise"
  encoders: # listed in the same order as `modalities`
    - _target_: models.mlp.MLP # Encoder of static patient info
      indim: 5
      hiddim: 10
      outdim: 10
      dropout: false
    - _target_: models.gru.GRU # Encoder of health recordings
      indim: 12
      hiddim: 512
      dropout: false
      batch_first: true
  adapters: # listed in the same order as `encoders`
    - _target_: models.input_adapters.FeaturesInputAdapter
      n_features: 10 # equals the MLP encoder outdim
      dim_tokens: 512
    - null # no adapter configured for the GRU encoder

# MOSI: vision + text sequences, classification task.
mosi:
  modalities:
    - "vision"
    - "text"
  task: "classification"
  kwargs: # augmentations to apply
    augmentations: "drop+noise" # a single spec string here, not a per-modality list
  encoders: # listed in the same order as `modalities`
    - _target_: models.transformer.Transformer
      n_features: 20 # !! according to MultiBench but not in FactorCL
      dim: 40
      max_seq_length: 50
      positional_encoding: false
    - _target_: models.transformer.Transformer
      n_features: 300
      dim: 40 # !! according to FactorCL paper but not in FactorCL implementation (==600)
      max_seq_length: 50
      positional_encoding: false
  adapters: # not required
    - null
    - null

# MOSEI: vision + text sequences, classification task.
mosei:
  modalities:
    - "vision"
    - "text"
  task: "classification"
  kwargs: # augmentations to apply
    augmentations: "permute+noise+drop" # a single spec string here, not a per-modality list
  encoders: # listed in the same order as `modalities`
    - _target_: models.transformer.Transformer
      n_features: 35
      dim: 40
      max_seq_length: 50
      positional_encoding: false
    - _target_: models.transformer.Transformer
      n_features: 300
      dim: 40
      max_seq_length: 50
      positional_encoding: false
  adapters: # not required
    - null
    - null

# Humor: vision + text sequences, classification task.
humor:
  modalities:
    - "vision"
    - "text"
  task: "classification"
  kwargs: # augmentations to apply
    augmentations: "drop+noise" # a single spec string here, not a per-modality list
  encoders: # listed in the same order as `modalities`
    - _target_: models.transformer.Transformer
      n_features: 371
      dim: 40
      max_seq_length: 50
      positional_encoding: false
    - _target_: models.transformer.Transformer
      n_features: 300
      dim: 40
      max_seq_length: 50
      positional_encoding: false
  adapters: # not required
    - null
    - null

# Sarcasm: vision + text sequences, classification task.
sarcasm:
  modalities:
    - "vision"
    - "text"
  task: "classification"
  kwargs: # augmentations to apply
    augmentations: "drop+noise" # a single spec string here, not a per-modality list
  encoders: # listed in the same order as `modalities`
    - _target_: models.transformer.Transformer
      n_features: 371
      dim: 40
      max_seq_length: 50
      positional_encoding: false
    - _target_: models.transformer.Transformer
      n_features: 300
      dim: 40
      max_seq_length: 50
      positional_encoding: false
  adapters: # not required
    - null
    - null