defaults:
  # Base settings from the `esm2` config (resolved relative to this file's
  # config group — presumably a shared ESM-2 base; confirm location).
  - esm2
  # Merge this file's own values last so they override the base above.
  # Hydra >=1.1 already does this implicitly for group configs; stating it
  # makes the order explicit and avoids the missing-`_self_` warning if this
  # file is ever used as a primary config.
  - _self_

# Identity of this model variant and its default batch size.
# NOTE(review): if this is passed to Hugging Face `from_pretrained`, the Hub
# id is usually `facebook/esm2_t48_15B_UR50D` — confirm how the name is
# resolved by the loader.
name: 'ESM2_T48'
model_name_or_path: 'esm2_t48_15B_UR50D'
batch_size: 4

# Parameter-efficient fine-tuning settings. Key names (r / alpha / dropout)
# suggest LoRA adapters — confirm against the code that consumes `peft`.
peft:
  # Adapter rank.
  r: 2
  # Adapter scaling numerator; under LoRA the effective scale is alpha / r
  # (32 / 2 = 16 here).
  alpha: 32
  # NOTE(review): 0.5 is unusually high for adapter dropout — confirm this
  # is intended rather than a typo for e.g. 0.05.
  dropout: 0.5

prediction_head:
  # Presumably must equal the width of the backbone's embeddings fed to the
  # head; 5120 matches the published embedding size of esm2_t48_15B_UR50D.
  input_dim: 5120

# Hyperparameter sweep definition. Each value is a string in what appears to
# be Hydra's extended override grammar (`choice(...)`, `interval(lo, hi)`,
# `tag(log, ...)` for log-scale sampling). Quoting keeps every expression an
# unambiguous YAML string for the sweeper to parse.
search_space:
  model.learning_rate: 'tag(log, interval(1e-6, 1e-3))'
  model.weight_decay: 'tag(log, interval(1e-4, 1e-1))'
  model.batch_size: 'choice(2, 4, 8, 16)'
  model.loss: 'choice(mse, huber, l1)'