#!/usr/bin/env python
"""
Step 3: 毒性评估测试
====================

测试使用Granite Guardian评估有害提示的毒性：
- 检查Granite Guardian服务可用性 
- 加载之前生成的有害提示
- 调用Granite Guardian进行毒性评估
- 验证评估结果格式和合理性
- 保存评估后的数据

注意: 此步骤需要切换到Granite Guardian服务器 (端口8001)
     请先停止之前步骤（step2）所用的微调模型服务，再启动Granite Guardian

使用方法:
    python step3_toxicity_evaluation.py --domain medicine
    python step3_toxicity_evaluation.py --domain finance --test-size 10
    python step3_toxicity_evaluation.py --domain medicine --check-service-only

"""

import argparse
import sys
import os
import time
import json
import requests
from pathlib import Path
from typing import Dict, List, Tuple, Any, Optional

# Put the project root (the parent of this script's directory) at the front of
# sys.path so the `src.*` imports below resolve when the script is run directly.
# SCRIPT_DIR is also the base for the `outputs/<domain>/` input/output files.
SCRIPT_DIR = Path(__file__).parent
PROJECT_ROOT = SCRIPT_DIR.parent
sys.path.insert(0, str(PROJECT_ROOT))

from src.utils.logger_utils import get_logger
from src.modules.granite_guardian_evaluator import GraniteGuardianEvaluator

class ToxicityEvaluationTester:
    """Toxicity-evaluation test harness built around Granite Guardian (step 3).

    ``run_all_tests`` executes, in order: a Granite Guardian service check,
    evaluator initialization, evaluation of the prompts produced by step 2,
    and a summary analysis of those evaluations. Per-test outcomes accumulate
    in ``self.test_results`` and are written as JSON under
    ``SCRIPT_DIR / "outputs" / <domain>``.
    """

    def __init__(self, domain: str = "medicine", test_size: Optional[int] = None,
                 verbose: bool = False, request_interval: float = 1.0):
        """Create a tester.

        Args:
            domain: Domain name; selects the domain config and the
                ``outputs/<domain>`` directory used for input and output files.
            test_size: Maximum number of step-2 prompts to evaluate; ``None``
                evaluates all of them.
            verbose: When true, log per-prompt evaluation details.
            request_interval: Seconds to wait between consecutive evaluation
                requests (throttling). ``0`` disables the delay.
        """
        self.domain = domain
        self.test_size = test_size
        self.verbose = verbose
        self.request_interval = request_interval
        self.logger = get_logger(self.__class__.__name__)

        # Load the domain config. Failure is non-fatal: callers fall back to
        # defaults (e.g. the request timeout in _request_timeout).
        try:
            from src.utils.config_manager import ConfigManager
            config_manager = ConfigManager()
            self.domain_config = config_manager.load_domain_config(domain)
        except Exception as e:
            self.logger.warning(f"无法加载领域配置: {e}")
            self.domain_config = None

        # Granite Guardian server settings; the /v1 base URL is probed via
        # GET <base_url>/models in check_granite_service.
        self.model_config = {
            "base_url": "http://localhost:8001/v1",
            "model_name": "granite-guardian-3.1-8b-vllm-server",
        }

        self.test_results = {}
        self.granite_evaluator = None

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _format_score(score: Any) -> str:
        """Render a toxicity score for logging without ever raising.

        Numeric scores are shown with three decimals. Anything else (e.g.
        ``None`` if the evaluator could not produce a score) is shown via
        ``str`` instead of crashing the ``.3f`` format.
        """
        try:
            return f"{score:.3f}"
        except (TypeError, ValueError):
            return str(score)

    def _request_timeout(self, default: int = 10):
        """Return ``test_params.request_timeout`` from the domain config,
        or ``default`` when no config was loaded."""
        if not self.domain_config:
            return default
        # NOTE(review): assumes domain_config is dict-like (it is used via .get).
        return self.domain_config.get('test_params', {}).get('request_timeout', default)

    @staticmethod
    def _format_prompt_record(prompt_data: Dict[str, Any], index: int) -> Dict[str, Any]:
        """Convert one step-2 prompt entry into the record shape used here.

        Missing fields fall back to ``"unknown"`` (or ``""`` for text fields).
        ``index`` is only used to build a fallback ``prompt_id``. The node
        details are kept under ``node_info`` so they survive into the
        evaluated output.
        """
        return {
            "prompt_id": prompt_data.get("prompt_id", f"unknown_{index}"),
            "prompt": prompt_data.get("prompt_text", ""),
            "entity": prompt_data.get("node_name", "unknown"),
            "category": prompt_data.get("category", "unknown"),
            "behavior_type": prompt_data.get("behavior_type", "unknown"),
            "node_id": prompt_data.get("node_id", "unknown"),
            "wikidata_uri": prompt_data.get("wikidata_uri", "unknown"),
            "timestamp": prompt_data.get("timestamp", time.strftime("%Y-%m-%d %H:%M:%S")),
            "node_info": {
                "id": prompt_data.get("node_id", "unknown"),
                "name": prompt_data.get("node_name", "unknown"),
                "wikidata_uri": prompt_data.get("wikidata_uri", "unknown"),
                "description": prompt_data.get("wikipedia_description", ""),
                "domain": prompt_data.get("domain", "unknown"),
            },
        }

    def _evaluate_single_prompt(self, prompt_data: Dict[str, Any], index: int,
                                total: int) -> Dict[str, Any]:
        """Evaluate one prompt and return its output record (never a duplicate).

        The record is ``prompt_data`` plus ``toxicity_evaluation`` and
        ``evaluation_timestamp``. On evaluator failure or an empty result,
        ``toxicity_evaluation`` is ``{"error": <message>}`` so the prompt is
        still recorded. A missing ``"prompt"`` key raises ``KeyError`` to the
        caller.
        """
        prompt_text = prompt_data["prompt"]
        prompt_id = prompt_data.get("prompt_id", f"prompt_{index + 1}")
        self.logger.info(f"  🔍 评估提示 {index + 1}/{total}: {prompt_id}")

        error = None
        toxicity_evaluation = None
        try:
            evaluation_result = self.granite_evaluator.evaluate_prompt(prompt_text)
            if evaluation_result:
                toxicity_evaluation = {
                    "is_harmful": evaluation_result.is_harmful,
                    "toxicity_score": evaluation_result.toxicity_score,
                    "reasoning": evaluation_result.reasoning,
                    "raw_response": evaluation_result.raw_response,
                }
            else:
                # Previously such prompts were silently dropped from the output.
                error = "empty evaluation result"
        except Exception as e:
            error = str(e)

        timestamp = time.strftime("%Y-%m-%d %H:%M:%S")
        if error is not None:
            self.logger.warning(f"    ⚠️  提示 {prompt_id} 评估失败: {error}")
            return {**prompt_data, "toxicity_evaluation": {"error": error},
                    "evaluation_timestamp": timestamp}

        # Logging happens after the record is complete and cannot raise, so a
        # non-numeric score can no longer produce a second (error) record.
        if self.verbose:
            is_harmful = toxicity_evaluation["is_harmful"]
            score_text = self._format_score(toxicity_evaluation["toxicity_score"])
            self.logger.info(f"    📊 有害性: {is_harmful}, 毒性分数: {score_text}")

        return {**prompt_data, "toxicity_evaluation": toxicity_evaluation,
                "evaluation_timestamp": timestamp}

    # -------------------------------------------------------------------- tests

    def check_granite_service(self) -> bool:
        """Check that the Granite Guardian server is reachable and serves the target model.

        Queries ``<base_url>/models`` and looks for ``model_config["model_name"]``
        among the returned model ids. The outcome is always recorded under
        ``test_results["granite_service"]``.

        Returns:
            True if the target model is listed; False on HTTP error,
            connection failure, malformed response, or missing model.
        """
        self.logger.info("🛡️  检查Granite Guardian服务...")
        target_model = self.model_config["model_name"]
        # Computed outside the try so config errors are not reported as a
        # service outage.
        timeout = self._request_timeout()

        try:
            health_url = f"{self.model_config['base_url']}/models"
            response = requests.get(health_url, timeout=timeout)
            if response.status_code != 200:
                self.logger.error(f"  ❌ Granite Guardian服务不可用: HTTP {response.status_code}")
                self.test_results["granite_service"] = {
                    "status": False,
                    "error": f"HTTP {response.status_code}",
                }
                return False
            models_data = response.json()
            available_models = [model["id"] for model in models_data.get("data", [])]
        except requests.RequestException as e:
            self.logger.error(f"  ❌ Granite Guardian服务连接失败: {e}")
            self.logger.error(f"  💡 请确保运行: bash scripts/start_vllm_granite_guardian_server.sh")
            self.test_results["granite_service"] = {"status": False, "error": str(e)}
            return False
        except (ValueError, KeyError, TypeError, AttributeError) as e:
            # The server answered, but the /models payload is not the expected
            # {"data": [{"id": ...}, ...]} JSON shape.
            self.logger.error(f"  ❌ Granite Guardian服务响应格式异常: {e}")
            self.test_results["granite_service"] = {"status": False, "error": str(e)}
            return False

        if target_model not in available_models:
            self.logger.error(f"  ❌ Granite Guardian模型不可用: {target_model}")
            self.logger.error(f"  📊 可用模型: {available_models}")
            self.test_results["granite_service"] = {
                "status": False,
                "available_models": available_models,
                "target_model": target_model,
            }
            return False

        self.logger.info(f"  ✅ Granite Guardian服务可用: {target_model}")
        self.logger.info(f"  📊 可用模型: {available_models}")
        self.test_results["granite_service"] = {
            "status": True,
            "available_models": available_models,
            "target_model": target_model,
        }
        return True

    def test_granite_initialization(self) -> bool:
        """Build the Granite Guardian evaluator and sanity-check it on a benign prompt.

        Sets ``self.granite_evaluator`` (even if the sanity check fails) and
        records the outcome under ``test_results["granite_initialization"]``.

        Returns:
            True if the benign prompt produced a result with ``is_harmful``.
        """
        self.logger.info("🔧 测试Granite Guardian初始化...")

        try:
            from src.models.llm_factory import LLMManager
            llm_manager = LLMManager()
            self.granite_evaluator = GraniteGuardianEvaluator(llm_manager=llm_manager)

            test_prompt = "Tell me about the weather today."
            result = self.granite_evaluator.evaluate_prompt(test_prompt)

            if result and hasattr(result, 'is_harmful'):
                self.logger.info("  ✅ Granite Guardian初始化成功")
                if self.verbose:
                    score_text = self._format_score(result.toxicity_score)
                    self.logger.info(f"  📊 测试评估结果: 有害={result.is_harmful}, 毒性分数={score_text}")

                self.test_results["granite_initialization"] = {
                    "status": True,
                    "test_result": {
                        "is_harmful": result.is_harmful,
                        "toxicity_score": result.toxicity_score,
                    },
                }
                return True

            self.logger.error("  ❌ Granite Guardian评估失败")
            self.test_results["granite_initialization"] = {
                "status": False,
                "error": "invalid evaluation result",
            }
            return False

        except Exception as e:
            self.logger.error(f"  ❌ Granite Guardian初始化失败: {e}")
            self.test_results["granite_initialization"] = {"status": False, "error": str(e)}
            return False

    def load_generated_prompts(self) -> List[Dict]:
        """Load step 2's prompts and convert them to evaluation records.

        Reads ``SCRIPT_DIR/outputs/<domain>/step2_generated_prompts.json``
        and keeps at most ``test_size`` entries (all when ``None``).

        Returns:
            The formatted records, or ``[]`` if the file is missing or unreadable.
        """
        self.logger.info("📂 加载生成的有害提示...")

        prompts_path = SCRIPT_DIR / "outputs" / self.domain / "step2_generated_prompts.json"

        if prompts_path.exists():
            try:
                with open(prompts_path, 'r', encoding='utf-8') as f:
                    prompts = json.load(f)

                self.logger.info(f"  ✅ 从step2加载了 {len(prompts)} 个提示")

                # Slice before formatting so unused entries are never processed.
                if self.test_size is not None:
                    prompts = prompts[:self.test_size]
                return [self._format_prompt_record(entry, idx)
                        for idx, entry in enumerate(prompts)]

            except (OSError, ValueError, TypeError, AttributeError) as e:
                # Unreadable file, invalid JSON, or entries that are not dicts.
                self.logger.warning(f"  ⚠️  加载step2提示失败: {e}")

        self.logger.error(f"  ❌ 没有找到可用的提示文件: {prompts_path}")
        self.logger.error(f"  💡 请先运行step2生成有害提示")
        return []

    def test_toxicity_evaluation(self) -> bool:
        """Evaluate every loaded prompt with Granite Guardian.

        Every prompt yields exactly one record (failed ones carry
        ``{"error": ...}``). Results are stored under
        ``test_results["toxicity_evaluation"]`` with ``evaluated_count``
        (all records), ``successful_count``, ``failed_count`` and ``prompts``.

        Returns:
            True if at least one prompt was evaluated successfully.
        """
        self.logger.info("🔍 测试毒性评估...")

        if not self.granite_evaluator:
            self.logger.error("  ❌ Granite Guardian评估器未初始化")
            return False

        prompts_to_evaluate = self.load_generated_prompts()

        if not prompts_to_evaluate:
            self.logger.error("  ❌ 没有可评估的提示")
            return False

        total = len(prompts_to_evaluate)
        try:
            evaluated_prompts = []
            failed_count = 0

            for i, prompt_data in enumerate(prompts_to_evaluate):
                # Throttle between requests only; no pointless wait after the last one.
                if i > 0 and self.request_interval > 0:
                    time.sleep(self.request_interval)

                record = self._evaluate_single_prompt(prompt_data, i, total)
                if "error" in record["toxicity_evaluation"]:
                    failed_count += 1
                evaluated_prompts.append(record)

            successful_count = len(evaluated_prompts) - failed_count
            result = {
                "status": successful_count > 0,
                "evaluated_count": len(evaluated_prompts),
                "successful_count": successful_count,
                "failed_count": failed_count,
                "prompts": evaluated_prompts,
            }
            self.test_results["toxicity_evaluation"] = result

            if successful_count > 0:
                self.logger.info(f"  ✅ 成功评估 {successful_count} 个提示")
                if failed_count:
                    self.logger.warning(f"  ⚠️  评估失败 {failed_count} 个提示")
                return True

            self.logger.error("  ❌ 未成功评估任何提示")
            return False

        except Exception as e:
            self.logger.error(f"  ❌ 毒性评估测试失败: {e}")
            self.test_results["toxicity_evaluation"] = {"status": False, "error": str(e)}
            return False

    def test_evaluation_analysis(self) -> bool:
        """Summarize the evaluation results into safe/unsafe/error counts.

        Records ``analysis`` and ``success_rate`` (share of prompts with a
        True/False verdict) under ``test_results["evaluation_analysis"]``.

        Returns:
            True if at least one prompt received a True/False verdict.
        """
        self.logger.info("📊 测试评估结果分析...")

        evaluation = self.test_results.get("toxicity_evaluation")
        if not evaluation or not evaluation.get("status"):
            self.logger.error("  ❌ 没有可分析的评估结果")
            return False

        try:
            evaluated_prompts = evaluation.get("prompts", [])
            analysis = {
                "total_prompts": len(evaluated_prompts),
                "safe_prompts": 0,
                "unsafe_prompts": 0,
                "evaluation_errors": 0,
            }

            for prompt_data in evaluated_prompts:
                eval_result = prompt_data.get("toxicity_evaluation", {})

                if "error" in eval_result:
                    analysis["evaluation_errors"] += 1
                else:
                    # Identity checks: a missing/None verdict counts in neither bucket.
                    is_harmful = eval_result.get("is_harmful")
                    if is_harmful is True:
                        analysis["unsafe_prompts"] += 1
                    elif is_harmful is False:
                        analysis["safe_prompts"] += 1

            self.logger.info(f"  📊 评估分析结果:")
            self.logger.info(f"      总提示数: {analysis['total_prompts']}")
            self.logger.info(f"      安全提示: {analysis['safe_prompts']}")
            self.logger.info(f"      不安全提示: {analysis['unsafe_prompts']}")
            self.logger.info(f"      评估错误: {analysis['evaluation_errors']}")

            total = analysis["total_prompts"]
            judged = analysis["safe_prompts"] + analysis["unsafe_prompts"]
            success_rate = judged / total if total else 0.0
            success = success_rate > 0  # any valid verdict counts as success

            if success:
                self.logger.info("  ✅ 评估结果分析完成")
            else:
                self.logger.error("  ❌ 评估结果分析失败：没有有效的评估结果")

            self.test_results["evaluation_analysis"] = {
                "status": success,
                "analysis": analysis,
                "success_rate": success_rate,
            }
            return success

        except Exception as e:
            self.logger.error(f"  ❌ 评估结果分析失败: {e}")
            self.test_results["evaluation_analysis"] = {"status": False, "error": str(e)}
            return False

    def run_all_tests(self, check_service_only: bool = False) -> bool:
        """Run the service check and, unless ``check_service_only``, the full pipeline.

        The service check always runs first. If it fails or raises, the
        remaining tests are skipped. A report is saved at the end either way.

        Returns:
            True only if every planned test passed. Previously a full run
            counted as passed as soon as any single test (e.g. only the
            service check) succeeded.
        """
        test_size_info = f"全部数据" if self.test_size is None else str(self.test_size)
        self.logger.info(f"\n🛡️  开始毒性评估测试 (领域: {self.domain}, 大小: {test_size_info})")
        self.logger.info("=" * 60)

        tests = [
            ("Granite Guardian服务检查", self.check_granite_service),
        ]

        if not check_service_only:
            tests.extend([
                ("Granite Guardian初始化", self.test_granite_initialization),
                ("毒性评估", self.test_toxicity_evaluation),
                ("评估结果分析", self.test_evaluation_analysis),
            ])

        passed_tests = 0
        total_tests = len(tests)

        for index, (test_name, test_func) in enumerate(tests):
            self.logger.info(f"\n📋 执行测试: {test_name}")
            try:
                success = bool(test_func())
            except Exception as e:
                self.logger.error(f"测试 '{test_name}' 时发生异常: {e}")
                success = False

            if success:
                passed_tests += 1
            elif index == 0:
                # The service check is always first. Nothing else can run
                # without the service, whether it returned False or raised.
                self.logger.error("💔 Granite Guardian服务不可用，无法继续测试")
                break

        self.logger.info(f"\n{'='*60}")
        self.logger.info(f"📊 毒性评估测试总结")
        self.logger.info(f"✅ 通过: {passed_tests}/{total_tests}")

        overall_status = passed_tests == total_tests

        if check_service_only:
            if overall_status:
                self.logger.info("🎉 Granite Guardian服务检查通过，可以运行完整测试")
            else:
                self.logger.error("❌ Granite Guardian服务不可用，请启动服务后重试")
        else:
            if overall_status:
                self.logger.info("🎉 毒性评估测试完成")
            else:
                self.logger.error("❌ 毒性评估测试失败")

        self._save_test_results(overall_status, passed_tests, total_tests)

        return overall_status

    def _save_test_results(self, overall_status: bool, passed: int, total: int):
        """Write the JSON test report, plus the evaluated prompts when evaluation succeeded.

        Files go to ``SCRIPT_DIR/outputs/<domain>/``:
        ``step3_toxicity_evaluation_report.json`` always, and
        ``step3_evaluated_prompts.json`` when the evaluation status is True.
        """
        report = {
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
            "domain": self.domain,
            "test_size": self.test_size,
            "model_config": self.model_config,
            "overall_status": overall_status,
            "summary": {
                "passed_tests": passed,
                "total_tests": total,
                "success_rate": passed / total if total > 0 else 0,
            },
            "detailed_results": self.test_results,
        }

        output_dir = SCRIPT_DIR / "outputs" / self.domain
        output_dir.mkdir(parents=True, exist_ok=True)

        report_path = output_dir / "step3_toxicity_evaluation_report.json"
        with open(report_path, 'w', encoding='utf-8') as f:
            json.dump(report, f, indent=2, ensure_ascii=False)

        self.logger.info(f"📋 毒性评估测试报告已保存: {report_path}")

        evaluation = self.test_results.get("toxicity_evaluation", {})
        if evaluation.get("status"):
            evaluated_path = output_dir / "step3_evaluated_prompts.json"
            with open(evaluated_path, 'w', encoding='utf-8') as f:
                json.dump(evaluation["prompts"], f, indent=2, ensure_ascii=False)
            self.logger.info(f"📝 评估后的提示已保存: {evaluated_path}")

def main():
    """Command-line entry point: parse arguments, run the tests, exit 0 or 1.

    Prints a one-line verdict whose wording depends on whether only the
    service check was requested. The process exit code is 0 on success and
    1 otherwise.

    NOTE(review): ``--output-dir`` is accepted so the test runner can pass it,
    but it is never read here; the tester writes its files under the
    script's own ``outputs/<domain>`` directory.
    """
    cli = argparse.ArgumentParser(description="Step 3: 毒性评估测试")
    cli.add_argument(
        "--domain",
        default="medicine",
        choices=["medicine", "finance", "education", "law"],
        help="测试领域",
    )
    cli.add_argument("--test-size", type=int, default=None, help="测试数据量大小 (默认为全部数据)")
    cli.add_argument("--verbose", action="store_true", help="显示详细信息")
    cli.add_argument("--check-service-only", action="store_true", help="仅检查Granite Guardian服务可用性")
    cli.add_argument("--output-dir", type=str, help="输出目录（由测试运行器传入）")
    options = cli.parse_args()

    runner = ToxicityEvaluationTester(
        domain=options.domain,
        test_size=options.test_size,
        verbose=options.verbose,
    )
    passed = bool(runner.run_all_tests(check_service_only=options.check_service_only))

    # (passed, service-check-only) -> final verdict line
    verdicts = {
        (True, True): "\n✅ Step 3 服务检查: Granite Guardian服务可用",
        (True, False): "\n✅ Step 3 完成: 毒性评估测试通过",
        (False, True): "\n❌ Step 3 服务检查: Granite Guardian服务不可用",
        (False, False): "\n❌ Step 3 失败: 毒性评估测试未通过",
    }
    print(verdicts[(passed, bool(options.check_service_only))])

    sys.exit(0 if passed else 1)

# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
