#!/bin/bash

source scripts/setup_env.sh

# ---------------------------------------------------------------------------
# Run configuration. Everything below is plain data consumed by the summary
# output and the python invocation further down in this script.
# ---------------------------------------------------------------------------

# Model name, trainer name, and experiment config path.
model="Llama-3.1-8B-Instruct"
trainer="T3"
experiment="unlearn/tofu/default"

# Forget percentage, kept as a string so the leading zero survives in the
# split names (e.g. "forget05").
FORGET_PCT="05"
forget_split="forget${FORGET_PCT}"
# NOTE: holdout_split is not referenced elsewhere in the visible part of this
# script; it is kept in case other tooling relies on it.
holdout_split="holdout${FORGET_PCT}"
# The `10#` prefix forces base-10 arithmetic; without it a zero-padded value
# such as "08" would be rejected as an invalid octal number.
retain_split="retain$(( 100 - 10#${FORGET_PCT} ))"

# Forwarded as trainer.args.seed. Defined once here (a second, identical
# assignment used to follow the batch-size settings).
seed=1

# Forwarded as trainer.args.* (num_train_epochs, warmup_epochs,
# learning_rate, weight_decay).
num_epochs=100
warmup_epochs=25
lr=5e-4
weight_decay=1e-3

# Forwarded as trainer.method_args.guidance_cfg.* (hidden_size,
# num_hidden_layers, activation_str, bias).
hidden_size=20
num_hidden_layers=1
activation_str="id"
bias="false"

# Forwarded as trainer.method_args.* (extraction_layer, guidance_scale,
# base_temp, pooling).
extraction_layer=-1
guidance_scale=1
base_temp=2.5
pooling="mean"

# Forwarded as trainer.args.per_device_train_batch_size and
# trainer.args.gradient_accumulation_steps.
per_device_train_batch_size=32
gradient_accumulation_steps=1

# Derived names: the model path passed as
# model.model_args.pretrained_model_name_or_path, and the task name.
model_path="open-unlearning/tofu_${model}_full"
task_name="tofu_${model}_${forget_split}_param_count_${trainer}"

#######################################
# Print a short summary of the run to stdout.
# Globals:   model_path, trainer, hidden_size, num_hidden_layers, bias (read)
# Arguments: none
# Outputs:   one header line; when trainer is "T3", a settings header plus
#            one line each for hidden_size, num_hidden_layers, and bias
#######################################
print_run_summary() {
  printf 'Counting parameters for unlearning %s using %s\n' \
    "${model_path}" "${trainer}"

  if [[ "${trainer}" == "T3" ]]; then
    printf ' Using T3 with the following settings:\n'
    printf ' hidden_size: %s\n' "${hidden_size}"
    printf ' num_hidden_layers: %s\n' "${num_hidden_layers}"
    # Same one-space indent as the other settings lines (it was missing here).
    printf ' bias: %s\n' "${bias}"
  fi
}

print_run_summary

#######################################
# Run src/param_count.py with the unlearn.yaml config and a list of
# key=value arguments built from the variables set earlier in this script.
# Globals:   experiment, trainer, task_name, model, forget_split,
#            retain_split, model_path, per_device_train_batch_size,
#            gradient_accumulation_steps, seed, num_epochs, lr, weight_decay,
#            warmup_epochs, extraction_layer, pooling, guidance_scale,
#            base_temp, hidden_size, num_hidden_layers, activation_str,
#            bias (all read)
# Arguments: none
# Returns:   exit status of the python command
#######################################
run_param_count() {
  # One array element per argument: quoting each entry guarantees a value is
  # passed as a single word and is never word-split or glob-expanded.
  local -a overrides=(
    --config-name=unlearn.yaml
    "experiment=${experiment}"
    "trainer=${trainer}"
    "task_name=${task_name}"
    "model=${model}"
    "forget_split=${forget_split}"
    "retain_split=${retain_split}"
    "model.model_args.pretrained_model_name_or_path=${model_path}"
    "retain_logs_path=saves/eval/tofu_${model}_${retain_split}/TOFU_EVAL.json"
    "trainer.args.per_device_train_batch_size=${per_device_train_batch_size}"
    "trainer.args.gradient_accumulation_steps=${gradient_accumulation_steps}"
    trainer.args.gradient_checkpointing=false
    trainer.args.ddp_find_unused_parameters=true
    "trainer.args.seed=${seed}"
    "trainer.args.num_train_epochs=${num_epochs}"
    "trainer.args.learning_rate=${lr}"
    "trainer.args.weight_decay=${weight_decay}"
    "trainer.args.warmup_epochs=${warmup_epochs}"
    "trainer.method_args.extraction_layer=${extraction_layer}"
    "trainer.method_args.pooling=${pooling}"
    "trainer.method_args.guidance_scale=${guidance_scale}"
    "trainer.method_args.base_temp=${base_temp}"
    "trainer.method_args.guidance_cfg.hidden_size=${hidden_size}"
    "trainer.method_args.guidance_cfg.num_hidden_layers=${num_hidden_layers}"
    "trainer.method_args.guidance_cfg.activation_str=${activation_str}"
    "trainer.method_args.guidance_cfg.bias=${bias}"
  )

  python src/param_count.py "${overrides[@]}"
}

# Kept as the last command so the script's exit status is python's.
run_param_count