# Hydra defaults list. `_self_` is listed first, so entries composed after it
# (the config groups and overrides below) are applied on top of this file.
defaults:
  - _self_
  # Weights & Biases settings, taken from the `wandb` config group.
  - wandb: private.yaml
  # Select Hydra's `disabled` logging configs for Hydra itself and for the job.
  - override hydra/hydra_logging: disabled
  - override hydra/job_logging: disabled

# Disable Hydra's output handling: run in the current directory instead of a
# generated run dir, and do not write the `.hydra` output subdirectory.
hydra:
  run:
    dir: "."
  output_subdir: null

# Experiment name and the method to train.
name: "kalmanSSL_figure8"
method: "kalmanSSL"
method_kwargs:
  # `state_dim` is anchored so the decoders' `input_dim: *state_dim` stays in
  # sync with it; `latent_dim` is anchored for the same kind of reuse.
  state_dim: &state_dim 10
  latent_dim: &latent_dim 5
  # NOTE(review): presumably image_dim * image_dim (10 * 10) flattened pixels
  # from the data settings; keep in sync manually if image_dim changes.
  input_dim: 100
  hidden_dim: 100
  # NOTE(review): these are the *string* "None", not YAML null; presumably the
  # consumer compares against that string. Confirm before changing to null.
  sigma_D_dependency: "None"
  sigma_A_dependency: "None"

# Decoders trained on top of the model's outputs. Each entry names the model
# output it reads (`input_variable`), its sizes, loss, backbone and its own lr.
decoder_kwargs:
  # Reconstruct the observation vector from the state estimates.
  observations:
    input_variable: "estimates"
    input_dim: *state_dim
    output_dim: 100
    loss_func: "mse"
    backbone: "mlp"
    backbone_kwargs:
      hidden_dim: 100
      n_hidden_layers: 2
    lr: 0.1
  # Linear readout of a 2-dimensional position from the state estimates.
  pos:
    input_variable: "estimates"
    input_dim: *state_dim
    output_dim: 2
    loss_func: "mse"
    backbone: "linear"
    lr: 0.5
  # Linear readout of a scalar noise level from the inferred covariances.
  noise_level:
    input_variable: "inferences_covariances"
    # NOTE(review): presumably a flattened latent_dim x latent_dim (5 * 5)
    # covariance. YAML cannot derive this from the anchor, so keep it in sync
    # manually — confirm.
    input_dim: 25
    output_dim: 1
    loss_func: "mse"
    backbone: "linear"
    lr: 0.5
# Synthetic dot-motion dataset following a figure-8 trajectory.
data:
  dataset: "dot_motion"
  num_workers: 4
  settings:
    # Trajectory shape and dataset size.
    motion_type: "figure8"
    num_train_sequences: 800
    num_test_sequences: 8
    seq_len: 100
    # Rendering and time discretisation.
    image_dim: 10
    dt: 0.1
    period_duration: 1.0
    # Noise injected into observations / predictions.
    noise_type: "gaussian"
    observation_noise: 0.01
    prediction_noise: 0.0
# Optimizer settings for the main model.
optimizer:
  name: "adam"
  batch_size: 8
  # NOTE(review): confirm whether the training code rescales this lr by
  # batch_size; the value here is used as written in this file.
  lr: 0.005
  # Written with a decimal point so YAML 1.1 loaders (e.g. plain PyYAML)
  # resolve it as a float instead of the string "1e-4".
  weight_decay: 1.0e-4
# Step learning-rate schedule.
scheduler:
  name: "step"
  # NOTE(review): milestones are fractions (< 1). If passed unchanged to an
  # epoch-indexed MultiStepLR, they would never fire. Presumably the consumer
  # scales them by max_epochs — confirm.
  lr_decay_steps:
    - 0.33
    - 0.66
# Checkpointing: written under `dir`. `frequency` is presumably measured in
# epochs — confirm against the checkpoint callback.
checkpoint:
  enabled: true
  dir: "trained_models"
  frequency: 1
# Automatic resumption from a previous run is turned off.
auto_resume:
  enabled: false

# Trainer settings (presumably forwarded to a PyTorch Lightning Trainer, given
# the accelerator/strategy/precision keys — confirm).
max_epochs: 20
# NOTE(review): in Lightning an int val_check_interval counts training batches,
# not epochs. Confirm that validating every 10 batches is intended.
val_check_interval: 10
devices: [0]
sync_batchnorm: true
accelerator: "gpu"
strategy: "ddp"
# Full fp32 precision, needed for the matrix inverse. Quoted so it is
# unambiguously the string "32-true".
precision: "32-true"
