# @package _global_
# Name of the method this configuration is for.
# NOTE(review): presumably read by the training/inference entry points to pick the method — confirm.
method_name: foldflow
# Dataset location, filtering and sampling settings.
dataset:
  # Both paths are anchored at Hydra's runtime working directory (hydra:runtime.cwd).
  # NOTE(review): the cache directory name appears to encode filtering.min_len /
  # filtering.max_len (60 / 512) — confirm and keep in sync if those bounds change.
  cache_dir: ${hydra:runtime.cwd}/../preprocess/.cache/processed_pdb_60_512
  # Alternative cache location (disabled):
  # cache_dir: ${hydra:runtime.cwd}/../preprocess/.cache/jsonl
  cluster_path: ${hydra:runtime.cwd}/data/clusters-by-entity-30.txt
  filtering:
    max_len: 512
    min_len: 60
    # Selects a subset of examples. Useful for debugging.
    subset: null
    allowed_oligomer: [monomeric]
    max_helix_percent: 1.0
    max_loop_percent: 0.5
    min_beta_percent: -1.0
    rog_quantile: 0.96
  min_t: 0.01
  # experiment.eval_batch_size follows samples_per_eval_length via interpolation.
  num_eval_lengths: 1
  samples_per_eval_length: 10

  num_t: 100
  # Number of PDB entries with the same residue count used to compute the OT plan.
  max_same_res: 50
  num_csv_processors: 5
  cache_full_dataset: false
  split_ratio: 0.1


# Flow-matcher settings: shared options first, then the R(3) and SO(3) argument groups.
flow_matcher:
  flow_trans: true
  flow_rot: true
  ot_fn: exact
  reg: 0.05  # Only used if ot_fn is 'sinkhorn'.
  ot_plan: true  # Use an OT plan to pair the noise with data. Default: False.
  stochastic_paths: false  # Switches to stochastic paths when enabled.

  # R(3) Flow Matcher arguments
  r3:
    min_b: 0.01
    min_sigma: 0.01
    max_b: 20.0
    # Also consumed by model.ipa.coordinate_scaling and experiment.coordinate_scaling
    # through interpolation, so changing it here changes all three.
    coordinate_scaling: 0.1
    g: 0.1

  # SO(3) Flow Matcher arguments
  so3:
    min_sigma: 0.01
    max_sigma: 1.5
    # Also consumed by model.axis_angle through interpolation.
    axis_angle: true
    inference_scaling: 10
    g: 0.1



# Model hyperparameters: top-level embedding sizes, then the `embed` and `ipa` sub-groups.
model:
  # Referenced by ipa.c_s and ipa.c_z below.
  node_embed_size: 256
  edge_embed_size: 128
  dropout: 0.0
  embed:
    index_embed_size: 32
    aatype_embed_size: 64
    embed_self_conditioning: true
    num_bins: 22
    # Written with an explicit decimal point so that every YAML loader (not only
    # OmegaConf's) resolves it as a float rather than a string.
    min_bin: 1.0e-5
    max_bin: 20.0
  ipa:
    # c_s / c_z are tied to the node / edge embedding sizes above.
    c_s: ${model.node_embed_size}
    c_z: ${model.edge_embed_size}
    c_hidden: 256
    c_skip: 64
    no_heads: 8
    no_qk_points: 8
    no_v_points: 12
    seq_tfmr_num_heads: 4
    seq_tfmr_num_layers: 2
    num_blocks: 4
    # Tied to the R(3) flow matcher's coordinate scaling.
    coordinate_scaling: ${flow_matcher.r3.coordinate_scaling}
    p_uncond: 0.2
  # Tied to the SO(3) flow matcher's axis_angle flag.
  axis_angle: ${flow_matcher.so3.axis_angle}

# Training-run settings: trainer options, metadata, optimisation, checkpointing,
# loss weights and evaluation.
experiment:
  # Lightning Trainer Settings
  resume_from_ckpt: false
  resume_ckpt_path: null
  # NOTE(review): strategy is `ddp` while use_ddp (below) is false — confirm which
  # of the two the training code actually honours.
  strategy: ddp
  use_distributed_sampler: false
  num_epoch: 100
  check_val_every_n_epoch: 10
  use_wandb: false
  ckpt_freq: 5
  lr_scheduler: step
  monitor: train_loss

  # Experiment metadata
  seed: 111
  name: baseline
  run_id: null

  # Training mode
  use_ddp: false

  # Training arguments
  log_freq: 100
  batch_size: 128
  # Follows dataset.samples_per_eval_length via interpolation.
  eval_batch_size: ${dataset.samples_per_eval_length}
  num_loader_workers: 12
  torch_num_threads: 8
  precision: 32

  # Optimizer and Scheduler
  learning_rate: 0.0001
  # Plain integer (one hundred thousand); `_` digit separators are only understood
  # by YAML 1.1 loaders, so they are avoided here.
  lr_decay_steps: 100000
  lr_decay_min_lr: 0.00008
  lr_decay_rate: 0.95

  # Other runtime options
  max_squared_res: 500000
  prefetch_factor: 100
  use_gpu: true
  sample_mode: cluster_time_batch

  # Take early checkpoint at step 100. Helpful for catching eval bugs early.
  early_ckpt: true

  # Checkpoint directory to warm start from.
  # If warm_start is "auto" then checks the dir for any checkpoints.
  warm_start: null
  use_warm_start_conf: false
  ckpt_dir: ./ckpt/
  # Built from ckpt_dir and name. With the defaults above this resolves to
  # `./ckpt//baseline/` (double slash, because ckpt_dir already ends in `/`).
  full_ckpt_dir: ${experiment.ckpt_dir}/${experiment.name}/

  # Loss weights.
  trans_loss_weight: 1.0
  rot_loss_weight: 0.5
  rot_loss_t_threshold: 0.0
  separate_rot_loss: true
  trans_x0_threshold: 0.0
  # Tied to the R(3) flow matcher's coordinate scaling.
  coordinate_scaling: ${flow_matcher.r3.coordinate_scaling}
  bb_atom_loss_weight: 1.0
  bb_atom_loss_t_filter: 0.25
  dist_mat_loss_weight: 1.0
  dist_mat_loss_t_filter: 0.25
  aux_loss_weight: 0.25

  # Evaluation.
  eval_dir: ./eval_outputs
  noise_scale: 1.0
  # Filled in during training.
  num_parameters: null



# Inference settings: device/seed, resource paths, flow integration and sample counts.
inference:
  name: null
  gpu_id: 0  # CUDA GPU to use
  seed: 123


  # Directory of software, weights, and outputs.
  # pt_hub_dir and weights_path are anchored at Hydra's runtime working directory;
  # output_dir is a relative path.
  pt_hub_dir: ${hydra:runtime.cwd}/resource/torch/
  output_dir: ./foldflow_outputs

  # Path to model weights.
  weights_path: ${hydra:runtime.cwd}/resource/foldflow/best-epoch=79-ot.ckpt
  flow:
    # Number of steps.
    num_t: 100
    # Analogous to sampling temperature.
    noise_scale: 0.1
    # Final t.
    min_t: 0.01

  samples:
    # Number of backbone samples per sequence length.
    samples_per_length: 10
    # Number of ESMFold samples per backbone sample.
    seq_per_sample: 8
    # Minimum sequence length to sample.
    min_length: 100
    # Maximum sequence length to sample.
    max_length: 200
    # Gap between lengths to sample, i.e. this script will sample all lengths
    # in range(min_length, max_length, length_step).
    # NOTE(review): if the upper bound is exclusive as in Python's range, these
    # defaults (100, 200, 100) sample only length 100 — confirm against the sampler.
    length_step: 100