#!/usr/bin/env bash
#
# Sweep launcher: runs train_ppo.py once for every combination of seed,
# epoch count and k_slice listed below, one run after another.
#
# Deliberately no `set -e`: a failed run must not abort the rest of the sweep.
# Failed runs are instead collected, summarised on stderr at the end, and
# reflected in the final exit status (non-zero if any run failed).
set -u

export WANDB_DISABLED=false    # NOTE(review): presumably leaves W&B logging enabled — confirm
# NOTE(review): CUDA_LAUNCH_BLOCKING=1 makes CUDA kernel launches synchronous;
# it is normally a debugging aid and slows training — confirm it is wanted here.
export CUDA_LAUNCH_BLOCKING=1
export CUDA_VISIBLE_DEVICES=0  # expose only GPU 0 to the training process

# Sweep grid: every seed x epoch x k_slice combination is trained.
SEEDS=(42 43)
EPOCHS=(8 10)
k_slices=(10)

# One human-readable entry per run whose train_ppo.py exited non-zero.
failed_runs=()

for seed in "${SEEDS[@]}"; do
  # The policy and the reference model are both loaded from the same
  # per-seed checkpoint directory.
  init_model_path="main_results/seed_${seed}/models/arxiv_original_llama2_7b_15_1e-4"

  for epoch in "${EPOCHS[@]}"; do
    for k_slice in "${k_slices[@]}"; do
      echo "=== Running: seed=${seed}, epoch=${epoch}, k_slice=${k_slice} ==="

      # Arguments are kept in an array so each one stays a single word and
      # the list is easy to extend.
      train_args=(
        --seed "${seed}"
        --epoch "${epoch}"
        --k_slice "${k_slice}"
        --learning_rate 1e-4
        --per_device_train_batch_size 8
        --dataset "arxiv"
        --dataset_path "Glow-AI/WaterDrum-Ax"
        --model_family "llama"
        --base_model_path "meta-llama/Llama-2-7b-chat-hf"
        --policy_model_path "${init_model_path}"
        --ref_model_path "${init_model_path}"
        --reward_base_model_path "cross-encoder/nli-deberta-v3-base"
        --reward_model_path "reward_model/arxiv/llama/classifier"
        --output_dir "main_results/seed_${seed}/models/arxiv/llama/ppo_400_400_${k_slice}_1_${epoch}epoch_lr_1e-4_nli"
        --response_length 200
        --class_num 20
        --forget_label 19
      )

      # Capture the exit status without stopping the sweep.
      run_status=0
      python train_ppo.py "${train_args[@]}" || run_status=$?

      if (( run_status != 0 )); then
        printf 'FAILED (exit %d): seed=%s, epoch=%s, k_slice=%s\n' \
          "${run_status}" "${seed}" "${epoch}" "${k_slice}" >&2
        failed_runs+=("seed=${seed} epoch=${epoch} k_slice=${k_slice} (exit ${run_status})")
      fi
    done
  done
done

if (( ${#failed_runs[@]} > 0 )); then
  printf '%d run(s) failed:\n' "${#failed_runs[@]}" >&2
  printf '  %s\n' "${failed_runs[@]}" >&2
fi

# Final command: its status becomes the script's exit status, so callers
# (schedulers, CI, `&&` chains) see a non-zero result when any run failed.
(( ${#failed_runs[@]} == 0 ))
