#!/bin/bash

# Load the project environment. The path is relative, so the script has to be
# started from the directory that contains scripts/; fail early with a clear
# message instead of continuing with an unconfigured environment.
if [[ ! -r scripts/setup_env.sh ]]; then
    echo "error: scripts/setup_env.sh not found or unreadable in $(pwd)" >&2
    exit 1
fi
source scripts/setup_env.sh

# Rendezvous address: the first hostname of the SLURM allocation.
# The assignment is kept separate from `export` so that a failure of the
# command substitution is not masked by export's own (always zero) status.
: "${SLURM_JOB_NODELIST:?must be set (run this script inside a SLURM allocation)}"
MASTER_ADDR="$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)"
# The pipeline's status is that of `head`, so an scontrol failure is detected
# through the empty result rather than through the exit code.
if [[ -z "$MASTER_ADDR" ]]; then
    echo "error: could not resolve a master host from SLURM_JOB_NODELIST='${SLURM_JOB_NODELIST}'" >&2
    exit 1
fi
export MASTER_ADDR

# Rendezvous port: pseudo-random value in the inclusive range 30000-50000.
MASTER_PORT="$(( RANDOM % (50000 - 30000 + 1) + 30000 ))"
export MASTER_PORT

# --- Model / experiment selection -------------------------------------------
model='Llama-3.1-8B-Instruct'
model_path="open-unlearning/tofu_${model}_full"
experiment='unlearn/tofu/default.yaml'
trainer='RMU'

# --- Dataset splits ----------------------------------------------------------
# FORGET_PCT is a zero-padded percentage string; the 10# prefix forces base-10
# so that the leading zero is not parsed as an octal literal.
FORGET_PCT='05'
forget_split="forget${FORGET_PCT}"
holdout_split="holdout${FORGET_PCT}"
retain_split="retain$(( 100 - 10#${FORGET_PCT} ))"

# --- Batch geometry ----------------------------------------------------------
gradient_accumulation_steps=4
per_device_train_batch_size=8

# --- Optimisation and method parameters -------------------------------------
num_epochs=10
warmup_epochs=1
lr=1e-4
alpha=1.0
steering_coeff=1.0

# The module regex targets the layer index one below last_layer.
last_layer=32
regex_layer=$(( last_layer - 1 ))
module_regex='model\.layers\.'"${regex_layer}"

# One full unlearn + eval run is performed per seed.
seeds=({1..5})

# Run unlearning followed by evaluation once per seed.
# A failing seed does not stop the remaining seeds, but evaluation is skipped
# when training failed (its checkpoint directory would be missing or stale),
# and the script exits non-zero at the end if any seed failed so that the job
# status reflects the failure.
failed_seeds=()

for seed in "${seeds[@]}"; do

    task_name="tofu_${model}_${forget_split}_${trainer}/seed${seed}"
    echo "${task_name}: Unlearning ${model_path} using ${trainer}"

    # Unlearn
    # Options before src/train.py configure the accelerate launcher; everything
    # after it is passed to src/train.py as key=value overrides.
    # NOTE(review): --num_processes and --num_machines both receive
    # SLURM_NNODES, i.e. one process per node, and no --machine_rank is given.
    # Confirm this matches the intended node/GPU layout.
    if ! accelerate launch \
        --num_processes "$SLURM_NNODES" \
        --num_machines "$SLURM_NNODES" \
        --main_process_ip "$MASTER_ADDR" \
        --main_process_port "$MASTER_PORT" \
        --config_file configs/accelerate/default_config_single.yaml \
        src/train.py \
        --config-name=unlearn.yaml \
        experiment="${experiment}" \
        trainer="${trainer}" \
        task_name="${task_name}" \
        model="${model}" \
        forget_split="${forget_split}" \
        retain_split="${retain_split}" \
        model.model_args.pretrained_model_name_or_path="${model_path}" \
        retain_logs_path="saves/eval/tofu_${model}_${retain_split}/TOFU_EVAL.json" \
        trainer.args.per_device_train_batch_size="$per_device_train_batch_size" \
        trainer.args.gradient_accumulation_steps="$gradient_accumulation_steps" \
        trainer.args.ddp_find_unused_parameters=true \
        trainer.args.gradient_checkpointing=true \
        trainer.args.num_train_epochs="${num_epochs}" \
        trainer.args.warmup_epochs="${warmup_epochs}" \
        trainer.args.learning_rate="${lr}" \
        trainer.args.seed="${seed}" \
        trainer.method_args.alpha="${alpha}" \
        trainer.method_args.steering_coeff="${steering_coeff}" \
        trainer.method_args.module_regex="${module_regex}" \
        trainer.args.eval_strategy="no" \
        trainer.args.eval_on_start=false \
        trainer.args.do_eval=false; then
        echo "${task_name}: unlearning failed; skipping evaluation" >&2
        failed_seeds+=("$seed")
        continue
    fi

    # Eval
    # Evaluates the model saved under saves/unlearn/${task_name} and writes the
    # results into that directory's evals/ subfolder.
    if ! python src/eval.py \
        experiment=eval/tofu/default \
        forget_split="${forget_split}" \
        holdout_split="${holdout_split}" \
        model="${model}" \
        task_name="${task_name}" \
        eval=tofu \
        model.model_args.pretrained_model_name_or_path="saves/unlearn/${task_name}" \
        paths.output_dir="saves/unlearn/${task_name}/evals" \
        retain_logs_path="saves/eval/tofu_${model}_${retain_split}/TOFU_EVAL.json" \
        seed="${seed}"; then
        echo "${task_name}: evaluation failed" >&2
        failed_seeds+=("$seed")
    fi

done

# Surface partial failures through the exit status instead of reporting only
# the status of the last command that happened to run.
if (( ${#failed_seeds[@]} > 0 )); then
    echo "error: run failed for seed(s): ${failed_seeds[*]}" >&2
    exit 1
fi
