# Global run settings.
seed: 0                                   # Also interpolated into hydra.run.dir
# NOTE(review): a bare `None` is not YAML null; YAML loaders read it as the string 'None'.
# Confirm the consuming code treats that string as "auto-select" (or switch to `null`).
device: None                              # Set to 'cpu' or 'cuda' to override automatic device selection
proj_name: sj-pdes-exp

# Hydra run directory layout:
#   outputs/<dataset.name>/<seed>-<nef.invariant_type>-<launch timestamp>/
# The timestamp comes from Hydra's `now` resolver using the strftime pattern below.
hydra:
  run:
    dir: outputs/${dataset.name}/${seed}-${nef.invariant_type}-${now:%Y-%m-%d-%H-%M-%S}/

logging:
  # Empty string: use optuna's log directory instead.
  log_dir: ''

  # Step-based intervals.
  log_every_n_steps: 50
  visualize_every_n_steps: 5000

  # Checkpoint settings.
  checkpoint_every_n_epochs: 50
  keep_n_checkpoints: 1
  checkpoint: false

  debug: false

# Dataset configuration
dataset:
  name: diff_sphere                       # Also interpolated into hydra.run.dir
  batch_size: 2
  traj_len_train: 10
  traj_len_out_horizon: 10
  path: 'data/'
  num_signals_train: 512
  num_signals_test: 128
  num_workers: 0
  image_shape: -1                         # NOTE(review): presumably a placeholder filled in at runtime, like nef.num_in / nef.num_out — confirm

# Neural field backbone configuration
nef:
  # Input/output sizes; -1 means they are derived automatically from the dataset.
  num_in: -1
  num_out: -1

  num_layers: 0
  num_hidden: 16
  num_heads: 2
  condition_value_transform: true
  condition_invariant_embedding: true

  latent_dim: 4
  num_latents: 18

  # Gaussian window size.
  #   -1   : the value is set proportional to the number of latents.
  #   None : the gaussian window is not used.
  gaussian_window: -1
  # If true, the gaussian window size is optimized during training.
  optimize_gaussian_window: false
  use_gaussian_window: false

  # Choices: 'rff', 'polynomial', 'ffn'
  embedding_type: rff
  # Frequency multipliers: for RFF the 1/std, for polynomial the degree.
  embedding_freq_multiplier_invariant: 0.01
  embedding_freq_multiplier_value: 0.01
  # Choices: 'rel_pos', 'norm_rel_pos', 'abs_pos', 'ponita', 'rel_pos_periodic', 'polar_periodic'
  # This value is also interpolated into hydra.run.dir.
  invariant_type: polar_periodic

node:
  name: ponita
  num_layers: 3
  num_hidden: 32
  widening_factor: 2
  kernel_size: 'global'
  degree: 3
  basis_dim: 32

  # ODE solver parameters
  dt: 1
  method: 'euler'

# Training configuration
training:
  num_epochs: 750
  max_num_sampled_points: 8192

  # Per-component epoch windows. With the values below, the nef window ends at
  # epoch 150, which is the epoch where the ode window starts.
  nef:
    train_from_epoch: 0
    train_until_epoch: 150
    fit_on_num_steps: 4  # Number of steps on which to fit the neural field backbone
  ode:
    train_from_epoch: 150
    train_until_epoch: 9999

# Testing configuration
test:
  test_interval: 250
  # NOTE(review): 50000 exceeds training.num_epochs (750); if this interval is
  # counted in epochs it would never trigger — confirm the unit.
  test_dp_interval: 50000
  test_equiv_at_epoch: 500

# Meta learning configuration
meta:
  meta_sgd: true
  num_inner_steps: 3

  # Inner-loop learning rates
  inner_learning_rate_p: 1.0
  inner_learning_rate_a: 5.0
  inner_learning_rate_window: 0.0
  # Explicit mantissa dot so the value reads as a float under any YAML loader.
  learning_rate_meta_sgd: 1.0e-4

  noise_pos_inner_loop: 0.0

# Optimizer configuration
optimizer:
  name: adamw
  # Learning rates. Scientific notation is written with an explicit mantissa
  # dot so the values read as floats under any YAML loader.
  learning_rate_snef: 1.0e-4
  learning_rate_codes: 0.0
  learning_rate_ode: 1.0e-3
