#!/bin/bash
#
# Runs src/param_count.py with the unlearn.yaml config for an RMU unlearning
# setup on a TOFU forget split, passing the settings below as key=value
# overrides on the command line.
#
# Every path used here (scripts/setup_env.sh, src/param_count.py, saves/...)
# is relative, so invoke this script from the directory where they resolve.
# The script's exit status is that of the python command.

source scripts/setup_env.sh

model="Llama-3.1-8B-Instruct"
trainer="RMU"
experiment="unlearn/tofu/default"

# Forget percentage as a zero-padded two-digit string; the split names are
# derived from it.
FORGET_PCT="05"
forget_split="forget${FORGET_PCT}"
# NOTE(review): holdout_split is assigned but not passed to the command below.
holdout_split="holdout${FORGET_PCT}"
# "10#" forces base-10 so a leading zero (e.g. "08") is not read as octal.
retain_split="retain$(( 100 - 10#${FORGET_PCT} ))"

per_device_train_batch_size=8
gradient_accumulation_steps=4

model_path="open-unlearning/tofu_${model}_full"
task_name="tofu_${model}_${forget_split}_param_count_${trainer}"

lr=1e-4
num_epochs=10
alpha=1.0
steering_coeff=1.0
# The regex targets index (last_layer - 1).
# Presumably last_layer is a layer count and module names are 0-based -- TODO confirm.
last_layer=32
regex_layer=$(( last_layer - 1 ))
# Inside double quotes "\." stays a literal backslash-dot, so the dots are
# escaped in the regex that reaches the Python side.
module_regex="model\.layers\.${regex_layer}"
warmup_epochs=1
seed=1

echo "Counting parameters for unlearning ${model_path} using ${trainer}"
echo "Using Last Layer: ${last_layer}, Module Regex: ${module_regex}"

# Each override is one quoted array element, so it always reaches python as a
# single argument regardless of the characters in the value (no word
# splitting or globbing).
overrides=(
    "experiment=${experiment}"
    "trainer=${trainer}"
    "task_name=${task_name}"
    "model=${model}"
    "forget_split=${forget_split}"
    "retain_split=${retain_split}"
    "model.model_args.pretrained_model_name_or_path=${model_path}"
    "retain_logs_path=saves/eval/tofu_${model}_${retain_split}/TOFU_EVAL.json"
    "trainer.args.per_device_train_batch_size=${per_device_train_batch_size}"
    "trainer.args.gradient_accumulation_steps=${gradient_accumulation_steps}"
    "trainer.args.ddp_find_unused_parameters=true"
    "trainer.args.gradient_checkpointing=true"
    "trainer.args.num_train_epochs=${num_epochs}"
    "trainer.args.warmup_epochs=${warmup_epochs}"
    "trainer.args.learning_rate=${lr}"
    "trainer.args.seed=${seed}"
    "trainer.method_args.alpha=${alpha}"
    "trainer.method_args.steering_coeff=${steering_coeff}"
    "trainer.method_args.module_regex=${module_regex}"
)

python src/param_count.py --config-name=unlearn.yaml "${overrides[@]}"