#!/usr/bin/env bash
# Pipeline:
#   1. Fine-tune Llama-2-7b-chat on aoa_dataset with FSDP (torchrun).
#   2. Convert the resulting FSDP checkpoint into a consolidated HF model dir.
#   3. Run the fine-tuned model over harmful_p5.csv with the "aoa" template.
#
# All relative paths (../finetuning.py, ../fsdp, ../finetuned_models, ...)
# resolve against the current working directory, so run this script from
# the directory those paths were written for.
#
# Each stage consumes the previous stage's output. Abort on the first failure
# rather than converting or evaluating a missing or stale checkpoint.
set -euo pipefail

readonly BASE_MODEL=/data/zhaohan/LLMs-Safety/hf/Llama-2-7b-chat-fp16
readonly NUM_EPOCHS=10
# The run name encodes the epoch count, so it cannot drift from --num_epochs.
readonly RUN_NAME="aoa-epoch=${NUM_EPOCHS}"
readonly FSDP_CKPT_DIR="../fsdp/Llama-2-7b-chat-fp16/${RUN_NAME}/"
readonly HF_MODEL_DIR="../finetuned_models/Llama-2-7b-chat-fp16/${RUN_NAME}/"
readonly PROMPT_FILE=../safety_evaluation/data/harmful_p5.csv
readonly OUTPUT_FILE=/data/zhaohan/LLMs-Safety/ft_datasets/pure_bad_dataset/train.jsonl

readonly TRAIN_GPUS=0,2,3
readonly EVAL_GPU=3

# Launch one torchrun worker per visible training GPU, so the process count
# always matches TRAIN_GPUS.
IFS=, read -r -a train_gpu_list <<< "$TRAIN_GPUS"
readonly NPROC=${#train_gpu_list[@]}

# --- Stage 1: FSDP fine-tuning (bf16) -------------------------------------
CUDA_VISIBLE_DEVICES="$TRAIN_GPUS" torchrun --nproc_per_node "$NPROC" ../finetuning.py \
  --batch_size_training 2 --lr 5e-5 \
  --num_epochs "$NUM_EPOCHS" \
  --dataset aoa_dataset \
  --enable_fsdp \
  --model_name "$BASE_MODEL" --pure_bf16 \
  --fsdp_checkpoint_path "$FSDP_CKPT_DIR"

# --- Stage 2: FSDP checkpoint -> consolidated HF model --------------------
# Single-dash long flags are kept as in the original invocation of the
# converter script.
python ../inference/checkpoint_converter_fsdp_hf.py \
  -fsdp_checkpoint_path "$FSDP_CKPT_DIR" \
  -consolidated_model_path "$HF_MODEL_DIR" \
  -HF_model_path_or_name "$BASE_MODEL"

# --- Stage 3: inference on harmful prompts with the aoa template ----------
# NOTE(review): OUTPUT_FILE is the pure_bad_dataset training file. This run
# presumably regenerates that dataset from the fine-tuned model's answers.
# If question_inference.py opens it for writing, the existing contents may
# be replaced. Confirm this is intended.
CUDA_VISIBLE_DEVICES="$EVAL_GPU" python -u ../safety_evaluation/question_inference.py \
  --model_name "$HF_MODEL_DIR" \
  --prompt_file "$PROMPT_FILE" \
  --prompt_template_style aoa \
  --output_file "$OUTPUT_FILE"
