# @package _global_
# Base configs are composed in list order. `_self_` is listed last so that the
# values set in this file are merged after, and therefore override, the bases.
# This matches the implicit placement Hydra uses when `_self_` is omitted.
defaults:
  - /base/logging_config@_here_
  - /base/exp_config@_here_
  - _self_

# Experiment identity. Only `name` and `experiment` are set in this file.
# `exp_config.directory` (interpolated into the checkpoint `dirpath` under
# trainer.callbacks) is not defined here.
# NOTE(review): presumably it is supplied by /base/exp_config; confirm.
exp_config:
  name: internlm_transitions
  experiment: internlm_transitions

# Logger settings for this experiment.
# NOTE(review): remaining logger fields presumably come from
# /base/logging_config, and `offline` presumably toggles the logger's offline
# mode; confirm against the base config.
logging_config:
  project: internlm_transitions
  offline: false

# Model definition: `_target_` is the dotted path of the class to build and
# `config` holds its hyperparameters.
model:
  _target_: models.end_to_end.tactic_models.error_pred.batched_tac_embed.BatchTacEmbed
  config:
    load_ckpt: false
    error_weight: 1
    time_weight: 1
    # Per-label weights, for use when the error/success label distribution is
    # biased.
    # NOTE(review): confirm which index is "error" and which is "success".
    label_weights: [0.5, 0.5]
    # All three components use the same model identifier, which is also the
    # value of data_module.model_name.
    tac_encoder: kaiyuy/leandojo-lean4-tacgen-byt5-small
    goal_encoder: kaiyuy/leandojo-lean4-tacgen-byt5-small
    decoder: kaiyuy/leandojo-lean4-tacgen-byt5-small
    # NOTE(review): differs from data_module.max_seq_len (1650); confirm the
    # mismatch is intentional.
    max_length: 2500
    num_samples: 4
    # Written with an explicit decimal point so that YAML 1.1 loaders (e.g.
    # plain PyYAML) also parse it as a float instead of the string "5e-5".
    lr: 5.0e-5
    warmup_steps: 200

# Data module definition: `_target_` is the dotted path of the class to build;
# the remaining keys are its settings.
data_module:
  _target_: models.end_to_end.tactic_models.error_pred.datamodule_separate_tac.SeparateTacModule
  model_name: kaiyuy/leandojo-lean4-tacgen-byt5-small

  # effective_batch_size == batch_size * accumulate_grad_batches * devices
  batch_size: 1
  eval_batch_size: 1
  max_seq_len: 1650
  num_workers: 0

  database: lean_e2e
  collection: intern_lm_transitions
  # e.g. runs/internlm_bestfs/traces/0
  trace_files: null
  # add/keep/drop files wrt the existing collection
  replace: keep


# Trainer settings: one GPU, bf16 mixed precision, DeepSpeed stage 1.
trainer:
  accelerator: gpu
  devices: 1
  # num_nodes: 2
  precision: bf16-mixed
  gradient_clip_val: 1.0
  # Disabled validation overrides, kept for reference:
  # val_check_interval: 100
  # limit_val_batches: 100

  strategy:
    _target_: lightning.pytorch.strategies.DeepSpeedStrategy
    stage: 1
    offload_optimizer: false
    cpu_checkpointing: false
    logging_batch_size_per_gpu: 1

  callbacks:
    # Records the learning rate at every step.
    - _target_: lightning.pytorch.callbacks.LearningRateMonitor
      logging_interval: step
    # Keeps the three checkpoints with the highest `top4_acc_val`, plus the
    # most recent one.
    - _target_: lightning.pytorch.callbacks.ModelCheckpoint
      dirpath: ${exp_config.directory}/checkpoints
      filename: '{epoch}-{step}-{top4_acc_val:.2f}'
      auto_insert_metric_name: true
      monitor: top4_acc_val
      mode: max
      save_top_k: 3
      save_last: true
      enable_version_counter: false
      verbose: true

