# =================================================================================================
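# USAGE:
# to start a training run locally from the command line:
#   python train.py
# to submit the job(s) through the SLURM launcher configured below, add --multirun:
#   python train.py --multirun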
# =================================================================================================
# Base configuration file for training
# =================================================================================================
defaults:
  - override hydra/launcher: submitit_slurm
  # explicit _self_ last: values in this file override the defaults above; this is
  # Hydra's implicit order since 1.1 and silences the "missing _self_" warning
  - _self_
# root directory for all runs -- enter your own workspace path here
workspace: 'runs'

# per-run output directory: <workspace>/<action type>/<date>/<run name>
out_dir: '${workspace}/${action.type}/${now:%Y-%m-%d}/${name}'
# timestamp string produced by Hydra's `now` resolver
start_time: '${now:%Y-%m-%d %H:%M:%S}'
hydra:
  # Hydra's output directory (config snapshots, logs) lives inside the run's
  # out_dir; interpolating ${out_dir} keeps both paths in sync if it is edited
  sweep:
    dir: ${out_dir}/hydra
  run:
    dir: ${out_dir}/hydra

  # SLURM submission settings; Hydra launchers are only used with --multirun
  launcher:
    partition: mlgpu_medium
    timeout_min: 1440  # max time for the job in minutes, 8h: 480, 24h: 1440, 7d: 10080
    gres: 'gpu:1'  # request one GPU for the job


# =================================================================================================
# Job configuration
# =================================================================================================
# run name; it is interpolated into out_dir and the Hydra run/sweep directory
# paths, so each seed gets its own directory
name: 'sesamo/seed${train.seed}'

# parameters for the training process
train:
  n_steps: 400_000  # total number of training steps
  save_interval: 50_000  # save model every n steps
  ess_interval: 1_000  # compute ESS every n steps
  batch_size: 8_000  # batch size per training step
  # Floats carry an explicit decimal point: a bare `5e-4` is loaded as a string
  # by YAML 1.1 loaders such as plain PyYAML (OmegaConf/Hydra patch this case).
  lr: 5.0e-4  # learning rate of the ADAM optimizer
  seed: 42  # random seed for the training
  # NOTE(review): presumably kwargs for a ReduceLROnPlateau-style scheduler;
  # the unit of `patience` (steps vs. evaluations) should be confirmed in train.py
  lr_scheduler_params:
    factor: 0.92
    patience: 2_000
    min_lr: 1.0e-6
  clip_grad: 1  # clip gradients to this value, set to 0 to disable
  reinforce: true  # use reinforce to train the model


# train sampler according to this action: hubbard, scalarphi4, complexphi4, gaussianmixture
action:
  type: scalarphi4
  # NOTE(review): presumably the phi^4 couplings (lambda, kappa) and the strength
  # of an explicit symmetry-breaking term -- confirm against the action code
  scalarphi4_params: {lambd: 0.022, kappa: 0.3, broken: 0.005}



# init different samplers, set type to null to disable
sampler:
  # lattice dimensions; reused below through ${sampler.lat_shape}
  lat_shape:
    - 16
    - 8
  dtype: float32  # !!! for hubbard take float64 !!!
  # flow components by name; presumably applied in the listed order -- confirm
  flow:
    - realnvp
    - z2_reg
    - broken_z2_stochmod

  # prior (base) distribution and its parameters
  prior: gaussian
  gaussian_params:
    mean: 0
    var: 1
    # same lattice shape as declared at sampler.lat_shape
    lat_shape: '${sampler.lat_shape}'
  # presumably configures the `realnvp` entry of sampler.flow -- confirm
  realnvp_params:
    num_coupling_layers: 6
    num_hidden_layers: 4
    num_hidden_features: 100
    bias: true
    activation: relu
    lat_shape: ${sampler.lat_shape}