#!/usr/bin/env bash
# Training and evaluation commands for the "xs" GPT-2 experiments on the
# hallucinate_small/ data: SFT from a pretrain checkpoint, then inference.
#
# Stop at the first failing command so later stages never run against a
# checkpoint that was not actually produced.
set -euo pipefail

die() { printf '%s\n' "$*" >&2; exit 1; }

# pretrain script
# Commented out: re-enable to regenerate the pretrain run (run_name
# xs_pretrain_small, output_dir temp_log) whose checkpoint the SFT runs load.
# torchrun --standalone --nproc_per_node=1 train_gpt2.py --input_folder hallucinate_small/pretrain_perturbed_mixed --save_every 2000 --val_loss_every 2000 --run_name xs_pretrain_small --warmup_ratio 0.05 --warmdown_ratio 0.9 --sequence_length 512 --device_batch_size 32 --num_epochs 4 --weight_decay 0.1 --learning_rate 0.0003 --batch_size 32 --bf16 --model_size xs --output_dir temp_log

# Pretrain checkpoint every SFT run starts from; override with PRETRAIN_CKPT=...
pretrain_ckpt=${PRETRAIN_CKPT:-temp_log/xs_pretrain_small/state_step006769.pt}
[[ -f "$pretrain_ckpt" ]] || die "pretrain checkpoint not found: $pretrain_ckpt"

# Hyperparameters shared by every SFT run, in the original flag order, so the
# three runs cannot silently drift apart.
sft_hparams=(
  --warmup_ratio 0.05
  --warmdown_ratio 0.9
  --sequence_length 512
  --device_batch_size 32
  --num_epochs 4
  --weight_decay 0.1
  --learning_rate 0.0003
  --batch_size 32
  --bf16
  --model_size xs
  --output_dir temp_log
  --val_tokens 0
)

#######################################
# Fine-tune from $pretrain_ckpt on one dataset folder.
# Globals:   pretrain_ckpt, sft_hparams (read)
# Arguments: $1 - input folder passed to --input_folder
#            $2 - run name passed to --run_name (output_dir is temp_log)
# Returns:   torchrun's exit status
#######################################
run_sft() {
  local input_folder=$1
  local run_name=$2
  torchrun --standalone --nproc_per_node=1 train_gpt2.py \
    --input_folder "$input_folder" \
    --load_checkpoint "$pretrain_ckpt" \
    --save_every 200 \
    --val_loss_every 200 \
    --run_name "$run_name" \
    "${sft_hparams[@]}"
}

# sft
run_sft hallucinate_small/SFT xs_pretrain_small_sft
# sft on data mixed with "unknown" examples (SFT_mix_unknown)
run_sft hallucinate_small/SFT_mix_unknown xs_pretrain_small_sft_unknown
# sft on data mixed with refused "unknown" examples (SFT_mix_unknown_refused)
run_sft hallucinate_small/SFT_mix_unknown_refused xs_pretrain_small_sft_unknown_refused

# evaluation script
# Runs inference_SFT.py on a checkpoint of each SFT run
# (temp_log/<run_name>/state_step<step>.pt).

# Checkpoint step evaluated for every run; the SFT runs are launched with
# --save_every 200. Override with e.g. EVAL_STEP=000400.
ckpt_step=${EVAL_STEP:-000200}

# Number of examples scored per input file. Every model uses the same value
# on the same file, so the scores are computed over identical examples.
# (Previously the baseline SFT run used 400 on SFT_test.txt while the others
# used 1000.)
readonly test_first_n=1000
readonly id_first_n=400

#######################################
# Run inference_SFT.py for one run's checkpoint on one input file.
# Globals:   ckpt_step (read)
# Arguments: $1 - SFT run name (directory under temp_log/)
#            $2 - input path
#            $3 - output JSON path
#            $4 - value for --first_n
# Returns:   1 if the checkpoint file is missing, else python's exit status
#######################################
run_inference() {
  local run_name=$1
  local input_path=$2
  local output_path=$3
  local first_n=$4
  local model_path="temp_log/${run_name}/state_step${ckpt_step}.pt"

  if [[ ! -f "$model_path" ]]; then
    printf 'error: checkpoint not found: %s\n' "$model_path" >&2
    return 1
  fi

  python inference_SFT.py \
    --model_path "$model_path" \
    --input_path "$input_path" \
    --output_path "$output_path" \
    --first_n "$first_n" \
    --processes_per_gpu 4
}

# Test set (SFT_test.txt).
# NOTE(review): this output name lacks the `_result` suffix its siblings use;
# kept unchanged so existing readers of the file keep working.
run_inference xs_pretrain_small_sft hallucinate_small/SFT_test.txt \
  xs_pretrain_small_sft.json "$test_first_n"
run_inference xs_pretrain_small_sft_unknown hallucinate_small/SFT_test.txt \
  xs_pretrain_small_sft_unknown_result.json "$test_first_n"
run_inference xs_pretrain_small_sft_unknown_refused hallucinate_small/SFT_test.txt \
  xs_pretrain_small_sft_unknown_refused_result.json "$test_first_n"

# additional test on SFT_unknown_refused_test
run_inference xs_pretrain_small_sft_unknown_refused \
  hallucinate_small/SFT_unknown_refused_test.txt \
  xs_pretrain_small_sft_unknown_refused_result_test.json "$test_first_n"

# SFT.txt (outputs suffixed `_id`; presumably the in-distribution / SFT
# training data — confirm)
run_inference xs_pretrain_small_sft hallucinate_small/SFT.txt \
  xs_pretrain_small_sft_id.json "$id_first_n"
run_inference xs_pretrain_small_sft_unknown hallucinate_small/SFT.txt \
  xs_pretrain_small_sft_unknown_id.json "$id_first_n"
run_inference xs_pretrain_small_sft_unknown_refused hallucinate_small/SFT.txt \
  xs_pretrain_small_sft_unknown_refused_id.json "$id_first_n"