#!/usr/bin/env python3
"""
使用人类真实演示轨迹测试 WebShop 里程碑检测器

该脚本在 WebShop 环境中重演人类专家轨迹，并使用里程碑检测器
跟踪每个步骤的里程碑达成情况。
"""

import json
import sys
from typing import List, Dict, Any, Optional, Tuple
from collections import defaultdict

# 添加 WebShop 路径
sys.path.insert(0, 'agent_system/environments/env_package/webshop/webshop')
sys.path.insert(0, 'agent_system/environments/env_package/webshop/webshop/web_agent_site/envs')

# 直接导入模块文件
import web_agent_text_env
WebAgentTextEnv = web_agent_text_env.WebAgentTextEnv

from migpo.webshop_milestone_detector import (
    MilestoneDetector,
    MilestonePhase,
    MilestoneResult,
)

# Path to the human demonstration trajectories (JSONL: one JSON object per
# line, as read by load_human_trajectories). The path is relative, so it is
# resolved against the current working directory.
TRAJ_FILE_PATH = "agent_system/environments/env_package/webshop/webshop/data/il_trajs_finalized_images.jsonl"


def load_human_trajectories(file_path: str, num_trajs: Optional[int] = None) -> List[Dict]:
    """Load human demonstration trajectories from a JSONL file.

    Args:
        file_path: Path to a JSON Lines file with one trajectory object per line.
        num_trajs: Maximum number of trajectories to load. ``None`` loads all of
            them; ``0`` loads none.

    Returns:
        The parsed trajectory dicts, in file order.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    trajectories = []
    # JSON text is UTF-8 by specification; do not rely on the locale default.
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            # Compare against None explicitly so that num_trajs=0 means
            # "load nothing" rather than "no limit".
            if num_trajs is not None and len(trajectories) >= num_trajs:
                break
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. trailing separators) in the JSONL.
                continue
            trajectories.append(json.loads(line))
    return trajectories


def extract_instruction(initial_state: str) -> str:
    """Return the goal instruction embedded in an initial observation string.

    The observation is split on ``'\\n'``. The instruction is the line that
    immediately follows the first line reading exactly ``Instruction:``
    (ignoring surrounding whitespace), with its own whitespace stripped.

    Raises:
        ValueError: if no ``Instruction:`` line is followed by another line.
    """
    rows = initial_state.split('\n')
    # Walk adjacent (line, next line) pairs; the final line has no successor
    # and so can never produce an instruction.
    for current, following in zip(rows, rows[1:]):
        if current.strip() == "Instruction:":
            return following.strip()
    raise ValueError("无法从状态中提取目标指令")


def find_matching_goal(env: WebAgentTextEnv, instruction: str) -> Optional[int]:
    """Return the index of the first goal whose instruction text equals *instruction*.

    Scans ``env.server.goals`` in order and compares each goal's
    ``'instruction_text'`` for exact string equality. Returns ``None`` when
    nothing matches.
    """
    candidates = (
        position
        for position, candidate in enumerate(env.server.goals)
        if candidate['instruction_text'] == instruction
    )
    return next(candidates, None)


class TrajectoryReplayer:
    """Replay human demonstration trajectories inside a WebShop environment.

    Each replay resets the environment to a fixed goal, executes the recorded
    actions one by one, and feeds every transition to a fresh
    ``MilestoneDetector``.
    """

    def __init__(self, env: WebAgentTextEnv):
        # The environment is reused by every replay performed by this object.
        self.env = env

    def replay_trajectory(
        self,
        trajectory: Dict,
        goal_idx: int,
        verbose: bool = False
    ) -> Tuple[List[MilestoneResult], float, bool]:
        """Replay a single trajectory and collect per-step milestone results.

        Args:
            trajectory: Trajectory record. Only its ``'actions'`` list is used;
                each action is passed verbatim to ``env.step``.
            goal_idx: Passed as ``session`` to ``env.reset``. The goal is then
                read back from ``env.server.user_sessions`` for that session.
            verbose: Print the goal, every step and its milestone result.

        Returns:
            ``(milestone_results, final_reward, success)`` where

            - ``milestone_results`` has one ``MilestoneResult`` per executed
              step, stopping after the step that reports ``done``;
            - ``final_reward`` is the reward of the step that reported
              ``done``, or ``0.0`` if the actions ran out first;
            - ``success`` is True iff the detector ended in
              ``MilestonePhase.COMPLETE``.
        """
        # Reset the environment to the requested goal. Upstream WebShop's
        # ``WebAgentTextEnv.reset`` returns ``(obs, info)``. Passing that tuple
        # through unchanged would hand the detector a tuple instead of the
        # observation text as ``prev_state`` on the first step. Accept both a
        # tuple and a bare observation in case the vendored env differs.
        reset_output = self.env.reset(session=goal_idx)
        if isinstance(reset_output, tuple):
            obs = reset_output[0]
        else:
            obs = reset_output

        # Goal record the server associated with this session.
        session_id = self.env.session
        goal_data = self.env.server.user_sessions[session_id]['goal']

        if verbose:
            print("  环境目标:")
            print(f"    instruction: {goal_data['instruction_text']}")
            print(f"    attributes: {goal_data.get('attributes', [])}")
            print(f"    options: {goal_data.get('goal_options', [])}")
            print(f"    price_upper: {goal_data.get('price_upper', 'N/A')}")

        # A fresh detector per trajectory, so milestone state never leaks
        # between replays.
        detector = MilestoneDetector(
            goal_data,
            self.env.server.product_item_dict,
            self.env.server.product_prices
        )

        milestone_results = []
        prev_state = obs
        final_reward = 0.0

        for i, action in enumerate(trajectory['actions']):
            if verbose:
                print(f"\n  步骤 {i+1}: {action}")

            next_state, reward, done, info = self.env.step(action)

            # On the terminal step, replace ``info`` with the server's
            # per-session reward breakdown so the detector can inspect it.
            if done:
                session_info = self.env.server.user_sessions[session_id]
                verbose_info = session_info.get('verbose_info', {})
                info = {'verbose': verbose_info}
                final_reward = reward

                if verbose:
                    print(f"    环境详细信息: r_att={verbose_info.get('r_att', 0)}, "
                          f"r_price={verbose_info.get('r_price', 0)}, "
                          f"r_option={verbose_info.get('r_option', 0)}")

            result = detector.process(action, prev_state, next_state, info)
            milestone_results.append(result)

            if verbose:
                print(f"    阶段: {result.phase.name}, 达成: {result.achieved}")
                print(f"    消息: {result.message}")
                if result.metadata:
                    print(f"    元数据: {result.metadata}")

            prev_state = next_state

            if done:
                break

        success = (detector.phase == MilestonePhase.COMPLETE)

        return milestone_results, final_reward, success


def print_statistics(results: List[Dict], matched_count: int, total_trajs: int):
    """Print aggregate statistics for a batch of replayed trajectories.

    Args:
        results: One dict per replayed trajectory with at least the keys
            ``'success'``, ``'final_reward'`` and ``'milestone_results'``. The
            last is a sequence of objects exposing ``.achieved`` and
            ``.phase.name``.
        matched_count: Number of trajectories whose instruction matched an
            environment goal. Used as the denominator for the success rate and
            for the milestone percentages.
        total_trajs: Total number of trajectories that were loaded.
    """
    print("\n\n" + "=" * 80)
    print("测试统计")
    print("=" * 80)

    success_count = sum(1 for r in results if r['success'])

    print(f"\n总轨迹数: {total_trajs}")
    print(f"匹配的轨迹数: {matched_count}")
    print(f"成功完成的轨迹数: {success_count}")

    if matched_count > 0:
        if total_trajs > 0:
            print(f"匹配率: {matched_count/total_trajs*100:.1f}%")
        print(f"成功率（基于匹配的轨迹）: {success_count/matched_count*100:.1f}%")

    # Count each milestone at most once per trajectory. The percentage below
    # is relative to the number of matched trajectories, so counting
    # individual steps would exceed 100% whenever the detector reports the
    # same phase as achieved on more than one step of a trajectory.
    milestone_counts = defaultdict(int)
    for result in results:
        achieved_phases = {
            mr.phase.name for mr in result['milestone_results'] if mr.achieved
        }
        for phase_name in achieved_phases:
            milestone_counts[phase_name] += 1

    print(f"\n里程碑达成统计:")
    for phase in ['SEARCH', 'DETAIL', 'OPTIONS', 'PURCHASE']:
        count = milestone_counts.get(phase, 0)
        percentage = (count / matched_count * 100) if matched_count > 0 else 0
        print(f"  {phase}: {count} 次 ({percentage:.1f}%)")

    if results:
        avg_reward = sum(r['final_reward'] for r in results) / len(results)
        print(f"\n平均最终奖励: {avg_reward:.4f}")

        success_results = [r for r in results if r['success']]
        if success_results:
            success_avg_reward = sum(r['final_reward'] for r in success_results) / len(success_results)
            print(f"成功轨迹的平均奖励: {success_avg_reward:.4f}")


def main(num_trajs: Optional[int] = 50, verbose: bool = True):
    """Replay human trajectories through the milestone detector and report stats.

    Loads the first ``num_trajs`` trajectories from ``TRAJ_FILE_PATH`` and
    builds a WebShop text environment. Each trajectory is pinned to the goal
    whose instruction text matches its first state, then replayed. Aggregate
    statistics are printed at the end. An error in one trajectory is reported
    with a traceback and does not abort the remaining ones.

    Args:
        num_trajs: How many trajectories to load from the start of the file
            (``None`` loads all of them).
        verbose: Forwarded to ``TrajectoryReplayer.replay_trajectory`` to print
            per-step details.
    """
    import traceback

    print("=" * 80)
    print("使用人类真实演示轨迹测试 WebShop 里程碑检测器")
    print("=" * 80)

    print(f"\n加载轨迹文件: {TRAJ_FILE_PATH}")
    trajectories = load_human_trajectories(TRAJ_FILE_PATH, num_trajs)
    print(f"✓ 加载了 {len(trajectories)} 条轨迹")

    print("\n初始化 WebShop 环境...")
    env = WebAgentTextEnv(
        observation_mode='text',
        num_products=1000,
        human_goals=True,
        seed=42
    )
    print("✓ 环境初始化完成")
    print(f"✓ 环境目标数: {len(env.server.goals)}")

    replayer = TrajectoryReplayer(env)

    results = []
    # Counts trajectories whose instruction matched a goal, including those
    # whose replay later raised. It is the denominator for success rates.
    matched_count = 0

    for traj_idx, traj in enumerate(trajectories):
        print(f"\n{'='*80}")
        print(f"测试轨迹 {traj_idx + 1}/{len(trajectories)}")
        print(f"{'='*80}")

        # Per-trajectory error boundary: report and move on to the next one.
        try:
            instruction = extract_instruction(traj['states'][0])
            print(f"轨迹目标: {instruction}")

            goal_idx = find_matching_goal(env, instruction)

            if goal_idx is None:
                print("✗ 未找到匹配的目标，跳过此轨迹")
                continue

            matched_count += 1
            print(f"✓ 找到匹配的目标（索引: {goal_idx}）")

            milestone_results, final_reward, success = replayer.replay_trajectory(
                traj, goal_idx, verbose=verbose
            )

            results.append({
                'trajectory_idx': traj_idx,
                'goal_idx': goal_idx,
                'instruction': instruction,
                'actions': traj['actions'],
                'milestone_results': milestone_results,
                'final_reward': final_reward,
                'success': success
            })

            print(f"\n结果: {'✓ 成功' if success else '✗ 失败'}")
            print(f"最终奖励: {final_reward:.4f}")

        except Exception as e:
            print(f"✗ 错误: {e}")
            traceback.print_exc()

    print_statistics(results, matched_count, len(trajectories))
    print("\n测试完成！\n")


if __name__ == "__main__":
    # Run the replay and statistics pass only when executed as a script,
    # not when this module is imported.
    main()
