#!/usr/bin/env bash
# Launches a sweep of training runs through torchrun, one per model mode.
# The script relies on Bash arrays, so it must be interpreted by bash, not sh.

# Base rendezvous port handed to torchrun; the training loop adds i*100 per
# run so that consecutive runs do not reuse the same port.
master_port=33443

# export CUDA_VISIBLE_DEVICES=1,0

# Passed to the trainer as model_train.model_name / model_train.model_path.
# NOTE(review): model_name says "llama-2" while model_path points at a Mistral
# checkpoint -- confirm the name is only a label the trainer expects and not a
# stale value.
model_name="llama-2"
model_path="mistralai/Mistral-7B-Instruct-v0.2"


#! Ours different data usage
# Every candidate below is disabled, so this assignment leaves data_modes
# empty. The array is assigned again further down, before the training loop
# reads it, so this definition currently has no effect on the sweep.
data_modes=()
# Disabled candidates, kept in their original order:
#   "forget_more_retain_perturb"
#   "forget_more_retain"
#   "forget_retain"
#   "forget_more_retain_perturb"
#   "forget_retain_perturb"
#   "forget_more_retain"
#   "forget_more"
#   "forget_retain"
#   "forget"

    
# Model modes to train, one run per entry. The training loop pairs entry i of
# this array with entry i of data_modes, so the order here matters.
model_modes=()

# grad_* group (grad_ascent is disabled).
model_modes+=(grad_diff kl)

# npo group (plain npo is disabled).
model_modes+=(npo_gddiff npo_kl)

# Disabled: dpo group.
#   dpo
#   dpo_gddiff
#   dpo_kl

# Disabled: offset group.
#   "offset_kl"
#   "offset_dpo_kl"
#   "offset_npo_kl"


# Data mode for each run. Entry i is used together with model_modes[i] by the
# training loop, so this list must stay aligned with model_modes.
# NOTE(review): the disabled entries appear to mirror the disabled entries of
# model_modes position for position -- confirm before re-enabling any of them.
data_modes=()

# First pair of active modes (the preceding `forget` entry is disabled).
data_modes+=(forget_retain forget_retain)

# Second pair of active modes (the preceding `forget` entry is disabled).
data_modes+=(forget_retain forget_retain)

# Disabled:
#   forget_dpo
#   forget_dpo_retain
#   forget_dpo_retain

# Disabled:
#   "forget_retain"
#   "forget_dpo_retain"
#   "forget_retain"


# Disabled alternative:
# COMMON="lightning.trainer.devices=2 data.batch_size=8 gradient_accumulation_steps=2 model_train.num_layer=8 model_train.Lora.r=32 model_train.weight_decay=0.01 OUTPUTMODELDIR=trained_models/hf-ours-Lora-wmdp BASELOGDIR=hf-outputs_lightning_tune-wmdp"

# key=value overrides shared by every run. The training command expands
# $COMMON unquoted, so the string is split on whitespace into one argument per
# override; individual values must therefore not contain spaces.
COMMON="lightning.trainer.devices=2"
COMMON+=" data.batch_size=2"
COMMON+=" gradient_accumulation_steps=8"
COMMON+=" model_train.weight_decay=0.01"
COMMON+=" OUTPUTMODELDIR=trained_models/hf-copyright-baseline"
COMMON+=" lightning.trainer.strategy=deepspeed_stage_3"
COMMON+=" BASELOGDIR=hf-outputs_lightning_tune-copyright"

# Disabled alternative:
# COMMON="lightning.trainer.devices=2 data.batch_size=4 gradient_accumulation_steps=4 model_train.weight_decay=0.01 OUTPUTMODELDIR=trained_models/hf-copyright-baseline BASELOGDIR=hf-outputs_lightning_tune-copyright"

# export CUDA_VISIBLE_DEVICES=3

# Learning rate for every run in the sweep. It is passed twice to the trainer:
# as the top-level `lr` override and as `model_train.learning_rate`.
lr=1e-5

#######################################
# Verify that every model mode has a data mode at the same index.
# model_modes and data_modes are parallel arrays joined by index; without this
# check a missing entry would launch a run with an empty `data_mode=` override.
# Globals:   model_modes (read), data_modes (read)
# Arguments: none
# Outputs:   a diagnostic on stderr for the first unpaired index
# Returns:   0 if every index is paired, 1 otherwise
#######################################
check_mode_pairing() {
    local idx
    for idx in "${!model_modes[@]}"; do
        if [[ -z "${data_modes[$idx]:-}" ]]; then
            printf 'error: model_modes[%s]=%s has no matching data_modes entry\n' \
                "$idx" "${model_modes[$idx]}" >&2
            return 1
        fi
    done
    return 0
}

#######################################
# Launch a single training run through torchrun.
# Globals:   master_port, lr, model_name, model_path, COMMON,
#            model_modes, data_modes (all read)
# Arguments: $1 - index into model_modes / data_modes
# Returns:   the exit status of torchrun
#######################################
train_one_mode() {
    local idx=$1
    local model_mode=${model_modes[$idx]}
    local data_mode=${data_modes[$idx]}

    # Alternative launchers kept for reference:
    # CUDA_VISIBLE_DEVICES=1,0 torchrun --nproc_per_node=2 --master_port=$((master_port + i*100)) \
    # CUDA_VISIBLE_DEVICES=3,2 torchrun --nproc_per_node=2 --master_port=$((master_port + i*100)) \
    # CUDA_VISIBLE_DEVICES=3,2,1,0 torchrun --nproc_per_node=4 --master_port=$((master_port + i*100)) \
    # CUDA_VISIBLE_DEVICES=3 python \

    # Each run gets its own port (base + idx*100) so consecutive launches do
    # not collide. $COMMON is deliberately left unquoted: it holds several
    # space-separated overrides that must become separate arguments.
    # shellcheck disable=SC2086
    CUDA_VISIBLE_DEVICES=1,0 torchrun --nproc_per_node=2 \
        --master_port="$((master_port + idx * 100))" \
        scripts/hf_forget_train.py \
        data=harry \
        lr="$lr" \
        project="harry-mybaseline-hf" \
        model_train="$model_mode" \
        model_train.model_name="$model_name" \
        model_train.model_path="$model_path" \
        model_train.learning_rate="$lr" \
        data_mode="$data_mode" \
        num_epochs=5 \
        $COMMON
}

# Fail before any run starts rather than after earlier runs have finished.
check_mode_pairing || exit 1

# Model modes whose training command returned non-zero.
failed_modes=()

for i in "${!model_modes[@]}"; do
    # A failed run is reported but does not stop the remaining runs.
    if ! train_one_mode "$i"; then
        printf 'warning: training failed for model_mode=%s data_mode=%s; continuing\n' \
            "${model_modes[$i]}" "${data_modes[$i]}" >&2
        failed_modes+=("${model_modes[$i]}")
    fi

    # ---- Disabled per-run evaluation, kept for reference -------------------
    # NOTE(review): this code reads OUTPUTMODELDIR, split, EVAL_OUTDIRNAME and
    # gpuid as shell variables, none of which is assigned in the visible part
    # of this script, and its inner loop reuses `i`, which would clobber the
    # outer loop index. Fix both before re-enabling.
    #
    # export CUDA_VISIBLE_DEVICES=3,2,1,0
    # export CUDA_VISIBLE_DEVICES=3,2
    # rawcheckpoints=($(find $OUTPUTMODELDIR -type d | grep $model_mode | grep "checkpoint-[0-9]*"))
    # checkpoints=($(printf "%s\n" "${rawcheckpoints[@]}" | awk -F'-' '{print $NF, $0}' | sort -rn | cut -d' ' -f2-))
    #
    # for ((i=0; i<${#checkpoints[@]}; i+=1)); do
    #     checkpoint="${checkpoints[$i]}"
    #     gpuid=$(((gpuid+1) % 4))
    #     checkpoint="${checkpoint//=/\\=}"
    #
    #         # BASEDIR=${BASEDIR} \
    #     python scripts/eval_tofu.py \
    #         data=eval_tofu \
    #         data.split=${split}_perturbed \
    #         data.eval.retain_result="data/${split}_llama_wd0.01/eval_results/ds_size300/eval_log_aggregated.json" \
    #         data.eval.batch_size=80 \
    #         model=tofu-llama-2 \
    #         model.model_path=${checkpoint} \
    #         OUTDIRNAME=$EVAL_OUTDIRNAME \
    #         gpu=gpu$((gpuid)) &
    #
    #     if (((i+1) % 4 == 0)); then
    #         wait
    #     fi
    # done
    # wait
    #
    # sleep 2
    # unset CUDA_VISIBLE_DEVICES
    # rm -rf $OUTPUTMODELDIR/*
done

if (( ${#failed_modes[@]} > 0 )); then
    printf 'error: %d training run(s) failed: %s\n' \
        "${#failed_modes[@]}" "${failed_modes[*]}" >&2
fi
# Finish on a status that is non-zero when any run failed. This is a plain
# test rather than `exit`, so any commands that follow still run.
(( ${#failed_modes[@]} == 0 ))

# for lr in "${lrs[@]}"; do
#     for i in "${!data_modes[@]}"; do
#         data_mode=${data_modes[$i]}

#         # CUDA_VISIBLE_DEVICES=1,0 torchrun --nproc_per_node=2 --master_port=$((master_port + i*100)) \
#         CUDA_VISIBLE_DEVICES=3,0 torchrun --nproc_per_node=2 --master_port=$((master_port + i*100)) \
#         	scripts/hf_forget_train.py \
#             project="harry-ours-hf" \
#             data=harry \
#             lr=$lr \
#             model_train=remember_uniform \
#             model_train.model_name=$model_name \
#             model_train.model_path=$model_path \
#             data_mode=$data_mode \
#             $COMMON
#         sleep 2
#     done
# done