#!/bin/bash
# ---------------------------------------------------------------------------
# Environment
# ---------------------------------------------------------------------------
# GPUs exposed to the run (eight device indices).
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
# Keep a W&B key that the caller already exported; "xxx" is only the fallback
# used when the variable is unset or empty (presumably a placeholder, not a
# real key).
export WANDB_API_KEY="${WANDB_API_KEY:-xxx}"
export TRITON_CACHE_DIR='.triton_cache'

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
BASE_DIR="train_reasoner"
# NOTE(review): DATA_DIR is not referenced anywhere else in this script; the
# dataset file lists are built with a relative 'datasets/' prefix instead.
DATA_DIR="${BASE_DIR}/datasets"

# Dataset names; each one contributes a <name>_train.parquet and a
# <name>_test.parquet entry to the file lists.
# DATASETS=("eu_ai_act" "data_act" "gdpr")
DATASETS=("eu_ai_act" "gdpr")
# DATASETS+=("polyguard_edu_ai4edu" "polyguard_edu_college")

# Model the run starts from (alternatives kept commented out).
# BASE_MODEL="Qwen/Qwen2.5-7B-Instruct"
# BASE_MODEL="Qwen/Qwen3-8B"
BASE_MODEL="train_reasoner/.model_save/verl-safety-compliancer-sft/qwen2_5_8b_sft_exp1_10ep_gdpr_eu_ai_act/global_step_1680"

# Project / experiment names; they are used for the trainer's logging names,
# the log file name, and the checkpoint paths below.
project_name="verl-safety-compliancer"
# experiment_name="qwen3_8b_exp1_rl_gdpr_eu_ai_act"
# experiment_name="qwen2_5_8b_exp1_rl_gdpr_eu_ai_act"
experiment_name="qwen3_8b_exp1_rl_gdpr_eu_ai_act_without_coldstart"

# NOTE(review): MODEL_DIR is relative to the working directory while
# BASE_MODEL is prefixed with "train_reasoner/" -- confirm which directory
# the script is meant to be launched from.
MODEL_DIR=".model_save/${project_name}/${experiment_name}"
# File that is read after training to obtain the latest checkpointed step.
iteration_file="${MODEL_DIR}/latest_checkpointed_iteration.txt"
# Gather one single-quoted parquet path per dataset, then render each
# collection as a bracketed, comma-separated list string, e.g. ['a', 'b'].
train_entries=()
test_entries=()

# sft_path_3='.model_save/verl-safety-compliancer-sft/qwen3_8b_sft_exp1_10ep_eu_ai_act/global_step_1050'

for dataset in "${DATASETS[@]}"; do
  train_entries+=("'datasets/${dataset}_train.parquet'")
  test_entries+=("'datasets/${dataset}_test.parquet'")
done

# printf reuses the format for every element, so the joined text always ends
# in one extra ", "; that suffix is trimmed before the brackets are added.
# With no datasets at all the result is simply "[]".
printf -v train_files '%s, ' "${train_entries[@]}"
printf -v test_files '%s, ' "${test_entries[@]}"
train_files="[${train_files%, }]"
test_files="[${test_files%, }]"

# Print the resolved inputs so they appear at the top of the run output.
printf 'Training files: %s\n' "$train_files"
printf 'Testing files: %s\n' "$test_files"
printf 'MODEL_DIR: %s\n' "$MODEL_DIR"
printf 'Base model: %s\n' "$BASE_MODEL"

# Training configuration
#
# Runs verl.trainer.main_ppo with the GRPO advantage estimator. The overrides
# are kept in an array so they can be grouped and commented; any arguments
# given to this script ("$@") are appended after them. Combined stdout/stderr
# is shown on the terminal and copied to a per-experiment log file.
#
# NOTE(review): no checkpoint directory override is passed here, while the
# steps after training read from MODEL_DIR -- confirm the trainer's configured
# save location matches it.
ppo_overrides=(
  # Rollout sampling; validation uses the same values as training.
  actor_rollout_ref.rollout.repetition_penalty=1.2
  actor_rollout_ref.rollout.temperature=0.7
  actor_rollout_ref.rollout.top_p=0.8
  actor_rollout_ref.rollout.val_kwargs.repetition_penalty=1.2
  actor_rollout_ref.rollout.val_kwargs.temperature=0.7
  actor_rollout_ref.rollout.val_kwargs.top_p=0.8

  # Algorithm and rollout sizing.
  algorithm.adv_estimator=grpo
  actor_rollout_ref.rollout.n=5
  actor_rollout_ref.rollout.gpu_memory_utilization=0.8

  # Data.
  data.train_files="$train_files"
  data.val_files="$test_files"
  data.max_prompt_length=1024
  data.max_response_length=2048
  data.truncation='right'
  data.filter_overlong_prompts=true
  data.filter_overlong_prompts_workers=4
  data.train_batch_size=56

  # Model and actor optimisation.
  actor_rollout_ref.model.path="$BASE_MODEL"
  actor_rollout_ref.actor.optim.lr=5e-7
  actor_rollout_ref.actor.use_kl_loss=true
  actor_rollout_ref.actor.ppo_mini_batch_size=56
  actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4
  actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4
  actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4

  # Logging, checkpointing and schedule.
  # Fully quoted so the brackets can never be treated as a glob pattern; the
  # trainer receives the value [wandb].
  'trainer.logger=[wandb]'
  trainer.project_name="$project_name"
  trainer.experiment_name="$experiment_name"
  trainer.val_before_train=false
  trainer.default_hdfs_dir=null
  trainer.save_freq=50
  trainer.test_freq=1000
  trainer.total_epochs=5
  trainer.log_val_generations=10
  trainer.max_actor_ckpt_to_keep=1
  trainer.max_critic_ckpt_to_keep=1

  # Parallelism and memory offload.
  actor_rollout_ref.rollout.tensor_model_parallel_size=1
  trainer.n_gpus_per_node=8
  trainer.nnodes=1
  actor_rollout_ref.actor.fsdp_config.param_offload=true
  actor_rollout_ref.actor.fsdp_config.optimizer_offload=true
)

PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \
  "${ppo_overrides[@]}" \
  "$@" 2>&1 | tee "safety_prompt_training_${experiment_name}.log"

# The pipeline's own status is tee's, so read the trainer's status from
# PIPESTATUS and stop here instead of post-processing a failed run.
train_status=${PIPESTATUS[0]}
if (( train_status != 0 )); then
  printf 'Training exited with status %d; skipping the remaining steps.\n' \
    "$train_status" >&2
  exit "$train_status"
fi

# Read the latest checkpointed step and convert that checkpoint with
# merge_fsdp_to_hf.py. Every input is validated first so that a missing or
# malformed marker file stops the script instead of producing paths such as
# "global_step_" with an empty step number.
iteration_file="${MODEL_DIR}/latest_checkpointed_iteration.txt"
if [[ ! -r "$iteration_file" ]]; then
  printf 'Checkpoint marker not found or not readable: %s\n' \
    "$iteration_file" >&2
  exit 1
fi

step_number=$(<"$iteration_file")
if [[ ! "$step_number" =~ ^[0-9]+$ ]]; then
  printf 'Unexpected step number "%s" in %s\n' \
    "$step_number" "$iteration_file" >&2
  exit 1
fi

# NOTE(review): --local_dir is prefixed with BASE_DIR whereas MODEL_DIR is
# not -- confirm both resolve to the same checkpoint tree.
# NOTE(review): --base_model_name is hard-coded and independent of
# BASE_MODEL -- confirm it matches the architecture that was trained.
if ! python3 merge_fsdp_to_hf.py \
    --local_dir "${BASE_DIR}/.model_save" \
    --experiment_name "${project_name}/${experiment_name}" \
    --global_step "$step_number" \
    --base_model_name "Qwen/Qwen3-8B"; then
  printf 'merge_fsdp_to_hf.py failed for step %s\n' "$step_number" >&2
  exit 1
fi

# Presumably the directory the merge step writes to -- TODO confirm.
ppo_path=".model_save/${project_name}/${experiment_name}/global_step_${step_number}/huggingface"


# Run verl.trainer.main_generation once per test set, using the checkpoint
# under the "huggingface" directory of the latest step. The sampling settings
# are the same for every set; only the input parquet and the output path
# change. The leading '+' on repetition_penalty is presumably the override
# syntax for adding a key that the base generation config does not define.
hf_model_dir=".model_save/${project_name}/${experiment_name}/global_step_${step_number}/huggingface"

for eval_set in eu_ai_act gdpr; do
  python3 -m verl.trainer.main_generation \
    +rollout.repetition_penalty=1.2 \
    rollout.temperature=0.7 \
    rollout.top_p=0.8 \
    data.path="datasets/${eval_set}_test.parquet" \
    data.output_path=".model_gen_out_put_reasoner/results_on_${eval_set}_${experiment_name}" \
    data.n_samples=1 \
    model.path="$hf_model_dir"
done

