# Top-level run settings.
# NOTE(review): the meaning of each key is defined by the code that loads this
# file, which is not visible here — confirm before changing values.
model: batched_diversity_model
# Explicit YAML null: no checkpoint path is set at this level.
ckpt_path: null
num_sampled_tactics: 64

distributed: true
# NOTE(review): presumably fractional GPU/CPU reservations per worker process
# (e.g. for a distributed scheduler) — confirm against the launcher.
gpu_per_process: 0.9
cpu_per_process: 0.5

# Same values as the `*_per_process` pair above, for the `*_per_diversity` keys.
# NOTE(review): 0.9 + 0.9 exceeds 1.0, so if both are fractions of a single GPU
# the two kinds of worker cannot share one device — confirm this is intended.
gpu_per_diversity: 0.9
cpu_per_diversity: 0.5

# NOTE(review): presumably the settings for the model selected by the top-level
# `model` key — confirm against the loader.
diversity_config:
  # Exactly one `ckpt_dir` is active in this mapping; the commented-out
  # `ckpt_dir` lines are alternative checkpoints kept for reference.
  #
  # transition only:
  #   ckpt_dir: 'runs/minif2f/minif2f_combined.ckpt'

  # combined transition/error/time prediction:
  #   ckpt_dir: runs/error_pred/combined_transition_model/combined_minif2f_valid.ckpt

  ckpt_dir: runs/internlm_transition_model.ckpt
  #   ckpt_dir: runs/minif2f_combined_val_test.ckpt

  model: kaiyuy/leandojo-lean4-tacgen-byt5-small
  max_seq_len: 2300
  num_filtered: 8
  # Written as the integer `2`, so YAML loaders parse it as an int, not a float.
  temperature: 2
  # NOTE(review): presumably a nucleus (top-p) sampling threshold — confirm.
  p: 0.75

  score_network: true
  error_weight: 0.5
  # Set to zero here.
  # NOTE(review): presumably disables any time-based scoring term — confirm.
  time_weight: 0.0

  error_only: false
  fixed_size: true

# Parameters and sampling settings for the model named in `model_params.model`.
# NOTE(review): the key names resemble vLLM engine arguments and
# `SamplingParams` fields — confirm against the code that loads this file.
config:
  model_params:
    model: 'internlm/internlm2_5-step-prover'
    # Canonical lowercase boolean (was `True`); parses to the same value.
    trust_remote_code: true
    gpu_memory_utilization: 0.9
    # tensor_parallel_size: 2
  sampling_params:
    n: 128
    temperature: 0.7  # temperature as per paper
    stop_token_ids: [92542]  # stop token from internlm
    # Same value as `n` above.
    best_of: 128
    logprobs: 0  # only return logprob for chosen tokens