#!/bin/bash
# Fine-tune BASE_MODEL with verl's FSDP SFT trainer on the GDPR / EU AI Act
# dataset, then sample completions from one saved checkpoint on the EU AI Act
# and GDPR test sets.
#
# Usage: bash <this script> [extra Hydra overrides for fsdp_sft_trainer...]
#   Any extra arguments are forwarded verbatim to verl.trainer.fsdp_sft_trainer.
#
# Environment:
#   WANDB_API_KEY  W&B key; falls back to the placeholder "xxxx" when unset.
#   CKPT_STEP      global step of the checkpoint used for generation
#                  (default: 1680).

# Abort on errors and unset variables. pipefail matters for the training
# pipeline: without it, a crashed torchrun is hidden by `tee` exiting 0.
set -euo pipefail

export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
# Let a key exported by the caller win over the hardcoded placeholder.
export WANDB_API_KEY="${WANDB_API_KEY:-xxxx}"
export TRITON_CACHE_DIR='.triton_cache'

BASE_MODEL="Qwen/Qwen2.5-7B-Instruct"
# BASE_MODEL="Qwen/Qwen3-8B"
project_name="verl-safety-compliancer-sft"
# NOTE(review): the name says "8b" but BASE_MODEL is Qwen2.5-7B; kept unchanged
# so existing checkpoint, log, and W&B run paths keep resolving.
experiment_name="qwen2_5_8b_sft_exp1_10ep_gdpr_eu_ai_act"
MODEL_DIR=".model_save/${project_name}/${experiment_name}"
CKPT_STEP="${CKPT_STEP:-1680}"

# train_files='datasets/concat_train.parquet'
# test_files='datasets/concat_test.parquet'

train_files='datasets/gdpr_ai_act_train.parquet'
test_files='datasets/gdpr_ai_act_test.parquet'

echo "Training files: $train_files"
echo "Testing files: $test_files"
echo "MODEL_DIR: $MODEL_DIR"
echo "Base model: $BASE_MODEL"

# Prompt and response both come from the `extra_info` column, picking its
# `question` / `answer` sub-keys. The list overrides are single-quoted so bash
# never treats `[...]` as a filename glob; Hydra receives the same `[question]`
# string the unquoted form produced.
# trainer.default_local_dir points checkpoint saving at MODEL_DIR, which is
# where the generation step below looks for global_step_<CKPT_STEP>.
if ! torchrun --nproc_per_node=8 -m verl.trainer.fsdp_sft_trainer \
    data.train_files="$train_files" \
    +data.test_files="$test_files" \
    data.prompt_key=extra_info \
    data.response_key=extra_info \
    '+data.prompt_dict_keys=[question]' \
    '+data.response_dict_keys=[answer]' \
    data.max_length=4096 \
    data.truncation=right \
    data.micro_batch_size_per_gpu=1 \
    data.train_batch_size=8 \
    model.partial_pretrain="$BASE_MODEL" \
    trainer.project_name="$project_name" \
    trainer.experiment_name="$experiment_name" \
    trainer.default_local_dir="$MODEL_DIR" \
    trainer.total_epochs=10 \
    trainer.logger='["console","wandb"]' \
    "$@" 2>&1 | tee "safety_prompt_training_${experiment_name}.log"; then
    echo "SFT training failed; see safety_prompt_training_${experiment_name}.log" >&2
    exit 1
fi

ckpt_path="${MODEL_DIR}/global_step_${CKPT_STEP}"
if [[ ! -d "$ckpt_path" ]]; then
    echo "Checkpoint not found: $ckpt_path (set CKPT_STEP to an existing global_step_N)" >&2
    exit 1
fi

# Sample one completion per prompt from the checkpoint at $ckpt_path.
# Arguments: $1 - tag embedded in the output path (e.g. eu_ai_act)
#            $2 - input parquet file
run_generation() {
    local tag=$1
    local data_path=$2
    python3 -m verl.trainer.main_generation \
        +rollout.repetition_penalty=1.2 \
        rollout.temperature=0.7 \
        rollout.top_p=0.8 \
        data.path="$data_path" \
        data.output_path=".model_gen_out_put_reasoner/results_on_${tag}_${experiment_name}" \
        data.n_samples=1 \
        model.path="$ckpt_path"
}

run_generation eu_ai_act 'datasets/eu_ai_act_test.parquet'
run_generation gdpr 'datasets/gdpr_test.parquet'

