# =================================================================================================
# USAGE:
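# to start a training run from the command line:
#   python train.py
# the SLURM launcher configured under hydra.launcher is only used in Hydra's
# multirun mode, so submit to the cluster with:
#   python train.py --multirun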
# =================================================================================================
# Base configuration file for training
# =================================================================================================
# Hydra defaults list: swap the default launcher for the submitit SLURM launcher
# (hydra-submitit-launcher plugin); its settings live under hydra.launcher
defaults:
  - override hydra/launcher: submitit_slurm

# enter your own workspace path here (root directory for all runs)
workspace: runs

# per-run output directory: <workspace>/<action type>/<date>/<job name>
out_dir: ${workspace}/${action.type}/${now:%Y-%m-%d}/${name}
# start timestamp, formatted via Hydra's `now` resolver
start_time: ${now:%Y-%m-%d %H:%M:%S}

hydra:
  # Hydra's own output directories; reuse out_dir rather than repeating its
  # path pattern so the two can never drift apart
  sweep:
    dir: ${out_dir}/hydra
  run:
    dir: ${out_dir}/hydra

  # settings for the submitit SLURM launcher (only used with --multirun)
  launcher:
    # job name shown in the SLURM queue; the plugin default would be
    # ${hydra.job.name} (the script name), so wire it to the top-level name
    name: ${name}
    partition: mlgpu_medium
    timeout_min: 1440  # max job time in minutes; 8h: 480, 24h: 1440, 7d: 10080
    gres: 'gpu:1'  # request one GPU for the job


# =================================================================================================
# Job configuration
# =================================================================================================
# job name; it is also the last component of out_dir, so the `/` below nests
# the run output one directory deeper (.../test/01)
# NOTE(review): intended as the name shown in the SLURM queue; that name comes
# from hydra.launcher.name (plugin default ${hydra.job.name}), so it only
# matches this key when hydra.launcher.name is set to ${name}
name: test/01

# parameters for the training process
train:
  # Numeric literals are written so every YAML loader agrees on their type:
  # floats carry an explicit `.` (PyYAML / YAML 1.1 reads `5e-4` as a string)
  # and integers have no `_` separators (YAML 1.2 reads `10_000` as a string).
  n_steps: 10000  # total number of training steps
  save_interval: 1000  # save the model every n steps
  ess_interval: 100  # compute the ESS every n steps
  batch_size: 8000  # batch size per training step
  lr: 5.0e-4  # learning rate of the Adam optimizer
  seed: 42  # random seed for the training
  lr_scheduler_params:
    factor: 0.92
    patience: 2000
    min_lr: 1.0e-6
  clip_grad: 1  # clip gradients to this value, set to 0 to disable
  reinforce: false


# train sampler according to this action: hubbard, scalarphi4, complexphi4, gaussianmixture
action:
  # selects the action; also a path component of out_dir
  type: gaussianmixture
  # NOTE(review): presumably only the <type>_params entry matching `type` is
  # read; no params are defined here for scalarphi4 / complexphi4 — confirm
  hubbard_params:
    u: 18
    beta: 1
  gaussianmixture_params:
    n_gaussians: 8
    radius: 12
    broken: 0.05



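# sampler configuration: lattice shape, dtype, flow layers, prior and their parameters
# NOTE(review): there is no `type` key under sampler; confirm how an individual
# flow layer or the prior is meant to be disabled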
sampler:
  # lattice shape, shared with the prior and flow parameter sets via ${sampler.lat_shape}
  lat_shape:
    - 1
    - 2
  dtype: float64  # !!! use float64 for the hubbard action !!!
  # flow layers to stack; presumably each entry reads the matching
  # `<name>_params` mapping below (realnvp -> realnvp_params) — TODO confirm
  flow:
    - realnvp

  # prior (base) distribution; settings for a gaussian and a uniform prior are
  # defined below (gaussian_params / uniform_params)
  prior: gaussian
  uniform_params:
    low: -1
    high: 1
    lat_shape: ${sampler.lat_shape}
  gaussian_params:
    mean: 0
    var: 20
    lat_shape: ${sampler.lat_shape}
  realnvp_params:
    ncouplings: 6
    nblocks: 4
    mid_dim: 40
    bias: true
    activation: relu
    lat_shape: ${sampler.lat_shape}
  # NOTE(review): brokenz2_stochmod and vmonf are not listed in `flow` above;
  # presumably their params are only read when added there — confirm
  brokenz2_stochmod_params:
    flip_direction: [1]
  vmonf_params:
    ncouplings: 6
    nblocks: 4
    mid_dim: 40
    bias: true
    activation: tanh
    lat_shape: ${sampler.lat_shape}
    sectors: 4