# @package _global_
# NOTE: keep the `# @package` directive above as the very first line so Hydra
# reads it as the config header (header parsing stops at the first other line).
# Sweep identifier; interpolated into the output dirs as results/${name}.
name: sweeps_sde4

hydra:
  # Launch as a multirun: the sweeper below generates many jobs from `params`.
  mode: MULTIRUN
  sweeper:
    # Search space for the sweeper chosen in `defaults` (override /sweeper: tpe).
    # Plain values are fixed overrides applied to every trial; `choice(...)`
    # entries are categorical dimensions the sweeper searches over.
    # All keys here must sit at the same indentation level -- a mis-indented
    # key breaks YAML parsing of the whole file.
    params:
      # Fixed settings shared by every trial.
      method: score
      # Config-group selection (the `/` makes this a group override, not a
      # dotted value override).
      method/sampler: em_gaus_auto_full
      task: periodic_sde
      task.num_simulations: 10000
      eval: swd
      # Network architecture search space.
      method.neural_net.hidden_dim: 'choice(50, 100, 200)'
      method.neural_net.time_embedding_dim: 'choice(8,16)'
      method.neural_net.num_hidden: 'choice(3,4,5,6)'
      # Passed as strings; presumably resolved to callables by the project's
      # config code -- TODO confirm.
      method.neural_net.activation: 'choice(jax.nn.gelu, jax.nn.relu)'
      method.neural_net.layer_norm: 'choice(True, False)'
      method.neural_net.skip_connection: 'choice(True, False)'
      # Training search space.
      method.params_train.learning_rate: 'choice(1e-3, 5e-4, 1e-4)'
      method.params_train.ema: 'choice(False, True)'
      method.params_train.ema_decay: 'choice(0.99, 0.999,0.9999)'
      method.params_train.clip_max_norm: 'choice(1.,10.,100.)'
      method.params_train.num_epochs: 'choice(10,20,30)'
      method.params_train.num_inner_epochs: 'choice(100,500,1000,2000)'
      # Optimizer given as a dotted name string; presumably resolved by the
      # project -- TODO confirm.
      method.params_train.optimizer: 'choice(optax.adam, optax.adamw)'
      method.params_train.scheduler: 'choice(cosine, constant)'

  # Single-run and multirun outputs both live under results/<name>.
  run:
    dir: results/${name}
  sweep:
    dir: results/${name}
    # One subdirectory per job, named from that job's overrides.
    subdir: ${hydra.job.override_dirname}


defaults:
  # Order matters: with `_self_` listed first, the entries after it are composed
  # later and take precedence over values set in this file.
  - _self_
  # `override` swaps a group default that is already declared elsewhere
  # (presumably the primary config); the leading `/` addresses the group by
  # absolute path rather than relative to this config's package.
  - override /partition: gpu
  - override /sweeper: tpe