# Base config files pulled in by this config; keys defined below are given in
# addition to whatever the included file provides.
# NOTE(review): this is an absolute, machine-specific path, so the config only
# loads on a host where it exists. Consider a repository-relative path once it
# is confirmed how the config loader resolves `includes` entries.
includes:
  - /home/hul/hul/yl2428/SFM_framework_Dec-9/EScAIP/configs/s2ef/2M/base.yml

# Trainer selected for this run (presumably a name registered with the
# training framework -- confirm against the trainer registry).
trainer: equiformerv2_forces

# Model definition: a `hydra` model made of one backbone plus one output head
# per predicted target (forces, energy).
model:
  name: hydra
  pass_through_head_outputs: true
  otf_graph: true

  backbone:
    model: src.EScAIP.EScAIPBackbone

    # Global Configs
    activation: gelu
    direct_force: true
    hidden_size: 512
    regress_forces: true
    use_fp16_backbone: false
    batch_size: 12  # must be the same as optim.batch_size
    use_compile: true
    use_padding: true

    # Molecular Graph Configs
    enforce_max_neighbors_strictly: true
    distance_function: gaussian
    max_neighbors: 20
    max_num_elements: 90
    max_num_nodes_per_batch: 150  # average 73, max 220; 150 is used for padding
    max_radius: 12.0
    otf_graph: true
    use_pbc: true
    use_pbc_single: false

    # Graph Neural Networks Configs
    atom_embedding_size: 128
    atten_name: memory_efficient
    atten_num_heads: 8
    edge_distance_embedding_size: 512
    edge_distance_expansion_size: 600
    node_direction_embedding_size: 64
    node_direction_expansion_size: 10
    num_layers: 10
    output_hidden_layer_multiplier: 2
    readout_hidden_layer_multiplier: 2
    ffn_hidden_layer_multiplier: 2
    use_angle_embedding: true

    # Regularization Configs
    atten_dropout: 0.1
    mlp_dropout: 0.05
    normalization: rmsnorm
    stochastic_depth_prob: 0.0

  # Output heads, keyed by the target each one predicts.
  heads:
    forces:
      module: src.EScAIP.EScAIPDirectForceHead
    energy:
      module: src.EScAIP.EScAIPEnergyHead


# Optimization settings: batch sizes, data loading, optimizer, learning-rate
# schedule, and evaluation cadence.
optim:
  # Must stay equal to model.backbone.batch_size.
  batch_size:                   12         # 6
  eval_batch_size:              12         # 6
  load_balancing: atoms
  num_workers: 9
  lr_initial:                   0.0001    # EquiformerV2 uses 0.0004 for a single-GPU batch size of 8

  optimizer: AdamW
  optimizer_params:
    weight_decay: 0.01
  scheduler: LambdaLR # fairchem-1.2.0 needs an lr_lambda here; it was removed in favor of a dynamically built lr_lambda
  scheduler_params:
    lambda_type: cosine
    warmup_factor: 0.2
    warmup_epochs: 0.1
    lr_min_factor: 0.5         # EquiformerV2 uses 0.01

  max_epochs: 30
  clip_grad_norm: 10
  ema_decay: 0.999

  eval_every: 30000
  # checkpoint_every: 50000

# Slurm submission settings. `constraint` is presumably passed through as the
# Slurm node-feature constraint -- confirm against the job launcher.
slurm:
  constraint: "volta32gb"

