#!/usr/bin/env bash
# Launch DPO + contrastive (dual-loss) LoRA fine-tuning of
# deepseek-ai/DeepSeek-Prover-V1.5-RL on a single node with 4 GPUs.
#
# Usage:    bash <this script>
# Requires: conda env "lean-finder" providing python and torchrun.

# Make a pipeline's status reflect any failing stage (e.g. `torchrun | tee`).
set -o pipefail

# `conda activate` only works once conda's shell functions are loaded. A
# non-interactive bash does not read ~/.bashrc, so load them explicitly.
if command -v conda >/dev/null 2>&1; then
  eval "$(conda shell.bash hook)"
fi
conda activate lean-finder || {
  printf 'error: failed to activate conda env "lean-finder"\n' >&2
  exit 1
}
# Report the interpreter *after* activation so the log shows the env's python.
python --version

# NCCL: warnings-only logging plus InfiniBand timeout/retry tuning.
export NCCL_DEBUG=WARN
export NCCL_IB_TIMEOUT=22
export NCCL_IB_RETRY_CNT=13
export NCCL_IB_AR_THRESHOLD=0

OUTS="dpo_ckpt"

# -1 leaves step count to num_train_epochs.
max_steps=-1
per_gpu_batch_size=2
accum_grad=4
epoch_size=3
dataloader_num_workers=0

# lambda -> --rpo_alpha, temp -> --contrastive_loss_temp.
lambda=0.01
temp=0.01
train_group_size=2

# NOTE(review): not passed to train(), and the name looks like a typo for
# "merged". TODO confirm whether this is still needed.
save_meged_model=True

base_model="deepseek-ai/DeepSeek-Prover-V1.5-RL"
adapter="checkpoints/dsprover-v1.5-rl"

# Single node: WORLD_SIZE is used as the torchrun node count, RANK as node rank.
export WORLD_SIZE=1
export RANK=0
export GPUPerNode=4
export CUDA_VISIBLE_DEVICES=0,1,2,3

run_name="dsproverv1.5rl_dpo_train_all_modality_temp_${temp}_lambda_${lambda}_epoch${epoch_size}_$(date +%Y%m%d_%H%M%S)"


mkdir -p -- "$OUTS/$run_name"


#######################################
# Launch distributed DPO + contrastive training via torchrun.
# Globals (read):    OUTS, run_name, WORLD_SIZE, RANK, GPUPerNode,
#                    CUDA_VISIBLE_DEVICES, train_group_size, temp, lambda,
#                    dataloader_num_workers, max_steps, epoch_size,
#                    per_gpu_batch_size, accum_grad, adapter, base_model
# Globals (written): CHECKPOINT_DIR; exports WANDB_DIR and WANDB_RUN_ID
# Outputs:           torchrun stdout+stderr, also copied to
#                    $OUTS/$run_name/log.txt
# Returns:           torchrun's exit status (not tee's)
#######################################
train() {
  # An empty OUTS or run_name would silently redirect output to the wrong dir.
  : "${OUTS:?OUTS must be set}" "${run_name:?run_name must be set}"

  CHECKPOINT_DIR="$OUTS/$run_name"

  # Keep wandb's local files next to the checkpoints; wandb run id = run_name.
  export WANDB_DIR="$CHECKPOINT_DIR"
  export WANDB_RUN_ID="$run_name"

  # Build argv as an array so every value stays a single word, without
  # string concatenation or eval. (Previously the command was only stored
  # in a string and never executed.)
  local -a cmd=(
    torchrun
    --nnodes "$WORLD_SIZE"
    --node_rank "$RANK"
    --nproc_per_node "$GPUPerNode"
    --rdzv-backend=c10d
    --rdzv-endpoint=localhost:0
    --module leanfinder.retriever.driver.train_dpo_dual_loss
    --ddp_timeout 360000
    --dataset_path_list
      datasets/preference_data/preference.jsonl
      datasets/preference_data/contrastive.jsonl
    --train_group_size "$train_group_size"
    --contrastive_loss_temp "$temp"
    --rpo_alpha "$lambda"
    --run_name "$run_name"
    --output_dir "$CHECKPOINT_DIR"
    --no_timestamps
    --dataloader_num_workers "$dataloader_num_workers"
    --max_steps "$max_steps"
    --num_train_epochs "$epoch_size"
    --save_steps 50
    --save_total_limit 2
    --step_save_interval 500
    --warmup_ratio 0.1
    --logging_steps 1
    --learning_rate 4e-5
    --weight_decay 0.1
    --lr_scheduler_type cosine
    --per_device_train_batch_size "$per_gpu_batch_size"
    --gradient_accumulation_steps "$accum_grad"
    --seed 3407
    --bf16
    --stream
    --do_train
    --report_to wandb
    --lora_r 64
    --lora_alpha 128
    --lora_dropout 0.1
    --adapter_cfg "$adapter"
    --tokenizer_cfg "$adapter"
    --model_cfg "$base_model"
  )

  # tee needs the directory to exist before it can create log.txt.
  mkdir -p -- "$CHECKPOINT_DIR" || return

  CUDA_VISIBLE_DEVICES="$CUDA_VISIBLE_DEVICES" OMP_NUM_THREADS=1 \
    "${cmd[@]}" 2>&1 | tee "$CHECKPOINT_DIR/log.txt"
  # Without this, the function would report tee's status and hide a failed run.
  return "${PIPESTATUS[0]}"
}

# Entry point: run training. The explicit exit makes train's status the
# script's exit status, so a caller can detect a failed run.
train
exit "$?"
