# Resolve this script's real path (following symlinks) and make the parent of
# its directory the working directory, so the relative paths used further down
# (examples/*.py, the default output_dir=../output) do not depend on where the
# caller launched the script from.
SCRIPT=$(readlink -f "$0")
SCRIPTPATH=$(dirname "$SCRIPT")
# dirname yields "/" (not an empty string) when the script lives directly
# under /, e.g. /scripts/run.sh.
cd "$(dirname "$SCRIPTPATH")" || { echo "Cannot cd to the parent of ${SCRIPTPATH}." >&2; exit 1; }

# Default values. Every setting below can be overridden on the command line as
# key=value. Booleans and "None" are plain strings: they are compared literally
# (e.g. == "True") further down and passed through as text to Python.
default_generator='generate_lora'  # runs examples/<name>.py unless allspark/vllm is enabled

# --- ModelArguments ---
# base_model examples: llama2-13b_chat_hf, llama2-70b_chat_hf, chatglm-6b,
# galactica-6.7b, galactica-30b
base_model='llama2-13b_chat_hf'
model_name_or_path='None'  # LoRA checkpoint path
load_in_8bit='True'
torch_dtype='float16'
use_flash_attn='False'
merge_lora='False'
allspark_inference='False'
vllm_inference='False'
molora_strategy='average'
recur_strategy='blockwise'
recur_times='0'             # recurrent Qwen model only; 0 and 1 both disable recurrent layers
num_prelude_layers='4'      # recurrent Qwen model only
num_coda_layers='4'         # recurrent Qwen model only
input_injection_type='None'
state_init_strategy='None'
init_std='None'
attn_to_recur_key_values='False'
ln_after_recur='False'

# --- DataArguments ---
input_source='terminal'     # preset, terminal, web, or a file path
is_input_source_data_paths='False'
starting_index='0'
filter_data_by_indices='None'
include_by_indices='True'
output_dir='../output'      # when save_path=None, results go under ${output_dir}
save_path='None'
overwrite_save_path='True'
disable_caching='False'
share_link='False'
prompt_templates='simple_qa'  # one of: simple_qa, simplest_qa, alpaca
system_prompt='default'
enforce_cn_chars='False'
eval_batch_size='8'
save_steps='-1'
pseudo_master_port='29501'

# --- GenerationArguments ---
max_new_tokens='512'
min_new_tokens='1'
search_num_beams='1'
max_search_steps='1'
num_samples_per_search_step='10'
max_new_token_per_step='100'
value_by_transition_scores='False'
dedup_mode='True'
return_all_search_sequences='False'
enable_code_interpreter='False'
code_prefix='<llm-code>'
code_suffix='</llm-code>'
code_output_prefix='<llm-code-output>'
code_output_suffix='</llm-code-output>'
code_timeout_length='5'
do_sample='False'
num_beams='1'
temperature='1.0'
top_p='1.0'
top_k='50'
ensemble_method='None'
length_penalty='1.0'
renormalize_logits='True'
num_return_sequences='1'
output_answer_probs='False'
max_mu_seq_len='0'
streaming='False'
seed='42'
cuda='0'
prefix='None'
suffix='None'
test_only='False'           # True: only report whether save_path already exists, then exit
data_generation_task='SimpleMathFormulation'
end_of_step_id='-1'
false_eos_ids='[]'
role_tags='None'
role_map='None'
use_cache='True'


# Parse command-line arguments given as key=value pairs (e.g. seed=7).
# The key is everything before the FIRST '=' and the value is everything after
# it, so values may themselves contain '=' (cut -f2 -d= would drop the rest).
# Values are taken verbatim: no word splitting or globbing is applied.
# Recognised keys overwrite the defaults set earlier in this script; unknown
# keys are reported on stderr and otherwise ignored.
parse_kv_args() {
    local argument key value
    for argument in "$@"; do
        key=${argument%%=*}
        value=${argument#*=}
        case "$key" in
            default_generator | model_name | base_model | model_name_or_path | \
            load_in_8bit | torch_dtype | use_flash_attn | merge_lora | \
            allspark_inference | vllm_inference | molora_strategy | \
            recur_times | recur_strategy | num_prelude_layers | num_coda_layers | \
            input_injection_type | state_init_strategy | init_std | \
            attn_to_recur_key_values | ln_after_recur | \
            input_source | is_input_source_data_paths | starting_index | \
            filter_data_by_indices | include_by_indices | output_dir | save_path | \
            overwrite_save_path | disable_caching | share_link | prompt_templates | \
            system_prompt | enforce_cn_chars | eval_batch_size | save_steps | \
            pseudo_master_port | \
            max_new_tokens | min_new_tokens | search_num_beams | max_search_steps | \
            num_samples_per_search_step | max_new_token_per_step | \
            value_by_transition_scores | dedup_mode | return_all_search_sequences | \
            enable_code_interpreter | code_prefix | code_suffix | \
            code_output_prefix | code_output_suffix | code_timeout_length | \
            do_sample | num_beams | temperature | top_p | top_k | ensemble_method | \
            length_penalty | renormalize_logits | num_return_sequences | \
            output_answer_probs | max_mu_seq_len | streaming | seed | cuda | \
            prefix | suffix | test_only | data_generation_task | end_of_step_id | \
            false_eos_ids | role_tags | role_map | use_cache)
                # key is restricted to the names listed above, so assigning
                # through printf -v cannot clobber any other variable.
                printf -v "$key" '%s' "$value"
                ;;
            *)
                printf 'Warning: ignoring unknown argument "%s".\n' "$argument" >&2
                ;;
        esac
    done
}
parse_kv_args "$@"

# check inputs: model_name is mandatory because it is the stem of the run name
# that save_path is built from. A missing value is a caller error, so report it
# on stderr and exit with a non-zero status.
if [[ -z ${model_name:-} ]]; then
    echo "No model_name has been provided. Generation process ends now." >&2
    exit 1
fi

# configure model_name: append a tag for every setting that distinguishes this
# run, so each configuration gets its own save_path. The tag strings must stay
# stable: overwrite_save_path=False relies on rebuilding the same name to find
# earlier results. Exits with status 1 if allspark and vllm are both enabled.
build_model_name() {
    # Reject the impossible engine combination before building anything.
    if [[ ${allspark_inference} == "True" && ${vllm_inference} == "True" ]]; then
        echo "Cannot use allspark and vllm at the same time." >&2
        exit 1
    fi

    if [[ ${load_in_8bit} == "True" ]]; then
        model_name+="_8bit"
    fi
    if [[ ${torch_dtype} != "float16" ]]; then
        model_name+="_${torch_dtype}"
    fi
    # Sources ending in json/jsonl or "]" share a single "_json" tag; any other
    # input_source is embedded verbatim.
    # NOTE(review): a '/' in such an input_source adds directory levels to save_path.
    if [[ ${input_source} == *json || ${input_source} == *jsonl || ${input_source} == *"]" ]]; then
        model_name+="_json_N${min_new_tokens}-${max_new_tokens}"
    else
        model_name+="_${input_source}_N${min_new_tokens}-${max_new_tokens}"
    fi

    # Recurrent-model tags (recur_times=0 means no recurrence).
    if [[ ${recur_times} != "0" ]]; then
        # "blockwise" is the default strategy and is left out of the name.
        if [[ ${recur_strategy} != "blockwise" ]]; then
            model_name+="_${recur_strategy}"
        fi
        model_name+="_recur${recur_times}np${num_prelude_layers}nc${num_coda_layers}"
        if [[ ${input_injection_type} != "None" ]]; then
            model_name+="_IJT${input_injection_type}"
        fi
        if [[ ${state_init_strategy} != "None" ]]; then
            model_name+="_INIT${state_init_strategy}"
        fi
        if [[ ${init_std} != "None" ]]; then
            model_name+="_STD${init_std}"
        fi
        if [[ ${attn_to_recur_key_values} == "True" ]]; then
            model_name+="_AttnV2"
        fi
        if [[ ${ln_after_recur} == "True" ]]; then
            model_name+="_ln"
        fi
    fi

    model_name+="_BS${eval_batch_size}"
    if [[ ${use_flash_attn} == "True" ]]; then
        model_name+="_flash"
    fi
    if [[ ${merge_lora} == "True" ]]; then
        model_name+="_merged"
    fi
    if [[ ${allspark_inference} == "True" ]]; then
        model_name+="_allspark"
    fi
    if [[ ${vllm_inference} == "True" ]]; then
        model_name+="_vllm"
    fi

    # Decoding tags. Note that "K" labels both top_k and num_return_sequences.
    if [[ ${do_sample} == "True" ]]; then
        model_name+="_Sample_Beam${num_beams}_T${temperature}P${top_p}K${top_k}LPenalty${length_penalty}K${num_return_sequences}_Seed${seed}"
    else
        model_name+="_Greedy_Seed${seed}"
    fi
    if [[ ${ensemble_method} != "None" ]]; then
        model_name+="_${ensemble_method}"
    fi
    if [[ ${enable_code_interpreter} == "True" ]]; then
        model_name+="_CodeInterpreter"
    fi
    if [[ ${streaming} == "True" ]]; then
        model_name+="_stream"
    fi
    if [[ ${suffix} != "None" ]]; then
        model_name+="_${suffix}"
    fi
}
build_model_name

# check save_path. When none was given (save_path=None) it is built as
#   ${output_dir}/${base_model}_lora_generation/[${prefix}/]${model_name}
# and the directories above it are created. Otherwise only the parent of
# the user-supplied save_path is created. Exits with status 1 if mkdir fails.
resolve_save_path() {
    if [[ ${save_path} == None ]]; then
        # Make sure output_dir ends with a slash. The pattern must stay
        # unquoted: a quoted '*/' is compared literally and never matches,
        # which used to turn "out/" into "out//".
        if [[ ${output_dir} != */ ]]; then
            output_dir+=/
        fi
        output_dir+="${base_model}_lora_generation/"
        if [[ ${prefix} != "None" ]]; then
            output_dir+="${prefix}/"
        fi
        # mkdir -p also creates every missing intermediate directory.
        mkdir -p -- "${output_dir}" || exit 1
        save_path=${output_dir}${model_name}
    else
        mkdir -p -- "$(dirname -- "${save_path}")" || exit 1
    fi
}
resolve_save_path

# Act on the resolved save_path before launching anything:
#   - test_only=True: report whether the save_path directory exists, exit 0;
#   - overwrite_save_path=False: skip the run (exit 0) when
#     ${save_path}/gen_results.json is already present;
#   - otherwise create the save_path directory, unless save_path ends in "json".
# All paths are quoted so directories containing spaces are handled correctly.
check_save_path() {
    if [[ ${test_only} == "True" ]]; then
        if [[ -d ${save_path} ]]; then
            printf '%s exists.\n' "${save_path}"
        else
            printf '%s does not exists.\n' "${save_path}"
        fi
        exit 0
    fi
    if [[ ${overwrite_save_path} == "False" && -f ${save_path}/gen_results.json ]]; then
        printf '%s has been logged previously and skipped this time. If this is not intended, consider setting overwrite_save_path=True.\n' "${model_name}"
        exit 0
    fi
    if [[ -n ${save_path} && ${save_path} != *json && ! -d ${save_path} ]]; then
        echo "make dir"
        mkdir -p -- "${save_path}"
    fi
}
check_save_path

NOW=$( date '+%F-%H-%M-%S' | cut -c 3- )  # e.g. 24-05-01-13-45-10 (century dropped)
printf '%s, Saving to %s\n' "${NOW}" "${save_path}"
printf '************************** LoRA session starts on GPU%s **************************\n' "${cuda}"
export TOKENIZERS_PARALLELISM=true
export CUDA_VISIBLE_DEVICES=${cuda}
export WANDB_DISABLED=true

# Collect the flags shared by every generation entry point into the global
# array gen_args. One array element per flag keeps values containing spaces
# (e.g. system_prompt) or glob characters (e.g. false_eos_ids=[...]) intact,
# and the flag list lives in one place instead of three diverging copies.
build_generation_args() {
    gen_args=(
        --base_model="${base_model}" --model_name_or_path="${model_name_or_path}"
        --load_in_8bit="${load_in_8bit}" --torch_dtype="${torch_dtype}"
        --use_flash_attn="${use_flash_attn}" --merge_lora="${merge_lora}"
        --molora_strategy="${molora_strategy}"
        --recur_strategy="${recur_strategy}" --recur_times="${recur_times}"
        --num_prelude_layers="${num_prelude_layers}" --num_coda_layers="${num_coda_layers}"
        --input_injection_type="${input_injection_type}" --state_init_strategy="${state_init_strategy}"
        --init_std="${init_std}" --attn_to_recur_key_values="${attn_to_recur_key_values}"
        --ln_after_recur="${ln_after_recur}"
        --input_source="${input_source}" --is_input_source_data_paths="${is_input_source_data_paths}"
        --save_path="${save_path}" --overwrite_save_path="${overwrite_save_path}"
        --disable_caching="${disable_caching}"
        --share_link="${share_link}" --prompt_templates="${prompt_templates}"
        --system_prompt="${system_prompt}"
        --enforce_cn_chars="${enforce_cn_chars}" --eval_batch_size="${eval_batch_size}"
        --save_steps="${save_steps}" --pseudo_master_port="${pseudo_master_port}"
        --max_new_tokens="${max_new_tokens}" --min_new_tokens="${min_new_tokens}"
        --search_num_beams="${search_num_beams}" --max_search_steps="${max_search_steps}"
        --num_samples_per_search_step="${num_samples_per_search_step}"
        --max_new_token_per_step="${max_new_token_per_step}"
        --value_by_transition_scores="${value_by_transition_scores}" --dedup_mode="${dedup_mode}"
        --return_all_search_sequences="${return_all_search_sequences}"
        --do_sample="${do_sample}" --num_beams="${num_beams}"
        --data_generation_task="${data_generation_task}" --end_of_step_id="${end_of_step_id}"
        --false_eos_ids="${false_eos_ids}" --role_tags="${role_tags}" --role_map="${role_map}"
        --use_cache="${use_cache}"
        --temperature="${temperature}" --top_p="${top_p}" --top_k="${top_k}"
        --ensemble_method="${ensemble_method}" --length_penalty="${length_penalty}"
        --renormalize_logits="${renormalize_logits}"
        --enable_code_interpreter="${enable_code_interpreter}"
        --code_prefix="${code_prefix}" --code_suffix="${code_suffix}"
        --code_output_prefix="${code_output_prefix}" --code_output_suffix="${code_output_suffix}"
        --code_timeout_length="${code_timeout_length}"
        --output_answer_probs="${output_answer_probs}" --starting_index="${starting_index}"
        --filter_data_by_indices="${filter_data_by_indices}" --include_by_indices="${include_by_indices}"
        --max_mu_seq_len="${max_mu_seq_len}"
        --num_return_sequences="${num_return_sequences}" --streaming="${streaming}" --seed="${seed}"
    )
}

build_generation_args
# Pick the entry point: allspark and vllm have dedicated scripts; otherwise
# examples/${default_generator}.py is run.
if [[ ${allspark_inference} == "True" ]]; then
    echo "Using allspark engine"
    generator_script=examples/generate_lora_allspark.py
elif [[ ${vllm_inference} == "True" ]]; then
    echo "Using vllm engine"
    generator_script=examples/generate_lora_vllm.py
else
    generator_script=examples/${default_generator}.py
fi
# To capture output in a log file, append: > "${save_path}.log" 2>&1
# Python is the last command, so its exit status becomes the script's.
python "${generator_script}" "${gen_args[@]}"