# @package _global_
method_name: framediff
# Dataset settings. ${hydra:runtime.cwd} is Hydra's resolver for the
# directory the application was launched from.
dataset:
  # NOTE(review): the directory name suggests preprocessed PDB entries with
  # lengths 60-512 — confirm against the preprocessing step.
  cache_dir: ${hydra:runtime.cwd}/../preprocess/.cache/processed_pdb_60_512
  # Disabled alternative cache location:
  # cache_dir: ${hydra:runtime.cwd}/../preprocess/.cache/jsonl
  cluster_path: ${hydra:runtime.cwd}/data/clusters-by-entity-30.txt
  min_t: 0.01
  # Disabled alternative values for the two evaluation keys that are set at
  # the end of this mapping:
  #  samples_per_eval_length: 12
  #  num_eval_lengths: 10
  num_t: 100
  split_ratio: 0.1
  num_eval_lengths: 1
  samples_per_eval_length: 4




# Frame diffusion settings: toggles for the translation and rotation
# components, followed by the parameters of each diffuser.
frame:
  diffuse_trans: true
  diffuse_rot: true

  # Settings for the R(3) diffuser.
  r3:
    min_b: 0.1
    max_b: 20.0
    # Other sections of this file pick this value up through the
    # ${frame.r3.coordinate_scaling} interpolation.
    coordinate_scaling: 0.1

  # Settings for the SO(3) diffuser.
  so3:
    num_omega: 1000
    num_sigma: 1000
    min_sigma: 0.1
    max_sigma: 1.5
    schedule: logarithmic
    cache_dir: ${hydra:runtime.cwd}/data/framediff/.cache/
    use_cached_score: false


# Network hyperparameters: shared embedding widths, the input embedding
# settings (`embed`) and the IPA settings (`ipa`).
model:
  node_embed_size: 256
  edge_embed_size: 256
  dropout: 0.0
  embed:
    index_embed_size: 32
    aatype_embed_size: 64
    embed_self_conditioning: true
    num_bins: 22
    # Written with an explicit mantissa dot: YAML 1.1 loaders such as plain
    # PyYAML resolve a bare `1e-5` as a string, not a float.
    min_bin: 1.0e-5
    max_bin: 20.0
  ipa:
    # c_s / c_z follow the embedding widths declared at the top of `model`.
    c_s: ${model.node_embed_size}
    c_z: ${model.edge_embed_size}
    c_hidden: 256
    c_skip: 64
    no_heads: 8
    no_qk_points: 8
    no_v_points: 12
    seq_tfmr_num_heads: 4
    seq_tfmr_num_layers: 2
    num_blocks: 4
    # Kept in sync with the R(3) diffuser's scaling via interpolation.
    coordinate_scaling: ${frame.r3.coordinate_scaling}


# Training-run configuration: trainer settings, run metadata, data-loading
# and optimisation arguments, checkpointing, loss weights and evaluation.
experiment:
  # Lightning Trainer settings
  resume_from_ckpt: false
  resume_ckpt_path: null
  strategy: ddp_find_unused_parameters_true
  use_distributed_sampler: false
  check_val_every_n_epoch: 100
  num_epoch: 200
  use_wandb: false
  ckpt_freq: 5
  lr_scheduler: step
  monitor: train_loss

  # Experiment metadata
  seed: 111
  name: baseline
  run_id: null

  # Training arguments
  log_freq: 1000
  batch_size: 128
  steps_per_epoch: null
  num_loader_workers: 5
  precision: 32

  weight_decay: 0
  max_squared_res: 500000
  prefetch_factor: 100
  use_gpu: false
  train_sample_mode: cluster_time_batch
  # num_gpus: 1

  # Optimizer and scheduler
  learning_rate: 0.0001

  # 100000 written without digit separators: the previous `10_0000` spelling
  # only loads as an integer under YAML 1.1 rules and is easy to misread.
  lr_decay_steps: 100000
  lr_decay_min_lr: 0.00008
  lr_decay_rate: 0.95

  # Take early checkpoint at step 100. Helpful for catching eval bugs early.
  early_ckpt: true

  # Checkpoint directory to warm start from.
  warm_start: null
  # NOTE(review): despite the key name, this value points at a single
  # `.ckpt` file rather than a directory — confirm what the consumer expects.
  ckpt_dir: ${hydra:runtime.cwd}/ckpts/warm/best.ckpt

  # Loss weights.
  trans_loss_weight: 1.0
  rot_loss_weight: 0.5
  rot_loss_t_threshold: 0.2
  separate_rot_loss: false
  trans_x0_threshold: 1.0
  # Kept in sync with the R(3) diffuser's scaling via interpolation.
  coordinate_scaling: ${frame.r3.coordinate_scaling}
  bb_atom_loss_weight: 1.0
  bb_atom_loss_t_filter: 0.25
  dist_mat_loss_weight: 1.0
  dist_mat_loss_t_filter: 0.25
  aux_loss_weight: 0.25

  # Evaluation.
  eval_dir: ./eval_outputs
  noise_scale: 1.0
  eval_sample_mode: length_batch
  # Filled in during training.
  num_parameters: null



# Sampling / inference configuration.
inference:
  name: null
  gpu_id: 0  # Index of the CUDA GPU to use.
  seed: 123

  # Outputs are written inside the Hydra output subdirectory, e.g.
  # lightning/hydra_inference/2024-10-21_12-12-09/framediff_outputs
  output_dir: ./framediff_outputs/

  # ${hydra:runtime.cwd} = protein-se3/lightning
  # Directory of software, weights, and outputs.
  pt_hub_dir: ${hydra:runtime.cwd}/resource/torch/
  # Checkpoint file holding the model weights.
  weights_path: ${hydra:runtime.cwd}/resource/framediff/best-epoch=149-train_loss=1.935.ckpt

  diffusion:
    # Number of steps.
    num_t: 500
    # Analogous to a sampling temperature.
    noise_scale: 0.1
    # Final t.
    min_t: 0.01

  samples:
    # How many backbone samples to draw for each sequence length.
    samples_per_length: 4
    # Shortest sequence length to sample.
    min_length: 61
    # Longest sequence length to sample.
    max_length: 320
    # Gap between sampled lengths, i.e. the script samples every length in
    # range(min_length, max_length, length_step).
    length_step: 1
