---
# Sweep definition. The ${env} / ${program} / ${args} macros below follow
# the Weights & Biases sweep-configuration format.
program: main.py
project: GMMDistrax_gridsearch_02_03_LV_and_LD

# Grid search: every combination of the candidate lists under
# `parameters` is enumerated.
method: grid

# Metric the sweep is scored on; `minimize` means lower is better.
metric:
  goal: minimize
  name: 'Free_Energy_at_T=1'

# Command line launched for each run. ${args} is replaced by the swept
# parameters; the three trailing flags are passed unchanged to every run.
command:
  - ${env}
  - python
  - ${program}
  - ${args}
  - --use_interpol_gradient
  - --use_normal
  - --gridsearch

# name: GMM Bridge LV, use_off_policy, 2k Steps, SDE-learning, n_part 3 SCAN T
# Display name of this sweep.
name: gridsearch

# Search space. `value` pins a constant; `values` lists the candidates the
# grid enumerates. Only SDE_lr (3), Interpol_lr (2) and beta_max (2) have
# more than one candidate, so the grid holds 3 * 2 * 2 = 12 configurations.
parameters:
  # ---- Swept: more than one candidate ----
  SDE_lr:
    values:
      - 0.0001
      - 0.00005
      - 0.00001
  Interpol_lr:
    values:
      - 0.01
      - 0.001
  beta_max:
    values:
      - 1.0
      - 1.5

  # ---- Single-candidate lists: kept as `values` so further candidates
  # ---- can be appended without changing the key ----
  SDE_Loss:
    values:
      - Bridge_rKL_fKL_logderivative
  n_eval_samples:
    values:
      - 16000
  sigma_init:
    values:
      - 80.0
  model_seeds:
    # values: [1]
    values:
      - 0

  # ---- Constants ----
  Energy_Config:
    value: GMMDistrax
  SDE_Type:
    value: Bridge_SDE
  base_net:
    value: PISgradnet
  GPU:
    value: -1
  N_anneal:
    value: 1200
  T_start:
    value: 1.0
  T_end:
    value: 1.0
  batch_size:
    value: 2000
  beta_min:
    value: 0.01
  feature_dim:
    value: 64
  n_hidden:
    value: 64
  n_integration_steps:
    value: 128
  n_particles:
    value: 50



# early_terminate:
#   type: hyperband
#   min_iter: 100



# name: GMM Bridge rKL_LD & LogVar no expl no SDE learning
# method: grid
# project: DDS_GaussianMixtureClass_
# program: main.py
# parameters:
#   SDE_Loss:
#     values: ['Bridge_rKL_logderiv', 'Bridge_LogVarLoss']
#   Energy_Config:
#     value: GaussianMixture
#   n_integration_steps:
#     value: 128
#   T_start:
#     values: [1.]
#   T_end:
#     value: 1.0
#   batch_size:
#     value: 2000
#   lr:
#     value: 0.001
#   Energy_lr:
#     value: 0.0
#   SDE_lr:
#     values: [0]
#   N_anneal:
#     value: 6000
#   feature_dim:
#     value: 64
#   n_hidden:
#     value: 64
#   GPU:
#     value: -1
#   beta_max:
#     value: 0.1
#   beta_min:
#     value: 0.01
#   Network_Type:
#     value: FeedForward
#   SDE_Type:
#     value: Bridge_SDE
#   sigma_init:
#     value: 1.0
#   model_seeds:
#     values: [1,2,3]

# command:
#   - ${env}
#   - python
#   - ${program}
#   - ${args}
#   - --use_interpol_gradient
#   - --use_normal
#   # - --use_off_policy












# name: GMM VP rKL repara & log_der & LogVar
# method: grid
# project: DDS_GaussianMixtureClass_
# program: main.py
# parameters:
#   SDE_Loss:
#     values: ['Reverse_KL_Loss', 'LogVariance_Loss','Reverse_KL_Loss_logderiv']
#   Energy_Config:
#     value: GaussianMixture
#   n_integration_steps:
#     value: 128
#   T_start:
#     values: [1, 3]
#   T_end:
#     value: 1.0
#   batch_size:
#     value: 2000
#   lr:
#     value: 0.001
#   Energy_lr:
#     value: 0.0
#   SDE_lr:
#     value: 0.001
#   N_anneal:
#     value: 6000
#   feature_dim:
#     value: 64
#   n_hidden:
#     value: 64
#   GPU:
#     value: -1
#   beta_max:
#     value: 5
#   beta_min:
#     value: 0.05
#   Network_Type:
#     value: FeedForward
#   SDE_Type:
#     value: VP_SDE
#   sigma_init:
#     value: 1.0
#  # sigma_scale_factor:
#  #   values: [0.,0.15, 1.]

# command:
#   - ${env}
#   - python
#   - ${program}
#   - ${args}
#   - --use_interpol_gradient
#   - --use_normal
#   #- --use_off_policy













# name: "GMM_scaling + sigma_scaling LV & rKL"
# # name: "BS, loss, T=6.1, GMM var 1e-4, Seed 1,2"
# method: grid
# program: main.py
# project: DDS_GaussianMixtureClass_
# parameters:
#   SDE_Loss:
#     values: ['Reverse_KL_Loss','LogVariance_Loss']   #,'LogVariance_Loss','Reverse_KL_Loss'
#   Energy_Config:
#     value: 'GaussianMixture'
#   n_integration_steps:
#     value: 100
#   T_start:
#     values: [1.0]
#   T_end:
#     value: 1.0
#   batch_size:
#     values: [2000] 
#   lr:
#     value: 0.0005
#   N_anneal:
#     values: [300]
#   GPU:
#     value: -1
#   beta_max:
#     value: 1.0
#   beta_min:
#     value: 0.001
#   Network_Type:
#     value: 'FeedForward'
#   SDE_Type:
#     values: ['VP_SDE']
#   SDE_lr:
#     values: [0.0005]
#   Scaling_factor:
#     values: [40, 1]
#   Variances:
#     values: [1. ,0.1]
#   model_seed:
#     values: [1]
#   steps_per_epoch:
#     values: [100]
#   sigma_scale_factor:
#     values: [0, 0.5]

# command:
#   - ${env}
#   - python
#   - ${program}
#   - ${args}
#   - --use_interpol_gradient
#   - --use_normal
#   - --use_off_policy