#!/usr/bin/env bash
#
# Distillation pipeline for a victim model. It runs three steps in order:
#   1. collect_samples.py  (invoked with --exp_name "$EXP_NAME")
#   2. main.py --mode distillation, starting from the SFT checkpoint and using
#      offline_dataset/<VICTIM>-gfn/dataset.json as the few-shot prompt file
#   3. eval.py on save/<EXP_NAME>/latest against the victim model
#
# Usage: <script> [VICTIM] [STEPS] [BS]
#   VICTIM  victim model name   (default: gpt2)
#   STEPS   --train_steps       (default: 1000)
#   BS      --batch_size        (default: 1024)
#
# Environment:
#   GPU     Exported as CUDA_VISIBLE_DEVICES for every step. If unset, an
#           existing CUDA_VISIBLE_DEVICES is kept; otherwise device 0 is used.
#
# The script stops at the first failing step. This keeps training from running
# on a missing dataset, and keeps eval from scoring a stale checkpoint.
set -euo pipefail

die() { printf 'error: %s\n' "$*" >&2; exit 1; }
log() { printf '==> %s\n' "$*" >&2; }

VICTIM=${1:-gpt2}
STEPS=${2:-1000}
BS=${3:-1024}

[[ "$STEPS" =~ ^[0-9]+$ ]] || die "STEPS must be a non-negative integer, got '$STEPS'"
[[ "$BS" =~ ^[1-9][0-9]*$ ]] || die "BS must be a positive integer, got '$BS'"

readonly SFT_CKPT=save/gpt2-sft-position-final/latest
readonly CKPT="${VICTIM}-gfn"
readonly EXP_NAME="${CKPT}-distillation"
readonly PROMPT_FILE="offline_dataset/${CKPT}/dataset.json"

# An empty CUDA_VISIBLE_DEVICES hides every GPU, so never export an empty
# value. Use GPU if it is set, then any inherited CUDA_VISIBLE_DEVICES, then 0.
GPU=${GPU:-${CUDA_VISIBLE_DEVICES:-0}}
export CUDA_VISIBLE_DEVICES="$GPU"

log "collecting samples (exp_name=$EXP_NAME, GPU=$GPU)"
python collect_samples.py --exp_name "$EXP_NAME"

log "distillation training (steps=$STEPS, batch_size=$BS)"
python main.py \
  --mode distillation \
  --exp_name "$EXP_NAME" \
  --lr 1e-4 \
  --seed 42 \
  --batch_size "$BS" \
  --train_steps "$STEPS" \
  --grad_acc_steps 1 \
  --model_name "$SFT_CKPT" \
  --few_shot_file "$PROMPT_FILE"

log "evaluating save/${EXP_NAME}/latest against $VICTIM"
python eval.py \
  --ckpt "save/${EXP_NAME}/latest" \
  --output_file "$EXP_NAME" \
  --victim_model "$VICTIM" \
  --no_lora