#!/usr/bin/env bash
# Environment and path setup for the budget-forcing length sweep.
# The script uses bash-only features (arrays, [[ ]]), so it must run under bash.

ROOT=/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation

# Load conda's shell integration, then switch to the evaluation environment.
eval "$(/mnt/shared-storage-user/p1-shared/wangfuting/miniconda3/bin/conda shell.bash hook)"
conda activate verl041-test

# Evaluation set. Exactly one DATA assignment is active; the commented ones are
# alternatives kept for quick switching.
# DATA=$ROOT/data/luffy/valid.all.parquet
# DATA=$ROOT/data/luffy/valid.all_qwen3.parquet
# DATA=$ROOT/data/luffy/valid-polaris-qwen3.parquet
# DATA=/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/data/luffy/aime24_qwen3_128.parquet
# DATA=/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/data/luffy/aime24_qwen3_8.parquet
DATA="/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/data/luffy/aime24_qwen3_unique30.parquet"
# DATA=/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/data/qwen3-4b-s1-sampled1k.parquet
# DATA=$ROOT/data/luffy/openr1.parquet
OUTPUT_DIR="$ROOT/results_apr"

# Generations, logs and the summary files are all written under OUTPUT_DIR, so
# there is no point in continuing if it cannot be created.
if ! mkdir -p "$OUTPUT_DIR"; then
    echo "无法创建输出目录: $OUTPUT_DIR" >&2
    exit 1
fi

# The generation script is invoked through a path relative to ROOT; a failed cd
# would make every run in the sweep fail, so stop right here instead.
if ! cd "$ROOT"; then
    echo "无法进入项目目录: $ROOT" >&2
    exit 1
fi

# ---------------------------------------------------------------------------
# Length-matching search
#   * every model is run with budget_forcing over the (budget, num_ignore)
#     grid defined below;
#   * each run's "average length" is read back from its log;
#   * the reported result is the pair of settings whose realized average
#     lengths are closest.
#
# TARGET_AVG_LENGTH sits near the paper's overall AIME24 response length scale
# (about 9k-10k tokens). The final choice, however, is driven by minimizing the
# length gap between the two models.
# The grid is deliberately small so that a single sweep stays affordable.
# ---------------------------------------------------------------------------

# Sweep grid and the length the matched pair should ideally land near.
SEARCH_BUDGETS=(8192 10000 12288)
SEARCH_NUM_IGNORES=(2 4)
TARGET_AVG_LENGTH=9500

# Budget-forcing behaviour (the leading space in IGNORE_STR is part of the value).
GENERATION_MODE="budget_forcing"
IGNORE_STR=" Wait"
FORCE_GENERATE=True
FINAL_MAX_TOKENS=256

# Sampling parameters.
TEMPERATURE=0.0
TOP_P=1.0

# Model checkpoints and their display names.
# MODEL_PATHS, MODEL_NAMES and TEMPLATES are parallel arrays: index i of each
# one describes the same model.
MODEL_PATHS=(
    "/mnt/shared-storage-gpfs2/p1-shared-2/wangfuting/LIE/models/verl-qwen3-4b-oct/baseline/best_model/actor/huggingface"
    "/mnt/shared-storage-gpfs2/p1-shared-2/wangfuting/LIE/models/verl-qwen3-4b-oct/LIE/best_model/actor/huggingface"
)
MODEL_NAMES=("baseline-aime-s1" "LIE-aime-s1")
TEMPLATES=("own" "own")

# The three arrays have to line up one-to-one.
if (( ${#MODEL_PATHS[@]} != ${#MODEL_NAMES[@]} || ${#MODEL_PATHS[@]} != ${#TEMPLATES[@]} )); then
    echo "MODEL_PATHS / MODEL_NAMES / TEMPLATES 长度不一致，请检查配置。" >&2
    exit 1
fi

# The length-matching summary at the end of the script is written for exactly
# two models, so refuse any other count up front.
if (( ${#MODEL_PATHS[@]} != 2 )); then
    echo "当前脚本的自动配长汇总逻辑默认按两个模型写的，请先保持 MODEL_PATHS 长度为 2。" >&2
    exit 1
fi

# Result files for this sweep: a TSV with one row per run, and a short text
# report describing the selected pair of settings.
SUMMARY_FILE="${OUTPUT_DIR}/s1_length_search_summary.tsv"
BEST_FILE="${OUTPUT_DIR}/s1_length_search_best.txt"

# Drop whatever an earlier sweep left behind, then start the TSV with its
# header row.
rm -f "${SUMMARY_FILE}" "${BEST_FILE}"
printf "model_name\tbudget\tnum_ignore\tavg_length\tacc\tlog_file\n" > "${SUMMARY_FILE}"

#######################################
# Print the last value logged for a metric in a run log.
#
# The log is searched for "<metric_name>:" followed by optional whitespace and
# a non-negative decimal number; when the metric was logged several times the
# last occurrence wins.
#
# Arguments:
#   $1 - path to the log file
#   $2 - metric label exactly as it appears in the log (matched literally)
# Outputs:
#   stdout: the matched number, or an empty line when the metric is absent.
#   stderr: a one-line diagnostic when the log cannot be read.
# Returns:
#   0 when the log was read; 1 when it is missing or unreadable (nothing is
#   written to stdout in that case, so callers capturing the output get "").
#######################################
extract_metric() {
    local log_file="$1"
    local metric_name="$2"

    # A run that died before creating its log has no metrics. Say so in one
    # line instead of letting Python dump a traceback.
    if [[ ! -f "$log_file" || ! -r "$log_file" ]]; then
        echo "extract_metric: cannot read log file: $log_file" >&2
        return 1
    fi

    python - "$log_file" "$metric_name" <<'PY'
import re
import sys
from pathlib import Path

log_path = Path(sys.argv[1])
metric_name = sys.argv[2]
# Undecodable bytes are dropped rather than treated as an error.
text = log_path.read_text(encoding="utf-8", errors="ignore")
# The metric label is escaped so it is matched literally, not as a regex.
matches = re.findall(rf"{re.escape(metric_name)}:\s*([0-9]+(?:\.[0-9]+)?)", text)
print(matches[-1] if matches else "")
PY
}

# Run the sweep: for every model, generate once per (budget, num_ignore)
# combination, then append that run's metrics to the summary TSV.
for i in "${!MODEL_PATHS[@]}"; do
    MODEL_PATH="${MODEL_PATHS[$i]}"
    MODEL_NAME="${MODEL_NAMES[$i]}"
    TEMPLATE="${TEMPLATES[$i]}"

    echo "正在评估模型: $MODEL_NAME"
    echo "模型路径: $MODEL_PATH"

    for budget in "${SEARCH_BUDGETS[@]}"; do
        for num_ignore in "${SEARCH_NUM_IGNORES[@]}"; do
            # One tag per run keeps the generation output and the log of
            # different settings from overwriting each other.
            RUN_TAG="${MODEL_NAME}_bf_wait${num_ignore}x_${budget}_final${FINAL_MAX_TOKENS}"
            OUTPUT_FILE="$OUTPUT_DIR/${RUN_TAG}_test.jsonl"
            LOG_FILE="$OUTPUT_DIR/${RUN_TAG}.log"

            echo "开始生成，budget=$budget, num_ignore=$num_ignore, final_max_tokens=$FINAL_MAX_TOKENS, force_generate=$FORCE_GENERATE"

            # Only stdout goes to the log; stderr stays on the terminal.
            # The same budget is passed as both --length_budget and --max_tokens.
            python eval_scripts/generate_vllm.py \
                --model_path "$MODEL_PATH" \
                --input_file "$DATA" \
                --output_file "$OUTPUT_FILE" \
                --remove_system True \
                --generation_mode "$GENERATION_MODE" \
                --no-split-think True \
                --length_budget "$budget" \
                --max_tokens "$budget" \
                --final_max_tokens "$FINAL_MAX_TOKENS" \
                --temperature "$TEMPERATURE" \
                --num_ignore "$num_ignore" \
                --top_p "$TOP_P" \
                --ignore_str "$IGNORE_STR" \
                --enable_thinking False \
                --n 1 \
                --template "$TEMPLATE" \
                --force_generate "$FORCE_GENERATE" > "$LOG_FILE"
            GEN_STATUS=$?

            # A failed run must not go unnoticed, but it should not abort the
            # rest of the sweep either: warn and keep going. Metrics are still
            # read from the log in case they were printed before the failure;
            # rows without an average length are skipped by the final selection.
            if (( GEN_STATUS != 0 )); then
                echo "警告: generate_vllm.py 以非零状态退出 (exit=$GEN_STATUS, model=$MODEL_NAME, budget=$budget, num_ignore=$num_ignore)，日志: $LOG_FILE" >&2
            fi

            AVG_LENGTH="$(extract_metric "$LOG_FILE" "average length")"
            AVG_ACC="$(extract_metric "$LOG_FILE" "avg acc")"
            printf "%s\t%s\t%s\t%s\t%s\t%s\n" \
                "$MODEL_NAME" "$budget" "$num_ignore" "$AVG_LENGTH" "$AVG_ACC" "$LOG_FILE" >> "$SUMMARY_FILE"

            echo "模型 $MODEL_NAME 评估完成: budget=$budget, num_ignore=$num_ignore, average_length=$AVG_LENGTH, avg_acc=$AVG_ACC"
        done
    done

    echo "----------------------------------------"
done

# Choose the pair of settings (one per model) whose realized average lengths
# are closest, write it to BEST_FILE, and report where the results are.
# The success lines are printed only when the selection actually succeeded;
# otherwise the script exits non-zero instead of pointing at a file that was
# never written.
if python - "$SUMMARY_FILE" "$BEST_FILE" "$TARGET_AVG_LENGTH" <<'PY'
import csv
import itertools
import math
import sys
from pathlib import Path

summary_file = Path(sys.argv[1])
best_file = Path(sys.argv[2])
target = float(sys.argv[3])

# Load the per-run rows. A run whose average length could not be parsed cannot
# take part in length matching and is skipped; a missing accuracy is kept as NaN.
rows = []
with summary_file.open("r", encoding="utf-8") as f:
    reader = csv.DictReader(f, delimiter="\t")
    for row in reader:
        if not row["avg_length"]:
            continue
        row["budget"] = int(row["budget"])
        row["num_ignore"] = int(row["num_ignore"])
        row["avg_length"] = float(row["avg_length"])
        row["acc"] = float(row["acc"]) if row["acc"] else float("nan")
        rows.append(row)

if len(rows) == 0:
    raise SystemExit("No valid rows found in summary file.")

# Group rows by model, preserving the order in which models appear in the file.
grouped = {}
for row in rows:
    grouped.setdefault(row["model_name"], []).append(row)

if len(grouped) != 2:
    raise SystemExit("Expected exactly two models in summary file.")

model_names = list(grouped.keys())
model_a_rows = grouped[model_names[0]]
model_b_rows = grouped[model_names[1]]


def ranking_acc(row):
    """Accuracy used for tie-breaking; an unknown (NaN) accuracy ranks last.

    NaN must not reach the sort key: every ordering comparison involving NaN
    is False, which would make the tie-break depend on iteration order.
    """
    return -math.inf if math.isnan(row["acc"]) else row["acc"]


def pair_key(pair):
    """Sort key: smallest length gap, then closest to target, then highest accuracy."""
    row_a, row_b = pair
    length_gap = abs(row_a["avg_length"] - row_b["avg_length"])
    target_gap = abs(row_a["avg_length"] - target) + abs(row_b["avg_length"] - target)
    acc_score = -(ranking_acc(row_a) + ranking_acc(row_b))
    return (length_gap, target_gap, acc_score)


# min() keeps the first of several equally good pairs.
row_a, row_b = min(itertools.product(model_a_rows, model_b_rows), key=pair_key)

lines = [
    f"target_avg_length={target}",
    f"best_length_gap={abs(row_a['avg_length'] - row_b['avg_length']):.2f}",
    f"{row_a['model_name']}\tbudget={row_a['budget']}\tnum_ignore={row_a['num_ignore']}\tavg_length={row_a['avg_length']:.2f}\tavg_acc={row_a['acc']:.4f}\tlog={row_a['log_file']}",
    f"{row_b['model_name']}\tbudget={row_b['budget']}\tnum_ignore={row_b['num_ignore']}\tavg_length={row_b['avg_length']:.2f}\tavg_acc={row_b['acc']:.4f}\tlog={row_b['log_file']}",
]
best_file.write_text("\n".join(lines) + "\n", encoding="utf-8")
print("\n".join(lines))
PY
then
    echo "所有模型评估完成！"
    echo "汇总文件: $SUMMARY_FILE"
    echo "最佳配长: $BEST_FILE"
else
    echo "最佳配长选择失败，未生成: $BEST_FILE" >&2
    echo "汇总文件: $SUMMARY_FILE" >&2
    exit 1
fi
