# Hydra defaults list. Order matters: `_self_` comes first, so the values in
# this file are composed before the entries listed after it, and those later
# entries can override keys defined here.
defaults:
  - _self_
  # Option `private.yaml` from the `wandb` config group.
  # NOTE(review): Hydra convention omits the extension (`wandb: private`);
  # confirm the `.yaml` suffix resolves as intended with the Hydra version used.
  - wandb: private.yaml
  # Turn off the logging configuration Hydra would otherwise install, both for
  # Hydra itself and for the job.
  - override hydra/hydra_logging: disabled
  - override hydra/job_logging: disabled

# disable hydra outputs: `output_subdir: null` skips writing the .hydra/ config
# snapshot, and `run.dir: .` runs in the current directory instead of a
# per-run output directory.
hydra:
  output_subdir: null
  run:
    dir: .


# Run name and the method to use; method_kwargs are presumably forwarded to
# that method's constructor (confirm against the training entry point).
name: "mixtureRPL_random_square"
method: "mixtureRPL"
method_kwargs:
  # NOTE(review): anchor &latent_state_dim is not aliased anywhere in this file.
  state_dim: &latent_state_dim 2
  # Anchored as &image_dim and reused as decoder_kwargs.observations.output_dim.
  # NOTE(review): presumably the flattened image size, i.e.
  # data.settings.image_dim ** 2 (10 * 10 = 100); keep the two in sync.
  input_dim: &image_dim 100
  # Anchored as &hidden_dim and reused as every decoder's input_dim.
  rnn_hidden_dim: &hidden_dim 10
  mixture_dim: 2
  rnn_num_layers: 1
  encoder_hidden_dim: 100
# Per-target decoder settings. Every entry reads the "estimates" variable,
# takes the RNN hidden size (*hidden_dim) as input, uses an MSE loss, and
# carries its own learning rate.
decoder_kwargs:
  # MLP head whose output size matches the model input (*image_dim).
  observations:
    loss_func: "mse"
    input_variable: "estimates"
    input_dim: *hidden_dim
    output_dim: *image_dim
    backbone: "mlp"
    backbone_kwargs: {hidden_dim: 100, n_hidden_layers: 2}
    lr: 0.1
  # Linear head with a 2-dimensional output.
  pos:
    loss_func: "mse"
    input_variable: "estimates"
    input_dim: *hidden_dim
    output_dim: 2
    backbone: "linear"
    lr: 0.5
  # Linear head with a scalar output.
  noise_level:
    loss_func: "mse"
    input_variable: "estimates"
    input_dim: *hidden_dim
    output_dim: 1
    backbone: "linear"
    lr: 0.5
# Dataset selection and generator settings for the "dot_motion" dataset.
data:
  dataset: "dot_motion"
  num_workers: 4
  settings:
    # Motion pattern and dataset sizes.
    motion_type: "random_square"
    num_train_sequences: 800
    num_test_sequences: 8
    seq_len: 100
    # Image side length. NOTE(review): method_kwargs.input_dim (100) appears
    # to be image_dim ** 2; keep them consistent.
    image_dim: 10
    # Time step and period length.
    dt: 0.1
    period_duration: 0.1
    # Noise model and magnitudes.
    noise_type: "gaussian"
    observation_noise: 0.01
    prediction_noise: 0.0
optimizer:
  name: "adam"
  batch_size: 8
  # NOTE(review): a previous inline note read "0.001 * 256", which is 0.256,
  # not 0.01. Confirm the intended value and whether any batch-size scaling
  # is applied downstream.
  lr: 0.01
  # Written with a dot and explicit mantissa: YAML 1.1 loaders such as PyYAML's
  # default resolver read a bare "1e-6" as a string, not a float.
  weight_decay: 1.0e-6
scheduler:
  name: "linear"
  # Presumably the floor the linear schedule decays to.
  # NOTE(review): a previous inline note read "0.00006 * 256" (= 0.01536),
  # which does not match this value; confirm the intended floor.
  min_lr: 0.00001
# Checkpoint saving and resume behavior. Booleans use canonical lowercase
# `true`/`false` (yamllint `truthy`).
checkpoint:
  enabled: true
  dir: "trained_models"
  frequency: 1  # NOTE(review): unit (epochs vs. steps) not visible here — confirm
auto_resume:
  enabled: false

# Trainer settings
# NOTE(review): these keys match PyTorch Lightning `Trainer` arguments and are
# presumably forwarded to it; confirm against the training entry point.
max_epochs: 20
# In Lightning, an integer val_check_interval means "validate every N
# training batches" rather than a fraction of an epoch.
val_check_interval: 10
# A list selects specific device indices (here GPU 0 only).
devices: [0]
# Only has an effect when training across more than one device.
sync_batchnorm: true
accelerator: "gpu"
strategy: "ddp"
# Quoted so the number-looking flag is unambiguously a string.
precision: "32-true"  # needed for matrix inverse
