# Checkpoint path; null means no checkpoint is specified in this config.
ckpt_path: null

# Data loading: mean-ribosome-load (MRL) regression data module
data:
  class_path: modelgenerator.data.RNAMeanRibosomeLoadDataModule
  init_args:
    # Dataset source and the configuration (subset) to read from it
    path: "genbio-ai/rna-downstream-tasks"
    config_name: "mean_ribosome_load"

    # Samples per batch
    batch_size: 64

    # Loader worker processes, kept alive between epochs
    num_workers: 32
    persistent_workers: true

    # Page-locked host memory for batches
    pin_memory: true

# Model arguments: sequence-level regression task built from a backbone,
# an adapter head, an optimizer and a learning-rate scheduler.
model:
  class_path: modelgenerator.tasks.SequenceRegressionWithScaling
  init_args:
    use_legacy_adapter: false
    strict_loading: true
    reset_optimizer_states: false
    backbone:
      class_path: modelgenerator.backbones.aido_rna_1b600m
      init_args:
        # from_scratch: false
        max_length: 1024
        # PEFT is disabled for this run
        use_peft: false
        # Overrides applied on top of the backbone's own configuration
        config_overwrites:
          hidden_dropout_prob: 0.1
          attention_probs_dropout_prob: 0.1
        # model_init_args: null
    adapter:
      class_path: modelgenerator.adapters.ResNet1DAdapter
      init_args:
        channels: 256
        num_blocks: 9
        dropout: 0.1
    optimizer:
      class_path: torch.optim.AdamW
      init_args:
        # Written with an explicit mantissa dot: YAML 1.1 loaders (e.g. PyYAML)
        # resolve a bare `1e-4` to a string rather than a float.
        lr: 1.0e-4
        weight_decay: 0.01
    lr_scheduler:
      class_path: torch.optim.lr_scheduler.LinearLR
      init_args:
        # LR multiplier goes linearly from start_factor to end_factor over
        # total_iters scheduler steps.
        # NOTE(review): whether one scheduler step is an epoch or an optimizer
        # step depends on how the task wires the scheduler; confirm 5000 is in
        # the intended unit (the trainer runs for 70 epochs).
        start_factor: 1.0
        end_factor: 0.1
        total_iters: 5000
        # The deprecated `verbose` flag is intentionally not set; the
        # LearningRateMonitor callback under trainer.callbacks already logs the
        # learning rate at every step.

# Training configuration
trainer:
  accelerator: auto
  devices: 1
  # -1 puts no limit on the number of steps; max_epochs bounds the run.
  max_steps: -1
  max_epochs: 70
  precision: "32"
  default_root_dir: "/mgen_data/modelgenerator/huggingface_models/rna_mrl"
  log_every_n_steps: 50
  # NOTE(review): autograd anomaly detection is a debugging aid that adds
  # overhead to every step; confirm it is wanted for full-length runs.
  detect_anomaly: true
  accumulate_grad_batches: 1

  # callbacks
  callbacks:
    # Keep the single best checkpoint by validation R^2 (val_r2, higher is
    # better), plus the most recent one (save_last).
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      dict_kwargs:
        monitor: val_r2
        mode: max
        # Metric names are written out by hand in the template because
        # auto_insert_metric_name is false.
        filename: "mrl-epoch={epoch}--tr_mse={train_mse:.2f}-val_mse={val_mse:.2f}-val_r2={val_r2:.3f}"
        every_n_epochs: 1
        save_top_k: 1
        save_last: true
        verbose: true
        auto_insert_metric_name: false

    # learning rate monitor (logs the LR at every step)
    - class_path: lightning.pytorch.callbacks.LearningRateMonitor
      dict_kwargs:
        logging_interval: "step"

    # finetuning scheduler (no schedule file configured)
    - class_path: modelgenerator.callbacks.FTScheduler
      dict_kwargs:
        ft_schedule_path: null

  # DDP strategy
  strategy:
    class_path: lightning.pytorch.strategies.DDPStrategy
    dict_kwargs:
      find_unused_parameters: true