#!/usr/bin/env bash
#
# Launches contrastive LoRA training of the retriever
# (module leanfinder.retriever.driver.train_contrastive) on top of
# deepseek-ai/DeepSeek-Prover-V1.5-RL, using DeepSpeed on local GPUs 0-3.
#
# Environment:
#   WANDB_API_KEY  optional; when non-empty, `wandb login` is attempted
#                  (best effort) before training starts.
#
# Outputs:
#   checkpoints/dsprover-v1.5-rl/                   trainer output directory
#   checkpoints/dsprover-v1.5-rl/training_log.txt   combined stdout/stderr
#
# Exit status: that of the deepspeed launcher (or of tee if tee itself fails).

# NOTE(review): `conda activate` is normally a shell function provided by
# conda's shell integration; presumably this script is sourced or run from a
# shell where that has been initialised -- confirm.
conda activate lean-finder

# Logging in to W&B is best effort: a failed login must not block training,
# so output is silenced and the exit status is deliberately ignored.
if [[ -n "${WANDB_API_KEY:-}" ]]; then
  wandb login "$WANDB_API_KEY" >/dev/null 2>&1 || true
fi

# Timestamped so repeated launches get distinct run names.
run_name="lean-finder-train-$(date +%Y%m%d-%H%M%S)"
output_dir="checkpoints/dsprover-v1.5-rl"

# Training files handed to --dataset_path_list, one array element per path.
dataset_paths=(
  datasets/train_dir/train_augmented_state.jsonl
  datasets/train_dir/train_informalized_statement.jsonl
  datasets/train_dir/train_formal_statement.jsonl
  datasets/train_dir/train_synthetic_user_query.jsonl
)

mkdir -p -- "$output_dir"

# The pipeline runs in a subshell with `pipefail` so that a failing training
# run is not masked by tee's successful exit, while the option does not leak
# into the calling shell if this file is sourced.
(
  set -o pipefail
  deepspeed --include localhost:0,1,2,3 --master_port 60000 \
    --module leanfinder.retriever.driver.train_contrastive \
    --deepspeed deepspeed/ds_zero3_config.json \
    --output_dir "$output_dir" \
    --model_name_or_path deepseek-ai/DeepSeek-Prover-V1.5-RL \
    --lora \
    --lora_r 64 \
    --lora_alpha 128 \
    --lora_target_modules q_proj,k_proj,v_proj,o_proj,down_proj,up_proj,gate_proj \
    --save_steps 1000 \
    --dataset_path_list "${dataset_paths[@]}" \
    --query_prefix "" \
    --passage_prefix "" \
    --bf16 \
    --pooling eos \
    --append_eos_token \
    --normalize \
    --temperature 0.01 \
    --per_device_train_batch_size 8 \
    --gradient_checkpointing \
    --train_group_size 8 \
    --learning_rate 2e-5 \
    --query_max_len 610 \
    --passage_max_len 210 \
    --num_train_epochs 1 \
    --logging_steps 10 \
    --overwrite_output_dir \
    --gradient_accumulation_steps 4 \
    --run_name "$run_name" 2>&1 \
    | tee "$output_dir/training_log.txt"
)
