# Top-level run settings. Only `nblocks` is referenced elsewhere in this file;
# the remaining keys are consumed outside it.
logger: tensorboard
# cuda_device: "0"
# NOTE(review): presumably selects how data is normalized - confirm against the consumer.
normalization: "standardize"
# NOTE(review): unit (epochs vs. steps) is not visible in this file - TODO confirm.
warm_up_period: 2
# Interpolated as ${nblocks} into model.pnn_decoder_cfg.encoder_cfg.num_blocks.
nblocks: 5

# To specify: presumably marks values meant to be set per experiment
# (the same marker appears inline on other keys in this file).
seed: 0  # reused by data_module.seed through ${seed}
ensemble_size: 5  # not referenced elsewhere in this file
outer_loop_size: 3  # not referenced elsewhere in this file
experiment_name: "rpnn_noshape_ech"
# save_path: "/home/scratch/jiayuc2/fusion_model/rpnn_noshape_ech_synthesized"
# NOTE(review): the two paths below are machine-specific absolute paths;
# they need to be overridden when running anywhere else.
data_path: "/home/scratch/jiayuc2/fusion_data/noshape_ech_synthesized"  # reused by data_module.data_path
gt_model_path: "/home/scratch/jiayuc2/fusion_model/rpnn_noshape_ech"  # not referenced elsewhere in this file

# Training-loop settings (presumably keyword arguments for a PyTorch Lightning
# Trainer - TODO confirm against the launching script).
trainer:
  # To specify. The earlier inline note listed 1000 as the alternative value,
  # so 2 looks like a quick smoke-test setting.
  max_epochs: 2
  # NOTE(review): if this is passed to pytorch_lightning.Trainer, `gpus` was
  # deprecated in 1.7 and removed in 2.0 in favour of `accelerator`/`devices`;
  # confirm the installed version before upgrading.
  gpus: 1

# Early-stopping settings (key names match the arguments of PyTorch Lightning's
# EarlyStopping callback - TODO confirm that is the consumer).
early_stopping:
  # To specify. NOTE(review): trainer.max_epochs is 2 in this file, so a
  # patience of 250 cannot be reached if validation runs once per epoch.
  patience: 250
  min_delta: 0.0
  # Metric name to watch and the direction that counts as an improvement.
  monitor: 'val/loss'
  mode: 'min'

# Model definition. Every `_target_` is a dotted import path; the sibling keys
# are presumably passed to that class as constructor arguments
# (hydra.utils.instantiate convention - TODO confirm against the launcher).
model:
  _target_: dynamics_toolbox.models.pl_models.sequential_models.rpnn.RPNN
  encode_dim: 512
  rnn_num_layers: 1
  rnn_hidden_size: 256
  # Exponent floats carry an explicit decimal point so that YAML 1.1 loaders
  # (e.g. plain PyYAML) also resolve them as floats instead of strings.
  learning_rate: 3.0e-4  # To specify
  weight_decay: 1.0e-3
  add_mse_to_loss: false  # To specify
  # NOTE(review): presumably only relevant when add_mse_to_loss is true - confirm in RPNN.
  mse_wt: 1.0  # To specify
  nll_wt: 1.0  # To specify
  # Sub-config for the encoder network.
  encoder_cfg:
    _target_: dynamics_toolbox.models.pl_models.mlp.MLP
    num_layers: 1
    layer_size: 512
    hidden_activation: "relu"
  # Sub-config for the PNN decoder, which in turn nests three network configs.
  pnn_decoder_cfg:
    _target_: dynamics_toolbox.models.pl_models.pnn.PNN
    encoder_output_dim: 128
    logvar_lower_bound: -10
    logvar_upper_bound: 0.5
    encoder_cfg:
      _target_: dynamics_toolbox.models.pl_models.residual_mlp_blocks.ResidualMLPBlocks
      embed_dim: 512
      num_layers_per_block: 1
      # Resolves to the top-level `nblocks` value.
      num_blocks: ${nblocks}
      hidden_activation: "relu"
    mean_net_cfg:
      _target_: dynamics_toolbox.models.pl_models.mlp.MLP
      num_layers: 0
      layer_size: 128
      hidden_activation: "relu"
    logvar_net_cfg:
      _target_: dynamics_toolbox.models.pl_models.mlp.MLP
      num_layers: 0
      layer_size: 128
      hidden_activation: "relu"

# Data module definition; `_target_` is the dotted import path of the class.
data_module:
  # NOTE(review): package root here is `dynamics`, whereas every model target
  # in this file lives under `dynamics_toolbox` - confirm this import path.
  _target_: dynamics.data_modules.KFoldSequenceFusionDataModule
  data_path: ${data_path}  # resolves to the top-level data_path
  batch_size: 8  # To specify; the earlier inline note listed 1024 as the alternative
  n_folds: 10  # To specify
  # NOTE(review): te_fold equals n_folds; if folds are 0-indexed this is out of
  # range - confirm the indexing used by KFoldSequenceFusionDataModule.
  te_fold: 10  # To specify
  pin_memory: true
  min_shot_amt: 4
  shot_train_length: 225
  num_workers: 8
  train_from_start_only: true
  prop_validation: 0.05  # To specify
  seed: ${seed}  # resolves to the top-level seed
  bootstrapped: true  # To specify

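# NOTE: re-enabling this block also requires uncommenting `save_path` near the
# top of the file, because `dir` interpolates ${save_path}.
# hydra:
#     run:
#         dir: ${save_path} # os.getcwd()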