#!/usr/bin/env python3
"""
WebShop 里程碑检测器综合测试脚本

该脚本通过运行实际的 WebShop 环境来测试里程碑检测器的正确性。
"""

import sys
import os
from typing import List, Dict, Any, Tuple, Optional
from dataclasses import dataclass
import json
from collections import defaultdict
# 添加 WebShop 路径
sys.path.insert(0, 'agent_system/environments/env_package/webshop/webshop')
# 直接添加 envs 目录到路径，绕过 __init__.py
sys.path.insert(0, 'agent_system/environments/env_package/webshop/webshop/web_agent_site/envs')

# 直接导入模块文件，避免触发 __init__.py 中的 selenium 依赖
import web_agent_text_env
WebAgentTextEnv = web_agent_text_env.WebAgentTextEnv

from migpo.webshop_milestone_detector import (
    MilestoneDetector,
    MilestonePhase,
    MilestoneResult,
    load_human_goals
)


@dataclass
class TestCase:
    """Definition of one scripted test scenario.

    The tester replays ``actions`` step by step against a freshly reset
    environment session and judges the outcome using ``should_succeed``.
    """
    name: str  # short scenario name; copied into TestResult.test_name and used to group results
    description: str  # human-readable summary printed before the run
    actions: List[str]  # WebShop action strings, e.g. "search[...]" / "click[...]", replayed in order
    expected_milestones: List[MilestonePhase]  # informational list of phases the scenario is meant to reach; not a pass/fail criterion
    should_succeed: bool  # True if the detector is expected to reach MilestonePhase.COMPLETE


@dataclass
class TestResult:
    """Outcome of running a single TestCase."""
    test_name: str  # TestCase.name of the scenario that was run
    goal_text: str  # goal instruction text ('N/A' when the goal has none)
    actions: List[str]  # the action script that was replayed
    milestone_results: List[MilestoneResult]  # one detector result per executed step; empty if the run raised
    achieved_milestones: List[MilestonePhase]  # phases of the step results whose `achieved` flag was set
    final_reward: float  # environment reward at the step that ended the episode; 0.0 if it never ended
    success: bool  # whether the test was judged as passed
    error_message: str = ""  # exception text when the run raised, otherwise empty


class WebShopMilestoneDetectorTester:
    """Drive a real WebShop text environment through scripted scenarios.

    For every tested environment goal the tester builds a set of action
    scripts (see ``create_test_scenarios``). It replays each script on a freshly
    reset session and feeds every transition to a ``MilestoneDetector``. It then
    records a ``TestResult`` in ``self.test_results``.
    """

    def __init__(self, num_products: int = 1000, seed: int = 42):
        """Create the WebShop environment and load the human goal data.

        Args:
            num_products: Number of products the environment loads.
            seed: Seed forwarded to the environment.
        """
        print("=" * 80)
        print("初始化 WebShop 里程碑检测器测试器")
        print("=" * 80)

        # Text-observation environment using the human-written goal set.
        print("\n[1/2] 初始化 WebShop 环境...")
        self.env = WebAgentTextEnv(
            observation_mode='text',
            num_products=num_products,
            human_goals=True,
            seed=seed
        )
        print(f"✓ 环境初始化完成 (产品数量: {num_products}, 种子: {seed})")
        print(f"✓ 环境产品数据: {len(self.env.server.product_item_dict)} 个产品")

        # Human goals from the detector module. The mapping's values are sized
        # collections of goals (their lengths are summed below).
        print("\n[2/2] 加载人工目标数据...")
        self.human_goals = load_human_goals()
        print(f"✓ 人工目标加载完成 (目标数: {sum(len(goals) for goals in self.human_goals.values())})")

        self.test_results: List[TestResult] = []
        print("\n" + "=" * 80)
        print("初始化完成！")
        print("=" * 80 + "\n")

    def reset_environment(self, session_idx: Optional[int] = None) -> Tuple[str, Dict[str, Any], MilestoneDetector]:
        """Reset the environment and build a fresh detector for the new session.

        Args:
            session_idx: Session passed to ``env.reset``. ``None`` lets the
                environment pick its own session.

        Returns:
            ``(initial_observation, goal_data, detector)``. The detector is
            built from the environment's own product and price tables, so it
            judges products against the same catalogue the environment serves.
        """
        if session_idx is not None:
            reset_out = self.env.reset(session=session_idx)
        else:
            reset_out = self.env.reset()
        # Upstream WebShop's WebAgentTextEnv.reset returns an (observation, info)
        # tuple. Unwrap it so the detector's first prev_state and the returned
        # value are the observation itself; a plain return is also accepted.
        obs = reset_out[0] if isinstance(reset_out, tuple) else reset_out

        # The server keeps the goal chosen for each session.
        session_id = self.env.session
        goal_data = self.env.server.user_sessions[session_id]['goal']

        # Use the environment's product/price data (not a separately loaded copy).
        detector = MilestoneDetector(
            goal_data,
            self.env.server.product_item_dict,
            self.env.server.product_prices
        )

        return obs, goal_data, detector

    def run_episode(self, actions: List[str], detector: MilestoneDetector,
                   initial_obs: str, verbose: bool = True) -> Tuple[List[MilestoneResult], float]:
        """Replay ``actions`` in the environment and feed each step to ``detector``.

        Stops early when the environment reports ``done``.

        Returns:
            ``(milestone_results, final_reward)``: one detector result per
            executed step, and the reward of the terminating step (0.0 if the
            episode never finished).
        """
        milestone_results = []
        prev_state = initial_obs
        final_reward = 0.0

        for i, action in enumerate(actions):
            if verbose:
                print(f"\n  步骤 {i+1}: {action}")

            next_state, reward, done, info = self.env.step(action)

            # Debug output for the first (search) step: dump the observation and
            # show which ASINs the detector extracts and accepts.
            if verbose and i == 0 and "search" in action.lower():
                print(f"    完整观察内容（前1000字符）:")
                print(f"    {next_state[:1000]}")
                if len(next_state) > 1000:
                    print(f"    ... (总长度: {len(next_state)} 字符)")

                asins = detector._extract_asins_from_state(next_state)
                print(f"    搜索结果中的 ASIN 数量: {len(asins)}")
                print(f"    前5个 ASIN: {asins[:5]}")

                satisfying = [asin for asin in asins if detector._product_can_meet_goal(asin)]
                print(f"    符合目标的 ASIN 数量: {len(satisfying)}")
                if satisfying:
                    print(f"    符合目标的 ASIN: {satisfying[:3]}")

            # On the terminating step, replace the step info with the server's
            # per-session reward breakdown.
            if done:
                session_info = self.env.server.user_sessions[self.env.session]
                verbose_info = session_info.get('verbose_info', {})
                info = {'verbose': verbose_info}
                final_reward = reward

                if verbose and verbose_info:
                    print(f"    环境详细信息: r_att={verbose_info.get('r_att', 0)}, "
                          f"r_price={verbose_info.get('r_price', 0)}, "
                          f"r_option={verbose_info.get('r_option', 0)}")

            result = detector.process(action, prev_state, next_state, info)
            milestone_results.append(result)

            if verbose:
                print(f"    阶段: {result.phase.name}, 达成: {result.achieved}")
                print(f"    消息: {result.message}")
                if result.metadata:
                    print(f"    元数据: {result.metadata}")

            prev_state = next_state

            if done:
                break

        return milestone_results, final_reward

    def find_goal_consistent_product(self, goal_data: Dict[str, Any], detector: MilestoneDetector) -> Optional[str]:
        """Return an ASIN the detector considers goal-consistent, or ``None``.

        The goal's own ASIN is preferred when it is in the environment's
        catalogue and passes the detector's check. Otherwise the first
        accepted product in catalogue order is used.
        """
        goal_asin = goal_data.get('asin')

        if goal_asin and goal_asin in self.env.server.product_item_dict:
            if detector._product_can_meet_goal(goal_asin):
                return goal_asin

        # Fall back to scanning the whole catalogue.
        print(f"  调试: 目标 ASIN {goal_asin} 不在产品列表中或不符合目标，搜索替代产品...")
        for asin in self.env.server.product_item_dict.keys():
            if detector._product_can_meet_goal(asin):
                print(f"  调试: 找到符合目标的产品: {asin}")
                return asin

        print(f"  警告: 未找到任何符合目标的产品！")
        return None

    @staticmethod
    def _option_values(goal_options: Any) -> List[Any]:
        """Return the goal's option values as a new list.

        A dict contributes its values and a list its items. Any other type,
        and any empty value, yields an empty list.
        """
        if isinstance(goal_options, dict):
            return list(goal_options.values())
        if isinstance(goal_options, list):
            return list(goal_options)
        return []

    def create_test_scenarios(self, goal_data: Dict[str, Any], detector: MilestoneDetector) -> List[TestCase]:
        """Build the positive and negative test scenarios for one goal.

        Returns an empty list when no goal-consistent product can be found.
        Scenarios that need a non-matching product, goal options or a price
        cap are only added when that data is available.
        """
        test_cases: List[TestCase] = []
        products = self.env.server.product_item_dict

        goal_asin = self.find_goal_consistent_product(goal_data, detector)
        goal_attributes = goal_data.get('attributes', [])
        goal_options = goal_data.get('goal_options', [])

        if not goal_asin:
            print(f"  警告: 无法找到符合目标的产品")
            return test_cases

        # Deliberately broad query: the first goal attribute, or the first two
        # words of the chosen product's title.
        if goal_attributes:
            search_query = goal_attributes[0]
        else:
            product = products.get(goal_asin, {})
            title = product.get('Title', '')
            search_query = ' '.join(title.split()[:2]) if title else 'product'

        search_action = f"search[{search_query}]"
        goal_click = f"click[{goal_asin}]"
        option_values = self._option_values(goal_options)
        option_clicks = [f"click[{value}]" for value in option_values]
        buy = "click[buy now]"
        # Milestones of a complete purchase; copied per TestCase so no two
        # test cases share one list object.
        full_path = (MilestonePhase.SEARCH, MilestonePhase.DETAIL,
                     MilestonePhase.OPTIONS, MilestonePhase.PURCHASE)

        # Scenario A: full successful path.
        test_cases.append(TestCase(
            name="成功完整路径",
            description="搜索 → 点击正确产品 → 选择选项 → 购买",
            actions=[search_action, goal_click, *option_clicks, buy],
            expected_milestones=list(full_path),
            should_succeed=True
        ))

        # Scenario B: buy the first catalogue product the detector rejects.
        wrong_asin = next(
            (asin for asin in products
             if asin != goal_asin and not detector._product_can_meet_goal(asin)),
            None,
        )

        if wrong_asin:
            test_cases.append(TestCase(
                name="错误产品测试",
                description="搜索成功但点击不符合目标的产品",
                actions=[search_action, f"click[{wrong_asin}]", buy],
                expected_milestones=[MilestonePhase.SEARCH],
                should_succeed=False
            ))

        # Scenario C: correct product but options skipped (only if options are required).
        if goal_options:
            test_cases.append(TestCase(
                name="缺少选项测试",
                description="点击正确产品但跳过选项选择",
                actions=[search_action, goal_click, buy],
                expected_milestones=[MilestonePhase.SEARCH, MilestonePhase.DETAIL],
                should_succeed=False
            ))

        # Scenario D: a second, generic search before completing the purchase.
        test_cases.append(TestCase(
            name="重复搜索测试",
            description="多次搜索后购买",
            actions=[search_action, "search[product]", goal_click, *option_clicks, buy],
            expected_milestones=list(full_path),
            should_succeed=True
        ))

        # Scenario E: click a non-matching product first, then the correct one.
        if wrong_asin:
            test_cases.append(TestCase(
                name="重复点击产品测试",
                description="点击多个产品后购买最后一个",
                actions=[search_action, f"click[{wrong_asin}]", goal_click, *option_clicks, buy],
                expected_milestones=list(full_path),
                should_succeed=True
            ))

        # Scenario F: browse the Description and Features tabs before buying.
        test_cases.append(TestCase(
            name="详情页浏览测试",
            description="浏览产品详情后购买",
            actions=[search_action, goal_click, "click[Description]", "click[Features]",
                     *option_clicks, buy],
            expected_milestones=list(full_path),
            should_succeed=True
        ))

        # Scenario G: select only the first of several required options.
        if len(option_values) > 1 and option_values[0]:
            test_cases.append(TestCase(
                name="部分选项测试",
                description="只选择部分必需选项",
                actions=[search_action, goal_click, f"click[{option_values[0]}]", buy],
                expected_milestones=[MilestonePhase.SEARCH, MilestonePhase.DETAIL],
                should_succeed=False
            ))

        # Scenario H: click an ASIN that does not exist.
        test_cases.append(TestCase(
            name="无效产品测试",
            description="点击不存在的产品",
            actions=[search_action, "click[INVALID123]", buy],
            expected_milestones=[MilestonePhase.SEARCH],
            should_succeed=False
        ))

        # Scenario I: buy a product priced above the goal's cap. A missing or
        # None cap disables the scenario.
        # NOTE: does not verify that the product appears in the search results.
        price_upper = goal_data.get('price_upper') or 0
        if price_upper > 0:
            expensive_asin = next(
                (asin for asin, price in self.env.server.product_prices.items()
                 if price > price_upper and asin in products),
                None,
            )

            if expensive_asin:
                test_cases.append(TestCase(
                    name="价格超限测试",
                    description="购买价格超过上限的产品",
                    actions=[search_action, f"click[{expensive_asin}]", buy],
                    expected_milestones=[MilestonePhase.SEARCH],
                    should_succeed=False
                ))

        # Scenario J: open a non-matching product, go back, then pick the correct one.
        if wrong_asin:
            test_cases.append(TestCase(
                name="返回搜索测试",
                description="浏览产品后返回搜索再选择",
                actions=[search_action, f"click[{wrong_asin}]", "click[< Prev]", goal_click,
                         *option_clicks, buy],
                expected_milestones=list(full_path),
                should_succeed=True
            ))

        # Scenario K: move to the next results page before selecting the product.
        test_cases.append(TestCase(
            name="翻页测试",
            description="在搜索结果中翻页后选择产品",
            actions=[search_action, "click[Next >]", goal_click, *option_clicks, buy],
            expected_milestones=list(full_path),
            should_succeed=True
        ))

        # Scenario L: select every option, then re-select the first one.
        if goal_options:
            test_cases.append(TestCase(
                name="重复选择选项测试",
                description="多次选择同一选项后购买",
                actions=[search_action, goal_click, *option_clicks, *option_clicks[:1], buy],
                expected_milestones=list(full_path),
                should_succeed=True
            ))

        return test_cases

    def run_test(self, test_case: TestCase, session_idx: Optional[int] = None) -> TestResult:
        """Run one scenario on a freshly reset session and judge the outcome.

        Positive scenarios pass when the detector ends in
        ``MilestonePhase.COMPLETE``. Negative scenarios pass when it does not.
        An exception while replaying the actions is caught, printed with its
        traceback, and recorded as a failed result.
        """
        print(f"\n{'='*80}")
        print(f"测试: {test_case.name}")
        print(f"描述: {test_case.description}")
        print(f"{'='*80}")

        obs, goal_data, detector = self.reset_environment(session_idx)

        goal_text = goal_data.get('instruction_text', 'N/A')
        print(f"\n目标: {goal_text}")
        print(f"目标属性: {goal_data.get('attributes', [])}")
        print(f"目标选项: {goal_data.get('goal_options', [])}")
        print(f"价格上限: {goal_data.get('price_upper', 'N/A')}")

        print(f"\n动作序列:")
        for i, action in enumerate(test_case.actions, 1):
            print(f"  {i}. {action}")

        print(f"\n执行测试:")

        try:
            milestone_results, final_reward = self.run_episode(
                test_case.actions, detector, obs, verbose=True
            )

            achieved_milestones = [
                result.phase for result in milestone_results if result.achieved
            ]

            # Negative scenarios must NOT complete. Otherwise a detector that
            # rewards a wrong or invalid purchase would be reported as passing.
            reached_complete = detector.phase == MilestonePhase.COMPLETE
            success = reached_complete if test_case.should_succeed else not reached_complete

            print(f"\n{'='*80}")
            print(f"测试结果:")
            print(f"  预期的里程碑: {[m.name for m in test_case.expected_milestones]}")
            print(f"  达成的里程碑: {[m.name for m in achieved_milestones]}")
            print(f"  最终阶段: {detector.phase.name}")
            print(f"  环境奖励: {final_reward:.4f}")
            print(f"  测试状态: {'✓ 通过' if success else '✗ 失败'}")
            print(f"{'='*80}\n")

            return TestResult(
                test_name=test_case.name,
                goal_text=goal_text,
                actions=test_case.actions,
                milestone_results=milestone_results,
                achieved_milestones=achieved_milestones,
                final_reward=final_reward,
                success=success
            )

        except Exception as e:
            # Test-harness boundary: report the failure and keep running the suite.
            print(f"\n✗ 测试执行出错: {str(e)}")
            import traceback
            traceback.print_exc()

            return TestResult(
                test_name=test_case.name,
                goal_text=goal_text,
                actions=test_case.actions,
                milestone_results=[],
                achieved_milestones=[],
                final_reward=0.0,
                success=False,
                error_message=str(e)
            )

    def run_all_tests(self, num_goals: int = 3):
        """Run every scenario for the first ``num_goals`` environment goals, then print summaries.

        ``num_goals`` is capped at the number of goals the environment has.
        """
        print("\n" + "=" * 80)
        print(f"开始运行测试 (测试 {num_goals} 个目标)")
        print("=" * 80)

        available_goals = len(self.env.server.goals)
        print(f"\n环境可用目标数: {available_goals}")

        if available_goals < num_goals:
            print(f"警告: 环境只有 {available_goals} 个目标，将测试所有可用目标")
            num_goals = available_goals

        for goal_idx in range(num_goals):
            try:
                # The session index is expected to select the environment goal
                # at that index (hence the cap against env.server.goals above).
                obs, goal_data, detector = self.reset_environment(session_idx=goal_idx)
            except IndexError as e:
                print(f"错误: 无法访问目标 {goal_idx}: {e}")
                print(f"跳过此目标")
                continue

            # Report the ASIN of the goal actually loaded for this session. It
            # is not related to the ordering of self.human_goals.
            goal_asin = goal_data.get('asin', 'N/A')
            print(f"\n\n{'#'*80}")
            print(f"# 目标 {goal_idx + 1}/{num_goals}: ASIN {goal_asin}")
            print(f"{'#'*80}")

            test_cases = self.create_test_scenarios(goal_data, detector)

            if not test_cases:
                print("  跳过此目标（无法创建测试场景）")
                continue

            # Each scenario resets to the same session index, i.e. the same goal.
            for test_case in test_cases:
                result = self.run_test(test_case, session_idx=goal_idx)
                self.test_results.append(result)

        self.print_summary()
        self.print_detailed_summary()

    def print_summary(self):
        """Print pass/fail totals and list the failed tests."""
        print("\n\n" + "=" * 80)
        print("测试汇总")
        print("=" * 80)

        total_tests = len(self.test_results)
        passed_tests = sum(1 for r in self.test_results if r.success)
        failed_tests = total_tests - passed_tests

        print(f"\n总测试数: {total_tests}")
        print(f"通过: {passed_tests}")
        print(f"失败: {failed_tests}")
        if total_tests > 0:
            print(f"通过率: {passed_tests/total_tests*100:.1f}%")
        else:
            # Keep the label when there is nothing to divide by.
            print("通过率: N/A")

        if failed_tests > 0:
            print(f"\n失败的测试:")
            for result in self.test_results:
                if not result.success:
                    print(f"  - {result.test_name}")
                    if result.error_message:
                        print(f"    错误: {result.error_message}")

        print("\n" + "=" * 80)

    def print_detailed_summary(self):
        """Print per-scenario-type pass rates, milestone counts and reward averages."""
        print("\n\n" + "=" * 80)
        print("详细测试统计")
        print("=" * 80)

        # Group results by scenario type: the test name with "测试" ("test") removed.
        by_type = defaultdict(list)
        for result in self.test_results:
            test_type = result.test_name.replace("测试", "")
            by_type[test_type].append(result)

        print("\n按测试类型分组:")
        for test_type, results in sorted(by_type.items()):
            total = len(results)
            passed = sum(1 for r in results if r.success)
            failed = total - passed
            pass_rate = (passed / total * 100) if total > 0 else 0
            print(f"\n  {test_type}:")
            print(f"    总数: {total}, 通过: {passed}, 失败: {failed}, 通过率: {pass_rate:.1f}%")

        # How often each milestone phase was achieved across all episodes.
        milestone_counts = defaultdict(int)
        for result in self.test_results:
            for milestone in result.achieved_milestones:
                milestone_counts[milestone.name] += 1

        print(f"\n里程碑达成统计:")
        total_episodes = len(self.test_results)
        for milestone, count in sorted(milestone_counts.items()):
            percentage = (count / total_episodes * 100) if total_episodes > 0 else 0
            print(f"  {milestone}: {count} 次 ({percentage:.1f}%)")

        total_reward = sum(r.final_reward for r in self.test_results)
        avg_reward = total_reward / len(self.test_results) if self.test_results else 0
        print(f"\n平均环境奖励: {avg_reward:.4f}")

        success_results = [r for r in self.test_results if r.success]
        if success_results:
            success_avg_reward = sum(r.final_reward for r in success_results) / len(success_results)
            print(f"成功测试的平均奖励: {success_avg_reward:.4f}")

        failed_results = [r for r in self.test_results if not r.success]
        if failed_results:
            failed_avg_reward = sum(r.final_reward for r in failed_results) / len(failed_results)
            print(f"失败测试的平均奖励: {failed_avg_reward:.4f}")

        print("\n" + "=" * 80)


def main():
    """Script entry point: build the tester and run the whole scenario suite."""
    banner = "=" * 80
    print("\n" + banner)
    print("WebShop 里程碑检测器测试")
    print(banner + "\n")

    # 1000 products so that more of the scenarios can be constructed.
    tester = WebShopMilestoneDetectorTester(seed=42, num_products=1000)

    # Five goals to exercise a wider range of scenarios.
    tester.run_all_tests(num_goals=5)

    print("\n测试完成！\n")


if __name__ == "__main__":
    main()
