# Global values shared by the model sections below. Each value carries a YAML
# anchor (`&name`) so sections can reuse it with an alias (`*name`).
globals:
  latent_cond_len: &latent_cond_len 2
  use_encoder_prior: &use_encoder_prior false

# Reusable templates for Optuna hyper-parameter search spaces. Pull one into a
# spec with a merge key (e.g. `<<: *int_param`) and override only the fields
# that differ from these defaults.
definitions:
  # Integer range template: 0..100 in unit steps, linear scale.
  int_param: &int_param
    type: int
    min: 0
    max: 100
    step: 1
    log: false

  # Float range template: 0.0..1.0, linear scale.
  float_param: &float_param
    type: float
    min: 0.0
    max: 1.0
    step: null  # no step, i.e. continuous sampling
    log: false

  # Categorical template; the default choice list is empty and is presumably
  # overridden wherever the template is used.
  categorical_param: &categorical_param
    type: categorical
    choices: []


##############################################
##############################################
##############################################
##############################################
##############################################
##############################################
##############################################
##############################################
##############################################


dataset:
  nice_name: Brusselator
  name: brusselator
  dataset_size: 2400
  dim: 2
  seq_length: 150
  pred_length: 75
  # The remaining ~0.168 of the data is presumably the test split; confirm.
  train_proportion: 0.666
  val_proportion: 0.166
  train_batch_size: 256
  gradient_accumulation_batch_size_multiplier: 1  # 256 effective training batch size
  val_batch_size: 256

  noise_std: 0.3
  tracking_sigma: 0.1
  time_scale_mult: 8.0

  hyper_params:
    train_batch_size:
      optuna_hyper_param:
        <<: *int_param
        # The int_param template names its bounds `min`/`max`, so setting only
        # `min_val`/`max_val` left the merged template defaults (0..100) in
        # place. Both spellings are set so that the intended 16..1024 range
        # applies whichever key the consumer reads.
        # NOTE(review): once the consumer's key is confirmed, drop the other.
        min: 16
        max: 1024
        min_val: 16
        max_val: 1024
        step: 16

  metric_to_compute:
    - crps
    - nll
    - nrmse

  evaluation_settings:
    - future_latent
    - future_observation
    - future_denoised_observation

##############################################

# Reparameterized autoregressive RNN trained with an MSE objective.
my_autoregressive_reparam_rnn:

  objective: mse
  model_type: MyReparameterizedAutoregressiveRNNModel

  model:
    hidden_size: 128
    parametrization: mixed
    predict_cov: false

  # Shared values aliased from the `globals` section.
  latent_cond_len: *latent_cond_len
  use_encoder_prior: *use_encoder_prior

  optimizer:
    lr: 1.0e-4  # keep the ".0": YAML 1.1 loaders (PyYAML) read "1e-4" as a string
    max_train_steps: 300_000
    warmup_steps: 1000

# Baseline autoregressive RNN trained with a maximum-likelihood ("ml")
# objective and a predicted covariance.
# NOTE(review): this section's settings are identical to
# `baseline_autoregressive_rnn`; confirm whether a difference was intended.
true_baseline_autoregressive_rnn:

  objective: ml
  model_type: MyBaselineAutoregressiveRNNModel

  model:
    hidden_size: 128
    parametrization: std
    predict_cov: true

  # Shared values aliased from the `globals` section.
  latent_cond_len: *latent_cond_len
  use_encoder_prior: *use_encoder_prior

  optimizer:
    lr: 1.0e-4
    max_train_steps: 300_000
    warmup_steps: 1000

# Baseline autoregressive RNN trained with a maximum-likelihood ("ml")
# objective and a predicted covariance.
baseline_autoregressive_rnn:

  objective: ml
  model_type: MyBaselineAutoregressiveRNNModel

  model:
    hidden_size: 128
    parametrization: std
    predict_cov: true

  # Shared values aliased from the `globals` section.
  latent_cond_len: *latent_cond_len
  use_encoder_prior: *use_encoder_prior

  optimizer:
    lr: 1.0e-4
    max_train_steps: 300_000
    warmup_steps: 1000

# Diffusion RNN trained with a flow-matching objective.
my_diffusion_model_rnn:
  model_type: MyDiffusionRNNModel
  objective: flow_matching
  # Both values are aliases into the `globals` section.
  latent_cond_len: *latent_cond_len
  use_encoder_prior: *use_encoder_prior
  model:
    hidden_size: 128
  optimizer:
    lr: 1.0e-4
    warmup_steps: 1000
    max_train_steps: 300_000

# Non-probabilistic RNN trained with an MSE objective.
my_non_probabilistic_rnn:

  objective: mse
  model_type: MyNonProbabilisticRNNModel

  model:
    hidden_size: 128
    parametrization: mixed
    predict_cov: false

  latent_cond_len: *latent_cond_len
  # Deliberately overrides the global `use_encoder_prior` default.
  # The original author notes this is a hack to keep training stable.
  use_encoder_prior: true

  optimizer:
    lr: 1.0e-4
    max_train_steps: 300_000
    warmup_steps: 1000

# Baseline diffusion RNN trained with a flow-matching objective.
baseline_diffusion_model_rnn:
  model_type: MyBaselineDiffusionRNNModel
  objective: flow_matching
  # Both values are aliases into the `globals` section.
  latent_cond_len: *latent_cond_len
  use_encoder_prior: *use_encoder_prior
  model:
    hidden_size: 128
  optimizer:
    lr: 1.0e-4
    warmup_steps: 1000
    max_train_steps: 300_000

# Neural-SDE RNN configured to predict a drift, trained with drift matching.
my_neural_sde_rnn:
  model_type: MyNeuralSDERNN
  objective: drift_matching
  # Both values are aliases into the `globals` section.
  latent_cond_len: *latent_cond_len
  use_encoder_prior: *use_encoder_prior
  model:
    hidden_size: 128
    predict_flow_or_drift: drift
  optimizer:
    lr: 1.0e-4
    warmup_steps: 1000
    max_train_steps: 300_000

# Reuses the neural-SDE model class, configured to predict a flow and trained
# with flow matching. `predict_flow_or_drift: flow` presumably selects the
# deterministic (ODE) variant; confirm against the model implementation.
my_neural_ode_rnn:
  model_type: MyNeuralSDERNN
  objective: flow_matching
  # Both values are aliases into the `globals` section.
  latent_cond_len: *latent_cond_len
  use_encoder_prior: *use_encoder_prior
  model:
    hidden_size: 128
    predict_flow_or_drift: flow
  optimizer:
    lr: 1.0e-4
    warmup_steps: 1000
    max_train_steps: 300_000


##############################################

# Backward variant of the reparameterized autoregressive RNN (MSE objective).
# It uses a higher learning rate and a shorter schedule than the forward models.
my_autoregressive_reparam_rnn_bwd:

  objective: mse
  model_type: MyReparameterizedAutoregressiveRNNBwdModel

  model:
    hidden_size: 128
    parametrization: mixed
    predict_cov: false

  # Shared values aliased from the `globals` section.
  latent_cond_len: *latent_cond_len
  use_encoder_prior: *use_encoder_prior

  optimizer:
    lr: 1.0e-3
    max_train_steps: 30_000
    warmup_steps: 100

# Backward variant of the neural-SDE RNN (MSE objective).
# NOTE(review): the `model` keys (parametrization / predict_cov) and the `mse`
# objective match `my_autoregressive_reparam_rnn_bwd`. The forward
# `my_neural_sde_rnn` instead uses `predict_flow_or_drift` with
# `drift_matching`. Confirm that MyNeuralSDERNNBwd expects these keys and that
# this is not a copy-paste leftover.
my_neural_sde_rnn_bwd:

  objective: mse
  model_type: MyNeuralSDERNNBwd

  model:
    hidden_size: 128
    parametrization: mixed
    predict_cov: false

  # Shared values aliased from the `globals` section.
  latent_cond_len: *latent_cond_len
  use_encoder_prior: *use_encoder_prior

  optimizer:
    lr: 1.0e-3
    max_train_steps: 30_000
    warmup_steps: 100
