# @package _global_
# Keep the `# @package` directive above as the very first line so Hydra's
# header parsing picks it up; this config is placed at the global root.
#
# Base configs composed into this one. `@_here_` merges each of them at this
# config's own package (the global root), so they supply the base
# `logging_config` / `exp_config` values that the keys below override.
# NOTE(review): there is no explicit `_self_` entry. Hydra >= 1.1 composes
# this file after its defaults when `_self_` is omitted; confirm the Hydra
# version in use, or list `- _self_` last explicitly.
defaults:
  - /base/logging_config@_here_
  - /base/exp_config@_here_

# Experiment identity for this run; these keys override the values
# composed in from /base/exp_config.
exp_config:
  experiment: 'error_pred'
  name: 'combined_transition_model'

# Logger settings layered over /base/logging_config.
logging_config:
  # NOTE(review): presumably `false` means the logger syncs to its remote
  # backend; confirm against the logger that /base/logging_config configures.
  offline: false
  project: 'error_pred'

# Error-prediction model; Hydra instantiates `_target_` and passes `config`.
model:
  _target_: models.end_to_end.tactic_models.error_pred.model.ErrorPredModel
  config:
    load_ckpt: false
    # Relative weights of the two loss terms (presumably the error-prediction
    # loss and the time-prediction loss; confirm in ErrorPredModel).
    error_weight: 1
    time_weight: 1
    # Class weights compensating for the biased error/success distribution.
    # NOTE(review): confirm which index corresponds to "error" vs "success".
    label_weights: [0.25, 0.75]
    # Placeholders: point these at the pretrained tactic generator.
    tac_encoder: pretrained_tactic_generator_path
    goal_encoder: pretrained_tactic_generator_path
    decoder: pretrained_tactic_generator_path
    max_length: 2500
    # Presumably the k of the monitored top-k metric (cf. `top4_acc_val` in
    # the checkpoint callback); confirm in ErrorPredModel.
    num_samples: 4
    # Written with a decimal point so YAML 1.1 loaders (e.g. plain PyYAML)
    # also parse it as a float; a bare `5e-5` loads there as the string "5e-5".
    lr: 5.0e-5
    warmup_steps: 200

# Data module; Hydra instantiates `_target_` with the keys below.
data_module:
  _target_: models.end_to_end.tactic_models.error_pred.datamodule.ErrorPredDataModule
  # Appears to be a placeholder: set to the pretrained tactic generator path.
  model_name: pretrained_tactic_generator_path
  # Effective batch size = batch_size * accumulate_grad_batches * devices
  # (* num_nodes when running multi-node).
  batch_size: 1
  eval_batch_size: 1
  # NOTE(review): differs from model.config.max_length (2500); confirm the
  # two limits are meant to differ.
  max_seq_len: 1650
  num_workers: 0
  collection: minif2f_transitions
  # Placeholder: path to the trace files of the desired proof attempt.
  trace_files: trace_dir
  # How the trace files are applied to the existing collection:
  # add / keep / drop.
  replace: keep


# Lightning Trainer arguments.
trainer:
  accelerator: gpu
  devices: 1
  # num_nodes: 2
  precision: bf16-mixed
  strategy:
    _target_: lightning.pytorch.strategies.DeepSpeedStrategy
    # ZeRO stage 1 (optimizer-state partitioning), no CPU offload.
    stage: 1
    offload_optimizer: false
    cpu_checkpointing: false
    # Tied to the data module's batch size instead of a duplicated literal,
    # so DeepSpeed's per-GPU micro-batch (used for throughput logging) cannot
    # drift out of sync when batch_size changes.
    logging_batch_size_per_gpu: ${data_module.batch_size}

  gradient_clip_val: 1.0

  callbacks:
    - _target_: lightning.pytorch.callbacks.LearningRateMonitor
      logging_interval: step
    # Keeps the 3 best checkpoints by top4_acc_val (higher is better) plus
    # last.ckpt. `top4_acc_val` must be logged by the model; confirm the name.
    - _target_: lightning.pytorch.callbacks.ModelCheckpoint
      verbose: true
      save_top_k: 3
      save_last: true
      monitor: top4_acc_val
      mode: max
      dirpath: ${exp_config.directory}/checkpoints
      # With auto_insert_metric_name, files are named like
      # "epoch=2-step=500-top4_acc_val=0.75.ckpt".
      auto_insert_metric_name: true
      filename: "{epoch}-{step}-{top4_acc_val:.2f}"
      # Same-named checkpoints are overwritten rather than given a -v1 suffix.
      enable_version_counter: false

