#!/bin/bash
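# ALFWorld rollout analysis: runs analyze_rollout.py on merged checkpoints.
# Usage: ./run_rollout.sh [step] [mode] [algorithm]
#   step      - one step number, a comma-separated list (e.g. 40,80), or "all"
#               for every step_* directory of the model dir (default: 80)
#   mode      - "training" uses the train split; any other value uses the
#               valid_seen split (default: training)
#   algorithm - grpo | gigpo | migpo (default: migpo)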

# ============ Default configuration ============
# Directory that holds the merged checkpoints as step_<N> sub-directories.
DEFAULT_MERGED_DIR="./merged_models/zju_migpo_qwen2.5_1.5b"
# Step analysed when no first argument is given.
DEFAULT_STEP="80"
# Fallbacks for the optional second and third arguments.
MODE="training"
ALGORITHM="migpo"
# Passed to analyze_rollout.py as --max_steps / --history_length.
MAX_STEPS=30
HISTORY_LENGTH=30
# ===============================================

# Use GPU 1 unless the caller already exported a non-empty selection.
export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-1}"

# Every path below is relative to the script's own directory, so running from
# any other working directory would resolve them wrongly: abort if cd fails.
cd "$(dirname "$0")" || {
    echo "错误: 无法进入脚本所在目录: $(dirname "$0")" >&2
    exit 1
}

# Positional arguments: [step] [mode] [algorithm]; empty or missing values
# fall back to the defaults above.
STEPS=${1:-$DEFAULT_STEP}
MODE=${2:-$MODE}
ALGORITHM=${3:-$ALGORITHM}

# Mode-dependent rollout settings. "training" selects the train split with
# 16 environments and a group size of 8; every other MODE value selects the
# valid_seen split with 128 environments and a group size of 1.
case "$MODE" in
    training)
        SPLIT="train"
        NUM_ENVS=16
        GROUP_SIZE=8
        ;;
    *)
        SPLIT="valid_seen"
        NUM_ENVS=128
        GROUP_SIZE=1
        ;;
esac

# Algorithm-specific command-line flags, collected in ALGO_ARGS as a single
# space-separated string. An unsupported ALGORITHM aborts the script.
case "$ALGORITHM" in
    grpo)
        ALGO_ARGS=""
        ;;
    gigpo)
        ALGO_ARGS="--step_advantage_w 10.0 --migpo_threshold 0.95"
        ;;
    migpo)
        ALGO_ARGS="--step_advantage_w 1.0 --migpo_gamma 0.9 --migpo_threshold 0.95"
        ;;
    *)
        # Diagnostics go to stderr so they never mix with captured stdout.
        echo "未知算法: $ALGORITHM" >&2
        echo "支持的算法: grpo, gigpo, migpo" >&2
        exit 1
        ;;
esac

# Show the effective configuration (model dir, steps, mode, algorithm, split,
# GPU selection) before any analysis starts. One printf call emits each
# argument on its own line.
printf '%s\n' \
    "============================================" \
    "ALFWorld Rollout分析" \
    "============================================" \
    "模型目录: $DEFAULT_MERGED_DIR" \
    "Steps: $STEPS" \
    "模式: $MODE (envs=$NUM_ENVS, group=$GROUP_SIZE)" \
    "算法: $ALGORITHM" \
    "数据集: $SPLIT" \
    "GPU: $CUDA_VISIBLE_DEVICES" \
    "============================================"

#######################################
# List the steps that have a checkpoint directory.
# Arguments: $1 - directory containing step_<N> sub-directories
# Outputs:   the <N> part of every step_<N> directory, one per line, sorted
#            numerically; nothing when no such directory exists
#######################################
list_available_steps() {
    local merged_dir=$1
    local entry name
    # Glob instead of parsing `ls`; an unmatched pattern stays literal and is
    # rejected by the -d test, so an empty directory yields no output.
    for entry in "$merged_dir"/step_*; do
        [ -d "$entry" ] || continue
        name=${entry##*/}
        printf '%s\n' "${name#step_}"
    done | sort -n
}

# Determine which steps to analyse. STEP_LIST is a whitespace-separated list:
# "all" expands to every available step, anything else is treated as a
# comma-separated list of step numbers.
if [ "$STEPS" = "all" ]; then
    STEP_LIST=$(list_available_steps "$DEFAULT_MERGED_DIR")
    if [ -z "$STEP_LIST" ]; then
        echo "警告: 在 $DEFAULT_MERGED_DIR 下未找到任何 step_* 目录" >&2
    fi
else
    STEP_LIST=${STEPS//,/ }
fi

#######################################
# Run analyze_rollout.py for a single checkpoint step.
# Globals:   DEFAULT_MERGED_DIR, SPLIT, ALGORITHM, NUM_ENVS, GROUP_SIZE,
#            MAX_STEPS, HISTORY_LENGTH, ALGO_ARGS (all read)
# Arguments: $1 - step number; the checkpoint is $DEFAULT_MERGED_DIR/step_$1
# Outputs:   a step banner, or a skip warning, on stdout
# Returns:   0 when the checkpoint directory is missing (step skipped) or the
#            analysis succeeded; otherwise the exit status of python
#######################################
analyze_step() {
    local step=$1
    local model_path="${DEFAULT_MERGED_DIR}/step_${step}"

    if [ ! -d "$model_path" ]; then
        echo "警告: $model_path 不存在，跳过"
        return 0
    fi

    echo ""
    echo "========== Step $step =========="
    # ALGO_ARGS is a flat string of flags that has to be word-split into
    # separate arguments, so it is deliberately left unquoted.
    # shellcheck disable=SC2086
    python analyze_rollout.py \
        --merged_dir "$DEFAULT_MERGED_DIR" \
        --model_path "$model_path" \
        --split "$SPLIT" \
        --algorithm "$ALGORITHM" \
        --num_envs "$NUM_ENVS" \
        --group_size "$GROUP_SIZE" \
        --max_steps "$MAX_STEPS" \
        --history_length "$HISTORY_LENGTH" \
        $ALGO_ARGS
}

# Analyse every requested step. A failing step does not stop the remaining
# ones, but it is remembered so the script does not report success afterwards.
FAILED_STEPS=""
# STEP_LIST is a whitespace-separated list, hence the unquoted expansion.
for step in $STEP_LIST; do
    if ! analyze_step "$step"; then
        echo "错误: step $step 分析失败" >&2
        FAILED_STEPS="${FAILED_STEPS:+$FAILED_STEPS }$step"
    fi
done

echo ""
echo "============================================"
if [ -n "$FAILED_STEPS" ]; then
    echo "失败的 step: $FAILED_STEPS" >&2
    echo "============================================"
    exit 1
fi
echo "完成!"
echo "============================================"
