# Data Loading
data:
  # Data module for the pairwise RNA secondary-structure task.
  class_path: modelgenerator.rna_ss.rna_ss_data.RNASSPairwiseTokenClassification
  init_args:
    # No dataset location is set here.
    # NOTE(review): presumably supplied via a CLI/config override - confirm.
    path: null
    dataset: 'bpRNA'

    # Sequence-length bounds; with 0 and a very large upper bound this
    # presumably disables length filtering - TODO confirm in the data module.
    min_seq_len: 0
    max_seq_len: 999999999

    # DataLoader settings.
    num_workers: 4
    pin_memory: true
    persistent_workers: true

# Model Arguments
model:
  # Task module: pairwise token classification for RNA secondary structure.
  class_path: modelgenerator.rna_ss.rna_ss_task.RNASSPairwiseTokenClassification
  init_args:
    use_legacy_adapter: false
    strict_loading: true
    reset_optimizer_states: false

    # Backbone with both dropout probabilities overridden to 0.1.
    backbone_fn:
      class_path: modelgenerator.backbones.aido_rna_1b600m
      init_args:
        max_length: 1024
        use_peft: false
        config_overwrites:
          hidden_dropout_prob: 0.1
          attention_probs_dropout_prob: 0.1

    # Adapter head; no init_args are given, so the class defaults apply.
    adapter_fn:
      class_path: modelgenerator.adapters.ResNet2DAdapter

    optimizer:
      class_path: torch.optim.AdamW
      init_args:
        # Written with an explicit mantissa dot: YAML 1.1 loaders such as
        # PyYAML resolve a bare `1e-4` as a string rather than a float.
        lr: 1.0e-4
        weight_decay: 0.01

    # Linear decay of the LR multiplier from 1.0 down to 0.1 over the first
    # 7500 scheduler steps; the multiplier stays at 0.1 afterwards.
    # NOTE(review): whether a scheduler step is an optimizer step or an epoch
    # depends on how the task configures the scheduler - confirm.
    lr_scheduler:
      class_path: torch.optim.lr_scheduler.LinearLR
      init_args:
        start_factor: 1.0
        end_factor: 0.1
        total_iters: 7500
        # NOTE(review): `verbose` is deprecated on LR schedulers in recent
        # PyTorch releases - confirm against the pinned torch version.
        verbose: true

# Training Configuration
trainer:
  # Hardware selection and numeric precision.
  accelerator: auto
  devices: auto
  precision: '16-mixed'

  # Run length: -1 leaves the step count unbounded, so max_epochs is the limit.
  max_steps: -1
  max_epochs: 60
  accumulate_grad_batches: 1

  # Output location, logging and debugging.
  default_root_dir: '/mgen_data/modelgenerator/huggingface_models/rna_ss'
  log_every_n_steps: 50
  detect_anomaly: false

  # callbacks
  callbacks:
    # Keep only the single best checkpoint, ranked by highest val/f1.
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      dict_kwargs:
        monitor: val/f1
        mode: max
        # The template already spells out its own labels, so automatic
        # metric-name insertion is switched off below.
        filename: 'ss-epoch={epoch}-valF1={val/f1:.3f}'
        every_n_epochs: 1
        save_top_k: 1
        # save_last: true
        verbose: true
        auto_insert_metric_name: false

    # Log the learning rate at every step.
    - class_path: lightning.pytorch.callbacks.LearningRateMonitor
      dict_kwargs:
        logging_interval: 'step'

    # Fine-tuning scheduler; no schedule file is configured here.
    # NOTE(review): presumably a no-op or overridden at launch - confirm.
    - class_path: modelgenerator.callbacks.FTScheduler
      dict_kwargs:
        ft_schedule_path: null

  # DDP strategy; tolerates parameters that receive no gradient in a step.
  strategy:
    class_path: lightning.pytorch.strategies.DDPStrategy
    dict_kwargs:
      find_unused_parameters: true
