#!/bin/bash
# Usage: bash scripts/eval_aloha.sh

# --- Environment setup -------------------------------------------------------
# Run from the project root. WORK_DIR may be supplied by the caller; otherwise
# the directory the script was launched from is used.
WORK_DIR=${WORK_DIR:-"$(pwd)"}
cd "$WORK_DIR" || {
    # Relative paths below (./models, ./eval, ...) are meaningless elsewhere.
    echo "错误: 无法进入工作目录: $WORK_DIR" >&2
    exit 1
}

# Best effort: on hosts with environment modules, load conda from a module;
# elsewhere fall back to whatever ~/.bashrc sets up.
if ! module load miniforge3/24.11 2>/dev/null; then
    source ~/.bashrc 2>/dev/null || true
fi

# Load conda's shell integration. The path is quoted so that an install
# prefix containing spaces does not get word-split.
if conda_base=$(conda info --base 2>/dev/null) && [ -n "$conda_base" ]; then
    source "${conda_base}/etc/profile.d/conda.sh"
else
    echo "警告: 未找到 conda，跳过 conda.sh 初始化" >&2
fi

# Best effort: hosts without environment modules keep their default CUDA.
module load cuda/12.8 2>/dev/null || true

# conda.sh defines `conda` as a shell function; when it is available use the
# supported `conda activate`. Otherwise keep the legacy `source activate`,
# which relies on an `activate` script being on PATH.
if [ "$(type -t conda)" = "function" ]; then
    conda activate bagel
else
    source activate bagel
fi

# Put the project root on PYTHONPATH. The separator is only added when
# PYTHONPATH already has content, so no empty entry is left behind.
export PYTHONPATH="${WORK_DIR}${PYTHONPATH:+:${PYTHONPATH}}"

# --- Run configuration -------------------------------------------------------
# The run is pinned to GPU 0. GPUS is passed to torchrun as --nproc_per_node.
export CUDA_VISIBLE_DEVICES=0
GPUS=1

# Every setting below can be overridden from the caller's environment, e.g.
#   eval_mode=all num_samples=10 bash scripts/eval_aloha.sh
base_dir=${base_dir:-"./models/BAGEL-7B-MoT"}
model_path=${model_path:-"./checkpoints/aloha/plug_cable/ckpt/0005000"}
output_base=${output_base:-"./eval_result/aloha/plug_cable"}

dataset_name=${dataset_name:-"plug_cable"}
data_base_dir=${data_base_dir:-"./data/bagel_data/dynamics/plug_cable"}

eval_mode=${eval_mode:-"dynamics"}  # dynamics, vlm_reward, or all
num_samples=${num_samples:-50}

max_mem_per_gpu=${max_mem_per_gpu:-"80GiB"}

# Reject unknown modes up front: every stage is guarded by an exact string
# match on eval_mode, so a typo would otherwise skip all stages and still
# report completion.
case "$eval_mode" in
    dynamics | vlm_reward | all) ;;
    *)
        echo "错误: 未知的 eval_mode: '$eval_mode' (可选: dynamics, vlm_reward, all)" >&2
        exit 1
        ;;
esac

# Dynamics inputs. The data covers three prediction types:
#   1. [head, left, right][t] + action -> next_head
#   2. [head, left, right, next_head][t] + action -> next_left_wrist
#   3. [head, left, right, next_head][t] + action -> next_right_wrist
dynamics_jsonl="${data_base_dir}/aloha_dynamics.jsonl"
dynamics_prompt="${data_base_dir}/dynamics_prompt.txt"
dynamics_image_dir="${data_base_dir}/dynamics_images"

# VLM reward inputs; these live under the same data_base_dir.
vlm_reward_jsonl="${data_base_dir}/aloha_vlm_reward.jsonl"
vlm_reward_prompt="${data_base_dir}/vlm_reward_prompt.txt"
vlm_reward_image_dir="${data_base_dir}/vlm_images"

# Create the result root before any stage writes beneath it.
mkdir -p "$output_base"

# Startup banner: echo the effective configuration so a log file records
# exactly what was evaluated.
cat <<EOF

==========================================
ALOHA 评估脚本
==========================================
数据集: $dataset_name
评估模式: $eval_mode
GPUs: $CUDA_VISIBLE_DEVICES ($GPUS GPUs)
模型路径: $model_path
输出目录: $output_base
样本数: $num_samples
每 GPU 最大内存: $max_mem_per_gpu

数据配置:
  - 数据基础目录: $data_base_dir
EOF

# Only list the inputs of the stages that the selected mode will run.
case "$eval_mode" in
    dynamics | all)
        printf '%s\n' \
            "  - Dynamics JSONL: $dynamics_jsonl" \
            "  - Dynamics Prompt: $dynamics_prompt" \
            "  - Dynamics Images: $dynamics_image_dir"
        ;;
esac
case "$eval_mode" in
    vlm_reward | all)
        printf '%s\n' \
            "  - VLM Reward JSONL: $vlm_reward_jsonl" \
            "  - VLM Reward Prompt: $vlm_reward_prompt" \
            "  - VLM Reward Images: $vlm_reward_image_dir"
        ;;
esac
printf '%s\n' "=========================================="

# --- Stage 1/2: dynamics evaluation ------------------------------------------
# Runs ./eval/gen_aloha_dynamics.py under torchrun on a single node and writes
# its output to $output_base/dynamics. The stage is skipped with a warning
# when the dynamics JSONL file is missing.
if [ "$eval_mode" = "dynamics" ] || [ "$eval_mode" = "all" ]; then
    echo ""
    echo "【1/2】运行 Dynamics 评估（多图像格式）..."
    echo "=========================================="

    dynamics_output="$output_base/dynamics"
    mkdir -p "$dynamics_output"

    if [ ! -f "$dynamics_jsonl" ]; then
        echo "警告: Dynamics JSONL 文件不存在: $dynamics_jsonl"
        echo "跳过 dynamics 评估"
    else
        echo "使用 Dynamics 数据: $dynamics_jsonl"

        # Test the command directly instead of inspecting $? afterwards, and
        # quote every expansion so values are passed as single arguments.
        if torchrun \
            --nnodes=1 \
            --node_rank=0 \
            --nproc_per_node="$GPUS" \
            --master_addr=127.0.0.1 \
            --master_port=12345 \
            ./eval/gen_aloha_dynamics.py \
            --base_dir "$base_dir" \
            --model_path "$model_path" \
            --output_dir "$dynamics_output" \
            --jsonl_path "$dynamics_jsonl" \
            --prompt_path "$dynamics_prompt" \
            --image_dir "$dynamics_image_dir" \
            --cfg_text_scale 6.0 \
            --cfg_img_scale 2.0 \
            --num_samples "$num_samples" \
            --max_mem_per_gpu "$max_mem_per_gpu"; then
            echo "✓ Dynamics 评估完成"
        else
            echo "✗ Dynamics 评估失败"
            # Keep going so the remaining stages and the summary still run,
            # but make the script as a whole exit non-zero so callers and CI
            # can detect the failure.
            trap 'exit 1' EXIT
        fi
    fi
fi

# --- Stage 2/2: VLM reward evaluation ----------------------------------------
# Runs ./eval/gen_aloha_vlm_reward.py under torchrun on a single node and
# writes its output to $output_base/vlm_reward. The stage is skipped with a
# warning when the VLM reward JSONL file is missing.
if [ "$eval_mode" = "vlm_reward" ] || [ "$eval_mode" = "all" ]; then
    echo ""
    echo "【2/2】运行 VLM Reward 评估（任务判断）..."
    echo "=========================================="

    vlm_reward_output="$output_base/vlm_reward"
    mkdir -p "$vlm_reward_output"

    if [ ! -f "$vlm_reward_jsonl" ]; then
        echo "警告: VLM Reward JSONL 文件不存在: $vlm_reward_jsonl"
        echo "跳过 vlm reward 评估"
    else
        echo "使用 VLM Reward 数据: $vlm_reward_jsonl"

        # Test the command directly instead of inspecting $? afterwards, and
        # quote every expansion so values are passed as single arguments.
        if torchrun \
            --nnodes=1 \
            --node_rank=0 \
            --nproc_per_node="$GPUS" \
            --master_addr=127.0.0.1 \
            --master_port=12347 \
            ./eval/gen_aloha_vlm_reward.py \
            --base_dir "$base_dir" \
            --model_path "$model_path" \
            --output_dir "$vlm_reward_output" \
            --jsonl_path "$vlm_reward_jsonl" \
            --prompt_path "$vlm_reward_prompt" \
            --image_dir "$vlm_reward_image_dir" \
            --num_samples "$num_samples" \
            --max_new_tokens 10 \
            --max_mem_per_gpu "$max_mem_per_gpu"; then
            echo "✓ VLM Reward 评估完成"
        else
            echo "✗ VLM Reward 评估失败"
            # Keep going so the summary still prints, but make the script as
            # a whole exit non-zero so callers and CI can detect the failure.
            trap 'exit 1' EXIT
        fi
    fi
fi

# --- Result summary ----------------------------------------------------------
echo ""
echo "=========================================="
echo "评估结果汇总"
echo "=========================================="
echo "输出目录: $output_base"

# Pretty-print the dynamics metrics file if the stage produced one.
if [ "$eval_mode" = "dynamics" ] || [ "$eval_mode" = "all" ]; then
    dynamics_results="$output_base/dynamics/evaluation_results.json"
    if [ -f "$dynamics_results" ]; then
        echo ""
        echo "【Dynamics 评估结果（时间预测）】"
        # The path is handed to Python as an argument rather than spliced into
        # the source text, so quotes or backslashes in output_base cannot
        # break or alter the program. The quoted here-doc delimiter keeps the
        # shell from expanding anything inside the Python code.
        python - "$dynamics_results" <<'PY'
import json
import sys

try:
    with open(sys.argv[1], 'r') as f:
        results = json.load(f)
        print(json.dumps(results, indent=2))
except Exception as e:
    print(f'无法读取结果: {e}')
PY
    fi
fi

# Print the headline accuracy figures from the VLM reward summary file if the
# stage produced one.
if [ "$eval_mode" = "vlm_reward" ] || [ "$eval_mode" = "all" ]; then
    vlm_reward_summary="$output_base/vlm_reward/evaluation_summary.json"
    if [ -f "$vlm_reward_summary" ]; then
        echo ""
        echo "【VLM Reward 评估结果（任务判断）】"
        # The path is handed to Python as an argument rather than spliced into
        # the source text, so quotes or backslashes in output_base cannot
        # break or alter the program. The quoted here-doc delimiter keeps the
        # shell from expanding anything inside the Python code.
        python - "$vlm_reward_summary" <<'PY'
import json
import sys

try:
    with open(sys.argv[1], 'r') as f:
        results = json.load(f)
        print(f"总样本数: {results['total_samples']}")
        print(f"整体准确率: {results['overall_accuracy']*100:.2f}%")
        print(f"Yes 准确率: {results['yes_accuracy']*100:.2f}%")
        print(f"No 准确率: {results['no_accuracy']*100:.2f}%")
except Exception as e:
    print(f'无法读取结果: {e}')
PY
    fi
fi

# Closing banner: point the user at the output locations of each stage that
# the selected mode covers.
printf '%s\n' \
    "" \
    "==========================================" \
    "评估完成！" \
    "" \
    "查看详细结果:"

case "$eval_mode" in
    dynamics | all)
        # The '*' is printed literally; it sits inside double quotes, so the
        # shell does not expand it.
        printf '%s\n' \
            "  - Dynamics (多图像时间预测): $output_base/dynamics/" \
            "    · 对比图像: $output_base/dynamics/*_comparison.png" \
            "    · 评估指标: $output_base/dynamics/evaluation_results.json" \
            "    · 按预测类型分组: by_prediction_type (head/left_wrist/right_wrist)"
        ;;
esac

case "$eval_mode" in
    vlm_reward | all)
        printf '%s\n' \
            "  - VLM Reward (任务判断): $output_base/vlm_reward/" \
            "    · 分类结果: $output_base/vlm_reward/detailed_results.json" \
            "    · 准确率: $output_base/vlm_reward/evaluation_summary.json"
        ;;
esac

printf '%s\n' "=========================================="

