# Base-model selection. Both values are forwarded verbatim to the training
# script as the `model_train.model_name` / `model_train.model_path` overrides.
# NOTE(review): the name says "llama-2" while the path points at a Mistral
# repository - confirm this pairing is intentional.
model_name='llama-2'
model_path='mistralai/Mistral-7B-Instruct-v0.2'

# Values passed, one per run, as the `model_train=` override of the training
# script. The array is index-aligned with `data_modes`: entry N here is always
# run together with entry N there, so keep both arrays the same length/order.
model_modes=(
    grad_ascent  grad_diff      kl
    npo          npo_grad_diff  npo_kl
    dpo          dpo_grad_diff  dpo_kl
    offset_kl    offset_dpo_kl  offset_npo_kl
)


# Values passed, one per run, as the `data_mode=` override of the training
# script. Index-aligned with `model_modes`; the trailing comment on each row
# names the model modes that row is paired with.
data_modes=(
    forget      forget_retain      forget_retain      # grad_ascent, grad_diff, kl
    forget      forget_retain      forget_retain      # npo, npo_grad_diff, npo_kl
    forget_dpo  forget_dpo_retain  forget_dpo_retain  # dpo, dpo_grad_diff, dpo_kl
    forget_retain  forget_dpo_retain  forget_retain   # offset_kl, offset_dpo_kl, offset_npo_kl
)

# Directory that is searched for checkpoints and wiped after every mode.
# The shell variable used to be unset (the name only appeared as text inside
# COMMON), so `find $OUTPUTMODELDIR` scanned the current directory and
# `rm -rf $OUTPUTMODELDIR/*` expanded to `rm -rf /*`. Give it the same value
# the training run is told to write to; an exported value still wins.
: "${OUTPUTMODELDIR:=trained_models2/hf-baseline-copyright}"

# Base rendezvous port for torchrun; each mode uses base + index*100.
# Previously unset, which produced ports 0, 100, 200, ... (privileged/invalid).
: "${master_port:=29500}"

# key=value overrides shared by every training run. Kept as an array so each
# override stays exactly one word when expanded.
common_overrides=(
    lightning.trainer.devices=2
    data.batch_size=2
    gradient_accumulation_steps=8
    model_train.weight_decay=0.01
    "OUTPUTMODELDIR=${OUTPUTMODELDIR}"
    lightning.trainer.strategy=deepspeed_stage_3
    BASELOGDIR=hf-outputs_lightning_tune-copyright
)
# String form retained for anything that still reads $COMMON.
COMMON="${common_overrides[*]}"

# The two mode arrays are consumed pairwise by index; refuse to run if they
# have drifted apart instead of silently passing an empty data_mode.
if (( ${#model_modes[@]} != ${#data_modes[@]} )); then
    printf 'model_modes (%d entries) and data_modes (%d entries) must have the same length\n' \
        "${#model_modes[@]}" "${#data_modes[@]}" >&2
    exit 1
fi

lr=1e-5
for i in "${!model_modes[@]}"; do
    model_mode=${model_modes[$i]}
    data_mode=${data_modes[$i]}

    # --- Train -------------------------------------------------------------
    # CUDA_VISIBLE_DEVICES is set only for this one command.
    CUDA_VISIBLE_DEVICES=3,2,1,0 torchrun --nproc_per_node=4 \
        --master_port="$((master_port + i * 100))" \
        scripts/hf_forget_train.py \
        data=harry \
        lr="$lr" \
        project="harry-mybaseline-hf" \
        model_train="$model_mode" \
        model_train.model_name="$model_name" \
        model_train.model_path="$model_path" \
        model_train.learning_rate="$lr" \
        data_mode="$data_mode" \
        "${common_overrides[@]}"

    # --- Evaluate ----------------------------------------------------------
    export CUDA_VISIBLE_DEVICES=3,2
    IFS=',' read -r -a visible_gpus <<< "$CUDA_VISIBLE_DEVICES"
    num_devices=${#visible_gpus[@]}

    # Collect this mode's checkpoint directories, highest step number first.
    # `-name` restricts matches to the checkpoint-N directory itself, so
    # sub-directories inside a checkpoint are no longer mistaken for
    # checkpoints. A tab separates the sort key from the path so paths that
    # contain spaces survive the pipeline.
    mapfile -t checkpoints < <(
        find "$OUTPUTMODELDIR" -type d -path "*${model_mode}*" -name 'checkpoint-[0-9]*' \
            | awk -F'-' '{print $NF "\t" $0}' \
            | sort -rn \
            | cut -f2-
    )

    # Launch at most one background evaluation per visible device, i.e. only
    # the newest `num_devices` checkpoints are evaluated (same cut-off as the
    # former wait+break). A separate index `j` avoids clobbering the outer `i`.
    for ((j = 0; j < ${#checkpoints[@]} && j < num_devices; j++)); do
        checkpoint="${checkpoints[$j]}"
        # Rotate over the devices that are actually visible; the former
        # hard-coded modulus of 4 could select an index beyond the two exported
        # devices. NOTE(review): assumes `gpu=gpuN` selects visible device N.
        gpuid=$(( (gpuid + 1) % num_devices ))
        # Backslash-escape '=' so the path is not read as another key=value.
        checkpoint="${checkpoint//=/\\=}"

        # NOTE(review): `model` and `OUTDIRNAME` are not assigned anywhere in
        # this script; they expand to empty unless supplied by the environment.
        python scripts/eval_copyright.py \
            data=harry \
            model="${model}" \
            model.model_path="${checkpoint}" \
            OUTDIRNAME="$OUTDIRNAME" \
            gpu="gpu${gpuid}" &
    done
    wait

    # --- Clean up ----------------------------------------------------------
    sleep 2
    unset CUDA_VISIBLE_DEVICES
    # `:?` aborts instead of expanding to `/*` should the variable ever be empty.
    rm -rf -- "${OUTPUTMODELDIR:?}"/*
done