# @package _global_

# Hydra defaults list. `_self_` is listed first, so the config groups below are
# composed after this file and win when the same key is set in both places.
defaults:
  - _self_
  - data: tofu # base dataset
  - data_mode: forgetonly # training data mode
  - model_train: grad_ascent # model and its loss

# Training arguments
# Alternative value, currently commented out:
# batch_size: 16
batch_size: 8
# Interpolated into lightning.trainer.accumulate_grad_batches.
gradient_accumulation_steps: 2
# Interpolated into lightning.trainer.max_epochs.
num_epochs: 10
# NOTE(review): `lr` is not interpolated anywhere in this file; hydra.run.dir is
# built from model_train.learning_rate instead. Confirm which key the training
# code actually reads.
lr: 2e-3

# Project / run bookkeeping
# `???` is the OmegaConf "mandatory value" marker: `project` has no default and
# must be supplied (it is interpolated into lightning.logger.wandb.params.project).
project: ???
name: null  # placeholder; set at runtime
resume: false
resume_from_checkpoint: null
# Also interpolated into lightning.logger.wandb.params.offline.
debug: false
seed: 42
postfix: ''
base_logdir: null  # placeholder; set at runtime

# First path component of hydra.run.dir.
BASELOGDIR: 'hf-outputs_lightning_tune'
OUTPUTMODELDIR: 'trained_models'

lightning:
  # Experiment logger settings (Weights & Biases).
  logger:
    wandb:
      # Dotted import path of the logger class; the package is `pytorch_lightning`.
      target: "pytorch_lightning.loggers.WandbLogger"
      params:
        project: ${project}  # taken from the top-level `project` key
        name: null           # set at runtime
        save_dir: null       # set at runtime
        offline: ${debug}    # follows the top-level `debug` flag

  # Checkpointing settings.
  # NOTE(review): no `target` is given here; the parameter names look like
  # Lightning's ModelCheckpoint arguments — confirm against the consuming code.
  callbacks:
    checkpoint_callback:
      params:
        dirpath: null
        filename: '{epoch:02}-{step:04}'
        verbose: true
        save_last: false       # by default, don't save the last checkpoint
        save_top_k: -1         # by default, save all checkpoints
        every_n_epochs: 1      # by default, save a checkpoint every epoch
        monitor: null          # by default, no monitored metric
        save_weights_only: true

  # Trainer settings.
  trainer:
    accelerator: gpu
    # Device indices are specific to the host machine; override per run as needed.
    devices:
      - 5
      - 6
    strategy: ddp
    log_every_n_steps: 1  # counted in global steps
    # Alternative setting, currently disabled: precision: bf16-true
    precision: bf16-mixed
    # Relative interpolation: the three leading dots climb from `trainer` to the
    # config root, so this resolves to the top-level `num_epochs`.
    max_epochs: ${...num_epochs}
    check_val_every_n_epoch: 1
    accumulate_grad_batches: ${gradient_accumulation_steps}
    gradient_clip_val: 1.0
    gradient_clip_algorithm: norm
    benchmark: false

# Hydra runtime settings: output directory layout and job callbacks.
hydra:
  run:
    # Layout: <BASELOGDIR>/<job name>-<data split>/Lorar=<r>-lr=<lr>-layer=<n>/<override dirname>/<timestamp>
    dir: ${BASELOGDIR}/${hydra.job.name}-${data.split}/Lorar=${model_train.Lora.r}-lr=${model_train.learning_rate}-layer=${model_train.num_layer}/${hydra.job.override_dirname}/${now:%Y-%m-%d_%H-%M-%S}
  job:
    config:
      override_dirname:
        # Overrides of these keys are left out of ${hydra.job.override_dirname}.
        # Several of them (data.split, Lora.r, learning_rate, num_layer) already
        # appear explicitly in the run directory path above.
        exclude_keys:
          - save_dir
          - project
          - name
          - lightning.trainer.devices
          - model_train.learning_rate
          - model_train.num_layer
          - model_train.Lora.r
          - model_train.Lora.alpha
          - model_train.remember_weight
          - model_train.loss_type
          - model_train.model_path
          - BASELOGDIR
          - OUTPUTMODELDIR
          - lr
          - data.split
          - lightning.trainer.strategy

  callbacks:
    # Project-local Hydra callback; its behaviour is defined in src.hydra_callbacks.
    rewritejobdircallback:
      _target_: src.hydra_callbacks.RewriteJobDirCallback
  
# NOTE(review): not referenced by any interpolation in this file; presumably read
# by application code as the project base directory — confirm.
BASEDIR: ./
