# Hydra defaults list. Entry order matters: `_self_` is listed first, so the
# values in this file are composed first and the entries after it are merged
# on top of them.
defaults:
  - _self_
  # option `private.yaml` of the `wandb` config group (contents not visible
  # here)
  - wandb: private.yaml
  # select Hydra's built-in "disabled" logging configs so Hydra does not set
  # up its own logging
  - override hydra/hydra_logging: disabled
  - override hydra/job_logging: disabled

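# Disable Hydra's run outputs: no .hydra/ config-snapshot subdirectory, and
# run in the current working directory instead of a per-run output directory.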
hydra:
  # do not write the .hydra/ config-snapshot subdirectory
  output_subdir: null
  run:
    # stay in the launch directory rather than a generated output dir
    dir: "."

# presumably the run/experiment name and the method key used to select the
# model implementation — confirm against the launcher
name: "kalmanSSL_jump"
method: "kalmanSSL"
method_kwargs:
  # &state_dim is aliased by the decoder input_dim entries; &latent_dim is
  # defined but not aliased anywhere in this file
  state_dim: &state_dim 2
  latent_dim: &latent_dim 2
  # presumably the flattened observation size (data.settings.image_dim ** 2)
  # — keep in sync by hand if image_dim changes
  input_dim: 100
  hidden_dim: 100
  # NOTE(review): this is the string "None", not YAML null; presumably the
  # consumer compares against the literal string — confirm.
  sigma_D_dependency: "None"
  sigma_A_dependency: "observation+prediction"  # alternative: "error"
  sigma_A_pred_steps: 4
decoder_kwargs:
  # One sub-entry per decoder head. input_variable presumably selects which
  # model output feeds the head; each head carries its own loss and lr.
  observations:
    loss_func: "mse"
    input_variable: "estimates"
    input_dim: *state_dim
    # presumably the flattened image (data.settings.image_dim ** 2) — confirm
    output_dim: 100
    backbone: "mlp"
    backbone_kwargs:
      hidden_dim: 100
      n_hidden_layers: 2
    lr: 0.1
  pos:
    loss_func: "mse"
    input_variable: "estimates"
    input_dim: *state_dim
    output_dim: 2
    backbone: "linear"
    lr: 0.5
  noise_level:
    loss_func: "mse"
    input_variable: "inferences_covariances"
    # presumably a flattened state_dim x state_dim covariance (2 * 2); YAML
    # cannot derive this from *state_dim, so update by hand if it changes
    input_dim: 4
    output_dim: 1
    backbone: "linear"
    lr: 0.5
  jumps:
    loss_func: "cross_entropy"
    input_variable: "prediction_covariances"
    # same assumption as noise_level: flattened 2x2 covariance
    input_dim: 4
    # presumably two classes (jump / no jump) for the cross-entropy loss
    output_dim: 2
    backbone: "mlp"
    backbone_kwargs:
      hidden_dim: 10
      n_hidden_layers: 1
    lr: 0.5
data:
  dataset: "dot_motion"
  # presumably the number of dataloader worker processes
  num_workers: 4
  settings:
    # --- trajectory type ---
    motion_type: "circular"
    # --- dataset size ---
    num_train_sequences: 6400
    num_test_sequences: 8
    seq_len: 100
    # --- rendering and timing ---
    image_dim: 10
    dt: 0.1
    period_duration: 2.0
    # --- noise and jumps ---
    noise_type: "gaussian"
    observation_noise: 0.15
    prediction_noise: 0.0
    jump_rate: 0.01
optimizer:
  name: "adam"
  batch_size: 32
  # decoder heads set their own lr under decoder_kwargs.
  # NOTE(review): the old inline note "0.001 * 256" does not equal 0.01;
  # confirm whether the consumer rescales lr by batch size before use.
  lr: 0.01
  # written with a decimal point so YAML 1.1 loaders (e.g. plain PyYAML)
  # resolve it as a float rather than the string "1e-4"
  weight_decay: 1.0e-4
scheduler:
  # alternative scheduler:
  # name: "exponential"
  # min_lr: 0.00001
  name: "step"
  # NOTE(review): fractional milestones — presumably fractions of max_epochs
  # that the consumer converts to epochs; a plain torch MultiStepLR never
  # reaches non-integer milestones, so the LR would never decay. Confirm.
  lr_decay_steps: [0.33, 0.66]
checkpoint:
  # canonical lowercase boolean (portable across YAML 1.1 / 1.2 loaders)
  enabled: true
  dir: "trained_models"
  # presumably in epochs — confirm against the checkpoint callback
  frequency: 1
auto_resume:
  # canonical lowercase boolean (portable across YAML 1.1 / 1.2 loaders)
  enabled: false

# Trainer settings (presumably forwarded to a PyTorch Lightning Trainer)
max_epochs: 15
# in Lightning, an int means "run validation every N training batches"
val_check_interval: 10
# NOTE(review): only one device is listed, so "ddp" runs a single process and
# sync_batchnorm has nothing to synchronize across — harmless, but confirm
# this is intended.
devices: [0]
sync_batchnorm: true
accelerator: "gpu"
strategy: "ddp"
# full fp32, no mixed precision — needed for the matrix inverse; quoted so it
# is unambiguously a string
precision: "32-true"
