#!/bin/bash

# -----------------------------------------------------------------------------
# Paths and evaluation settings shared by every benchmark run in this script.
# -----------------------------------------------------------------------------

# Project root. The log directory, the AWQ cache and the search-result file
# are all located relative to it.
BASE_DIR='/home/yeq6/Research_project/llama'
LOG_DIR="$BASE_DIR/mass-run_1"

# Model passed to the benchmark with -m (the base model, for all tests).
# Alternative local chat checkpoint and its cache, kept for reference:
#   MODEL_PATH="${BASE_DIR}/llama-2-7b-chat_hf"
#   AWQ_CACHE="${BASE_DIR}/llm-awq/awq_cache/llama2-7b-w4-g128_mine.pt"
MODEL_PATH='meta-llama/Llama-2-7b-hf'
AWQ_CACHE="$BASE_DIR/llm-awq/awq_cache/llama-2-7b-w4-g128.pt"

# Search-result file used by the custom rotation runs (adjust as needed).
SEARCH_RESULT_PATH="$BASE_DIR/best_channel_scales_Llama-2-7b-hf.txt"

# Space-separated list of evaluation tasks.
# Other task sets that have been used:
#   TASKS="wikitext boolq piqa social_iqa hellaswag winogrande arc_easy arc_challenge openbookqa"
#   TASKS="social_i_qa"
TASKS='wikitext boolq hellaswag winogrande arc_easy arc_challenge'

# Create the log directory if it doesn't exist. Every run writes its log
# file under $LOG_DIR, so stop here if the directory cannot be created
# instead of running the whole batch with nowhere to write results.
if ! mkdir -p "$LOG_DIR"; then
    echo "Error: cannot create log directory: $LOG_DIR" >&2
    exit 1
fi

# Set environment variables. They are exported so that the python processes
# started later in this script inherit them.
export CUDA_LAUNCH_BLOCKING=1
export TOKENIZERS_PARALLELISM=false
export TORCH_USE_CUDA_DSA=1

# Startup banner: echo the effective settings for this batch.
echo "Starting benchmark runs..."
echo "Log directory: $LOG_DIR"
echo "Model: $MODEL_PATH"
echo "Tasks: $TASKS"
echo "========================================="

# Run one benchmark configuration and report whether it succeeded.
#
# Globals (read):
#   LOG_DIR     directory that receives the per-configuration log file
#   MODEL_PATH  its last path component prefixes the log file name
#   TASKS       task list, forwarded as ONE argument to --tasks
# Arguments:
#   $1  configuration name (log file suffix and label in the output)
#   $2  base arguments: one string, split on whitespace
#   $3  additional arguments: one string, split on whitespace
#       Quote characters inside $2/$3 are NOT interpreted; they are passed
#       to the benchmark literally, so do not embed \"...\" in them.
# Outputs:
#   Progress lines and a success/failure line on stdout.
# Returns:
#   The status of the trailing sleep. A failed benchmark is reported but
#   not propagated, so one failing configuration does not stop the batch.
run_benchmark() {
    local config_name="$1"
    local base_args="$2"
    local additional_args="$3"
    local model_name log_file
    local rc=0
    local -a base_argv=() extra_argv=()

    # Quote the path: unquoted, a model directory containing a space would
    # be handed to basename as NAME plus SUFFIX and yield the wrong prefix.
    model_name="$(basename "$MODEL_PATH")"
    log_file="${LOG_DIR}/${model_name}_${config_name}"

    # Split the argument strings on whitespace into arrays. Unlike a bare
    # unquoted expansion this never performs pathname expansion, so a "*"
    # or "?" inside an argument reaches the benchmark unchanged.
    # read reports EOF with a non-zero status here; that is expected.
    read -r -d '' -a base_argv <<< "$base_args" || true
    read -r -d '' -a extra_argv <<< "$additional_args" || true

    echo ""
    echo "Running: $config_name"
    echo "Log file: $log_file"
    echo "Base args: $base_args"
    echo "Additional args: $additional_args"
    echo "----------------------------------------"

    # Use quotes around TASKS to pass it as a single argument
    python Benchmark_all.py "${base_argv[@]}" --log_file "$log_file" \
        --tasks "$TASKS" "${extra_argv[@]}" || rc=$?

    if [ "$rc" -eq 0 ]; then
        echo "✅ $config_name completed successfully"
    else
        echo "❌ $config_name failed with exit code $rc"
    fi

    # Pause between runs (intended to let GPU memory be released)
    sleep 5
}

# =============================================================================
# FP16 baseline runs (no quantization, no AWQ)
# =============================================================================
printf '\n%s\n' "=== FP16 BASELINE TESTS ==="

# The FP16 runs are currently disabled, so only the banner above is printed.
# Uncomment the lines below to re-enable them.
#
# # FP16 common arguments (no quantization, no AWQ)
# FP16_ARGS="-m $MODEL_PATH --aggressive-memory --original --custom-model --original-max-position-embeddings 2048"
#
# # 1. FP16 with no position interpolation (yarn 1)
# run_benchmark "fp16_no_pi" "$FP16_ARGS" "--yarn 1"
#
# # 2. FP16 baseline with yarn 8 scaling
# run_benchmark "fp16_yarn8" "$FP16_ARGS" "--yarn 8"

# =============================================================================
# RTN (round-to-nearest) quantization runs
# =============================================================================
printf '\n%s\n' "=== RTN QUANTIZATION TESTS ==="

# Arguments shared by every RTN run.
RTN_ARGS="-m $MODEL_PATH --aggressive-memory --original --custom-model --original-max-position-embeddings 2048 --awq"

# 1. RTN with yarn 1 (no position interpolation).
run_benchmark "rtn_no_pi" "$RTN_ARGS" "--naive_quant --yarn 1"

# All remaining RTN runs use naive quantization with yarn 8 scaling.
rtn_yarn8_opts="--naive_quant --yarn 8"

# 2. RTN baseline.
run_benchmark "rtn_baseline" "$RTN_ARGS" "$rtn_yarn8_opts"

# 3. RTN + activation quantization.
run_benchmark "rtn_act_quant" "$RTN_ARGS" "$rtn_yarn8_opts --quant_activation --quant_activation_bitwidth 4"

# 4. RTN + Hadamard transformation, one layer group per run.
#    Each entry is "<config-name suffix>=<value for --hardmard_layers>".
for rtn_hadamard_spec in \
    "qkv=q_proj,k_proj,v_proj" \
    "o_proj=o_proj" \
    "up_proj=up_proj" \
    "gate_proj=gate_proj"; do
    run_benchmark "rtn_hadamard_${rtn_hadamard_spec%%=*}" "$RTN_ARGS" \
        "$rtn_yarn8_opts --apply_hardmard --hardmard_layers ${rtn_hadamard_spec#*=}"
done

# =============================================================================
# AWQ (activation-aware weight quantization) runs
# =============================================================================
printf '\n%s\n' "=== AWQ QUANTIZATION TESTS ==="

# Every run from here on passes --awq_cache, so stop the script if the
# cache file is missing.
[ -f "$AWQ_CACHE" ] || {
    echo "⚠️  Warning: AWQ cache file not found at $AWQ_CACHE"
    echo "Please ensure the AWQ cache file exists before running AWQ tests"
    exit 1
}

# Arguments shared by every AWQ run (also used by the sections that follow).
AWQ_ARGS="-m $MODEL_PATH --aggressive-memory --original --custom-model --original-max-position-embeddings 2048 --awq"

# Cache option common to the three runs in this section.
awq_cache_opt="--awq_cache $AWQ_CACHE"

# 1. AWQ with yarn 1 (no position interpolation).
run_benchmark "awq_no_pi" "$AWQ_ARGS" "$awq_cache_opt --yarn 1"

# 2. AWQ baseline with yarn 8 scaling.
run_benchmark "awq_baseline" "$AWQ_ARGS" "$awq_cache_opt --yarn 8"

# 3. AWQ + activation quantization.
run_benchmark "awq_act_quant" "$AWQ_ARGS" "$awq_cache_opt --yarn 8 --quant_activation --quant_activation_bitwidth 4"

# =============================================================================
# AWQ + custom rotation runs
# =============================================================================
printf '\n%s\n' "=== AWQ + CUSTOM ROTATION TESTS ==="

# A search-result path is configured but the file is missing: warn and
# clear the path so the runs below go ahead without search results.
[[ -z "$SEARCH_RESULT_PATH" || -f "$SEARCH_RESULT_PATH" ]] || {
    echo "⚠️  Warning: Search result file not found at $SEARCH_RESULT_PATH"
    echo "Running without search results..."
    SEARCH_RESULT_PATH=""
}

# Rotation flags shared by the three runs. The search-result options are
# appended only when a search-result path is (still) set.
rotation_opts="--rescale_attention_all --rescale_per_head${SEARCH_RESULT_PATH:+ --use_search_result --search_result_path $SEARCH_RESULT_PATH}"

# 1. Custom rotation baseline (yarn 8 scaling).
run_benchmark "awq_custom_rotation_baseline" "$AWQ_ARGS" "--awq_cache $AWQ_CACHE --yarn 8 $rotation_opts"

# 2. Custom rotation with yarn 1 (no position interpolation).
run_benchmark "awq_custom_rotation_no_pi" "$AWQ_ARGS" "--awq_cache $AWQ_CACHE --yarn 1 $rotation_opts"

# 3. Custom rotation + activation quantization (yarn 8 scaling).
run_benchmark "awq_custom_rotation_act_quant" "$AWQ_ARGS" "--awq_cache $AWQ_CACHE --yarn 8 $rotation_opts --quant_activation --quant_activation_bitwidth 4"

# AWQ + individual channel scaling.
#
# The channel list is comma-separated with no whitespace, so it survives
# run_benchmark's whitespace splitting as one argument without any quoting.
# Do not wrap it in \"...\": run_benchmark expands the argument string
# unquoted, so embedded quote characters are not removed and would be passed
# to --individual_channel_up as part of the value.
individual_channel_list="124,90,62,52,86,108,118,75,114,120,50,93,115,73,58,81,80,94,89,69,63,51,46,106,107,84,44,42,123,68,88,15,6,28,12,17,1,20,8,35"
individual_channel_opts="--individual_channel_up $individual_channel_list --individual_channel_scale 2"

# 4. AWQ + individual channel scaling (with yarn 8 scaling)
run_benchmark "awq_individual_channel_yarn8" "$AWQ_ARGS" "--awq_cache $AWQ_CACHE --yarn 8 $individual_channel_opts"

# 5. AWQ + individual channel scaling + no position interpolation (yarn 1)
run_benchmark "awq_individual_channel_no_pi" "$AWQ_ARGS" "--awq_cache $AWQ_CACHE --yarn 1 $individual_channel_opts"

# 6. AWQ + individual channel scaling + activation quantization (yarn 8)
#    Uses its own configuration name so it does not share a log file with
#    run 5 (run_benchmark derives the log file name from this name).
run_benchmark "awq_individual_channel_act_quant" "$AWQ_ARGS" "--awq_cache $AWQ_CACHE --yarn 8 $individual_channel_opts --quant_activation --quant_activation_bitwidth 4"

# =============================================================================
# Console summary
# =============================================================================
# Printed once all sections above have finished. The counts are fixed text,
# not computed from the runs.
cat <<EOF

=========================================
All benchmark runs completed!
Log files are stored in: $LOG_DIR

Summary of configurations tested:
📊 FP16 Tests: 2 configurations (commented out)
📊 RTN Tests: 7 configurations
📊 AWQ Tests: 3 configurations
📊 AWQ + Custom Rotation Tests: 3 configurations
📊 AWQ + Individual Channel Tests: 3 configurations
📊 Total: 16 configurations

Check individual log files for detailed results.
=========================================
EOF

# Optional: create a timestamped summary report next to the log files.
# The configuration breakdown mirrors the console summary: the FP16 runs are
# commented out, and the total of 16 is RTN 7 + AWQ 3 + custom rotation 3 +
# individual channel 3.
SUMMARY_FILE="${LOG_DIR}/benchmark_summary_$(date +%Y%m%d_%H%M%S).txt"

# One redirection for the whole report. If the file cannot be opened the
# group is skipped and the failure is reported instead of claiming success.
if {
    echo "Benchmark run completed on $(date)"
    echo "Model: $MODEL_PATH"
    echo "Tasks: $TASKS"
    echo ""
    echo "Configuration breakdown:"
    echo "- FP16 Tests: 2 configurations (commented out)"
    echo "- RTN Tests: 7 configurations"
    echo "- AWQ Tests: 3 configurations"
    echo "- AWQ + Custom Rotation Tests: 3 configurations"
    echo "- AWQ + Individual Channel Tests: 3 configurations"
    echo "- Total: 16 configurations"
    echo ""
    echo "Log files generated:"
    # Log files are named "<model basename>_<config name>" inside LOG_DIR.
    # The model path is quoted so a path containing spaces is not split.
    ls -la "${LOG_DIR}/$(basename "$MODEL_PATH")"_* || echo "(no log files found)"
} > "$SUMMARY_FILE"; then
    echo "Summary report saved to: $SUMMARY_FILE"
else
    echo "Warning: could not write summary report to $SUMMARY_FILE" >&2
fi