#!/usr/bin/env bash
# Run test-set prediction for a LoRA adapter checkpoint using `accelerate`.
#
# Usage: <script> MODEL DATASET TASK_ORDER EVAL_TASK_ORDER RUN_NAME
#   MODEL            model alias: llama3, llama3.1 or qwen
#   DATASET          evaluated on the "<DATASET>_test" split under data/cl
#   TASK_ORDER       adapter is loaded from ckpt/cl/mem/MODEL/RUN_NAME/TASK_ORDER
#   EVAL_TASK_ORDER  names the sub-directory the eval results are written to
#   RUN_NAME         run identifier used in the checkpoint/output paths

# Abort on command failure and on use of unset variables. `pipefail` is
# deliberately omitted: it is not supported by every /bin/sh, and this script
# may be invoked as `sh script.sh`.
set -eu

# Fail early with a usage message instead of silently building paths such as
# "ckpt/cl/mem///..." from empty arguments.
if [ "$#" -lt 5 ]; then
  printf 'Usage: %s MODEL DATASET TASK_ORDER EVAL_TASK_ORDER RUN_NAME\n' "${0##*/}" >&2
  exit 2
fi

export NCCL_DEBUG=WARNING

date=$(date +"%m%d")

echo "date: ${date}"

model="$1"
dataset="$2"
task_order="$3"
eval_task_order="$4"
run_name="$5"

echo "model: ${model}"
echo "dataset: ${dataset}"
echo "task_order: ${task_order}"
echo "eval_task_order: ${eval_task_order}"
echo "run_name: ${run_name}"


# ---- Sequence length / adapter settings ------------------------------------
max_len=4096
lora_rank=8
# NOTE(review): `epoch` is not passed to the launch command in this script.
epoch=5

# ---- Memory module settings --------------------------------------------------
memory_size=100
# Default insertion layer; the model-selection section re-assigns it per model.
memory_insert_layers=31
mem_freeze_backbone=false
fix_memory=false
fusion_func="sigmoid_alpha"

# ---- Memory lookup / update behaviour ---------------------------------------
use_last_prompt_token_as_key=false
update_while_predicting=false
update_strategy="attn"
memory_update_steps=200
# Literal string "none" (not an empty value) is forwarded as --memory_path.
memory_path=none


# Directory that receives the predictions for this
# (model, run, task order, eval task order) combination.
eval_output_dir="ckpt/cl/mem/${model}/${run_name}/eval_results/${task_order}/${eval_task_order}"

# Evaluation is skipped (successfully) if predictions were already produced.
if [ -f "${eval_output_dir}/generated_predictions.jsonl" ]; then
  echo "Prediction file already exists, skipping."
  exit 0
fi

echo "———————————————————— Test ————————————————————"

# Resolve the model alias to a local checkpoint path and the layer index at
# which the memory module is inserted (presumably the last decoder layer,
# 0-indexed -- confirm against the model configs). Unknown aliases are
# rejected here; otherwise model_name_or_path would stay unset and the launch
# would fail with a much less obvious error.
case "${model}" in
  llama3)
    model_name_or_path="/datanfs4/name/HuggingFaceModels/meta-llama/Meta-Llama-3-8B"
    memory_insert_layers=31
    ;;
  llama3.1)
    model_name_or_path="/datanfs4/name/HuggingFaceModels/meta-llama/Llama-3.1-8B"
    memory_insert_layers=31
    ;;
  qwen)
    model_name_or_path="/datanfs4/name/HuggingFaceModels/Qwen/Qwen3-8B-Base"
    memory_insert_layers=35
    ;;
  *)
    echo "Unknown model '${model}' (expected: llama3, llama3.1, qwen)" >&2
    exit 1
    ;;
esac

# gsm8k uses the QA-style prompt template; every other dataset uses the plain one.
if [ "${dataset}" = "gsm8k" ]; then
  template="vanilla_qa"
else
  template="vanilla"
fi


# Print a TCP port number that is currently unused on this host.
# Binding to port 0 makes the kernel choose an available port; the socket is
# then closed, so the port is only known-free at the moment it is printed.
find_free_port() {
  python3 - <<'PY'
import socket

sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try:
    sock.bind(('', 0))
    print(sock.getsockname()[1])
finally:
    sock.close()
PY
}

# Use an unused TCP port for accelerate's main process instead of a fixed one.
# An empty or failed lookup would otherwise be forwarded as
# `--main_process_port ""`, so it is rejected explicitly.
# NOTE: the probe socket is closed before accelerate binds, so another process
# could still grab the port in between; this is inherent to the approach.
if ! main_process_port=$(find_free_port) || [ -z "${main_process_port}" ]; then
  echo "Failed to determine a free port for accelerate." >&2
  exit 1
fi

# Single-process prediction run: loads the base model plus the LoRA adapter
# from ckpt/cl/mem/<model>/<run_name>/<task_order>, generates on the
# "<dataset>_test" split, and writes results under .../eval_results/.
# `${model_name_or_path:?...}` aborts with a clear message if no model path
# was resolved, instead of passing an empty path to the trainer.
accelerate launch \
    --main_process_port "${main_process_port}" \
    --num_processes 1 \
    --config_file src/configs/accelerate_zero2.yaml \
    src/train_sft.py \
    --seed 0 \
    --stage sft \
    --do_predict true \
    --predict_with_generate true \
    --cutoff_len "${max_len}" \
    --max_new_tokens 1024 \
    --model_name_or_path "${model_name_or_path:?model_name_or_path is not set (unknown model?)}" \
    --adapter_name_or_path "ckpt/cl/mem/${model}/${run_name}/${task_order}" \
    --eval_dataset "${dataset}_test" \
    --dataset_dir "data/cl" \
    --template "${template}" \
    --finetuning_type "lora" \
    --lora_target "q_proj,k_proj,v_proj,up_proj,down_proj" \
    --lora_rank "${lora_rank}" \
    --additional_target "alpha" \
    --output_dir "ckpt/cl/mem/${model}/${run_name}/eval_results/${task_order}/${eval_task_order}" \
    --overwrite_cache true \
    --per_device_eval_batch_size 1 \
    --report_to none \
    --fp16 true \
    --overwrite_output_dir true \
    --enable_mem true \
    --mem_freeze_backbone "${mem_freeze_backbone}" \
    --memory_insert_layers "${memory_insert_layers}" \
    --update_memory true \
    --memory_size "${memory_size}" \
    --use_gpu_to_search true \
    --init_from_backbone true \
    --init_cross_attn_from_self true \
    --fix_memory "${fix_memory}" \
    --fusion_func "${fusion_func}" \
    --use_last_prompt_token_as_key "${use_last_prompt_token_as_key}" \
    --update_while_predicting "${update_while_predicting}" \
    --update_strategy "${update_strategy}" \
    --memory_update_steps "${memory_update_steps}" \
    --memory_path "${memory_path}"