#!/usr/bin/env bash
#
# Usage: EPOCH=<n> <this-script> [EXP_NAME]
#
# Runs two steps in sequence:
#   1. create_safety_response.py, reading results/gemma/<EXP_NAME>.json.
#   2. main.py in "safety" mode for google/gemma-2b-it, reading
#      safety_dataset/<EXP_NAME>.json as the prompt file.
#
# Arguments:
#   $1     - experiment name (default: gemma-mle)
# Environment:
#   EPOCH  - value forwarded to main.py via --epoch (required)

# Abort on the first failing command and on any unset variable, so step 2
# never starts after step 1 has failed.
set -eu

EXP_NAME=${1:-gemma-mle}

# EPOCH is not assigned anywhere in this script; it has to come from the
# caller's environment. Check it up front: an empty expansion would leave
# --epoch without a value, and the failure would only surface after step 1.
: "${EPOCH:?EPOCH must be set (value passed to main.py --epoch)}"

# NOTE(review): presumably this step produces safety_dataset/<EXP_NAME>.json,
# which step 2 reads -- confirm against create_safety_response.py.
python create_safety_response.py \
    --input_file "results/gemma/${EXP_NAME}.json" \
    --output_file "${EXP_NAME}"

python main.py \
    --model_name google/gemma-2b-it \
    --mode safety \
    --lr 5e-5 \
    --weight_decay 0.0 \
    --batch_size 32 \
    --num_warmup_steps 0 \
    --epoch "${EPOCH}" \
    --prompt_file "safety_dataset/${EXP_NAME}.json" \
    --exp_name mle-safety-tuned