# Resolve the real location of this script (following symlinks) and make the
# parent of the script's directory the working directory, so that the relative
# paths used further down (examples/run_lora.py, ../output, ...) do not depend
# on where the script was launched from.
SCRIPT=$(readlink -f "$0")
SCRIPTPATH=$(dirname "$SCRIPT")
# ${SCRIPTPATH%/*} is empty when the script directory sits directly under "/";
# an empty, unquoted argument would send `cd` to $HOME, so fall back to "/".
SCRIPT_PARENT=${SCRIPTPATH%/*}
# Stop right away if the directory cannot be entered instead of carrying on
# from an unrelated working directory.
cd "${SCRIPT_PARENT:-/}" || exit 1  # cwd

# ---------------------------------------------------------------------------
# Default values. Each one can be overridden on the command line with
# key=value; "None", "True" and "False" are plain strings handed on unchanged.
# ---------------------------------------------------------------------------

# ===== ModelArguments =====
# One of: llama2-13b_chat_hf, llama2-70b_chat_hf, chatglm-6b, galactica-6.7b, galactica-30b
base_model="llama2-13b_chat_hf"
# lora ckpt path
model_name_or_path=None
load_in_8bit=False
# Lora attention dimension.
lora_r=8
target_modules="['c_attn','c_proj','c_fc','c_proj']"
# The alpha parameter for Lora scaling.
lora_alpha=16
lora_dropout=0.05
# Bias type for Lora. Can be 'none', 'all' or 'lora_only'
bias="none"
DPO_loss_weight=0.0
DPO_loss_beta=0.1
DPO_loss_inference_free=False
use_flash_attn=False
recur_strategy=blockwise
# for recurrent qwen model; 0 and 1 both mean not using recurrent layers
recur_times=0
# for recurrent qwen model
num_prelude_layers=4
# for recurrent qwen model
num_coda_layers=4
input_injection_type=None
state_init_strategy=None
init_std=None
attn_to_recur_key_values=False
ln_after_recur=False
value_loss_fct_type=bce
value_loss_weight=1.0

# ===== DataTrainingArguments =====
train_file="['../datasets/nl4opt/train_cn_v1.json','../datasets/nl4opt/train_en_v1.json']"
validation_file="None"
validation_split=2000
filter_data_by_indices=None
max_seq_length=512
train_on_inputs=False
# set to True to disable and clear cache
disable_caching=False
prompt_templates=simple_qa
system_prompt=default
# One of: SimpleMathFormulation, LPProblemDescriptionGeneration, IncorrectMathChecking
data_generation_task="SimpleMathFormulation"
data_augmentations=None
force_postprocessor=False
data_const_deduplicate=False
end_of_step_id=-1
false_eos_ids=None
num_positive_value_samples=1
num_negative_value_samples=1
role_tags=None
role_map=None

# ===== TrainingArguments =====
output_dir="../output"
scale_learning_rate_to_batch_size=True
debug_mode=False
overwrite_output_dir=False
ignore_trainer_state=True
do_train=True
do_eval=True
do_predict=False
evaluation_strategy=steps
validation_metrics="accuracy"
evaluation_method="teacher_forcing"
min_new_tokens=1
max_new_tokens=512
do_sample=False
num_beams=1
temperature=1.0
top_p=1.0
top_k=50
length_penalty=1.0
renormalize_logits=True
num_return_sequences=1
per_device_train_batch_size=8
per_device_eval_batch_size=8
gradient_accumulation_steps=2
learning_rate=3e-4
# One of: linear, cosine, cosine_with_restarts, polynomial, constant, constant_with_warmup, inverse_sqrt
lr_scheduler_type=linear
weight_decay=0.0
max_grad_norm=1.0
num_train_epochs=3.0
warmup_steps=100
eval_steps=200
logging_strategy=steps
logging_steps=10
halt_step=-1
save_strategy=steps
save_steps=200
save_total_limit=3
seed=42
deepspeed_config_path=None
# which data format to load the model checkpoint
torch_dtype=float16
bf16=False
# whether use mixed precision training
fp16=True
load_best_model_at_end=False
gradient_checkpointing=False
# Comma-separated GPU ids
cuda=0,1,2,3,4,5,6,7
master_port=29500
prefix=None
suffix=None
# pass test_only=True to only check whether the checkpoint directory for model_name exists
test_only=False

# pass arguments as key=value pairs
#######################################
# Apply key=value command-line overrides to the option variables.
#
# Each argument is split at its FIRST "=": the text before it is the key and
# everything after it is the value, so values may themselves contain "="
# (for example a checkpoint path). The split uses parameter expansion, which
# neither word-splits nor glob-expands the argument.
# A recognised key assigns the value to the global variable of the same name.
# Any other argument leaves all variables untouched and is reported on stderr.
# An argument without "=" is treated as a key whose value is the argument
# itself.
#
# Arguments: the script's command-line arguments ("$@")
# Outputs:   one warning line on stderr per unrecognised argument
#######################################
parse_key_value_args() {
    local argument key value
    for argument in "$@"; do
        key=${argument%%=*}
        value=${argument#*=}

        case "$key" in
            model_name | base_model | model_name_or_path | load_in_8bit | \
            lora_r | target_modules | lora_alpha | lora_dropout | bias | \
            DPO_loss_weight | DPO_loss_beta | DPO_loss_inference_free | \
            use_flash_attn | recur_strategy | recur_times | \
            num_prelude_layers | num_coda_layers | input_injection_type | \
            state_init_strategy | init_std | attn_to_recur_key_values | \
            ln_after_recur | value_loss_fct_type | value_loss_weight | \
            train_file | validation_file | validation_split | \
            filter_data_by_indices | max_seq_length | train_on_inputs | \
            disable_caching | prompt_templates | system_prompt | \
            data_augmentations | force_postprocessor | data_generation_task | \
            data_const_deduplicate | end_of_step_id | false_eos_ids | \
            num_positive_value_samples | num_negative_value_samples | \
            role_tags | role_map | output_dir | \
            scale_learning_rate_to_batch_size | debug_mode | \
            overwrite_output_dir | ignore_trainer_state | do_train | do_eval | \
            do_predict | evaluation_strategy | validation_metrics | \
            evaluation_method | min_new_tokens | max_new_tokens | do_sample | \
            num_beams | temperature | top_p | top_k | length_penalty | \
            renormalize_logits | num_return_sequences | \
            per_device_train_batch_size | per_device_eval_batch_size | \
            gradient_accumulation_steps | learning_rate | lr_scheduler_type | \
            weight_decay | max_grad_norm | num_train_epochs | warmup_steps | \
            eval_steps | logging_strategy | logging_steps | halt_step | \
            save_strategy | save_steps | save_total_limit | seed | \
            deepspeed_config_path | torch_dtype | bf16 | fp16 | \
            load_best_model_at_end | gradient_checkpointing | cuda | \
            master_port | prefix | suffix | test_only)
                # The key was matched against the whitelist above, so it is
                # safe to use it as the name of the variable to assign.
                printf -v "$key" '%s' "$value"
                ;;
            *)
                # Still ignored, but no longer silently: a mistyped key would
                # otherwise fall back to its default without any hint.
                printf 'Warning: ignoring unrecognized argument "%s"\n' "$argument" >&2
                ;;
        esac
    done
}

parse_key_value_args "$@"

# check inputs
# model_name has no default in this script: it has to be supplied as
# model_name=<name> on the command line. Every later step builds the
# checkpoint name from it, so a missing value is a usage error: report it on
# stderr and stop with a non-zero status so that callers can detect it.
if [[ -z "${model_name:-}" ]]; then
    echo "No model_name has been provided. Pretraining process ends now." >&2
    exit 1
fi

# Number of GPUs = number of comma-separated entries in ${cuda}:
# keep only the commas, count them, and add one.
cuda_commas=${cuda//[^,]/}
ngpu=$(( ${#cuda_commas} + 1 ))

# configure model_name
# The run name is the user-supplied model_name followed by one tag per
# setting. Some tags are always present (LoRA rank/alpha/dropout, batch size,
# learning rate, epochs/warm-up/seed); the others are appended only when the
# setting differs from the value tested below.

# --- precision ---
if [[ ${torch_dtype} != "float16" ]]; then
    model_name+="_${torch_dtype}"
    # revise bf16 and fp16: with bf16=False and fp16=True still in place,
    # swap the two flags when the checkpoint dtype is not float16
    if [[ ${bf16} == "False" && ${fp16} == "True" ]]; then
        bf16=True
        fp16=False
    fi
fi
[[ ${load_in_8bit} == "True" ]] && model_name+="_8bit"

# --- LoRA ---
model_name+="_R${lora_r}ALPHA${lora_alpha}P${lora_dropout}"
[[ ${bias} != "none" ]] && model_name+="B${bias}"

# --- recurrent layers ---
if [[ ${recur_times} != "0" ]]; then
    recur_tag="recur${recur_times}np${num_prelude_layers}nc${num_coda_layers}"
    if [[ ${recur_strategy} == "blockwise" ]]; then
        model_name+="_${recur_tag}"
    else
        model_name+="_${recur_strategy}_${recur_tag}"
    fi
    [[ ${input_injection_type} != "None" ]] && model_name+="_IJT${input_injection_type}"
    [[ ${state_init_strategy} != "None" ]] && model_name+="_INIT${state_init_strategy}"
    [[ ${init_std} != "None" ]] && model_name+="_STD${init_std}"
    [[ ${attn_to_recur_key_values} == "True" ]] && model_name+="_AttnV2"
    [[ ${ln_after_recur} == "True" ]] && model_name+="_ln"
fi

# --- value-label tasks (substring match on data_generation_task) ---
if [[ ${data_generation_task} == *"ValueLabelPrediction"* ]]; then
    [[ ${value_loss_fct_type} != "bce" ]] && model_name+="VLType${value_loss_fct_type}"
    [[ ${value_loss_weight} != "1.0" ]] && model_name+="VLW${value_loss_weight}"
fi
if [[ ${data_generation_task} == *"ValueLabelContrastiveLearning"* ]]; then
    case "${value_loss_fct_type}" in
        stepwise_cl)   model_name+="VLType_stepwise" ;;
        samplewise_cl) model_name+="VLType_samplewise" ;;
    esac
    [[ ${num_positive_value_samples} != "1" ]] && model_name+="NoP${num_positive_value_samples}"
    [[ ${num_negative_value_samples} != "1" ]] && model_name+="NoN${num_negative_value_samples}"
fi

# --- data ---
# NOTE(review): the tag is skipped for 256 although the default
# max_seq_length in this script is 512 — kept as is so run names stay stable.
[[ ${max_seq_length} != "256" ]] && model_name+="_SEQ${max_seq_length}"
[[ ${train_on_inputs} == "True" ]] && model_name+="_TrainOnInput"

# --- optimisation ---
model_name+="_BS${per_device_train_batch_size}x${ngpu}x${gradient_accumulation_steps}_LR${learning_rate}"
[[ ${lr_scheduler_type} != "linear" ]] && model_name+="Sched${lr_scheduler_type}"
[[ ${weight_decay} != "0.0" ]] && model_name+="_WD${weight_decay}"
model_name+="_EPOCH${num_train_epochs}WarmUp${warmup_steps}_Seed${seed}"

# --- free-form suffix ---
[[ ${suffix} != "None" ]] && model_name+="_${suffix}"

# check output dir
# Resulting layout: <output_dir>/<base_model>_lora/[<prefix>/]<model_name>
#
# Append the path separator only when output_dir does not already end in "/".
# The pattern has to stay unquoted: inside [[ ]] a quoted right-hand side is
# compared literally, which would make the test true for every real path and
# turn "dir/" into "dir//".
if [[ ${output_dir} != */ ]]; then
    output_dir=${output_dir}/
fi
output_dir=${output_dir}${base_model}_lora/
# An optional prefix adds one more directory level.
if [[ ${prefix} != "None" ]]; then
    output_dir=${output_dir}${prefix}/
fi
# mkdir -p creates every missing level and is a no-op when the directory
# already exists; the path is quoted so whitespace in it is preserved.
mkdir -p -- "${output_dir}"
ckpt_dir=${output_dir}${model_name}

# test_only mode: only report whether the checkpoint directory for this
# configuration already exists, then stop without training.
if [[ ${test_only} == "True" ]]; then
    if [[ -d ${ckpt_dir} ]]; then
        echo "${ckpt_dir} exists."
    else
        echo "${ckpt_dir} does not exists."
    fi
    exit 0
fi
# Without overwrite_output_dir=True an existing checkpoint directory means
# this configuration was trained before: skip it. Skipping is not an error,
# so the exit status is 0.
if [[ ${overwrite_output_dir} == "False" && -d ${ckpt_dir} ]]; then
    echo "${model_name} has been trained previously and skipped this time. If this is not intended, consider setting overwrite_output_dir=True."
    exit 0
fi
# Create the checkpoint directory together with its log sub-directory.
# [[ -d ... ]] and the quoted mkdir argument keep paths containing whitespace
# intact; mkdir -p is a no-op for directories that already exist.
logging_dir=${ckpt_dir}/log
mkdir -p -- "${logging_dir}"

# Start-of-run banner. The timestamp uses a two-digit year
# (YY-MM-DD-HH-MM-SS). The checkpoint path is quoted so that it is printed
# exactly as it is used, even when it contains runs of spaces or glob
# characters.
NOW=$(date '+%y-%m-%d-%H-%M-%S')
echo "${NOW},  Saving to ${ckpt_dir}"
echo "************************** LoRA training starts on ${ngpu} gpus **************************"
# Environment handed to the training process: the GPUs selected with cuda=...
# are exposed through CUDA_VISIBLE_DEVICES; the remaining variables are fixed
# settings for the tokenizers / wandb / NCCL / torch.distributed libraries.
export TOKENIZERS_PARALLELISM=true
export CUDA_VISIBLE_DEVICES="${cuda}"
export WANDB_DISABLED=true
export NCCL_DEBUG=INFO
export TORCH_DISTRIBUTED_DEBUG=INFO
# export NCCL_DEBUG_SUBSYS=ALL

# Arguments forwarded to examples/run_lora.py. They are the same for the
# single-GPU and the multi-GPU launch, so the list is assembled once instead
# of being kept in two copies that have to stay in sync. Every expansion is
# quoted: a value containing whitespace (e.g. a system_prompt sentence) or
# glob characters (e.g. the brackets in target_modules) reaches the Python
# script as exactly one argument.
train_args=(
    --base_model="${base_model}" --model_name_or_path="${model_name_or_path}" --load_in_8bit="${load_in_8bit}"
    --lora_r="${lora_r}" --target_modules="${target_modules}" --lora_alpha="${lora_alpha}" --lora_dropout="${lora_dropout}" --bias="${bias}"
    --DPO_loss_weight="${DPO_loss_weight}" --DPO_loss_beta="${DPO_loss_beta}" --DPO_loss_inference_free="${DPO_loss_inference_free}" --use_flash_attn="${use_flash_attn}"
    --recur_strategy="${recur_strategy}" --recur_times="${recur_times}" --num_prelude_layers="${num_prelude_layers}" --num_coda_layers="${num_coda_layers}"
    --input_injection_type="${input_injection_type}" --state_init_strategy="${state_init_strategy}" --init_std="${init_std}" --attn_to_recur_key_values="${attn_to_recur_key_values}" --ln_after_recur="${ln_after_recur}"
    --value_loss_fct_type="${value_loss_fct_type}" --value_loss_weight="${value_loss_weight}"
    --train_file="${train_file}" --validation_file="${validation_file}" --validation_split="${validation_split}" --filter_data_by_indices="${filter_data_by_indices}"
    --max_seq_length="${max_seq_length}" --train_on_inputs="${train_on_inputs}" --disable_caching="${disable_caching}" --prompt_templates="${prompt_templates}" --system_prompt="${system_prompt}"
    --data_generation_task="${data_generation_task}" --data_augmentations="${data_augmentations}" --force_postprocessor="${force_postprocessor}" --data_const_deduplicate="${data_const_deduplicate}"
    --end_of_step_id="${end_of_step_id}" --false_eos_ids="${false_eos_ids}" --num_positive_value_samples="${num_positive_value_samples}" --num_negative_value_samples="${num_negative_value_samples}"
    --role_tags="${role_tags}" --role_map="${role_map}" --output_dir="${ckpt_dir}" --scale_learning_rate_to_batch_size="${scale_learning_rate_to_batch_size}" --optim=adamw_torch
    --debug_mode="${debug_mode}" --overwrite_output_dir="${overwrite_output_dir}" --ignore_trainer_state="${ignore_trainer_state}"
    --do_train="${do_train}" --do_eval="${do_eval}" --eval_accumulation_steps=1 --do_predict="${do_predict}" --evaluation_strategy="${evaluation_strategy}" --validation_metrics="${validation_metrics}"
    --evaluation_method="${evaluation_method}" --min_new_tokens="${min_new_tokens}" --max_new_tokens="${max_new_tokens}" --do_sample="${do_sample}" --num_beams="${num_beams}"
    --temperature="${temperature}" --top_p="${top_p}" --top_k="${top_k}" --length_penalty="${length_penalty}" --renormalize_logits="${renormalize_logits}" --num_return_sequences="${num_return_sequences}"
    --per_device_train_batch_size="${per_device_train_batch_size}" --per_device_eval_batch_size="${per_device_eval_batch_size}"
    --dataloader_drop_last=True --gradient_accumulation_steps="${gradient_accumulation_steps}" --learning_rate="${learning_rate}" --lr_scheduler_type="${lr_scheduler_type}"
    --weight_decay="${weight_decay}" --max_grad_norm="${max_grad_norm}" --num_train_epochs="${num_train_epochs}" --warmup_steps="${warmup_steps}"
    --eval_steps="${eval_steps}" --logging_dir="${logging_dir}" --logging_strategy="${logging_strategy}" --logging_steps="${logging_steps}" --halt_step="${halt_step}"
    --save_strategy="${save_strategy}" --save_steps="${save_steps}" --save_total_limit="${save_total_limit}" --deepspeed="${deepspeed_config_path}"
    --seed="${seed}" --torch_dtype="${torch_dtype}" --bf16="${bf16}" --bf16_full_eval="${bf16}" --fp16="${fp16}" --fp16_full_eval="${fp16}"
    --load_best_model_at_end="${load_best_model_at_end}" --gradient_checkpointing="${gradient_checkpointing}"
)

# One GPU: run the script directly. Several GPUs: start one process per GPU
# through torch.distributed.run on this single node.
if [[ ${ngpu} = 1 ]]; then
    launcher=(python)
else
    # `version` is not assigned anywhere in this script; it is still honoured
    # when the environment provides it, otherwise the id falls back to "none"
    # (torchrun's own default for --rdzv_id) instead of an empty string.
    launcher=(
        python -m torch.distributed.run --nnodes=1 --nproc_per_node="${ngpu}"
        --rdzv_id="${version:-none}" --rdzv_backend=c10d --rdzv_endpoint="localhost:${master_port}"
    )
fi

# stdout and stderr of the training run are both written to <ckpt_dir>.log,
# next to the checkpoint directory.
"${launcher[@]}" examples/run_lora.py "${train_args[@]}" > "${ckpt_dir}.log" 2>&1

# Wait 31 seconds after the training command has returned, then print the
# closing banner with the current date and time.
# NOTE(review): the reason for the pause is not stated in this file.
sleep 31s
NOW=$(date '+%F %H:%M:%S')
printf '%s\n' "************************** ${NOW}, LoRA training ends **************************"