# =================================================================================================
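# USAGE:
# start a training run from the command line with
#   python train.py
# The SLURM launcher configured below is only used in Hydra multirun mode,
# i.e. `python train.py --multirun`.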
# =================================================================================================
# Base configuration file for training
# =================================================================================================
defaults:
  # use the SLURM launcher from hydra-submitit-launcher (multirun mode only)
  - override hydra/launcher: submitit_slurm
# enter your own workspace path here
workspace: runs

# root directory of a run; the Hydra run/sweep directories below are derived
# from it so the directory layout is defined in exactly one place
out_dir: ${workspace}/${action.type}/${now:%Y-%m-%d}/${name}
# human-readable timestamp of the run
start_time: ${now:%Y-%m-%d %H:%M:%S}
hydra:
  sweep:
    dir: ${out_dir}/hydra
  run:
    dir: ${out_dir}/hydra

  launcher:
    partition: sgpu_medium
    timeout_min: 1440  # max time for the job, 8h: 480, 24h: 1440, 7d: 10080
    gres: gpu:1  # request one GPU for the job


# =================================================================================================
# Job configuration
# =================================================================================================
# name of the run; it is part of out_dir and of the Hydra run/sweep directories
# NOTE(review): with submitit, the SLURM queue name is taken from
# hydra.launcher.name (by default the Hydra job name), not from this key — confirm
name: "sesamo_2x1/seed${train.seed}"

# parameters for the training process
# Numbers below are written so every YAML loader types them identically:
# floats carry a '.' and a signed exponent (plain PyYAML reads `5e-4` as a
# string), and ints avoid '_' digit separators (not valid in YAML 1.2).
train:
  n_steps: 6000  # total number of training steps
  save_interval: 1000  # save model every n steps
  ess_interval: 100  # compute ESS every n steps
  batch_size: 8000  # batch size per training step
  lr: 5.0e-4  # learning rate of the Adam optimizer
  seed: 42  # random seed for the training
  # NOTE(review): keys match torch ReduceLROnPlateau arguments — confirm
  lr_scheduler_params:
    factor: 0.92
    patience: 2000
    min_lr: 1.0e-6
  clip_grad: 1  # clip gradients to this value, set to 0 to disable
  reinforce: true  # use REINFORCE to train the model


# action the sampler is trained against;
# supported types: hubbard, scalarphi4, complexphi4, gaussianmixture
action:
  type: 'hubbard'
  # parameters of the Hubbard action
  hubbard_params:
    u: 18
    beta: 1
    nt: 1
    nx: 2



# init different samplers
# NOTE(review): an earlier note said "set type to null to disable", but no
# sampler.type key is visible in this section — confirm where it lives
sampler:
  # lattice shape [nt, nx], interpolated from the Hubbard action so the two
  # cannot silently drift apart; override sampler.lat_shape for other actions
  lat_shape:
    - ${action.hubbard_params.nt}
    - ${action.hubbard_params.nx}
  dtype: float64  # !!! for hubbard take float64 !!!
  # flow components, listed in the order given here
  flow:
    - realnvp
    - z2pown_reg
    - hubbard_stochmod

  # parameters for the prior Gaussian distribution
  prior: gaussian
  gaussian_params:
    mean: 0
    # NOTE(review): equals action.hubbard_params.u (18) here — confirm whether
    # it is meant to track the action parameters
    var: 18
    lat_shape: ${sampler.lat_shape}
  # parameters of the RealNVP coupling layers
  realnvp_params:
    num_coupling_layers: 6
    num_hidden_layers: 4
    num_hidden_features: 40
    bias: true
    activation: relu
    lat_shape: ${sampler.lat_shape}
  hubbard_stochmod_params:
    nx: 2