#!/bin/bash
#SBATCH --time=1-24:00:00
#SBATCH --ntasks=1
#SBATCH --mem=40G
#SBATCH --gres=gpu:A6000:1
#SBATCH --job-name=relearn_attack
#SBATCH --output=logs/relearn_attack_%A_%a.log   # Log file for each task
#SBATCH --error=logs/relearn_attack_error_%A_%a.log     # Error log file for each task


# Relearning-attack sweep: fine-tune each unlearned model on each WMDP-bio
# corpus split at several sample budgets (LoRA, evaluated on wmdp-bio).
# Every (model, task, n_samples) combination runs sequentially inside this one
# SLURM job. A failed run is reported on stderr and the sweep continues; the
# script exits non-zero at the end if any run failed, so SLURM marks the job
# as FAILED instead of reflecting only the last run's status.
#
# Optional environment:
#   RELEARN_SAVE_ROOT  base directory for checkpoints
#                      (default: /data/llm_weights/soft_mem/unlearning_vs_safety)

set -euo pipefail

# Unlearned models to attack (Hugging Face repo IDs).
models=('J4Q8/zephyr-npo-bio' 'cais/Zephyr_RMU')
# Datasets passed to --dataset for the relearning fine-tune.
tasks=("wmdp_bio-forget-corpus-mc" "wmdp_bio-retain-corpus-mc")
# Fine-tuning sample budgets passed to --n_samples.
n_samples=(5 10 50)

save_root=${RELEARN_SAVE_ROOT:-/data/llm_weights/soft_mem/unlearning_vs_safety}
failed_runs=()

for model in "${models[@]}"; do
    # Repo IDs contain '/', which would add an extra directory level in the
    # save path; replace every forward slash with '_' (loop-invariant per model).
    model_name=${model//\//_}
    for task in "${tasks[@]}"; do
        for n_sample in "${n_samples[@]}"; do
            run_id="${model_name}/relearn_${task}_${n_sample}"
            printf 'Starting run: %s\n' "$run_id"
            # Guarded with `if !` so one failing run neither aborts the sweep
            # under `set -e` nor goes unnoticed.
            if ! python -m src.finetuning.finetune \
                --model "$model" \
                --dataset "$task" \
                --n_samples "$n_sample" \
                --epochs 3 \
                --lora_rank 128 \
                --lora_alpha 16 \
                --lr 2e-4 \
                --weight_decay 0.01 \
                --eval_dataset wmdp-bio \
                --n_skip_samples 0 \
                --batch_size 1 \
                --save_dir "${save_root}/${run_id}" \
                --wandb_tags "relearn_attack_0123" \
                --tokenizer "HuggingFaceH4/zephyr-7b-beta"; then
                printf 'Run failed: %s\n' "$run_id" >&2
                failed_runs+=("$run_id")
            fi
        done
    done
done

# Report every failed run and fail the job if there were any.
if (( ${#failed_runs[@]} > 0 )); then
    printf '%d run(s) failed:\n' "${#failed_runs[@]}" >&2
    printf '  %s\n' "${failed_runs[@]}" >&2
    exit 1
fi
