# Training configuration: E2Former backbone with energy and force heads,
# trained on OC22 S2EF LMDB data (see `dataset` and `model` below).
# NOTE(review): the trainer name says "equiformerv2" although the backbone is
# E2Former — presumably the trainer is reused unchanged; confirm.
trainer: equiformerv2_forces

# Dataset splits.
# NOTE(review): `val` sets only `src` and `lin_ref`; presumably the remaining
# settings (format, key_mapping, transforms) are taken from `train` — confirm
# against the dataset loader.
dataset:
  train:
    format: oc22_lmdb
    src: data/oc22/s2ef/train
    # Field-name mapping between stored records and the trainer; presumably
    # `y` -> `energy` and `force` -> `forces` — confirm the direction.
    key_mapping:
      y: energy
      force: forces
    transforms:
      # Per-target normalization statistics (mean / standard deviation).
      normalizer:
        energy:
          mean: 0.0
          stdev: 5
        forces:
          mean: 0.0
          # 0.522 was noted next to this value as an alternative; its origin
          # is not recorded here.
          stdev: 0.14759646356105804
    # Both splits reference the same coefficients file.
    lin_ref: configs/oc22/linref/oc22_linfit_coeffs.npz
  val:
    src: data/oc22/s2ef/val_id
    lin_ref: configs/oc22/linref/oc22_linfit_coeffs.npz

# Logging backend, selected by name.
logger: wandb

# Prediction targets declared for the model.
outputs:
  # System-level target of shape 1.
  energy:
    shape: 1
    level: system
    prediction_dtype: float32
  # Atom-level target; the two flags below request that training and
  # evaluation consider free atoms only.
  forces:
    irrep_dim: 1
    level: atom
    train_on_free_atoms: true
    eval_on_free_atoms: true
    prediction_dtype: float32

# Loss terms, one list entry per target. Each entry names a loss function
# (`fn`) and a `coefficient` — presumably its weight in the total loss; confirm.
loss_functions:
  - energy:
      fn: mae
      coefficient: 4
  - forces:
      fn: l2mae
      coefficient: 100

# Metrics computed at evaluation time, grouped by target.
evaluation_metrics:
  metrics:
    energy:
      - mae
    forces:
      - mae
      - cosine_similarity
      - magnitude_error
    # Metrics not tied to a single target.
    misc:
      - energy_forces_within_threshold
  # NOTE(review): presumably the metric used to rank checkpoints — confirm.
  primary_metric: energy_mae

# Model definition: a `hydra` model made of one backbone plus the output heads
# listed under `heads`.
model:
  name: hydra
  pass_through_head_outputs: true
  # The same flag is also set inside `backbone`.
  otf_graph: true

  backbone:
    # Dotted path to the backbone implementation.
    model: src.E2Former.E2FormerBackbone

    # global_cfg — label only; the keys below are direct children of backbone.
    # Force-prediction switches.
    # NOTE(review): direct_force is false here and AutoGradForce is true
    # further down, yet heads.forces uses E2FormerDirectForceHead — confirm
    # these settings are meant to coexist.
    regress_forces: true
    direct_force: false
    hidden_size: 384
    # NOTE(review): differs from optim.batch_size (4) — confirm which value
    # the backbone actually uses.
    batch_size: 12
    activation: gelu
    use_padding: true
    use_fp16_backbone: false
    use_compile: false

    # molecular_graph_cfg — label only; keys below are direct children of backbone.
    # Graph construction settings (periodic boundaries, neighbor limits).
    use_pbc: true
    use_pbc_single: false
    # Also set at the model level.
    otf_graph: true
    max_neighbors: 20
    max_num_elements: 90
    # NOTE(review): 150 is below the stated maximum of 220 — confirm how
    # larger batches are handled.
    max_num_nodes_per_batch: 150  # Average 73, Max 220, use 150 for padding
    enforce_max_neighbors_strictly: false
    distance_function: gaussian

    # gnn_cfg — label only; the keys below are direct children of backbone.
    # Embedding / expansion sizes for atoms, node directions and edge distances.
    atom_embedding_size: 128
    node_direction_embedding_size: 64
    node_direction_expansion_size: 64
    edge_distance_expansion_size: 128
    edge_distance_embedding_size: 64
    # Attention implementation name and head count.
    atten_name: flash
    atten_num_heads: 4
    # Hidden-layer multipliers for the readout, output and feed-forward layers.
    readout_hidden_layer_multiplier: 2
    output_hidden_layer_multiplier: 2
    ffn_hidden_layer_multiplier: 4
    use_angle_embedding: true

    # reg_cfg — label only; the keys below are direct children of backbone.
    # Dropout and stochastic-depth rates, plus the normalization layer type.
    mlp_dropout: 0.1
    atten_dropout: 0.1
    stochastic_depth_prob: 0.1
    normalization: layernorm

    # psm_config — label only; the keys below are direct children of backbone.
    # Encoder / embedding dimensions.
    encoder_embed_dim: 384
    num_atom_features: 64
    num_timesteps: 100
    embedding_dim: 384
    diffusion_time_step_encoder_type: sinusoidal
    use_2d_bond_features: true
    use_charge_embedding: true
    # Periodic-boundary cutoffs and expansion limits.
    pbc_cutoff: 5.0
    pbc_expanded_token_cutoff: 256
    pbc_expanded_num_cell_per_direction: 4
    pbc_multigraph_cutoff: 12.0
    # NOTE(review): spelled NON_EXCHANGABLE — presumably it must match a name
    # defined by the consumer; do not correct the spelling here without
    # checking the code that reads it.
    node_type_edge_method: NON_EXCHANGABLE
    # NOTE(review): enabled while direct_force is false and the forces head is
    # E2FormerDirectForceHead — confirm the intended force path.
    AutoGradForce: true
    pbc_use_local_attention: true
    use_2d_atom_features: true
    mlm_from_decoder_feature: true
    num_encoder_layers: 6
    num_attention_heads: 8
    num_3d_bias_kernel: 128
    num_edges: 512
    edge_hidden_dim: 64
    num_spatial: 10
    multi_hop_max_dist: 510

 
    # backbone_config — label only; keys below are direct children of backbone.
    # Irreps of the node embedding and the layer count.
    irreps_node_embedding: "128x0e+128x1e+128x2e"
    num_layers: 8
    # max_neighbors: 20  (disabled here; max_neighbors is already set earlier
    # in this mapping)
    pbc_max_radius: 12.0
    max_radius: 15.0
    basis_type: "gaussian"
    number_of_basis: 128
    # Attention head layout.
    num_attn_heads: 4
    attn_scalar_head: 32
    irreps_head: "32x0e+32x1e+32x2e"
    rescale_degree: false
    nonlinear_message: false
    norm_layer: "rms_norm_sh_BL"
    # Dropout / drop-path rates.
    alpha_drop: 0.1
    proj_drop: 0.0
    out_drop: 0.0
    drop_path_rate: 0.0
    # Variant selectors, given as names.
    attn_type: "first-order"
    tp_type: "QK_alpha"
    edge_embedtype: "eqv2"
    attn_biastype: "share"
    ffn_type: "default"
    add_rope: false
    time_embed: false
    sparse_attn: false
    # NOTE(review): "threthod" is kept exactly as written — presumably the
    # consumer reads this exact key name; rename only together with the code.
    dynamic_sparse_attn_threthod: 1000
    force_head: null

  # Output heads attached to the backbone, one per target; each `module` is a
  # dotted path to the head implementation.
  heads:
    # NOTE(review): a direct force head is selected although
    # backbone.direct_force is false — confirm this is intended.
    forces:
      module: src.E2Former.E2FormerDirectForceHead
    energy:
      module: src.E2Former.E2FormerEnergyHead

# Optimization settings: batching, optimizer, LR schedule, and run length.
optim:
  # Batching and data loading.
  batch_size: 4
  eval_batch_size: 4
  load_balancing: atoms
  num_workers: 8
  lr_initial: 0.0002

  # Optimizer and learning-rate schedule.
  optimizer: AdamW
  optimizer_params:
    weight_decay: 0.001
  scheduler: LambdaLR
  scheduler_params:
    lambda_type: cosine
    warmup_factor: 0.2
    warmup_epochs: 0.1
    lr_min_factor: 0.01

  # Run length, gradient clipping, and EMA decay.
  max_epochs: 6
  clip_grad_norm: 50
  ema_decay: 0.999

  # Evaluation interval.
  eval_every: 5000