# Top-level run settings.
seed: 0
# NOTE(review): unquoted `None` is a plain YAML string ("None"), not a null value.
# If "no device" is intended, the YAML spelling is `null` — confirm how the
# consumer reads this key before changing it.
device: None
# Presumably the project name used by the experiment tracker — verify against the logging code.
proj_name: sj-pdes-exp

# Hydra run settings.
hydra:
  run:
    # Per-run output directory, built from the dataset name, the seed, the
    # neural field invariant type and a timestamp from Hydra's `now` resolver.
    # Each `${...}` interpolation is closed by exactly one `}`; an extra brace
    # would end up as a literal character in the directory name.
    dir: outputs/${dataset.name}/${seed}-${nef.invariant_type}-${now:%Y-%m-%d-%H-%M-%S}/

# Logging, visualisation and checkpointing settings.
logging:
  # Empty string by default.
  log_dir: ''

  # Step-based intervals.
  log_every_n_steps: 50
  visualize_every_n_steps: 5000

  # Epoch-based checkpoint settings.
  checkpoint_every_n_epochs: 50
  keep_n_checkpoints: 1
  checkpoint: true
  debug: false

# Dataset configuration
dataset:
  # Also interpolated into the Hydra run directory as ${dataset.name}.
  name: diffusion_plane
  batch_size: 8
  # Trajectory lengths; presumably counted in time steps — TODO confirm.
  traj_len_train: 10
  traj_len_out_horizon: 10
  path: 'data/'
  num_signals_train: 2048
  num_signals_test: 32
  num_workers: 0
  # NOTE(review): -1 looks like a placeholder that is filled in at runtime — confirm.
  image_shape: -1

# Neural field backbone configuration
nef:
  # Placeholders (-1): these values are derived from the dataset.
  num_in: -1
  num_out: -1

  # Network size.
  num_layers: 0
  num_hidden: 64
  num_heads: 2
  condition_value_transform: true
  condition_invariant_embedding: true

  # Latent set.
  latent_dim: 16
  num_latents: 4
  # Placeholder (-1): the value is calculated from the number of latents.
  gaussian_window: -1
  optimize_gaussian_window: false
  use_gaussian_window: true

  # Embedding type; one of 'rff', 'polynomial', 'ffn'.
  embedding_type: rff
  embedding_freq_multiplier_invariant: 0.05
  embedding_freq_multiplier_value: 0.01
  # Also interpolated into the Hydra run directory as ${nef.invariant_type}.
  invariant_type: ponita

# Settings for the `node` component (model size plus ODE solver options).
node:
  name: ponita
  num_layers: 3
  num_hidden: 64
  widening_factor: 2
  kernel_size: 'global'
  degree: 3
  basis_dim: 64

  # ODE solver parameters: step size and integration method.
  dt: 1
  method: 'euler'

# Training configuration
training:
  num_epochs: 1000
  max_num_sampled_points: 1024

  # Epoch window for the `ode` component.
  # NOTE(review): train_until_epoch (10000) is larger than num_epochs (1000) —
  # presumably this means "until the end of training"; confirm.
  ode:
    train_from_epoch: 100
    train_until_epoch: 10000
  # Epoch window for the `nef` component; it ends at the epoch where the
  # `ode` window above begins (100).
  nef:
    train_from_epoch: 0
    fit_on_num_steps: 4  # Number of steps on which to fit the neural field backbone
    train_until_epoch: 100

# Testing configuration
test:
  # The intervals and the epoch below are presumably counted in training
  # epochs — TODO confirm against the training loop.
  test_interval: 100
  test_dp_interval: 100
  test_equiv_at_epoch: 200

# Meta learning configuration
meta:
  meta_sgd: true
  num_inner_steps: 3

  # Inner-loop learning rates. The `_p` / `_a` / `_window` suffixes presumably
  # select which latent component each rate applies to — verify against the
  # training code.
  inner_learning_rate_p: 1.0
  inner_learning_rate_a: 5.0
  inner_learning_rate_window: 0.0
  # Written with an explicit decimal point: YAML 1.1 loaders (e.g. plain
  # PyYAML) resolve `1e-4` as a string, whereas `1.0e-4` is a float everywhere.
  learning_rate_meta_sgd: 1.0e-4

  noise_pos_inner_loop: 0.0

# Optimizer configuration
optimizer:
  name: adamw
  # Learning rates. Exponent-form values carry an explicit decimal point so
  # that YAML 1.1 loaders (e.g. plain PyYAML) resolve them as floats rather
  # than strings; the numeric values are unchanged.
  learning_rate_snef: 1.0e-4
  learning_rate_codes: 0.0
  learning_rate_ode: 1.0e-3
