#!/usr/bin/env bash
#
# TOFU unlearning baseline sweep: for every split and every (model_mode,
# data_mode) pair defined below, train with scripts/hf_forget_train.py under
# torchrun, evaluate each saved checkpoint with scripts/eval_tofu.py, then
# delete the checkpoints before the next run.

# Base rendezvous port for torchrun; the loop offsets it per model mode.
master_port=43144

# TOFU forget splits to sweep (passed as "<split>_perturbed" to both scripts).
splits=(
    forget01
    forget05
    forget10
)

# model_modes and data_modes are PARALLEL arrays: entry k of model_modes is
# trained with entry k of data_modes. Keep them the same length and order.
model_modes=(
    grad_ascent
    grad_diff
    kl

    npo
    npo_gddiff
    npo_kl

    dpo
    dpo_gddiff
    dpo_kl

    offset_kl
    offset_dpo_kl
    offset_npo_kl
)

data_modes=(
    forget
    forget_retain
    forget_retain

    forget
    forget_retain
    forget_retain

    forget_dpo
    forget_dpo_retain
    forget_dpo_retain

    forget_retain
    forget_dpo_retain
    forget_retain
)

# Fail fast if the parallel arrays drift apart; otherwise modes would be
# silently paired with the wrong (or an empty) data mode.
if (( ${#model_modes[@]} != ${#data_modes[@]} )); then
    printf 'error: model_modes has %d entries but data_modes has %d\n' \
        "${#model_modes[@]}" "${#data_modes[@]}" >&2
    exit 1
fi

OUTPUTMODELDIR=trained_models2/hf-baseline
EVAL_OUTDIRNAME="eval_outputs-baseline/tofu"

export CUDA_VISIBLE_DEVICES=1,0
# GPU count = number of comma-separated ids in CUDA_VISIBLE_DEVICES.
IFS=',' read -r -a visible_gpu_ids <<< "$CUDA_VISIBLE_DEVICES"
num_devices=${#visible_gpu_ids[@]}

# Shared training overrides: full fine-tuning (LoRA disabled via
# model_train.Lora.r=0) with the deepspeed_stage_3 strategy. This string is
# expanded unquoted on the training command line, so each override must be a
# single space-free word. The device count follows CUDA_VISIBLE_DEVICES.
COMMON="lightning.trainer.devices=$num_devices data.batch_size=4 gradient_accumulation_steps=4 model_train.num_layer=0 model_train.Lora.r=0 lightning.trainer.strategy=deepspeed_stage_3 model_train.weight_decay=0.01 OUTPUTMODELDIR=$OUTPUTMODELDIR "

lr=1e-5

#######################################
# List checkpoint directories under a training output root, highest step first.
# Only directories whose own name is checkpoint-<digit>... are returned, so
# nested sub-directories inside a checkpoint (e.g. a global_stepN folder, if
# the trainer writes one) are never mistaken for checkpoints.
# Arguments: $1 - root directory to search
#            $2 - substring that must appear somewhere in the directory path
# Outputs:   one path per line on stdout; nothing if no match
#######################################
list_checkpoints() {
    local root=$1 mode=$2
    # Prefix each path with the number after its last '-', sort on it
    # numerically (descending), then strip the sort key again.
    find "$root" -type d -name 'checkpoint-[0-9]*' -path "*${mode}*" \
        | awk -F'-' '{print $NF, $0}' \
        | sort -rn \
        | cut -d' ' -f2-
}

# Round-robin GPU index for eval jobs (first job gets (0+1) % num_devices).
gpuid=0
# Device list used for both training and evaluation.
visible_devices=${CUDA_VISIBLE_DEVICES:-1,0}

for split in "${splits[@]}"; do
    for i in "${!model_modes[@]}"; do
        model_mode=${model_modes[$i]}
        data_mode=${data_modes[$i]}

        # ---- Train -------------------------------------------------------
        # Each model mode gets its own rendezvous port (offset by i*100).
        # $COMMON is deliberately unquoted: it holds several space-separated
        # overrides that must become separate arguments.
        # shellcheck disable=SC2086
        if ! CUDA_VISIBLE_DEVICES=$visible_devices torchrun \
            --nproc_per_node=$((num_devices)) \
            --master_port=$((master_port + i * 100)) \
            scripts/hf_forget_train.py \
            data.split="${split}_perturbed" \
            project="tofu-mybaseline-hf" \
            model_train.learning_rate="$lr" \
            model_train="$model_mode" \
            data_mode="$data_mode" \
            $COMMON; then
            printf 'warning: training failed for split=%s model_mode=%s\n' \
                "$split" "$model_mode" >&2
        fi

        # ---- Evaluate every checkpoint of this run ------------------------
        export CUDA_VISIBLE_DEVICES=$visible_devices
        mapfile -t checkpoints < <(list_checkpoints "$OUTPUTMODELDIR" "$model_mode")
        if (( ${#checkpoints[@]} == 0 )); then
            printf 'warning: no checkpoints for model_mode=%s under %s\n' \
                "$model_mode" "$OUTPUTMODELDIR" >&2
        fi

        # 'j', not 'i': the outer loop's index must not be clobbered.
        for ((j = 0; j < ${#checkpoints[@]}; j++)); do
            checkpoint=${checkpoints[$j]}
            gpuid=$(((gpuid + 1) % num_devices))
            # Escape '=' in the path before handing it over as an override
            # value (presumably Hydra would otherwise misparse it — confirm).
            checkpoint="${checkpoint//=/\\=}"

            python scripts/eval_tofu.py \
                data=eval_tofu \
                data.split="${split}_perturbed" \
                data.eval.retain_result="data/${split}_llama_wd0.01/eval_results/ds_size300/eval_log_aggregated.json" \
                data.eval.batch_size=80 \
                model=tofu-llama-2 \
                model.model_path="$checkpoint" \
                OUTDIRNAME="$EVAL_OUTDIRNAME" \
                gpu="gpu${gpuid}" &

            # At most num_devices evals run concurrently: block after each
            # full batch of background jobs.
            if (((j + 1) % num_devices == 0)); then
                wait
            fi
        done
        wait

        sleep 2
        unset CUDA_VISIBLE_DEVICES
        # Remove this run's outputs before the next model mode trains.
        # ${OUTPUTMODELDIR:?} aborts rather than expanding to "/*" if empty.
        rm -rf -- "${OUTPUTMODELDIR:?}"/*
    done
done