# @package _global_

# Default configuration composed by Hydra for this entry point.
# The order of the entries determines the order in which configs override each other.
defaults:
  # this file's own values are placed first, so the entries below can override them
  - _self_
  - paths: default.yaml
  - extras: default.yaml
  - hydra: default.yaml
  # no experiment is selected by default.
  # NOTE(review): the ${experiment.*} interpolations further down presumably rely on
  # an experiment being chosen — confirm
  - experiment: null # needs command line override
  # no debug config is selected by default.
  # NOTE(review): this file also sets a top-level `debug` key — confirm the two do not collide
  - debug: null
  # `optional`: presumably machine-specific settings that may be absent — confirm
  - optional local: default

# Predictor network and optimizer hyperparameters.
model:
  # optimizer settings
  optimizer:
    lr: 0.0001
    weight_decay: 0.0001

  # predictor architecture
  predictor:
    # 20 for GFP; same count as the letters in data.alphabet
    n_tokens: 20
    kernel_size: 5
    input_size: 256
    dropout: 0.0
    name: 'CNN'
    activation: 'relu'
    linear: true
    # read from the `experiment` node of the composed config
    seq_len: '${experiment.seq_len}'

# Trainer settings.
# NOTE(review): keys look like pytorch_lightning Trainer arguments — confirm in the training script.
trainer:
  default_root_dir: '${paths.output_dir}'
  # prevents early stopping
  min_epochs: 1
  max_epochs: 2000
  accelerator: 'gpu'
  log_every_n_steps: 1
  check_val_every_n_epoch: 1
  deterministic: false

# Dataset and data-loading settings.
data:
  # GFP or folding or AAV; read from the `experiment` node of the composed config
  task: '${experiment.task}'
  seed: 420
  batch_size: 1024
  pin_memory: false
  num_workers: 0
  encoding: 'onehot'
  # amino-acid (AA) alphabet, 20 letters
  alphabet: 'ARNDCQEGHILKMFPSTWYV'
  val_samples: 200
  sequence_column: 'sequence'

# Callbacks; `model_checkpoint` targets pytorch_lightning's ModelCheckpoint.
callbacks:
  model_checkpoint:
    _target_: 'pytorch_lightning.callbacks.ModelCheckpoint'
    dirpath: './ckpt/${data.task}'
    filename: 'epoch_{epoch:03d}'
    # metric used to rank checkpoints; `mode: max` means higher is better
    monitor: 'val/spearmanr'
    mode: 'max'
    # additionally always save an exact copy of the last checkpoint to last.ckpt
    save_last: true
    # keep the k best models (as determined by the monitored metric)
    save_top_k: 3
    # when true, the checkpoint filenames will contain the metric name
    auto_insert_metric_name: false
    # if true, only the model's weights will be saved
    save_weights_only: false
    # number of training steps between checkpoints
    every_n_train_steps: null
    # checkpoints are monitored at the specified time interval
    train_time_interval: null
    # number of epochs between checkpoints
    every_n_epochs: null
    # whether to checkpoint at the end of the training epoch or at the end of validation
    save_on_train_epoch_end: null

# Logger settings; targets pytorch_lightning's WandbLogger.
wandb:
  _target_: 'pytorch_lightning.loggers.wandb.WandbLogger'
  name: null
  save_dir: '${paths.output_dir}'
  offline: true
  # project name follows the task name
  project: '${data.task}'
  # upload lightning ckpts
  log_model: false

# Run-level settings; entries set to null are unset by default.
model_checkpoint_dir: null
preprocessed_data_path: null
num_gpus: 1

run_name: null
debug: false

# task name, determines output directory path
task_name: 'train_predictor'
tags:
  - 'dev'
  - 'latest'

ckpt_path: null
seed: null

