# Default config file for polymer dynamics (reduced coordinate case)

temperature: 1.0 # non-dimensional temperature, 1.0 by default; a different value can be supplied via args
# dt: presumably the time step; the commented-out lines are alternative values kept for reference.
# dt: 0.0005
dt: 0.0001
# NOTE(review): 0.002 * 200 = 0.4, not 0.05 -- confirm which value the line below was meant to record.
# dt: 0.05 # 0.002 * 200 (downsample in the sim.in)

# microscopic_dim: 900 # dof (degrees of freedom) of the polymer
macroscopic_dim: 1 # dimension of the macroscopic state (the extension)
reduced_dim: 3 # total number of reduced coordinates

# Model hyperparameters, grouped by component: drift, potential, dissipation,
# conservation, diffusion, encoder and decoder. `units` gives the layer sizes.
model:
  seed: 42 # random seed
  drift:
    # Drift parameterisation: "onsager" or "mlp".
    type: "mlp"
    # type: "onsager"
    activation: "silu"
    units: [128, 128, 128]
  potential:
    activation: "srequ" # shifted ReQU activation
    alpha: 0.01 # regularization strength
    units: [128] # layer sizes
    n_pot: 32
  dissipation:
    activation: "tanh"
    alpha: 0.01
    units: [32, 32]
  conservation:
    activation: "tanh"
    units: [32, 32]
  diffusion:
    # Diffusion parameterisation:
    #   constant -> constant over time
    #   diagonal -> state-dependent diagonal
    #   full     -> state-dependent full matrix
    type: "diagonal"
    alpha: 0.01
    activation: "relu"
    units: [32, 32]
  encoder:
    activation: "tanh"
    units: [16]
    mlp_scale: 1.0 # scale of the mlp output
    mlp_input_scale: 0.01 # scale of the input to the mlp
  decoder:
    activation: "tanh"
    units: [16]
    mlp_scale: 1.0

# Training schedule, loss weights, optimiser and reduce-on-plateau options.
# Floats in scientific notation are written with an explicit decimal point
# (1.0e-3, not 1e-3): YAML 1.1 loaders such as PyYAML read "1e-3" as a string.
train:
  num_epochs_joint: 50 # number of epochs for joint training of encoder/decoder + SDE
  num_epochs: 1000 # number of epochs for training the SDE
  # Each batch is [batch_size, train_traj_len, microscopic_dim], so use a small
  # batch size if train_traj_len is large.
  batch_size: 4
  # Can shrink the length of the training trajectories for better GPU performance.
  # NOTE(review): null presumably keeps the full trajectory length -- confirm.
  train_traj_len: null
  checkpoint_every: 10 # checkpoint the model every this many epochs
  loss: # loss weights
    recon_weight: 1.0e-3
    compare_weight: 100.0
  opt: # optimiser options
    learning_rate: 1.0e-3
  # Reduce-on-plateau options.
  # NOTE(review): key names match optax.contrib.reduce_on_plateau -- confirm
  # that this is where they are forwarded.
  rop:
    patience: 20
    cooldown: 20
    factor: 0.4
    rtol: 1.0e-4
    min_scale: 1.0e-4
    accumulation_size: 2000

# Hydra output directories. ${now:...} is Hydra's timestamp resolver (strftime
# format), so each launch gets its own directory under ./outputs.
hydra:
  run:
    # Output directory for a single run.
    dir: './outputs/${now:%Y_%m_%d-%H_%M_%S}'
  sweep:
    # Parent directory for a multirun sweep; each job uses its job number as subdir.
    dir: './outputs/multirun/${now:%Y_%m_%d-%H_%M_%S}'
    subdir: '${hydra.job.num}'
