#!/bin/bash
# ---------- Environment ----------
# Hub endpoint override (HF_ENDPOINT is read by huggingface_hub; here it points
# at a mirror).
export HF_ENDPOINT=https://hf-mirror.com
# NOTE(review): presumably set to silence the HF tokenizers fork warning.
export TOKENIZERS_PARALLELISM=false

# ---------- Data ----------
TRAIN_DATASET=math_10k          # training file: ./ft-training_set/<name>.json
declare -a TEST_DATASETS=("AQuA" "mawps" "SVAMP" "gsm8k")  # each evaluated by run_eval

# ---------- Models / adapter ----------
METHOD_TYPE=explore
MODEL=meta-llama/Llama-3.1-8B
# Comma-separated copilot base models, passed as ONE argument
# (--explore_base_model / --explore_model / --base_explore_model).
EXPLORE_MODEL_A=meta-llama/Llama-3.1-8B,meta-llama/Llama-3.1-8B
ADAPTER=lora
# Used in output-dir and log names, and forwarded as --model_type.
MODEL_NAME=llmboost-arithmetic-8b-8b-8b

# ---------- Training sweep ----------
declare -a NUM_EPOCHS_LIST=(2)
LR_RATE=2e-4                    # --learning_rate
LR_RATE_A=3e-4                  # --learning_rate_a
CUTOFF_LEN=256
declare -a CLIP_VALUES=(0.1)
declare -a SEEDS=(1)
declare -a TOPK_LIST=(1)
declare -a ALPHAS=(0.1)
declare -a BETAS=(0.1)

# ---------- Explore / decoding ----------
EXPLORE_FLAG=1
EXPLORE_LOGITS_FACTOR=100
declare -a LAMBDAS=("0.3,0.3")  # each entry is passed as --explore_weight in run_eval
DECODING=greedy
TOP_K=-1                        # --top_k_explore_logits

# ---------- Runtime / tracking ----------
gpus="0,0,0"                    # forwarded verbatim as --gpus to both scripts
WANDB_PROJECT=llmboost-arithmetic-8b-8b-8b-llama

# ========== Function: training ==========
#######################################
# Run finetune_llmboost_llama.py once for a single hyper-parameter combination.
# The model is written under ./trained_models/... and its two copilot output
# directories (Copilot-1/, Copilot-2/) are created inside it.
# Globals (read):
#   MODEL, EXPLORE_MODEL_A, TRAIN_DATASET, METHOD_TYPE, MODEL_NAME, ADAPTER,
#   LR_RATE, LR_RATE_A, CUTOFF_LEN, EXPLORE_FLAG, EXPLORE_LOGITS_FACTOR,
#   WANDB_PROJECT, gpus
#   GPU_ID - optional; value exported as CUDA_VISIBLE_DEVICES (default 0)
# Arguments:
#   $1 - top-k logits (--topk-logits)
#   $2 - number of epochs (--num_epochs; also part of the output path)
#   $3 - clip value (--clip_value; also part of the output path)
#   $4 - random seed (--seed)
#   $5 - alpha (--alpha)
#   $6 - beta (--beta)
# Outputs:
#   Progress lines on stdout. The trainer's stdout+stderr is shown on the
#   terminal and also written to
#   logs/<MODEL_NAME>-<TRAIN_DATASET>_train_topk<K>_epoch<E>.log
# Returns:
#   The trainer's exit status.
#######################################
function run_train {
  local topk_logits=$1
  local epochs=$2
  local clip_value=$3
  local seed=$4
  local alpha=$5
  local beta=$6
  local gpu_id=${GPU_ID:-0}
  echo 'seed!!!!'"$seed"

  local outdir="./trained_models/${METHOD_TYPE}/${TRAIN_DATASET}/${epochs}/${MODEL_NAME}-${clip_value}-${ADAPTER}-${LR_RATE}/"
  local copilot_1="${outdir}Copilot-1/"
  local copilot_2="${outdir}Copilot-2/"
  # The trainer takes both copilot dirs as ONE comma-separated argument, but
  # mkdir must receive them separately: handing it the joined string creates a
  # bogus "Copilot-1/,./trained_models/..." tree and never creates Copilot-2/.
  local outdir_copilot="${copilot_1},${copilot_2}"
  local train_log="logs/${MODEL_NAME}-${TRAIN_DATASET}_train_topk${topk_logits}_epoch${epochs}.log"

  mkdir -p -- "$outdir" "$copilot_1" "$copilot_2" \
    "./logs/${TRAIN_DATASET}/${epochs}/${METHOD_TYPE}/meta-llama/" "${train_log%/*}"

  echo "🚀 [TRAIN] TOPK=$topk_logits on GPU $gpu_id"
  echo "train log file:$train_log"
  # tee writes the log announced above while keeping output on the terminal;
  # PYTHONUNBUFFERED stops python from block-buffering stdout into the pipe.
  CUDA_VISIBLE_DEVICES="$gpu_id" PYTHONUNBUFFERED=1 python finetune_llmboost_llama.py \
    --base_model "$MODEL" \
    --explore_base_model "$EXPLORE_MODEL_A" \
    --data_path "./ft-training_set/${TRAIN_DATASET}.json" \
    --test_data_path "./dataset/${TRAIN_DATASET}/test.json" \
    --output_dir "$outdir" \
    --explore_output_dir "$outdir_copilot" \
    --batch_size 16 \
    --micro_batch_size 4 \
    --num_epochs "$epochs" \
    --learning_rate "$LR_RATE" \
    --learning_rate_a "$LR_RATE_A" \
    --cutoff_len "$CUTOFF_LEN" \
    --val_set_size 5 \
    --eval_step 800 \
    --save_step 800 \
    --adapter_name "$ADAPTER" \
    --target_modules '["q_proj", "k_proj", "v_proj", "up_proj", "down_proj"]' \
    --lora_r 32 \
    --lora_alpha 64 \
    --explore_flag "$EXPLORE_FLAG" \
    --explore_logits_factor "$EXPLORE_LOGITS_FACTOR" \
    --direct_test False \
    --topk-logits "$topk_logits" \
    --clip_value "$clip_value" \
    --wandb_run_name "train-${MODEL_NAME}-${EXPLORE_FLAG}-seed-${seed}" \
    --wandb_project "$WANDB_PROJECT" \
    --seed "$seed" \
    --alpha "$alpha" \
    --beta "$beta" \
    --model_type "$MODEL_NAME" \
    --gpus "$gpus" 2>&1 | tee "$train_log"
  # Report the trainer's status, not tee's.
  return "${PIPESTATUS[0]}"
}

# ========== Function: evaluation ==========
#######################################
# Run commonsense_evaluate_llmboost_llama.py for every TEST_DATASETS entry and
# every LAMBDAS value. It loads weights from the same output directories that
# run_train builds (same path formula).
# Globals (read):
#   MODEL, EXPLORE_MODEL_A, TRAIN_DATASET, TEST_DATASETS, LAMBDAS, METHOD_TYPE,
#   MODEL_NAME, ADAPTER, LR_RATE, EXPLORE_FLAG, EXPLORE_LOGITS_FACTOR,
#   DECODING, TOP_K, WANDB_PROJECT, gpus
#   SEED, ALPHA, BETA - fallbacks when $4/$5/$6 are not given
#   GPU_ID - optional; value exported as CUDA_VISIBLE_DEVICES (default 0)
# Arguments:
#   $1 - top-k logits (--topk_logits)
#   $2 - number of epochs (selects the weight dir; --num_epochs)
#   $3 - clip value (selects the weight dir)
#   $4 - seed  (optional; defaults to $SEED)
#   $5 - alpha (optional; defaults to $ALPHA)
#   $6 - beta  (optional; defaults to $BETA)
# Outputs:
#   One progress line per run. Each run's stdout+stderr is shown on the
#   terminal and also written to the log file named in that line, under
#   logs/<TRAIN_DATASET>/<epochs>/<METHOD_TYPE>/<test set>/<MODEL_NAME>/.
# Returns:
#   0 if every evaluation succeeded, otherwise the last non-zero python status.
#######################################
function run_eval {
  local topk_logits=$1
  local epochs=$2
  local clip_value=$3
  # Previously $4 was passed by the caller but ignored in favor of globals.
  local seed=${4:-$SEED}
  local alpha=${5:-$ALPHA}
  local beta=${6:-$BETA}
  local gpu_id=${GPU_ID:-0}
  local outdir="./trained_models/${METHOD_TYPE}/${TRAIN_DATASET}/${epochs}/${MODEL_NAME}-${clip_value}-${ADAPTER}-${LR_RATE}/"
  # Both copilot weight dirs travel as ONE comma-separated argument.
  local outdir_copilot="${outdir}Copilot-1/,${outdir}Copilot-2/"
  local test_dataset lambda_value log_dir log_file py_status
  local status=0

  for test_dataset in "${TEST_DATASETS[@]}"; do
    for lambda_value in "${LAMBDAS[@]}"; do
      log_dir="logs/${TRAIN_DATASET}/${epochs}/${METHOD_TYPE}/${test_dataset}/${MODEL_NAME}"
      log_file="${log_dir}/${MODEL_NAME}-${ADAPTER}-${test_dataset}-${METHOD_TYPE}-${DECODING}-EXPLORE_FLAG-${EXPLORE_FLAG}-Copilot-lambda${lambda_value}.log"
      mkdir -p -- "$log_dir"

      echo "🧪 [EVAL] TOPK=$topk_logits, lambda=$lambda_value on GPU $gpu_id -> $log_file"

      # tee actually writes the announced log file while keeping terminal
      # output; PYTHONUNBUFFERED stops python from block-buffering into the pipe.
      CUDA_VISIBLE_DEVICES="$gpu_id" PYTHONUNBUFFERED=1 python commonsense_evaluate_llmboost_llama.py \
        --model "$MODEL" \
        --adapter "$ADAPTER" \
        --dataset "$test_dataset" \
        --base_model "$MODEL" \
        --lora_weights "$outdir" \
        --explore_model "$EXPLORE_MODEL_A" \
        --base_explore_model "$EXPLORE_MODEL_A" \
        --explore_lora_weights "$outdir_copilot" \
        --explore_flag "$EXPLORE_FLAG" \
        --explore_weight "$lambda_value" \
        --decoding_method "$DECODING" \
        --explore_logits_factor "$EXPLORE_LOGITS_FACTOR" \
        --top_k_explore_logits "$TOP_K" \
        --topk_logits "$topk_logits" \
        --wandb_run_name "eval-${MODEL_NAME}-${EXPLORE_FLAG}-lambda${lambda_value}-seed-${seed}" \
        --wandb_project "$WANDB_PROJECT" \
        --num_epochs "$epochs" \
        --seed "$seed" \
        --alpha "$alpha" \
        --beta "$beta" \
        --gpus "$gpus" \
        --model_type "$MODEL_NAME" 2>&1 | tee "$log_file"
      # Capture the evaluator's status (not tee's) before any other command.
      py_status=${PIPESTATUS[0]}
      if (( py_status != 0 )); then
        status=$py_status
      fi
    done
  done
  return "$status"
}

# Full grid sweep: for every (epochs, clip value, top-k, alpha, beta, seed)
# combination, train and then evaluate. Iteration order follows the nesting.
# The loop variables keep their global UPPER_CASE names on purpose: older
# versions of run_eval read SEED/ALPHA/BETA as globals.
for NUM_EPOCHS in "${NUM_EPOCHS_LIST[@]}"; do
  for CLIP_VALUE in "${CLIP_VALUES[@]}"; do
    for TOPK in "${TOPK_LIST[@]}"; do
      for ALPHA in "${ALPHAS[@]}"; do
        for BETA in "${BETAS[@]}"; do
          for SEED in "${SEEDS[@]}"; do
            run_train "$TOPK" "$NUM_EPOCHS" "$CLIP_VALUE" "$SEED" "$ALPHA" "$BETA"
            # Pass alpha/beta explicitly so evaluation matches the trained run
            # instead of relying on leaked globals.
            run_eval "$TOPK" "$NUM_EPOCHS" "$CLIP_VALUE" "$SEED" "$ALPHA" "$BETA"
          done
        done
      done
    done
  done
done