#!/usr/bin/env bash
#
# Launch one LoRA SFT run of src/train_sft.py (memory module enabled) through
# accelerate, for a single task of an ordered task sequence.
#
# Usage: <script> <model> <dataset> <task_order> <run_name>
#   model       llama3 | llama3.1 | qwen
#   dataset     dataset prefix; "<dataset>_train" is read from data/cl
#   task_order  task index; tasks after the first start from the previous
#               task's adapter in ckpt/cl/mem/<model>/<run_name>/<task_order-1>
#   run_name    groups the per-task checkpoints under ckpt/cl/mem/<model>/
#
# Requires bash: arrays and (( )) arithmetic are used further down.

# Only surface NCCL warnings/errors. NCCL's documented levels are
# VERSION, WARN, INFO and TRACE; "WARNING" is not one of them.
export NCCL_DEBUG=WARN

# MMDD stamp, printed for the log.
date=$(date +"%m%d")

echo "date: ${date}"

# Positional arguments (see Usage above).
model="$1"
dataset="$2"
task_order="$3"
run_name="$4"

# Training hyperparameters.
max_len=4096   # forwarded as --cutoff_len
epoch=5        # epochs to train; also used as the checkpoint retention limit
lora_rank=8

# Memory-module settings, forwarded verbatim to src/train_sft.py.
mem_freeze_backbone=false
memory_size=100
fix_memory=false
fusion_func="sigmoid_alpha"
use_last_prompt_token_as_key=false
update_while_predicting=false
update_strategy="attn"
memory_update_steps=200
memory_path=none


echo "———————————————————— Train ————————————————————"

# All four positional arguments are required; empty values would otherwise
# produce an empty model path or malformed checkpoint paths.
if [[ -z "${model}" || -z "${dataset}" || -z "${task_order}" || -z "${run_name}" ]]; then
  printf 'usage: %s <model> <dataset> <task_order> <run_name>\n' "${0##*/}" >&2
  exit 2
fi

# Resolve the base checkpoint and the memory-insertion layer per model.
# NOTE(review): memory_insert_layers looks like the index of the final decoder
# layer (31 for the 8B Llama checkpoints, 35 for Qwen3-8B) — confirm.
case "${model}" in
  llama3)
    model_name_or_path="/datanfs4/name/HuggingFaceModels/meta-llama/Meta-Llama-3-8B"
    memory_insert_layers=31
    ;;
  llama3.1)
    model_name_or_path="/datanfs4/name/HuggingFaceModels/meta-llama/Llama-3.1-8B"
    memory_insert_layers=31
    ;;
  qwen)
    model_name_or_path="/datanfs4/name/HuggingFaceModels/Qwen/Qwen3-8B-Base"
    memory_insert_layers=35
    ;;
  *)
    printf 'error: unknown model "%s" (expected llama3, llama3.1 or qwen)\n' "${model}" >&2
    exit 1
    ;;
esac

# task_order is evaluated inside (( )) and used in checkpoint paths, so accept
# only a plain decimal integer. Anything else would be evaluated as an
# arithmetic expression. Leading zeros are rejected because (( )) treats them
# as octal.
if [[ ! "${task_order}" =~ ^(0|[1-9][0-9]*)$ ]]; then
  printf 'error: task_order must be a non-negative integer, got "%s"\n' "${task_order}" >&2
  exit 1
fi

# Every task after the first continues from the previous task's LoRA adapter.
# That adapter lives in the previous task's output directory, which uses the
# same ckpt/cl/mem/<model>/<run_name>/<n> layout as the launch's --output_dir.
adapter_path=""
adapter_arg=()
if (( task_order > 1 )); then
  prev=$(( task_order - 1 ))
  adapter_path="ckpt/cl/mem/${model}/${run_name}/${prev}"
  adapter_arg+=( --adapter_name_or_path "$adapter_path" )
fi

# gsm8k uses the QA-style prompt template; all other datasets use the plain one.
if [[ "${dataset}" == "gsm8k" ]]; then
  template="vanilla_qa"
else
  template="vanilla"
fi

echo "model: ${model}"
echo "epoch: ${epoch}"
echo "max_len: ${max_len}"
echo "dataset: ${dataset}"
echo "task_order: ${task_order}"
echo "run_name: ${run_name}"
echo "lora_rank: ${lora_rank}"
# Report the adapter actually passed to training ("none" for the first task).
echo "adapter_name_or_path: ${adapter_path:-none}"

#######################################
# Print a TCP port number that was free at the moment of the call.
# An IPv4 stream socket is bound to port 0 so the kernel picks an unused
# port. That port is printed and the socket is closed again. The port is not
# reserved afterwards, so another process may still grab it before the
# caller binds it.
# Arguments: none
# Outputs:   the port number on stdout
# Returns:   python3's exit status
#######################################
find_free_port() {
  python3 - <<'PY'
import socket

probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try:
    probe.bind(("", 0))
    _host, chosen_port = probe.getsockname()
    print(chosen_port)
finally:
    probe.close()
PY
}

# Pick a rendezvous port for accelerate's main process instead of a fixed one.
# find_free_port closes its probe socket before returning, so the port is not
# reserved. Fail loudly rather than launching with an empty port if the probe
# itself fails.
if ! main_process_port=$(find_free_port) || [[ -z "${main_process_port}" ]]; then
  echo "error: could not find a free port for accelerate" >&2
  exit 1
fi

# Single-process LoRA SFT with the memory module enabled. The effective batch
# size is 1 (per device) x 16 (gradient accumulation). --output_dir follows
# ckpt/cl/mem/<model>/<run_name>/<task_order>, the layout the next task reads
# its --adapter_name_or_path from. Each flag is passed once:
# --save_total_limit reuses ${epoch}, so one checkpoint per epoch is kept
# under --save_strategy epoch.
accelerate launch \
    --num_processes 1 \
    --main_process_port "${main_process_port}" \
    --config_file src/configs/accelerate_zero2.yaml \
    src/train_sft.py \
        --seed 0 \
        --stage sft \
        --do_train true \
        --per_device_eval_batch_size 1 \
        --predict_with_generate false \
        --num_train_epochs "${epoch}" \
        --cutoff_len "${max_len}" \
        --model_name_or_path "${model_name_or_path}" \
        "${adapter_arg[@]}" \
        --dataset "${dataset}_train" \
        --dataset_dir "data/cl" \
        --template "${template}" \
        --mask_history true \
        --finetuning_type "lora" \
        --lora_target "q_proj,k_proj,v_proj,up_proj,down_proj" \
        --lora_rank "${lora_rank}" \
        --additional_target "alpha" \
        --output_dir "ckpt/cl/mem/${model}/${run_name}/${task_order}" \
        --overwrite_cache true \
        --per_device_train_batch_size 1 \
        --gradient_accumulation_steps 16 \
        --learning_rate 3e-4 \
        --lr_scheduler_type "linear" \
        --save_total_limit "${epoch}" \
        --logging_strategy "steps" \
        --logging_steps 10 \
        --fp16 true \
        --overwrite_output_dir true \
        --save_strategy "epoch" \
        --enable_mem true \
        --mem_freeze_backbone "${mem_freeze_backbone}" \
        --memory_insert_layers "${memory_insert_layers}" \
        --update_memory true \
        --memory_size "${memory_size}" \
        --use_gpu_to_search true \
        --init_from_backbone true \
        --init_cross_attn_from_self true \
        --fix_memory "${fix_memory}" \
        --fusion_func "${fusion_func}" \
        --use_last_prompt_token_as_key "${use_last_prompt_token_as_key}" \
        --update_while_predicting "${update_while_predicting}" \
        --update_strategy "${update_strategy}" \
        --memory_update_steps "${memory_update_steps}" \
        --memory_path "${memory_path}" \
        --report_to none