#!/bin/bash

# Source scripts/setup_env.sh (path is relative to the current directory; its
# contents are not visible from this script). Check readability first so that
# launching from the wrong directory aborts here with a clear message instead
# of continuing with an unconfigured environment.
if [[ ! -r scripts/setup_env.sh ]]; then
    echo "error: scripts/setup_env.sh not found or unreadable (cwd: ${PWD})" >&2
    exit 1
fi
source scripts/setup_env.sh

# Rendezvous address: the first hostname in the job's SLURM node list.
: "${SLURM_JOB_NODELIST:?SLURM_JOB_NODELIST is not set; run inside a SLURM allocation}"

# Assign first, export afterwards: `export VAR="$(cmd)"` hides cmd's exit
# status. The pipeline's status is head's rather than scontrol's, so the
# resolved value itself is validated.
MASTER_ADDR="$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)"
if [[ -z "$MASTER_ADDR" ]]; then
    echo "error: could not resolve a hostname from SLURM_JOB_NODELIST='${SLURM_JOB_NODELIST}'" >&2
    exit 1
fi
export MASTER_ADDR

# Random rendezvous port in the inclusive range [30000, 50000].
MASTER_PORT="$(( RANDOM % (50000 - 30000 + 1) + 30000 ))"
export MASTER_PORT

# ---- Data splits ------------------------------------------------------------
# Forget percentage, kept as a zero-padded string because it is pasted verbatim
# into the forget/holdout split names.
FORGET_PCT='05'

forget_split="forget${FORGET_PCT}"
holdout_split="holdout${FORGET_PCT}"
# `10#` forces base-10 parsing so the zero-padded value is never read as octal.
retain_split="retain$(( 100 - 10#${FORGET_PCT} ))"

# ---- Model / experiment selection -------------------------------------------
model='Llama-3.1-8B-Instruct'
model_path="open-unlearning/tofu_${model}_full"
experiment='unlearn/tofu/default.yaml'
trainer='ULD'

# ---- Batching and generation ------------------------------------------------
per_device_train_batch_size=32
gradient_accumulation_steps=1
use_cache=true

# ---- Optimisation -----------------------------------------------------------
lr=1e-3
weight_decay=1e-4
num_epochs=20
warmup_epochs=1
retain_loss_weight=0.1

# ---- LoRA / method arguments ------------------------------------------------
lora_rank=32
lora_alpha="${lora_rank}"   # alpha is tied to the rank
lora_dropout=0.05
lora_bias='none'
num_layers=16

# ---- Evaluation-time assistant-model generation settings --------------------
generation_weight=1.0
top_logit_filter=0.0

# One unlearn + eval run is performed per seed.
seeds=(1 2 3 4 5)


# Run unlearning followed by evaluation once per seed.
#
# The launcher flags below need these variables; fail fast with a clear
# message rather than passing accelerate an empty flag value.
: "${SLURM_NNODES:?SLURM_NNODES is not set; run inside a SLURM allocation}"
: "${MASTER_ADDR:?MASTER_ADDR is not set}"
: "${MASTER_PORT:?MASTER_PORT is not set}"

# Seeds whose training or evaluation step exited non-zero.
failed_seeds=()

for seed in "${seeds[@]}"; do

    task_name="tofu_${model}_${forget_split}_${trainer}/seed${seed}"
    echo "${task_name}: Unlearning ${model_path} using ${trainer}"

    # Unlearn. If training fails, skip evaluation for this seed so that a
    # missing or stale saves/unlearn/${task_name} directory is never scored.
    if ! accelerate launch \
        --num_processes "$SLURM_NNODES" \
        --num_machines "$SLURM_NNODES" \
        --main_process_ip "$MASTER_ADDR" \
        --main_process_port "$MASTER_PORT" \
        --config_file configs/accelerate/default_config_single_stage0.yaml \
        src/train.py \
        --config-name=unlearn.yaml \
        "experiment=${experiment}" \
        "trainer=${trainer}" \
        "task_name=${task_name}" \
        "model=${model}" \
        "forget_split=${forget_split}" \
        "retain_split=${retain_split}" \
        "model.model_args.pretrained_model_name_or_path=${model_path}" \
        "retain_logs_path=saves/eval/tofu_${model}_${retain_split}/TOFU_EVAL.json" \
        "trainer.args.per_device_train_batch_size=${per_device_train_batch_size}" \
        "trainer.args.gradient_accumulation_steps=${gradient_accumulation_steps}" \
        trainer.args.gradient_checkpointing=false \
        trainer.args.ddp_find_unused_parameters=true \
        "trainer.args.seed=${seed}" \
        "trainer.args.num_train_epochs=${num_epochs}" \
        "trainer.args.learning_rate=${lr}" \
        "trainer.args.weight_decay=${weight_decay}" \
        "trainer.args.warmup_epochs=${warmup_epochs}" \
        "trainer.method_args.lora.rank=${lora_rank}" \
        "trainer.method_args.lora.alpha=${lora_alpha}" \
        "trainer.method_args.lora.dropout=${lora_dropout}" \
        "trainer.method_args.lora.bias=${lora_bias}" \
        "trainer.method_args.num_layers=${num_layers}" \
        "trainer.method_args.retain_loss_weight=${retain_loss_weight}" \
        trainer.args.eval_strategy=no \
        trainer.args.eval_on_start=false \
        trainer.args.do_eval=false; then
        echo "${task_name}: training failed; skipping evaluation" >&2
        failed_seeds+=("$seed")
        continue
    fi

    # Eval. The base model is loaded from ${model_path}; the output of the
    # unlearning run above is passed as the assistant model path.
    if ! python src/eval.py \
        experiment=eval/tofu/default \
        "forget_split=${forget_split}" \
        "holdout_split=${holdout_split}" \
        "model=${model}" \
        "task_name=${task_name}" \
        eval=tofu \
        "paths.output_dir=saves/unlearn/${task_name}/evals" \
        "retain_logs_path=saves/eval/tofu_${model}_${retain_split}/TOFU_EVAL.json" \
        "seed=${seed}" \
        "model.model_args.pretrained_model_name_or_path=${model_path}" \
        "+model.assistant_model.path=saves/unlearn/${task_name}" \
        "+model.assistant_model.generation.weight=${generation_weight}" \
        "+model.assistant_model.generation.top_logit_filter=${top_logit_filter}" \
        "++eval.tofu.metrics.retain_Q_A_ROUGE.generation_args.use_cache=${use_cache}" \
        "++eval.tofu.metrics.ra_Q_A_ROUGE.generation_args.use_cache=${use_cache}" \
        "++eval.tofu.metrics.wf_Q_A_ROUGE.generation_args.use_cache=${use_cache}"; then
        echo "${task_name}: evaluation failed" >&2
        failed_seeds+=("$seed")
    fi

done

# Report failures through the exit status so the scheduler does not record a
# partially failed sweep as successful.
if (( ${#failed_seeds[@]} > 0 )); then
    echo "Failed seeds: ${failed_seeds[*]}" >&2
    exit 1
fi
