#!/bin/bash
#SBATCH --job-name=trace_arc_challenge_option_contrib
#SBATCH --partition=lvjq
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=3
#SBATCH --gres=gpu:1
#SBATCH -o %J.out
#SBATCH -e %J.err
#
# Runs TRACE_PY (trace_option_contrib.py) on a single ARC-Challenge sample for
# every pruned checkpoint listed in LAYERS, writing one result directory per
# checkpoint under OUT_BASE. Submit with: sbatch <this script>

module load anaconda3
source activate come

export CUDA_HOME=/usr/local/cuda
export PATH="$CUDA_HOME/bin:$PATH"
# Only append the previous value when it is non-empty: a trailing ':' would
# add an empty search-path entry, which the dynamic loader treats as the
# current working directory.
export LD_LIBRARY_PATH="$CUDA_HOME/lib64${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}"

# Absolute path to the interpreter of the "come" conda environment.
PYTHON=~/.conda/envs/come/bin/python

# -----------------------------
# Models
# -----------------------------
# Dense checkpoint, used as the optional baseline.
DENSE=/seu_nvme/ogai/models/Meta-Llama-3.1-8B-Instruct

# Root directory of the pruned checkpoints; every LAYERS entry names one
# sub-directory of it. Keep in sync with the pruning run that produced them.
PRUNED_BASE=/TO/MY/PATH/code/Understanding_Performance_Collapse/iter_shortgpt_output/calib_arc_challenge/llama3-8b/prun/ContinuePrun-from-ShortGPT-24Layer/

# Checkpoint sub-directory names to sweep (one trace run per entry).
declare -a LAYERS=(
    "Meta-Llama-3.1-8B-Instruct_shortgpt_24_shortgpt_20"
)

# -----------------------------
# Output root: each run writes into its own sub-directory of this path.
# -----------------------------
OUT_BASE=/TO/MY/PATH/code/Understanding_Performance_Collapse/tools/results_trace_option_contrib/arc_challenge/

# -----------------------------
# Trace script (point this at your local copy)
# -----------------------------
TRACE_PY=/TO/MY/PATH/code/Understanding_Performance_Collapse/TALE/trace_option_contrib.py

# -----------------------------
# Trace options (forwarded to TRACE_PY)
# -----------------------------
# Global dataset index of the one sample to trace.
TRACE_INDEX=10
# Decode-step budget: tracing stops if no A/B/C/D answer is predicted within
# this many steps.
MAX_NEW_TOKENS=64

# Create the output root up front; without it no run can write results, so
# abort the job immediately instead of failing once per model.
if ! mkdir -p "$OUT_BASE"; then
    echo "[Error] cannot create OUT_BASE=${OUT_BASE}" >&2
    exit 1
fi

echo "======================================================"
echo "[Job] trace option contribution"
echo "DENSE=$DENSE"
echo "OUT_BASE=$OUT_BASE"
# DS_CONFIG is never assigned in this script; report it only if the submitting
# environment exported it, rather than printing a misleading empty value.
echo "DS_CONFIG=${DS_CONFIG:-<unset>}"
echo "TRACE_PY=$TRACE_PY"
echo "TRACE_INDEX=$TRACE_INDEX"
echo "MAX_NEW_TOKENS=$MAX_NEW_TOKENS"
echo "======================================================"

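# # -----------------------------
# # Run: dense model (optional but recommended as baseline)
# # NOTE: this baseline passes --eval_split test, while the pruned sweep below
# # passes --eval_split validation; align the split before comparing results.
# # -----------------------------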
# OUT_DENSE="${OUT_BASE}/DENSE_Meta-Llama-3.1-8B-Instruct/smaple_${TRACE_INDEX}"
# mkdir -p "$OUT_DENSE"

# echo "======================================================"
# echo "[Run Dense] MODEL=$DENSE"
# echo "OUT=$OUT_DENSE"
# echo "======================================================"

# ${PYTHON} "$TRACE_PY" \
#   --model_name_or_path "$DENSE" \
#   --output_dir "$OUT_DENSE" \
#   --sft_dataset arc_challenge \
#   --eval_split test \
#   --max_length 512 \
#   --trace_sample_index "$TRACE_INDEX" \
#   --max_new_tokens_trace "$MAX_NEW_TOKENS" \
#   --dtype bf16 \
#   --force_eager_attn

# echo "[Done Dense]"

# -----------------------------
# Run: pruned models sweep
# -----------------------------
# For each LAYERS entry, trace sample TRACE_INDEX through the checkpoint at
# ${PRUNED_BASE}/<layer>/ and write into ${OUT_BASE}/<layer>/smaple_<index>.
# A failing run does not stop the sweep; failures are collected and the job
# exits non-zero at the end so SLURM does not report it as COMPLETED.
# NOTE(review): "smaple_" is a typo of "sample_"; kept unchanged so existing
# result folders (and anything reading them) still match.

# Preflight: without the interpreter or the trace script every run would fail.
if ! command -v "$PYTHON" >/dev/null 2>&1; then
    echo "[Error] python interpreter not found: ${PYTHON}" >&2
    exit 1
fi
if [[ ! -f "$TRACE_PY" ]]; then
    echo "[Error] trace script not found: ${TRACE_PY}" >&2
    exit 1
fi

failed=()
for i in "${!LAYERS[@]}"; do
    layer="${LAYERS[$i]}"
    # Drop a trailing '/' from the base paths so joined paths contain no '//'.
    PRUNED="${PRUNED_BASE%/}/${layer}/"
    OUT="${OUT_BASE%/}/${layer}/smaple_${TRACE_INDEX}"

    echo "======================================================"
    echo "[Run $((i+1))/${#LAYERS[@]}] layer=${layer}"
    echo "PRUNED=${PRUNED}"
    echo "OUT=${OUT}"
    echo "======================================================"

    if [[ ! -d "$PRUNED" ]]; then
        echo "[Error] pruned model directory not found: ${PRUNED}" >&2
        failed+=("$layer")
        continue
    fi
    if ! mkdir -p "$OUT"; then
        echo "[Error] cannot create output directory: ${OUT}" >&2
        failed+=("$layer")
        continue
    fi

    # Capture the exit status so one failure is reported but the sweep goes on.
    rc=0
    "$PYTHON" "$TRACE_PY" \
        --model_name_or_path "$PRUNED" \
        --output_dir "$OUT" \
        --sft_dataset arc_challenge \
        --eval_split validation \
        --max_length 512 \
        --trace_sample_index "$TRACE_INDEX" \
        --max_new_tokens_trace "$MAX_NEW_TOKENS" \
        --dtype bf16 \
        --force_eager_attn || rc=$?

    if (( rc != 0 )); then
        echo "[Error] trace failed for layer=${layer} (exit ${rc})" >&2
        failed+=("$layer")
    fi
done

if (( ${#failed[@]} > 0 )); then
    echo "[Done] ${#failed[@]}/${#LAYERS[@]} sweep(s) failed: ${failed[*]}" >&2
    exit 1
fi

echo "[Done] All sweeps finished."
