#!/bin/bash
#SBATCH -J sft                 # Job name
#SBATCH -N1 --gres=gpu:H100:1
#SBATCH -t 480                                    # Time limit in minutes (480 min = 8 h)
#SBATCH --mem-per-cpu=10G
#SBATCH -o sft_gsm8k-%j.out                         # Combined output and error messages file
#
# Fine-tune a model with train.py (--optimizer SafeKD_FT,
# --data_path PKU-Alignment/BeaverTails_dangerous,
# --benign_dataset data/gsm8k.json), then evaluate the resulting checkpoint with
# poison/evaluation/{pred,eval_sentiment}.py and gsm8k/pred_eval.py.
#
# Usage (from the repository root, i.e. the directory containing train.py):
#   sbatch <this script> [poison_ratio] [sample_num] [model_path]
#     poison_ratio  passed to train.py as --poison_ratio   (default: 0.1)
#     sample_num    passed to train.py as --sample_num     (default: 1000)
#     model_path    base model for training and evaluation
#                   (default: meta-llama/Meta-Llama-3-8B)
#
# All paths below are relative to the current directory. The checkpoint and the
# prediction files go to:
#   outputs/ckpt/gsm8k/<model>_ReFT_FT_<poison_ratio>_<lr>_<epochs>/
#
# GPUs: uses CUDA_VISIBLE_DEVICES as already set (Slurm sets it for --gres
# allocations) and falls back to devices 4,5,6,7 only when it is unset.

set -euo pipefail

# Environment setup used on other clusters; enable as needed.
# module load anaconda3/2022.05.0.1
# module load cuda/11.7.0-7sdye3
# module load anaconda3/2023.03
# module load cuda/11.8.0
# source activate hts

die() { printf 'error: %s\n' "$*" >&2; exit 1; }

poison_ratio=${1:-0.1}
sample_num=${2:-1000}
model_path=${3:-meta-llama/Meta-Llama-3-8B}
epochs=20
lr=1e-5
model_name=$(basename "$model_path")

echo "The value of poison_ratio is: $poison_ratio"
echo "The model is: $model_path"

# Every path in this script is relative, so fail fast if not at the repo root.
[[ -f train.py ]] || die "run this script from the repository root (train.py not found in $PWD)"
repo_root=$PWD

# Do not override the GPU(s) Slurm allocated; only default when nothing is set.
export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-4,5,6,7}"

# Single checkpoint location shared by training and every evaluation step.
ckpt_dir="${repo_root}/outputs/ckpt/gsm8k/${model_name}_ReFT_FT_${poison_ratio}_${lr}_${epochs}"

# NOTE(review): the KD teacher / refusal feature are fixed to a Meta-Llama-3-8B
# checkpoint regardless of model_path — confirm this when passing another model.
kd_teacher="outputs/ckpt/Meta-Llama-3-8B_SafeKD_Align_Cycle_30_0.1_5000_5000"

train_args=(
  --model_name_or_path "$model_path"
  --data_path PKU-Alignment/BeaverTails_dangerous
  --bf16 True
  --output_dir "$ckpt_dir"
  --num_train_epochs "$epochs"
  --per_device_train_batch_size 5
  --per_device_eval_batch_size 5
  --gradient_accumulation_steps 1
  --save_strategy "steps"
  --save_steps 100000
  --save_total_limit 0
  --learning_rate "$lr"
  --weight_decay 0.1
  --warmup_ratio 0.1
  --lr_scheduler_type "constant"
  --logging_steps 100
  --tf32 True
  --eval_steps 10000
  --cache_dir cache
  --optimizer SafeKD_FT
  --eval_strategy "steps"
  --sample_num "$sample_num"
  --poison_ratio "$poison_ratio"
  --label_smoothing_factor 0
  --benign_dataset data/gsm8k.json
  --log True
  --refusal_feature "${kd_teacher}/refusal.pt"
  --KD_teacher "$kd_teacher"
  --KD_threshold 0.9
  --KD_temperature 1.0
  --alpha 0.1
  --random_seed 42
  # Optional:
  # --alternating single_lora
  # --lora_folder /mnt/server5_hard2/seokil/Booster/outputs/ckpt/Meta-Llama-3-8B_SafeKD_Align_Cycle_500_0.1_5000_5000
)

# Training must succeed: without a checkpoint the evaluations are meaningless,
# so a failure here aborts the job (set -e) with a non-zero status.
python train.py "${train_args[@]}"

# The evaluations are independent of each other. Run them all, remember any
# failure, and report it through the job's exit status at the end.
eval_status=0

# Poison evaluation: generate predictions, then score them.
# Optional extra flag for pred.py: --lora_folder2 <path>
if python poison/evaluation/pred.py \
  --lora_folder "$ckpt_dir" \
  --model_folder "$model_path" \
  --output_path "${ckpt_dir}/pred.json"; then
  python poison/evaluation/eval_sentiment.py \
    --input_path "${ckpt_dir}/pred.json" || eval_status=1
else
  eval_status=1
fi

# GSM8K evaluation.
# Optional extra flag for pred_eval.py:
#   --lora_folder2 outputs/ckpt/gsm8k/${model_name}_SafeKD_FT_Cycle1000_base_${poison_ratio}_${lr}_${epochs}
python gsm8k/pred_eval.py \
  --lora_folder "$ckpt_dir" \
  --model_folder "$model_path" \
  --output_path "${ckpt_dir}/pred_gsm8k.json" || eval_status=1

exit "$eval_status"