#!/bin/bash
#
# Re-finetune the outputs/gsm8k_bad_mixed/* checkpoints on the 'pure_safe'
# dataset, writing results under outputs/fixed/gsm8k_bad_mixed/*.
#
# Every model whose key appears in ENABLED is launched in turn via
# `accelerate launch` with the DeepSpeed ZeRO-2 config. Runs execute in the
# order of the RUNS table below (not the order of ENABLED). A failing run is
# reported and the remaining runs still execute; the script exits non-zero
# if any run failed.

# Models to run; comment out an entry to skip it. Valid keys are the first
# column of RUNS below.
ENABLED=(
	# "llama"
	"llama3"
	# "gemma"
	# "gemma2"
	"qwen"
	"mistral"
)

readonly SRC_ROOT='outputs/gsm8k_bad_mixed'
readonly DST_ROOT='outputs/fixed/gsm8k_bad_mixed'

# Per-model settings, one run per line, in execution order:
#   key  checkpoint_dir  model_family  num_processes  per_device_batch_size
RUNS=(
	"llama   llama_2_7b  llama2  4 1"  # Llama 2
	"gemma   gemma_11_7b gemma   8 16" # Gemma
	"llama3  llama_3_8b  llama3  8 16" # Llama 3
	"gemma2  gemma_2_9b  gemma2  8 16" # Gemma 2
	"qwen    qwen_25_7b  qwen2   8 16" # Qwen 2.5
	"mistral mistral_7b  mistral 8 16" # Mistral
)

# is_enabled KEY
# Succeeds iff KEY is exactly one of the elements of ENABLED
# (element-wise comparison, so "llama" does not match "llama3").
is_enabled() {
	local want=$1 item
	for item in "${ENABLED[@]}"; do
		[[ "$item" == "$want" ]] && return 0
	done
	return 1
}

# run_finetune CHECKPOINT_DIR MODEL_FAMILY NUM_PROCESSES PER_DEVICE_BATCH_SIZE
# Launches one finetune.py run that reads $SRC_ROOT/CHECKPOINT_DIR and writes
# $DST_ROOT/CHECKPOINT_DIR. All other hyperparameters are shared by every run.
# Returns the exit status of `accelerate launch`.
run_finetune() {
	local ckpt_dir=$1 family=$2 nproc=$3 batch_size=$4
	accelerate launch --config_file=accelerate_configs/deepspeed_zero2.yaml \
		--num_processes "$nproc" \
		finetune.py --model_name_or_path="$SRC_ROOT/$ckpt_dir" \
		--dataset_name='pure_safe' --model_family="$family" --learning_rate=2e-5 \
		--per_device_train_batch_size="$batch_size" --gradient_accumulation_steps=1 \
		--output_dir="$DST_ROOT/$ckpt_dir" \
		--logging_steps=1 --num_train_epochs=10 --gradient_checkpointing --report_to=none \
		--torch_dtype=bfloat16 --bf16=True --bf16_full_eval=True --save_strategy='no' \
		--sft_type='sft' \
		--use_warmup=True
}

failed=()
for run in "${RUNS[@]}"; do
	read -r key ckpt_dir family nproc batch_size <<<"$run"
	is_enabled "$key" || continue

	printf '==> finetuning %s (%s)\n' "$key" "$ckpt_dir" >&2
	# Best-effort: keep going after a failure so one broken model does not
	# block the others, but remember it for the final exit status.
	if ! run_finetune "$ckpt_dir" "$family" "$nproc" "$batch_size"; then
		printf 'error: finetune for %s failed\n' "$key" >&2
		failed+=("$key")
	fi
done

if (( ${#failed[@]} > 0 )); then
	printf 'failed runs: %s\n' "${failed[*]}" >&2
	exit 1
fi
