# @package _global_

# Free-form suffix appended to the run name.
add_to_name: run
# Run name assembled from the dataset class name, the modulus and the suffix;
# with the values in this file: math_equations_ModDivisionDataset_97_run.
name: math_equations_${operation}_${mod_prime}_${add_to_name}
# Mirrors mod_prime. NOTE(review): not referenced anywhere else in this file —
# presumably read by another config or by code; confirm before removing.
num_classes: ${mod_prime}

# Dataset class name; interpolated into data.dataset._target_
# (llid.data.math_equations.<operation>) and into the run name.
operation: ModDivisionDataset
# Modulus; passed to the dataset as `p` and reused by num_classes/num_labels.
mod_prime: 97
# Forwarded to model.config.num_labels; tied to mod_prime.
num_labels: ${mod_prime}
# Forwarded to model.config.vocab_size and also reused as model.config.pad_token_id.
# NOTE(review): fixed at 100 independently of mod_prime — presumably leaves room
# for operator/special tokens; confirm it still fits if mod_prime is raised.
vocab_size: 100

# Forwarded to training_loop.config.alternating_reg_loss.
alternating_reg_loss: false
# Forwarded to data.num_workers.
num_workers: 10
# Data module spec: the _target_ class receives a nested dataset spec plus the
# loader settings below.
data:
  _target_: llid.data.math_equations.MathDataModule
  dataset:
    # The dataset class is selected by the top-level `operation` key.
    _target_: llid.data.math_equations.${operation}
    p: ${mod_prime}
    frac_train: 0.5
    # NOTE(review): -1 looks like a "disabled" sentinel — confirm in the dataset class.
    fix_val_pct: -1
  train_batch_size: 512
  val_batch_size: 512
  train_shuffle: true
  num_workers: ${num_workers}
  pin_memory: true
  persistent_workers: false

# Model spec: a GPT-2 sequence classifier whose GPT2Config is declared inline.
model:
  _target_: transformers.GPT2ForSequenceClassification
  config:
    _target_: transformers.GPT2Config
    n_embd: 128
    n_layer: 2
    n_head: 4
    vocab_size: ${vocab_size}  # TODO: don't hardcode
    num_labels: ${num_labels}
    # NOTE(review): the pad id equals vocab_size, i.e. one past the largest valid
    # token index — presumably a sentinel that never occurs in the inputs; confirm.
    pad_token_id: ${vocab_size}

# Base learning rate, referenced by optimizer.lr and training_loop.config.lr.
# Written with an explicit decimal point so that YAML 1.1 loaders which require
# one for floats also read it as a number rather than as a string.
initial_lr: 1.0e-3
# Scheduler spec referenced by training_loop.config.lr_scheduler.
# A factor of 1 scales the learning rate by one, i.e. leaves it unchanged.
no_lr_scheduler:
  _target_: torch.optim.lr_scheduler.ConstantLR
  factor: 1

# Training loop spec. With _recursive_ disabled, Hydra does not instantiate the
# nested _target_ entries under `config`; they reach the training loop class as
# config nodes.
training_loop:
  _target_: llid.training_loops.sequence_classification.SequenceClassificationTL
  _recursive_: false
  config:
    # NOTE(review): the scheduler referenced below is a constant-factor one; how
    # warmup_steps is consumed is not visible in this file — confirm.
    warmup_steps: 10000

    metrics: ${metrics}
    epoch_metrics: ${epoch_metrics}
    regs: ${regs}

    optimizer: ${optimizer}

    model: ${model}

    alternating_reg_loss: ${alternating_reg_loss}

    lr: ${initial_lr}
    lr_scheduler: ${no_lr_scheduler}

# Epoch-level metrics container, referenced by training_loop.config.epoch_metrics.
# _args_ lists its positional arguments; here a single NoEpochMetric entry.
epoch_metrics:
  _target_: llid.utils.metrics.EpochMetricsList
  _args_:
    - _target_: llid.utils.metrics.NoEpochMetric

# Metrics container, referenced by training_loop.config.metrics.
# _args_ lists its positional arguments; here a single SMAccuracy entry.
metrics:
  _target_: llid.utils.metrics.MetricsList
  _args_:
    - _target_: llid.utils.metrics.SMAccuracy

# Regularization spec, referenced by training_loop.config.regs.
regs:
  _target_: llid.utils.regs.NoRegularization
  # NOTE(review): presumably the regularization weight (lambda); its effect with
  # NoRegularization is not visible here — confirm.
  lamb: 1

# Weight decay forwarded to optimizer.weight_decay.
weight_decay: 0
# Adam optimizer spec, referenced by training_loop.config.optimizer; lr and
# weight_decay are taken from the top-level initial_lr and weight_decay keys.
# NOTE(review): no parameter list is given here — presumably supplied by the
# training loop when it builds the optimizer; confirm.
optimizer:
  _target_: torch.optim.Adam
  betas: [0.9, 0.98]
  lr: ${initial_lr}
  weight_decay: ${weight_decay}