# Global parameters: defined once as YAML anchors (&name), reused below via aliases (*name)
globals:
  # Each value carries an anchor (&name) so model sections can alias it (*name).
  latent_cond_len: &latent_cond_len 1
  # Canonical lowercase boolean (yamllint `truthy`); parses to the same bool.
  use_encoder_prior: &use_encoder_prior false

# Define templates for different parameter types
definitions:
  # Integer template. Defaults: min 0, max 100, step 1, log disabled.
  int_param: &int_param
    type: int
    min: 0
    max: 100
    step: 1
    log: false

  # Float template. Defaults: min 0.0, max 1.0, log disabled,
  # step null (no step, i.e. continuous).
  float_param: &float_param
    type: float
    min: 0.0
    max: 1.0
    log: false
    step: null

  # Categorical template. Defaults to an empty list of choices.
  categorical_param: &categorical_param
    type: categorical
    choices: []


##############################################
##############################################
##############################################
##############################################
##############################################
##############################################
##############################################
##############################################
##############################################

dataset:
  nice_name: Harmonic Oscillator
  name: harmonic_oscillator2
  # Written without "_" digit separators: YAML 1.2 core-schema parsers
  # resolve `10_000` as a string, while `10000` is an int everywhere.
  dataset_size: 10000
  dim: 1
  train_proportion: 0.8
  val_proportion: 0.1
  train_batch_size: 512
  # Effective training batch size = train_batch_size * multiplier = 512 * 1 = 512.
  gradient_accumulation_batch_size_multiplier: 1
  val_batch_size: 512

  seq_length: 15  # Higher number means more densely sampled data
  pred_length: 14
  # process_noise for the data-generating SDE. This is float32(0.1) printed at
  # double precision; presumably kept verbatim to match the value the dataset
  # was generated with -- confirm before rounding it to 0.1.
  data_latent_sigma: 0.10000000149011612
  # NOTE(review): unlike the commented-out harmonic_oscillator config below,
  # no data_noise_std is set here -- confirm harmonic_oscillator2 doesn't need it.

  latent_sigma: 0.1  # process_noise for our model's latent SDE
  noise_std: 0.01  # obs_noise for our model's latent SDE

  metric_to_compute:
    - crps
    - nll
    - nrmse

  evaluation_settings:
    - future_latent
    - future_observation

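# Alternative dataset config (harmonic_oscillator). To use it, uncomment it and
# comment out the active `dataset:` block above (duplicate keys would silently
# resolve to the last one in most parsers).
# dataset: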
#   nice_name: Harmonic Oscillator
#   name: harmonic_oscillator
#   dataset_size: 10_000
#   dim: 1
#   train_proportion: 0.8
#   val_proportion: 0.1
#   train_batch_size: 512
#   gradient_accumulation_batch_size_multiplier: 1  # 512 effective training batch size (512 * 1)
#   val_batch_size: 512

#   seq_length: 15 # Higher number means more densely sampled data
#   pred_length: 14
#   data_latent_sigma: 0.10000000149011612 # process_noise for the data generating SDE
#   data_noise_std: 0.01 # obs_noise for the data generating SDE

#   latent_sigma: 0.1 # process_noise for our model's latent SDE
#   noise_std: 0.01 # obs_noise for our model's latent SDE

#   metric_to_compute:
#     - crps
#     - nll
#     - nrmse

#   evaluation_settings:
#     - future_latent
#     - future_observation

##############################################

my_autoregressive_reparam_rnn_bwd:
  # Config section for model_type MyReparameterizedAutoregressiveRNNBwdModel.
  objective: mse
  model_type: MyReparameterizedAutoregressiveRNNBwdModel

  model:
    hidden_size: 64
    parametrization: mixed
    predict_cov: false  # canonical lowercase boolean

  # Shared values aliased from the anchors in `globals`.
  latent_cond_len: *latent_cond_len
  use_encoder_prior: *use_encoder_prior

  optimizer:
    lr: 1.0e-3
    max_train_steps: 5000
    warmup_steps: 100

my_neural_sde_rnn_bwd:
  # Config section for model_type MyNeuralSDERNNBwd.
  objective: mse
  model_type: MyNeuralSDERNNBwd

  model:
    hidden_size: 64
    parametrization: mixed
    predict_cov: false  # canonical lowercase boolean

  # Shared values aliased from the anchors in `globals`.
  latent_cond_len: *latent_cond_len
  use_encoder_prior: *use_encoder_prior

  optimizer:
    lr: 1.0e-3
    max_train_steps: 5000
    warmup_steps: 100
