# Dataset selection, length filtering, and caching options.
data:
  split: CPSea
  valid_num: 100
  # CSV with the path and metadata of each training example.
  # Left unset here (explicit null); fill in the path to the metadata CSV.
  csv_path: null
  # cluster_path: /mnt/shared_storage/processed_pdbs/clusters-by-entity-30.txt
  filtering:
    max_len: 512
    min_len: 0
    # Selects a subset of examples. Useful for debugging.
    subset: null
  min_t: 0.01
  num_eval_lengths: 10
  num_t: 100
  # Number of PDBs with the same number of residues used to compute the OT plan.
  max_same_res: 50
  num_csv_processors: 5
  cache_full_dataset: False  # Cache both to disk (LMDB) and in memory.
  cache_dataset_in_memory: False  # If True, load from memory. If False, load from disk (LMDB).
  cache_path: ./cache/  # Where to save the LMDB cache.

# Settings for a training run: metadata, training arguments, loss weights,
# checkpoint locations, and evaluation output.
experiment:
  # Experiment metadata
  name: baseline
  run_id: null

  # Training mode
  use_ddp: True

  # Training arguments
  log_freq: 100
  batch_size: 256
  # Interpolation: takes the value of experiment.batch_size
  # (assumes an OmegaConf/Hydra-style loader resolves `${...}` -- confirm).
  eval_batch_size: ${experiment.batch_size}
  num_loader_workers: 12
  torch_num_threads: 8
  num_epoch: 100
  # Keep the explicit ".0": YAML 1.1 loaders such as plain PyYAML only resolve
  # exponent notation as a float when the mantissa contains a dot; a bare
  # `2e-4` would be loaded as the string "2e-4".
  learning_rate: 2.0e-4
  max_squared_res: 500000
  prefetch_factor: 100
  use_gpu: True
  num_gpus: 8
  sample_mode: default

  # Loss weights.
  trans_loss_weight: 1.0
  rot_loss_weight: 0.5
  rot_loss_t_threshold: 0.0
  separate_rot_loss: True
  trans_x0_threshold: 0.0
  # Interpolation: shares the value of flow_matcher.r3.coordinate_scaling.
  coordinate_scaling: ${flow_matcher.r3.coordinate_scaling}
  bb_atom_loss_weight: 1.0
  bb_atom_loss_t_filter: 0.25
  dist_mat_loss_weight: 1.0
  dist_mat_loss_t_filter: 0.25
  aux_loss_weight: 0.25

  # Checkpoint directory to warm start from.
  base_dir: ./saved_model
  ckpt_dir: ckpt/
  # Interpolation: as written, resolves to the same value as experiment.ckpt_dir.
  full_ckpt_dir: ${experiment.ckpt_dir}

  # Evaluation.
  eval_dir: eval_outputs
  noise_scale: 1.0

  # Filled in during training.
  num_parameters: null

# Flow-matcher options: which components are flowed, optimal-transport (OT)
# pairing, and the per-manifold R(3) / SO(3) arguments.
flow_matcher:
  flow_trans: True
  flow_rot: True
  ot_fn: exact
  reg: 0.05  # Only used if ot_fn is 'sinkhorn'.
  ot_plan: False  # Use an OT plan to pair the noise with data. Default False.
  stochastic_paths: False  # Switches to stochastic paths.

  # R(3) Flow Matcher arguments
  r3:
    min_b: 0.01
    min_sigma: 0.01
    max_b: 20.0
    # Referenced via interpolation by experiment.coordinate_scaling and by
    # model.bb_encoder / model.bb_decoder, so changing it here changes all of them.
    coordinate_scaling: 0.1
    g: 0.1

  # SO(3) Flow Matcher arguments
  so3:
    min_sigma: 0.01
    max_sigma: 1.5
    axis_angle: True
    inference_scaling: -0.01
    g: 0.1

# Model architecture options: sequence-model key, encoder/decoder blocks,
# representation sizes, transformer trunk, and embedding switches.
model:
  model_name: "ff2"
  esm2_model_key: "esm2_650M"  # Trained with "esm2_650M"
  scaffold_training: False
  binder_training: False
  binder_percent_fix_structure: 1.0
  bb_encoder:
    num_blocks: 2 
    # Interpolation: shares the value of flow_matcher.r3.coordinate_scaling.
    coordinate_scaling: ${flow_matcher.r3.coordinate_scaling}
  bb_decoder:
    num_blocks: 2
    # Interpolation: shares the value of flow_matcher.r3.coordinate_scaling.
    coordinate_scaling: ${flow_matcher.r3.coordinate_scaling}
  seq_emb_to_block:
    single_dim: 128 
    pair_dim: 128
  representation_combiner:
    single_dim: 128  # NOTE: If proj+concat, the total dim will be 512
    # NOTE(review): the note below repeats the 512 figure from single_dim even
    # though pair_dim is 64 -- confirm the intended total for the pair track.
    pair_dim: 64  # NOTE: If proj+concat, the total dim will be 512
    layer_norm: True
  modalities_transformer:
    trunk_type: "transformer"
    num_blocks: 2
    sequence_head_width: 32
    pairwise_head_width: 32
    # null disables chunking. A lower chunk_size reduces memory use but also reduces speed.
    chunk_size: null
  p_mask_sequence: 0.5

  embed:
    embed_self_conditioning: True
    use_alphafold_position_embedding: False
    relpos_k: null