#!/bin/bash
#SBATCH -J smooth_align                 # Job name
#SBATCH -N1 --gres=gpu:H100:1           # One node, one H100 GPU
#SBATCH -t 480                          # Wall-clock limit in minutes (480 = 8 h)
#SBATCH --mem-per-cpu=40G
#SBATCH -o smooth_align-%j.out          # Combined stdout/stderr file (%j = job ID)

# Launch train.py with the SafeKD_Align optimizer on tatsu-lab/alpaca.
#
# Usage: sbatch <this script> [lamb] [ignored] [bad_sample_num] [model_path]
#   $1  lamb            forwarded as --lamb            (default: 0.1)
#   $2  (ignored)       not read by this script
#   $3  bad_sample_num  forwarded as --bad_sample_num  (default: 5000)
#   $4  model_path      forwarded as --model_name_or_path (default: Qwen/Qwen2-7B)
#
# NOTE: sbatch only honours #SBATCH lines that appear before the first
# executable command, so all directives must stay above this point.

# module load anaconda3/2022.05.0.1
# module load cuda/11.7.0-7sdye3
# module load anaconda3/2023.03
# module load cuda/11.8.0

# source activate hts

# Strict mode is enabled only after the (optional) environment setup above,
# since conda/module activation scripts are often not `set -u` clean.
set -euo pipefail

lamb=${1:-0.1}
# NOTE(review): $2 is skipped on purpose to keep the existing calling
# convention; it presumably once carried an 'alpha' value — confirm.
bad_sample_num=${3:-5000}
readonly sample_num=5000
model_path=${4:-Qwen/Qwen2-7B}
# Last path component of the model (e.g. "Qwen2-7B"), used to name the run.
path_after_slash=$(basename -- "$model_path")
output_dir="outputs/ckpt/${path_after_slash}_SafeKD_Align_Cycle_30_${lamb}_${bad_sample_num}_${sample_num}"

printf 'The value of lamb is: %s\n' "$lamb"
printf 'The value of bad_sample_num is: %s\n' "$bad_sample_num"
printf 'The short model path is: %s\n' "$path_after_slash"
printf 'The output directory is: %s\n' "$output_dir"
# cd  ../../                            # Change to working directory

CUDA_VISIBLE_DEVICES=0 python train.py \
	--model_name_or_path "$model_path" \
	--data_path tatsu-lab/alpaca \
	--poison_ratio 0 \
	--benign_dataset None \
	--bf16 True \
	--output_dir "$output_dir" \
	--num_train_epochs 20 \
	--per_device_train_batch_size 5 \
	--per_device_eval_batch_size 5 \
	--gradient_accumulation_steps 1 \
	--evaluation_strategy "steps" \
	--save_strategy "steps" \
	--save_steps 10000 \
	--save_total_limit 0 \
	--learning_rate 5e-4 \
	--weight_decay 0.1 \
	--warmup_ratio 0 \
	--lr_scheduler_type "constant" \
	--logging_steps 100 \
	--tf32 True \
	--cache_dir cache \
	--optimizer SafeKD_Align \
	--sample_num "$sample_num" \
	--bad_sample_num "$bad_sample_num" \
	--lamb "$lamb" \
	--eval_steps 100 \
	--refusal_update_steps 6 \
	--log True

# # cd poison/evaluation  

# CUDA_VISIBLE_DEVICES=3 python poison/evaluation/pred.py \
# 	--lora_folder /mnt/server5_hard2/seokil/Booster/outputs/ckpt/${path_after_slash}_SafeKD_Align_Cycle_30_${lamb}_${bad_sample_num}_${sample_num}\
# 	--model_folder ${model_path} \
# 	--output_path /mnt/server5_hard2/seokil/Booster/outputs/ckpt/${path_after_slash}_SafeKD_Align_Cycle_30_${lamb}_${bad_sample_num}_${sample_num}/pred.json

# CUDA_VISIBLE_DEVICES=3 python poison/evaluation/eval_sentiment.py \
# 	--input_path /mnt/server5_hard2/seokil/Booster/outputs/ckpt/${path_after_slash}_SafeKD_Align_Cycle_30_${lamb}_${bad_sample_num}_${sample_num}/pred.json

# # # cd ../../gsm8k

# CUDA_VISIBLE_DEVICES=3 python gsm8k/pred_eval.py   \
# 	--lora_folder /mnt/server5_hard2/seokil/Booster/outputs/ckpt/${path_after_slash}_SafeKD_Align_Cycle_30_${lamb}_${bad_sample_num}_${sample_num}\
# 	--model_folder ${model_path} \
# 	--output_path /mnt/server5_hard2/seokil/Booster/outputs/ckpt/${path_after_slash}_SafeKD_Align_Cycle_30_${lamb}_${bad_sample_num}_${sample_num}/pred_gsm8k.json\
