#########################################################################################
# Wave function
#########################################################################################
wave_function:
  # Each component selects its implementation via `module`; `args` holds one
  # settings entry per implementation, keyed by the same name.
  embedding:
    module: moon # ['moon', 'psiformer', 'ferminet']
    args:
      moon:
        dim: 256
        n_layer: 4
        embedding_dim: 256
        edge_embedding: 32
        edge_hidden_dim: 8
        edge_rbf: 6
        activation: silu
      psiformer:
        dim: 256
        embedding_dim: 256
        n_head: 4
        n_layer: 4
        activation: silu
      ferminet:
        embedding_dim: 256
        hidden_dims:
          - [256, 32]
          - [256, 32]
          - [256, 32]
          - [256, 32]
        activation: silu

  envelope:
    module: efficient # ['efficient', 'full']
    args:
      efficient:
        env_per_nuc: 8
      full: {}

  orbitals:
    module: pfaffian # ['pfaffian', 'lowrankpfaffian', 'slater']
    args:
      pfaffian:
        determinants: 16
        # keyed by charge; keys are quoted so they load as strings, not ints
        orb_per_charge:
          "1": 2
          "2": 2
          "3": 8
          "4": 8
          "5": 8
          "6": 8
          "7": 8
          "8": 8
          "9": 8
          "10": 8
        hf_match_steps: 50
        hf_match_lr: 1
        hf_match_orbitals: 1
        hf_match_antisymmetrizer: 1
        hf_match_ema: 0.999
        hf_match_init_bias: 0.0
        hf_match_noise_std: 0.0
      lowrankpfaffian:
        determinants: 16
        rank: 4
        # keyed by charge; keys are quoted so they load as strings, not ints
        orb_per_charge:
          "1": 2
          "2": 2
          "3": 16
          "4": 16
          "5": 16
          "6": 16
          "7": 16
          "8": 16
          "9": 16
          "10": 16
        hf_match_steps: 50
        hf_match_lr: 1
        hf_match_orbitals: 1
        hf_match_antisymmetrizer: 1
        hf_match_ema: 0.999
        hf_match_init_bias: 0.0
      slater:
        determinants: 16

  # List of [name, kwargs] pairs; available names: ['mlp', 'cusp']
  jastrows:
    - - mlp
      - hidden_dims: [128, 32]
        activation: silu
    - - cusp
      - {}

  meta_network:
    module: meta_gnn # ['meta_gnn', null]
    args:
      meta_gnn:
        message_dim: 32
        embedding_dim: 64
        num_layers: 3
        activation: silu
        n_rbf: 6

#########################################################################################
# VMC
#########################################################################################
vmc:
  epochs: 80_000
  batch_size: 512 # by default full batch
  max_consecutive_fails: 5
  max_total_rollbacks: 20
  # NOTE(review): the previous comment documented True (reweight everything),
  # `overlap` (reweight only the overlap) and False (no reweighting);
  # "overlap_mean_only" is not among those -- confirm accepted modes in the loader.
  reweighting_mode: "overlap_mean_only"
  determinant_regularization: 0.0
  normalizer_regularization: 0.0
  profiling:
    output_dir: null # defaults to <log_dir>/profiles when not provided
    trace:
      enabled: false
      start_step: null # optional 0-indexed step to begin tracing
      duration_steps: 1
      steps: [] # optional explicit list of 0-indexed steps to trace
      subdir: trace

  thermalizing_epochs: 10

  mcmc:
    steps: 20
    init_width: 0.1
    window_size: 20
    target_pmove: 0.525
    error: 0.025
    blocks: 1
    nonlocal_steps: 0
    nonlocal_step_width: 2.0
    langevin_steps: 0
    langevin_init_width: 1.0

  preconditioner:
    module: spring # ['identity', 'cg', 'spring', 'kfac', 'realkfac']
    args:
      spring:
        damping: 1.e-3
        decay_factor: 0.99
        aux_grad_damping: 0.0
        aux_grad_cutoff: 0.0
        aux_grad_global_damping: 1.e-3
        cutoff_to_zero: true
        dtype: float64
        clip_eigenvals: 1.e-6
      identity: {}
      cg:
        damping: 1.e-3
        decay_factor: 0.99
        maxiter: 100
        precondition_aux_grads: false
      kfac:
        damping_schedule: { schedule: constant_schedule, value: 0.001 }
        ema: 0.95
        lr_schedule: [hyperbolic_decay, 0.1, 6000]
        norm_constraint: 1.e-3
        norm_constraint_decay: 0.0 # c * (1 - decay_factor)
        decay_factor: 0
        fisher_reweighting: false
        dtype: float32
        kronecker_rank: 1
        num_iters: 5
      realkfac:
        lr_schedule: [hyperbolic_decay, 0.05, 10_000]
        damping: 1.e-3
        norm_constraint: 1.e-3
        ema: 0.95
        apply_mask_to_fisher: false

  # The following optimizer setup is typically used with SPRING or other preconditioners
  optimizer:
    - transform: scale_by_hyperbolic_schedule
      learning_rate: 0.02
      delay: 10_000
    - transform: clip_by_global_norm
      max_norm: 0.032 # approx. sqrt(1.e-3) = 0.0316

  # For KFAC we need a different optimizer
  # optimizer:
  #   - transform: sgd
  #     learning_rate: 1.0

  clipping:
    module: quantile # ['none', 'mean', 'median', 'quantile']
    args:
      none: {}
      mean:
        max_deviation: 5
      median:
        max_deviation: 5
      quantile:
        max_deviation: 5
        quantile: 0.95

  masking:
    module: quantile # ['quantile', 'iterative_mean']
    args:
      quantile:
        max_deviation: 10
        quantile: 0.95
      iterative_mean:
        max_deviation: 10
        iterations: 1

  spin_penalty:
    enabled: false
    penalty_scale: 8.0
    # penalty_scale:
    #   schedule: sigmoid
    #   init_value: 0.0
    #   end_value: 0.001
    #   midpoint_step: 40_000
    #   transition_steps: 10_000
    max_grad_norm: 10.0
    penalty_type: "minimize" # ["minimize", "snap"]
    decay: 0.999

    masking:
      module: quantile # ['none', 'iterative_mean', 'quantile']
      args:
        none: {}
        iterative_mean:
          max_deviation: 10
          iterations: 1
        quantile:
          max_deviation: 10
          quantile: 0.95

    clipping:
      module: quantile # ['none', 'mean', 'median', 'quantile']
      args:
        none: {}
        mean:
          max_deviation: 5
        median:
          max_deviation: 5
        quantile:
          max_deviation: 5
          quantile: 0.95

  state_overlap:
    dtype: float64
    penalty_scale: 4.0 # global scale of penalty

    clipping:
      module: quantile # ['none', 'mean', 'median', 'quantile']
      args:
        none: {}
        mean:
          max_deviation: 5
        median:
          max_deviation: 5
        quantile:
          max_deviation: 5
          quantile: 0.95

    masking:
      module: none # ['none', 'quantile']
      args:
        none: {}
        quantile:
          max_deviation: 10
          quantile: 0.95

    scaler:
      module: diff-std # ['none', 'diff-std', 'direct-energy']
      args:
        none: {}
        diff-std:
          decay_schedule:
            schedule: clipped_hyperbolic_growth
            init_value: 0.5
            delay: 10
            bound: 0.999
          min_scale_factor: 0.001
          max_scale_factor: 5.0
          asym_scale: 1000
          asym_strategy: "step" # ["none", "step", "sigmoid", "softplus"]
        direct-energy:
          decay_schedule:
            schedule: clipped_hyperbolic_growth
            init_value: 0.5
            delay: 10
            bound: 0.999
          min_scale_factor: 0.001
          max_scale_factor: 10.0
          energy_shift: 1.0

#########################################################################################
# Pretraining
#########################################################################################
pretraining:
  # NOTE(review): originally written `2_0`, which YAML 1.1 loaders read as 20
  # (a YAML 1.2 loader would read the string "2_0"); normalized to 20 --
  # confirm 2_000 was not intended.
  epochs: 20
  batch_size: 512 # by default full batch
  basis: aug-cc-pVTZ
  hf_config:
    n_tuple_excitations: 2
    exclude_core: false
    extra_overlap_excitations: 4
    hf_method: rohf # ['rhf', 'uhf', 'rohf']
    ordered_excitations: true
    use_smearing: false
    minimal_spin_only: false

  mcmc:
    hf_fraction: 0.0
    steps: 20
    init_width: 0.02
    window_size: 20
    target_pmove: 0.525
    error: 0.025
    blocks: 1
    nonlocal_steps: 0
    nonlocal_step_width: 2.0
    langevin_steps: 0
    langevin_init_width: 0.5

  optimizer:
    - transform: clip_by_global_norm
      max_norm: 1.0
    - transform: scale_by_adam
    - transform: filter_by_param
      name: kernel
      transformations: ["scale_by_trust_ratio"]
    - transform: filter_by_param
      name: embedding
      transformations: ["scale_by_trust_ratio_embeddings"]
    - transform: scale_by_hyperbolic_schedule
      learning_rate: 0.001
      delay: 1000

  reparam_loss_scale: 1.e-6

#########################################################################################
# Logging
#########################################################################################
logging: {}
# To configure logging, replace the empty mapping above with a block mapping,
# for example:
# logging:
#   wandb: {}
#   file:
#     base_dir: checkpoints
#     save_interval: 1000
#     max_num_checkpoints: 5

#########################################################################################
# Evaluation
#########################################################################################
evaluation:
  thermalizing_epochs: 500
  epochs: -1 # eval epochs in total; -1 = run until total_samples_per_energy is reached
  num_total_walker: -1 # walkers in total; -1 = reuse all walkers from VMC optimization
  total_samples_per_energy: 1_000_000 # samples per energy estimate (walker_per_mol * epochs)
  mcmc_steps: 100
