#!/usr/bin/env python
"""
Step 2: Harmful Prompt Generation Test
======================================

Test generating harmful prompts using real knowledge graph data and fine-tuned models:
- Check vLLM service availability
- Extract real nodes from knowledge graph built in Step 1
- Generate harmful prompts combining JBB dataset
- Validate generation results format and quality
- Save generated prompt data

Notes:
1. The fine-tuned model server must be running (port 8000) before this step.
2. Run Step 1 first to build the knowledge graph.
3. The JBB-Behaviors dataset is required.

Usage:
    python step2_harmful_generation.py --domain medicine
    python step2_harmful_generation.py --domain medicine --selected-nodes-count 5 --verbose
    python step2_harmful_generation.py --domain medicine --check-service-only

"""

import argparse
import sys
import os
import time
import json
import requests
from pathlib import Path
from typing import Dict, List, Tuple, Any, Optional

# Add project root directory to Python path
SCRIPT_DIR = Path(__file__).parent
PROJECT_ROOT = SCRIPT_DIR.parent
sys.path.insert(0, str(PROJECT_ROOT))

from src.utils.logger_utils import setup_logger
from src.modules.harmful_prompt_generator import HarmfulPromptGenerator, HarmCategory
from src.models.llm_factory import LLMManager
from src.utils.config_manager import ConfigManager
from src.services.knowledge_graph import NodeService
from src.utils.neo4j_utils import create_neo4j_manager
from src.utils.random_seed_utils import load_seed_from_config

class HarmfulGenerationTester:
    """Harmful prompt generation tester - using real knowledge graph data"""
    
    def __init__(self, domain: str = "medicine", verbose: bool = False, selected_nodes_count: int = 3):
        """Prepare loggers, configuration, the random seed and empty result holders.

        Args:
            domain: Domain name; used in logger names, the log file name and
                stored as ``self.domain``.
            verbose: Stored as ``self.verbose``.
            selected_nodes_count: Number of knowledge-graph nodes to use;
                stored as ``self.selected_nodes_count``.
        """
        import logging
        from datetime import datetime

        self.domain = domain
        self.verbose = verbose
        self.selected_nodes_count = selected_nodes_count

        # Two loggers: a DEBUG-level one writing to a timestamped file under
        # logs/step2, and an INFO-level one without file logging.
        run_stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.file_logger = setup_logger(
            name=f"Step2_File_{domain}",
            log_file=f"step2_harmful_generation_{domain}_{run_stamp}.log",
            level="DEBUG",
            log_dir=str(PROJECT_ROOT / "logs" / "step2"),
            enable_file_logging=True
        )
        self.console_logger = setup_logger(
            name=f"Step2_Console_{domain}",
            level="INFO",
            enable_file_logging=False
        )

        # The file logger is the default; log_progress() mirrors to the console.
        self.logger = self.file_logger

        self.config_manager = ConfigManager()

        # Apply the configured random seed and record the value that was used.
        actual_seed = load_seed_from_config(self.config_manager)
        self.logger.info(f"✅ Random seed has been set to: {actual_seed}")

        # Raise project-module and HTTP-library loggers to WARNING to cut noise.
        quiet_logger_names = (
            "src.modules.harmful_prompt_generator",
            "src.services",
            "src.models",
            "src.utils",
            "urllib3",
            "requests",
            "httpx",
        )
        for logger_name in quiet_logger_names:
            logging.getLogger(logger_name).setLevel(logging.WARNING)

        # Endpoint and model id of the fine-tuned model service.
        self.model_config = {
            "base_url": "http://localhost:8000/v1",
            "model_name": "llama3.1-70b-finetune",
        }

        # Mutable state filled in by the individual test steps.
        self.test_results = {}
        self.harmful_generator = None
        self.kg_service = None
        self.real_nodes = []
        self.prompt_templates = []  # generation prompts kept for debugging
        self.domain_config = None  # set by load_real_knowledge_graph_nodes()
    
    def log_progress(self, message: str, level: str = "info"):
        """Send a key progress message to both the file and the console logger.

        ``level`` is the name of the logger method to call (for example
        ``"info"`` or ``"error"``).
        """
        # Resolve both emitters up front so an unknown level fails before
        # anything has been logged.
        emitters = [getattr(sink, level) for sink in (self.file_logger, self.console_logger)]
        for emit in emitters:
            emit(message)
    
    def log_detail(self, message: str, level: str = "debug"):
        """Write a detailed message through the file logger only.

        ``level`` is the name of the logger method to call; it defaults to
        ``"debug"``.
        """
        getattr(self.file_logger, level)(message)
    
    def check_model_service(self) -> bool:
        """Check that the model server is reachable and lists the target model.

        Sends a GET request to ``<base_url>/models`` and looks for
        ``self.model_config["model_name"]`` among the returned model ids.
        Every outcome, success or failure, is recorded under
        ``self.test_results["model_service"]``.

        Returns:
            True when the endpoint answers HTTP 200 and lists the target
            model, False otherwise.
        """
        self.log_progress("🤖 Checking fine-tuned model service...")

        target_model = self.model_config["model_name"]

        # Request timeout in seconds; a loaded domain configuration may
        # override the default through test_params['request_timeout'].
        timeout = 10
        if self.domain_config and hasattr(self.domain_config, 'test_params'):
            timeout = self.domain_config.test_params.get('request_timeout', 10)

        health_url = f"{self.model_config['base_url']}/models"
        try:
            response = requests.get(health_url, timeout=timeout)
        except requests.RequestException as e:
            self.log_progress(f"  ❌ Model service connection failed: {e}", "error")
            self.log_progress(f"  💡 Please make sure to run: bash scripts/start_vllm_finetune_server.sh")
            self.test_results["model_service"] = {"status": False, "error": str(e)}
            return False

        if response.status_code != 200:
            self.log_progress(f"  ❌ Model service not available: HTTP {response.status_code}", "error")
            self.test_results["model_service"] = {
                "status": False,
                "error": f"HTTP {response.status_code}"
            }
            return False

        # A body that is not JSON, or that lacks the expected structure, is a
        # service failure rather than a crash of the test run.
        try:
            models_data = response.json()
            available_models = [model["id"] for model in models_data.get("data", [])]
        except (ValueError, KeyError, TypeError, AttributeError) as e:
            self.log_progress(f"  ❌ Model service returned an unexpected response: {e}", "error")
            self.test_results["model_service"] = {"status": False, "error": str(e)}
            return False

        if target_model not in available_models:
            self.log_progress(f"  ❌ Target model not available: {target_model}", "error")
            self.log_detail(f"  📊 Available models: {available_models}")
            self.test_results["model_service"] = {
                "status": False,
                "available_models": available_models,
                "target_model": target_model,
                "error": "target model not listed by the service"
            }
            return False

        self.log_progress(f"  ✅ Model service available: {target_model}")
        self.log_detail(f"  📊 Available models: {available_models}")
        self.test_results["model_service"] = {
            "status": True,
            "available_models": available_models,
            "target_model": target_model
        }
        return True
    
    def test_model_initialization(self) -> bool:
        """Switch the LLM manager to the target model and run one test generation.

        The test prompt and ``max_tokens`` come from the
        ``testing.model_initialization`` section of the main configuration,
        with built-in defaults. The outcome is recorded under
        ``self.test_results["model_initialization"]``.

        Returns:
            True when the model returns a non-blank response, False on an
            empty response or any exception.
        """
        self.log_progress("🔧 Test model initialization...")

        try:
            llm_manager = LLMManager()
            # Use the same model id that check_model_service() verifies,
            # rather than a second hard-coded copy of the name.
            llm_manager.switch_model(self.model_config["model_name"])

            test_config = self.config_manager.main_config.get('testing', {}).get('model_initialization', {})
            test_prompt = test_config.get('test_prompt', "Hello, this is a test prompt. Please respond briefly.")
            max_tokens = test_config.get('max_tokens', 50)

            response = llm_manager.generate(test_prompt, max_tokens=max_tokens)

            if response and response.strip():
                self.log_progress("  ✅ Model initialization successful")
                self.log_detail(f"  📝 Test response: {response[:100]}...")

                self.test_results["model_initialization"] = {
                    "status": True,
                    "test_response_length": len(response),
                    "test_max_tokens": max_tokens
                }
                return True

            self.log_progress("  ❌ Model generated empty response", "error")
            # Record the failure so the saved report explains why the step failed.
            self.test_results["model_initialization"] = {
                "status": False,
                "error": "empty response"
            }
            return False

        except Exception as e:
            self.log_progress(f"  ❌ Model initialization failed: {e}", "error")
            self.test_results["model_initialization"] = {"status": False, "error": str(e)}
            return False
    
    def test_harmful_prompt_generator(self) -> bool:
        """Create the ``HarmfulPromptGenerator`` used by the generation test.

        The generator is stored on ``self.harmful_generator`` and the outcome
        is recorded under ``self.test_results["generator_initialization"]``.

        Returns:
            True when the generator was constructed, False when construction
            raised.
        """
        self.log_progress("⚡ 测试有害提示生成器初始化...")

        try:
            # Generate with the same model id that check_model_service()
            # verifies, rather than a second hard-coded copy of the name.
            self.harmful_generator = HarmfulPromptGenerator(
                llm_manager=LLMManager(),
                generation_model=self.model_config["model_name"]
            )
        except Exception as e:
            self.log_progress(f"  ❌ 有害提示生成器初始化失败: {e}", "error")
            self.test_results["generator_initialization"] = {"status": False, "error": str(e)}
            return False

        self.log_progress("  ✅ 有害提示生成器初始化成功")
        self.test_results["generator_initialization"] = {"status": True}
        return True
    
    def load_real_knowledge_graph_nodes(self) -> bool:
        """Load described nodes from the knowledge graph and select the ones to use.

        Loads the domain configuration, connects to Neo4j using the
        ``database.neo4j`` section of the main configuration, fetches up to
        1000 nodes, keeps those whose description is longer than 20
        characters, and stores the selection in ``self.real_nodes``.

        Selection follows ``self.selected_nodes_count``: any negative value
        uses every described node; otherwise that many nodes are sampled at
        random, or all of them when fewer are available.

        Every outcome is recorded under
        ``self.test_results["knowledge_graph_loading"]``.

        NOTE(review): the Neo4j connection opened here is kept alive through
        ``self.kg_service`` and is not closed anywhere in this class; confirm
        whether the manager needs an explicit close.

        Returns:
            True when nodes were loaded and selected, False when the graph has
            no nodes, no node has a usable description, or an exception occurred.
        """
        import random

        self.log_progress("📊 加载真实知识图谱节点...")

        try:
            self.domain_config = self.config_manager.load_domain_config(self.domain)
            self.log_detail(f"  📋 领域配置: {self.domain}")

            # Only logged here; the generation step reads the value again itself.
            prompts_per_category = 2
            if self.domain_config and hasattr(self.domain_config, 'harm_categories'):
                prompts_per_category = self.domain_config.harm_categories.get('prompts_per_category', 2)
            self.log_detail(f"  📝 每个类别生成提示数量: {prompts_per_category}")

            neo4j_config = self.config_manager.main_config.get('database', {}).get('neo4j', {})
            neo4j_manager = create_neo4j_manager({
                'uri': neo4j_config.get('uri'),
                'user': neo4j_config.get('user'),
                'password': neo4j_config.get('password')
            })
            neo4j_manager.connect()
            self.kg_service = NodeService(neo4j_manager)

            # Fixed upper bound on how many nodes are fetched for sampling.
            kg_node_limit = 1000
            all_nodes = self.kg_service.get_all_nodes(limit=kg_node_limit)

            if not all_nodes:
                self.log_progress("  ❌ 知识图谱中没有找到节点，请先运行Step 1构建知识图谱", "error")
                self.test_results["knowledge_graph_loading"] = {
                    "status": False,
                    "error": "no nodes found in knowledge graph"
                }
                return False

            filtered_nodes = [
                record
                for record in (self._node_to_record(node) for node in all_nodes)
                if record is not None
            ]

            # Without any described node the later steps cannot run, so this
            # is a failure rather than a success with an empty selection.
            if not filtered_nodes:
                self.log_progress("  ❌ 知识图谱中没有带描述信息的节点", "error")
                self.test_results["knowledge_graph_loading"] = {
                    "status": False,
                    "total_nodes_found": len(all_nodes),
                    "filtered_nodes": 0,
                    "error": "no nodes with a usable description"
                }
                return False

            selected_nodes_count = self.selected_nodes_count
            if selected_nodes_count < 0:
                # Any negative count means "use everything"; passing it on to
                # random.sample() would raise ValueError.
                self.real_nodes = filtered_nodes
                self.log_progress(f"  📊 使用所有可用节点: {len(filtered_nodes)} 个")
            elif len(filtered_nodes) >= selected_nodes_count:
                self.real_nodes = random.sample(filtered_nodes, selected_nodes_count)
                self.log_progress(f"  🎲 随机选择了 {selected_nodes_count} 个节点")
            else:
                self.real_nodes = filtered_nodes
                self.log_progress(f"  ⚠️  可用节点不足，使用全部 {len(filtered_nodes)} 个节点")

            self.log_progress(f"  ✅ 最终使用 {len(self.real_nodes)} 个真实节点")

            # Per-node previews go to the log file only: every node for a
            # small selection, otherwise the first three as samples.
            if len(self.real_nodes) <= 5:
                for i, node in enumerate(self.real_nodes):
                    self.log_detail(f"    {i+1}. {node['name']}: {node['description'][:100]}...")
            else:
                self.log_detail(f"    节点数量较多({len(self.real_nodes)}个)，详细信息见日志文件")
                for i, node in enumerate(self.real_nodes[:3]):
                    self.log_detail(f"    样例 {i+1}. {node['name']}: {node['description'][:100]}...")

            self.test_results["knowledge_graph_loading"] = {
                "status": True,
                "total_nodes_found": len(all_nodes),
                "filtered_nodes": len(filtered_nodes),
                "selected_nodes": len(self.real_nodes)
            }
            return True

        except Exception as e:
            self.log_progress(f"  ❌ 加载知识图谱节点失败: {e}", "error")
            self.test_results["knowledge_graph_loading"] = {"status": False, "error": str(e)}
            return False

    @staticmethod
    def _node_to_record(node):
        """Turn a graph node into the plain dict used by later steps, or None if it has no usable description."""
        properties = getattr(node, 'properties', None)
        if not properties:
            return None

        # First non-empty description-like property wins.
        description = (
            properties.get('description') or
            properties.get('wikipedia_description') or
            properties.get('summary') or
            properties.get('wikipedia_summary')
        )
        if not description or len(description.strip()) <= 20:
            return None

        # Prefer the stored 'uri' property; as a last resort a URI is
        # synthesized from the node's internal id.
        wikidata_uri = (
            properties.get('uri') or
            properties.get('wikidata_id') or
            properties.get('id') or
            f"http://www.wikidata.org/entity/Q{node.id}"
        )

        return {
            'id': node.id,
            'name': node.name,
            'labels': node.labels,
            'properties': properties,
            'description': description.strip(),
            'wikidata_uri': wikidata_uri
        }

    def test_context_formatting(self) -> bool:
        """测试真实节点上下文格式化"""
        self.log_progress("📋 测试真实节点上下文格式化...")
        
        if not self.real_nodes:
            self.log_progress("  ❌ 没有可用的真实节点数据", "error")
            return False
        
        try:
            formatted_contexts = []
            domain_info = f"{self.domain} domain"
            
            for node in self.real_nodes:
                # 提取Wikipedia信息
                wikipedia_description = (
                    node['properties'].get('wikipedia_description') or 
                    node['properties'].get('description')
                )
                wikipedia_summary = (
                    node['properties'].get('wikipedia_summary') or 
                    node['properties'].get('summary')
                )
                
                formatted_context = {
                    "node_id": node['id'],
                    "node_name": node['name'],
                    "labels": node['labels'],
                    "wikipedia_description": wikipedia_description,
                    "wikipedia_summary": wikipedia_summary,
                    "domain_info": domain_info,
                    "raw_properties": node['properties']
                }
                
                formatted_contexts.append(formatted_context)
                
                # 只记录详细信息到文件
                self.log_detail(f"  📄 节点: {node['name']}")
                self.log_detail(f"      ID: {node['id']}")
                self.log_detail(f"      标签: {node['labels']}")
                if wikipedia_description:
                    self.log_detail(f"      描述长度: {len(wikipedia_description)}")
            
            self.log_progress(f"  ✅ 成功格式化 {len(formatted_contexts)} 个真实节点上下文")
            self.test_results["context_formatting"] = {
                "status": True,
                "formatted_count": len(formatted_contexts),
                "contexts": formatted_contexts
            }
            return True
            
        except Exception as e:
            self.log_progress(f"  ❌ 真实节点上下文格式化失败: {e}", "error")
            self.test_results["context_formatting"] = {"status": False, "error": str(e)}
            return False
    
    def test_harmful_prompt_generation(self) -> bool:
        """Run the generator for every selected node and every ``HarmCategory``.

        For each (node, category) pair the generation prompt is built and
        kept in ``self.prompt_templates`` for debugging, then the generator is
        asked for ``prompts_per_category`` prompts (taken from the domain
        configuration, default 2). A failure for one pair is logged as a
        warning and counted, and the loop continues with the next category.

        The outcome, including per-category counts and the number of failed
        (node, category) pairs, is recorded under
        ``self.test_results["harmful_generation"]``.

        Returns:
            True when at least one prompt was generated, False otherwise.
        """
        self.log_progress("🎯 测试有害提示生成...")

        if not self.harmful_generator:
            self.log_progress("  ❌ 有害提示生成器未初始化", "error")
            return False

        if not self.real_nodes:
            self.log_progress("  ❌ 没有可用的真实节点数据", "error")
            return False

        try:
            generated_prompts = []
            domain_info = f"{self.domain} domain"
            total_nodes = len(self.real_nodes)

            prompts_per_category = 2
            if self.domain_config and hasattr(self.domain_config, 'harm_categories'):
                prompts_per_category = self.domain_config.harm_categories.get('prompts_per_category', 2)
            self.log_detail(f"  📝 每个类别生成提示数量: {prompts_per_category}")

            test_categories = list(HarmCategory)
            self.log_progress(f"  📋 将测试 {len(test_categories)} 个有害内容类别")
            self.log_detail(f"  📋 类别列表: {[cat.value for cat in test_categories]}")

            total_expected = total_nodes * len(test_categories) * prompts_per_category
            self.log_progress(f"  🎯 预计生成 {total_expected} 个提示 ({total_nodes} 节点 × {len(test_categories)} 类别 × {prompts_per_category} 提示/类别)")

            # The behavior label depends only on the category, so it is looked
            # up once per category instead of once per generated prompt.
            behavior_types = {}
            failed_generations = 0

            for i, node in enumerate(self.real_nodes):
                node_name = node['name']
                wikipedia_description = (
                    node['properties'].get('wikipedia_description') or
                    node['properties'].get('description')
                )
                wikipedia_summary = (
                    node['properties'].get('wikipedia_summary') or
                    node['properties'].get('summary')
                )
                description_excerpt = wikipedia_description[:200] if wikipedia_description else None
                summary_excerpt = wikipedia_summary[:200] if wikipedia_summary else None

                # Only every 10th node and the last one are echoed to the console.
                if i % 10 == 0 or i == total_nodes - 1:
                    self.log_progress(f"  🔍 处理节点 {i+1}/{total_nodes}: {node_name}")
                else:
                    self.log_detail(f"  🔍 处理节点 {i+1}/{total_nodes}: {node_name}")

                for category in test_categories:
                    try:
                        self.log_detail(f"    📝 生成 {category.value} 类型提示...")

                        # Build and keep the generation prompt for debugging.
                        generation_prompt = self.harmful_generator.build_generation_prompt(
                            node_name=node_name,
                            category=category,
                            num_prompts=prompts_per_category,
                            wikipedia_description=wikipedia_description,
                            wikipedia_summary=wikipedia_summary,
                            domain_info=domain_info,
                        )
                        self.prompt_templates.append({
                            "node_id": node['id'],
                            "node_name": node_name,
                            "category": category.value,
                            "wikipedia_description": description_excerpt,
                            "wikipedia_summary": summary_excerpt,
                            "domain_info": domain_info,
                            "prompt_template": generation_prompt,
                            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
                        })

                        harmful_prompts = self.harmful_generator.generate_prompts_for_category(
                            node_name=node_name,
                            category=category,
                            num_prompts=prompts_per_category,
                            wikipedia_description=wikipedia_description,
                            wikipedia_summary=wikipedia_summary,
                            domain_info=domain_info,
                            verbose=False,
                            debug_prompt=False
                        )

                        for j, prompt in enumerate(harmful_prompts):
                            if category.value not in behavior_types:
                                behavior_types[category.value] = self._get_behavior_type_for_category(category)

                            generated_prompts.append({
                                "node_id": node['id'],
                                "node_name": node_name,
                                "wikidata_uri": node.get('wikidata_uri'),
                                "category": category.value,
                                "behavior_type": behavior_types[category.value],
                                "prompt_id": f"{node['id']}_{category.value}_{j+1}",
                                "prompt_text": prompt,
                                "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
                                "wikipedia_description": description_excerpt,
                                "domain": self.domain
                            })

                            self.log_detail(f"      ✓ 提示 {j+1}: {prompt[:100]}...")

                        time.sleep(0.5)  # short pause between model calls

                    except Exception as e:
                        # One failing pair must not abort the run, but it is a
                        # warning and is counted in the results.
                        failed_generations += 1
                        self.log_detail(f"    ⚠️  类别 {category.value} 生成失败: {e}", "warning")
                        continue

                processed_nodes = i + 1
                if processed_nodes % 10 == 0 or processed_nodes == total_nodes:
                    self.log_progress(f"  📊 已处理 {processed_nodes}/{total_nodes} 节点，生成 {len(generated_prompts)} 个提示")

                time.sleep(1)  # pause between nodes

            if generated_prompts:
                self.log_progress(f"  ✅ 成功生成 {len(generated_prompts)} 个有害提示")

                category_counts = {}
                for prompt in generated_prompts:
                    category_counts[prompt['category']] = category_counts.get(prompt['category'], 0) + 1

                self.log_detail("  📊 按类别统计:")
                for category_name, count in category_counts.items():
                    self.log_detail(f"    {category_name}: {count} 个")

                self.test_results["harmful_generation"] = {
                    "status": True,
                    "generated_count": len(generated_prompts),
                    "failed_generations": failed_generations,
                    "category_counts": category_counts,
                    "prompts": generated_prompts
                }
                return True

            self.log_progress("  ❌ 未生成任何有害提示", "error")
            self.test_results["harmful_generation"] = {
                "status": False,
                "generated_count": 0,
                "failed_generations": failed_generations
            }
            return False

        except Exception as e:
            self.log_progress(f"  ❌ 有害提示生成测试失败: {e}", "error")
            self.test_results["harmful_generation"] = {"status": False, "error": str(e)}
            return False
    
    def _get_behavior_type_for_category(self, category) -> str:
        """Return the most frequent ``Behavior`` value for ``category`` in the JBB data.

        Reads ``self.harmful_generator.jbb_df`` and filters it on its
        ``Category`` column. Falls back to ``category.value`` when the
        generator has no data frame, no row matches, or the lookup raises
        (the latter is logged as a warning).
        """
        try:
            jbb_frame = getattr(self.harmful_generator, 'jbb_df', None)
            if jbb_frame is None:
                return category.value

            matching_rows = jbb_frame[jbb_frame['Category'] == category.value]
            if len(matching_rows) == 0:
                return category.value

            # value_counts() orders by frequency, so the first label is the
            # most common behavior.
            ranked = matching_rows['Behavior'].value_counts()
            if len(ranked) == 0:
                return category.value
            return ranked.index[0]
        except Exception as e:
            self.logger.warning(f"获取行为类型失败: {e}")
            return category.value

    def test_output_validation(self) -> bool:
        """测试输出验证"""
        self.log_progress("✅ 测试生成结果验证...")
        
        if "harmful_generation" not in self.test_results or not self.test_results["harmful_generation"]["status"]:
            self.log_progress("  ❌ 没有可验证的生成结果", "error")
            return False
        
        try:
            prompts = self.test_results["harmful_generation"]["prompts"]
            validation_results = {
                "total_prompts": len(prompts),
                "valid_prompts": 0,
                "empty_prompts": 0,
                "issues": []
            }
            
            for prompt_data in prompts:
                prompt_text = prompt_data["prompt_text"]
                
                # 验证标准
                if not prompt_text or len(prompt_text.strip()) == 0:
                    validation_results["empty_prompts"] += 1
                    validation_results["issues"].append(f"空提示: {prompt_data['prompt_id']}")
                else:
                    validation_results["valid_prompts"] += 1
            
            # 计算验证率
            valid_rate = validation_results["valid_prompts"] / validation_results["total_prompts"]
            
            self.log_progress(f"  📊 验证统计:")
            self.log_progress(f"      总提示数: {validation_results['total_prompts']}")
            self.log_progress(f"      有效提示: {validation_results['valid_prompts']}")
            self.log_progress(f"      验证率: {valid_rate:.1%}")
            
            if validation_results["issues"]:
                self.log_detail(f"  ⚠️  发现问题:")
                for issue in validation_results["issues"][:5]:  # 只显示前5个问题
                    self.log_detail(f"      {issue}")
            
            # 简单判断是否有有效结果
            success = valid_rate > 0  # 只要有生成结果就算成功
            if success:
                self.log_progress("  ✅ 输出验证完成")
            else:
                self.log_progress("  ❌ 输出验证失败：没有有效的生成结果", "error")
            
            self.test_results["output_validation"] = {
                "status": success,
                "validation_results": validation_results,
                "valid_rate": valid_rate
            }
            return success
            
        except Exception as e:
            self.log_progress(f"  ❌ 输出验证失败: {e}", "error")
            self.test_results["output_validation"] = {"status": False, "error": str(e)}
            return False
    
    def run_all_tests(self, check_service_only: bool = False) -> bool:
        """Run the test steps in order, log a summary and save the report.

        The model-service check always runs first; if it fails, the remaining
        steps are skipped. With ``check_service_only`` set, it is the only
        step. An exception raised by a step is logged and the run continues
        with the next step.

        Returns:
            With ``check_service_only``: whether the service check passed.
            Otherwise: True as soon as at least one step passed.
        """
        if self.selected_nodes_count == -1:
            nodes_desc = "所有节点"
        else:
            nodes_desc = f"{self.selected_nodes_count}个节点"
        self.log_progress(f"🎯 开始有害提示生成测试 (领域: {self.domain}, 节点策略: {nodes_desc})")
        self.log_progress("=" * 60)

        service_step = "模型服务检查"
        tests = [(service_step, self.check_model_service)]
        if not check_service_only:
            tests += [
                ("模型初始化", self.test_model_initialization),
                ("生成器初始化", self.test_harmful_prompt_generator),
                ("知识图谱加载", self.load_real_knowledge_graph_nodes),
                ("上下文格式化", self.test_context_formatting),
                ("有害提示生成", self.test_harmful_prompt_generation),
                ("输出验证", self.test_output_validation),
            ]

        total_tests = len(tests)
        passed_tests = 0

        for test_name, test_func in tests:
            self.log_progress(f"\n📋 执行测试: {test_name}")
            try:
                if test_func():
                    passed_tests += 1
                elif test_name == service_step:
                    # Nothing else can work without the model service.
                    self.log_progress("💔 Model service not available，无法继续测试", "error")
                    break
            except Exception as e:
                self.log_progress(f"测试 '{test_name}' 时发生异常: {e}", "error")

        # Summary
        self.log_progress("\n" + "=" * 60)
        self.log_progress("📊 有害提示生成测试总结")
        self.log_progress(f"✅ 通过: {passed_tests}/{total_tests}")

        if check_service_only:
            overall_status = passed_tests >= 1
            if not overall_status:
                self.log_progress("❌ Model service not available，请启动服务后重试", "error")
            else:
                self.log_progress("🎉 模型服务检查通过，可以运行完整测试")
        else:
            # Deliberately lenient: any passing step counts as overall success.
            success_rate = passed_tests / total_tests
            overall_status = success_rate > 0
            if not overall_status:
                self.log_progress("❌ 有害提示生成测试失败", "error")
            else:
                self.log_progress("🎉 有害提示生成测试完成")

        self._save_test_results(overall_status, passed_tests, total_tests)

        return overall_status
    
    def _save_test_results(self, overall_status: bool, passed: int, total: int):
        """Write the run report and, when present, the prompts and prompt templates.

        Files are written to ``SCRIPT_DIR / "outputs" / <domain>``:
        the JSON report always; the generated prompts only when the generation
        step succeeded; the prompt templates (JSON plus a readable text
        version) only when any were collected.

        Args:
            overall_status: Overall verdict of the run.
            passed: Number of steps that passed.
            total: Number of steps that were scheduled.
        """
        def dump_json(payload, destination):
            # All JSON outputs share the same encoding and formatting.
            with open(destination, 'w', encoding='utf-8') as handle:
                json.dump(payload, handle, indent=2, ensure_ascii=False)

        report = {
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
            "domain": self.domain,
            "selected_nodes_count": self.selected_nodes_count,
            "model_config": self.model_config,
            "overall_status": overall_status,
            "summary": {
                "passed_tests": passed,
                "total_tests": total,
                "success_rate": passed / total if total > 0 else 0
            },
            "detailed_results": self.test_results
        }

        output_dir = SCRIPT_DIR / "outputs" / self.domain
        output_dir.mkdir(parents=True, exist_ok=True)

        report_path = output_dir / "step2_harmful_generation_report.json"
        dump_json(report, report_path)
        self.log_progress(f"📋 有害提示生成测试报告已保存: {report_path}")

        # Generated prompts get their own file when generation succeeded.
        generation_succeeded = (
            "harmful_generation" in self.test_results
            and self.test_results["harmful_generation"]["status"]
        )
        if generation_succeeded:
            prompts_path = output_dir / "step2_generated_prompts.json"
            dump_json(self.test_results["harmful_generation"]["prompts"], prompts_path)
            self.log_progress(f"📝 生成的提示已保存: {prompts_path}")

        if not self.prompt_templates:
            return

        # Prompt templates, kept for debugging.
        templates_path = output_dir / "step2_prompt_templates.json"
        dump_json(self.prompt_templates, templates_path)
        self.log_detail(f"🔍 Prompt templates saved: {templates_path}")

        # The same templates again as a more readable text file.
        templates_txt_path = output_dir / "step2_prompt_templates.txt"
        with open(templates_txt_path, 'w', encoding='utf-8') as handle:
            handle.writelines([
                "Step 2 Harmful Prompt Generation - Template Debug Information\n",
                "=" * 80 + "\n\n",
            ])

            for number, template in enumerate(self.prompt_templates, start=1):
                section = [
                    f"Template {number}: {template['node_name']} - {template['category']}\n",
                    "-" * 60 + "\n",
                    f"Node ID: {template['node_id']}\n",
                    f"Node Name: {template['node_name']}\n",
                    f"Harm Category: {template['category']}\n",
                    f"Domain Info: {template['domain_info']}\n",
                ]
                if template['wikipedia_description']:
                    section.append(f"Wikipedia Description: {template['wikipedia_description']}\n")
                if template['wikipedia_summary']:
                    section.append(f"Wikipedia Summary: {template['wikipedia_summary']}\n")
                section += [
                    f"Generation Time: {template['timestamp']}\n",
                    "\n--- FULL GENERATION PROMPT ---\n",
                    f"{template['prompt_template']}\n",
                    "\n" + "=" * 80 + "\n\n",
                ]
                handle.writelines(section)

        self.log_detail(f"📄 Prompt template text version saved: {templates_txt_path}")

def main():
    """Parse the command line, run the tester and exit with 0 on success, 1 on failure."""
    cli = argparse.ArgumentParser(description="Step 2: 有害提示生成测试（使用真实知识图谱数据）")
    cli.add_argument(
        "--domain",
        default="medicine",
        choices=["medicine", "finance", "education", "law"],
        help="测试领域",
    )
    cli.add_argument(
        "--selected-nodes-count",
        type=int,
        default=-1,
        help="随机选择的知识图谱节点数量 (默认: -1 表示使用所有节点，设置正数限制节点数量)",
    )
    cli.add_argument("--verbose", action="store_true", help="显示详细信息")
    cli.add_argument(
        "--check-service-only",
        action="store_true",
        help="仅检查Model service available性",
    )
    # NOTE(review): --output-dir is accepted but never read below; outputs
    # always go to the tester's own default location. Confirm whether the
    # test runner relies on this option.
    cli.add_argument("--output-dir", type=str, help="输出目录（由测试运行器传入）")

    options = cli.parse_args()

    tester = HarmfulGenerationTester(
        domain=options.domain,
        verbose=options.verbose,
        selected_nodes_count=options.selected_nodes_count,
    )
    succeeded = tester.run_all_tests(check_service_only=options.check_service_only)

    # No printing here; results are reported through the loggers.
    exit_code = 0 if succeeded else 1
    sys.exit(exit_code)

if __name__ == "__main__":
    main()
