#!/bin/bash
#SBATCH --job-name=trace_option_contrib_llama3-8b
#SBATCH --partition=lvjq
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=3
#SBATCH --gres=gpu:1
#SBATCH -o %J.out
#SBATCH -e %J.err

# Load the conda module and activate the "come" environment. Failures only
# warn (rather than abort) because the interpreter below is invoked by its
# absolute path; check the .err file if a run misbehaves.
module load anaconda3 || printf 'WARNING: "module load anaconda3" failed\n' >&2
source activate come || printf 'WARNING: "source activate come" failed\n' >&2

# Put the CUDA toolkit first on PATH / the library search path.
export CUDA_HOME=/usr/local/cuda
export PATH="${CUDA_HOME}/bin:${PATH}"
# Only append the previous value when it is non-empty: a trailing ':' (empty
# entry) would make the dynamic loader also search the current directory.
export LD_LIBRARY_PATH="${CUDA_HOME}/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}"

# Interpreter of the "come" conda environment, used for every run below.
PYTHON=~/.conda/envs/come/bin/python

# Task to evaluate, the root directory of the pruned checkpoints, and the
# per-task root directory for the results.
TASK=arc_easy
BASE_PATH="/TO/MY/PATH/code/Understanding_Performance_Collapse/iter_shortgpt_output/calib_arc_challenge/llama3-8b/prun"
OUTPUT_BASE="/TO/MY/PATH/code/Understanding_Performance_Collapse/tools/results/2_1-New-llama3-8b-instruct/results_option_logits/${TASK}"

# Parallel arrays: element i of each array describes the same run
# (results sub-directory label, checkpoint location, --dense_layers value).
PR_RATIOS=(
    "Dense"
    "PR_12.5%"
    "PR_37.5%"
    "PR_50%"
)

MODEL_PATHS=("/seu_nvme/ogai/models/Meta-Llama-3.1-8B-Instruct")
MODEL_PATHS+=("${BASE_PATH}/ContinuePrun-from-ShortGPT-31Layer/Meta-Llama-3.1-8B-Instruct_shortgpt_28")
MODEL_PATHS+=("${BASE_PATH}/ContinuePrun-from-ShortGPT-24Layer_1/Meta-Llama-3.1-8B-Instruct_shortgpt_24_shortgpt_20")
MODEL_PATHS+=("${BASE_PATH}/ContinuePrun-from-ShortGPT-24Layer_1/Meta-Llama-3.1-8B-Instruct_shortgpt_24_shortgpt_16")

# Currently the same --dense_layers value (32) is used for every run.
DENSE_LAYERS=("32" "32" "32" "32")

# Run trace_layerwise_option_logits.py once per configured model.
#
# PR_RATIOS, MODEL_PATHS and DENSE_LAYERS are indexed in parallel, so they
# must have the same length; otherwise a run would silently receive an empty
# model path or layer count.
if (( ${#MODEL_PATHS[@]} != ${#PR_RATIOS[@]} || ${#DENSE_LAYERS[@]} != ${#PR_RATIOS[@]} )); then
    printf 'ERROR: array length mismatch: PR_RATIOS=%d MODEL_PATHS=%d DENSE_LAYERS=%d\n' \
        "${#PR_RATIOS[@]}" "${#MODEL_PATHS[@]}" "${#DENSE_LAYERS[@]}" >&2
    exit 1
fi

# Labels of runs whose python process exited non-zero; used to make the job's
# final status reflect failures instead of always reporting success.
FAILED_RUNS=()

for i in "${!PR_RATIOS[@]}"; do
    PR_RATIO="${PR_RATIOS[i]}"
    MODEL_PATH="${MODEL_PATHS[i]}"
    LAYERS="${DENSE_LAYERS[i]}"
    OUTPUT_DIR="${OUTPUT_BASE}/${PR_RATIO}/"
    
    echo "========================================"
    echo "Processing: ${PR_RATIO}"
    echo "Model path: ${MODEL_PATH}"
    echo "Output dir: ${OUTPUT_DIR}"
    echo "Dense layers: ${LAYERS}"
    echo "========================================"
    
    # Capture the exit status so one failing model does not stop the rest,
    # but is still recorded and reported.
    rc=0
    "${PYTHON}" /TO/MY/PATH/code/Understanding_Performance_Collapse/TALE/trace_layerwise_option_logits.py \
        --model_name_or_path "${MODEL_PATH}" \
        --sft_dataset "${TASK}" \
        --eval_split validation \
        --num_eval_samples 500 \
        --output_dir "${OUTPUT_DIR}" \
        --dense_layers "${LAYERS}" || rc=$?
    if (( rc != 0 )); then
        printf 'ERROR: %s failed with exit status %d\n' "${PR_RATIO}" "${rc}" >&2
        FAILED_RUNS+=("${PR_RATIO}")
    fi
    
    echo ""
done

# Exit non-zero when any run failed so SLURM marks the job as FAILED.
if (( ${#FAILED_RUNS[@]} > 0 )); then
    printf 'ERROR: %d run(s) failed: %s\n' "${#FAILED_RUNS[@]}" "${FAILED_RUNS[*]}" >&2
    exit 1
fi

echo "All llama3-8b-instruct models processed!"
