#!/bin/bash

# Launches src/preprocess_t3_data.py through `accelerate launch` once per seed
# and points each run at its own output directory under saves/preprocess/.
#
# Expected environment: SLURM_JOB_NODELIST and SLURM_NNODES must be set
# (they are read below). All paths are relative, so run this from the
# repository root.

# Project-provided environment setup (contents are not visible from this file).
source scripts/setup_env.sh

# Rendezvous address: first hostname of the SLURM allocation.
# Assignment and export are split so a failing pipeline is not masked by
# the exit status of `export`.
MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
export MASTER_ADDR

# Pseudo-random port in [30000, 50000].
MASTER_PORT=$(( RANDOM % (50000 - 30000 + 1) + 30000 ))
export MASTER_PORT

model="Llama-3.1-8B-Instruct"
trainer="T3"
experiment="unlearn/tofu/default"

FORGET_PCT="05"
forget_split="forget${FORGET_PCT}"
# NOTE(review): holdout_split is defined but not passed to the launch command
# in this script.
holdout_split="holdout${FORGET_PCT}"
# `10#` forces base-10 so a zero-padded value such as "05" (or "08") is not
# parsed as octal by shell arithmetic.
retain_split="retain$(( 100 - 10#$FORGET_PCT ))"

seeds=(1 2 3 4 5)
extraction_layer=-1
pooling="mean"

per_device_train_batch_size=32
# NOTE(review): gradient_accumulation_steps is defined but not passed to the
# launch command in this script.
gradient_accumulation_steps=1

# Loop-invariant values, computed once.
model_path="open-unlearning/tofu_${model}_full"
task_name="dummy_task"

for seed in "${seeds[@]}"; do
    # Per-seed paths must be computed BEFORE the directory is created;
    # otherwise mkdir acts on the previous iteration's path (or on "." for
    # the first seed) and the current seed's directory is never prepared.
    output_dir="saves/preprocess/tofu_${model}_${forget_split}/seed${seed}"
    precomputed_path="${output_dir}/precomputed_states.pt"

    mkdir -p "$(dirname "${precomputed_path}")"

    # Options consumed by `accelerate launch` itself.
    accelerate_opts=(
        --num_processes "$SLURM_NNODES"
        --num_machines "$SLURM_NNODES"
        --main_process_ip "$MASTER_ADDR"
        --main_process_port "$MASTER_PORT"
        --config_file configs/accelerate/default_config_single_stage0.yaml
    )

    # Arguments forwarded to the preprocessing script.
    script_args=(
        --config-name=preprocess.yaml
        "experiment=${experiment}"
        eval=tofu_t3
        "trainer=${trainer}"
        "model=${model}"
        "paths.output_dir=${output_dir}"
        "precomputed_path=${precomputed_path}"
        "task_name=${task_name}"
        "forget_split=${forget_split}"
        "retain_split=${retain_split}"
        "model.model_args.pretrained_model_name_or_path=${model_path}"
        "trainer.args.per_device_train_batch_size=${per_device_train_batch_size}"
        "trainer.args.seed=${seed}"
        "trainer.method_args.extraction_layer=${extraction_layer}"
        "trainer.method_args.pooling=${pooling}"
    )

    accelerate launch \
        "${accelerate_opts[@]}" \
        src/preprocess_t3_data.py \
        "${script_args[@]}"
done
