#!/usr/bin/env python
"""
Step 5: 并行优化的批量隐蔽化改写处理器
====================================

保持原有算法和模型完全一致，仅通过并行处理优化性能：
- 多进程/多线程并行处理prompt
- 保持原有的双路径隐蔽化算法
- 使用相同的模型和评估标准
- 相同的约束检查和ASR评估
- 仅优化I/O和并发执行

⚡ 性能目标: 通过并行处理将1000个prompt的处理时间从29小时缩短到3-6小时

使用方法:
    python step5_optimized_batch.py --domain medicine --workers 20
    python step5_optimized_batch.py --domain medicine --workers 10 --max-prompts 100
"""

import argparse
import sys
import os
import time
import json
import re
from pathlib import Path
from typing import Dict, List, Tuple, Any, Optional
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
from dataclasses import dataclass
import multiprocessing as mp
import threading
import types
from queue import Queue

# Put the project root (the parent of this script's directory) at the front of
# sys.path; presumably this is what lets the `src.*` imports below resolve.
SCRIPT_DIR = Path(__file__).parent
PROJECT_ROOT = SCRIPT_DIR.parent
sys.path.insert(0, str(PROJECT_ROOT))

# 保持原有的导入完全一致
from src.utils.logger_utils import get_logger
from src.services.implicit import ImplicitGeneratorService
from src.services.evaluation.asr_evaluator_service import ASREvaluatorService
from src.services.evaluation.constraint_checker_service import ConstraintCheckerService
from src.core.interfaces import ImplicitRequest, ImplicitMethod

@dataclass
class ParallelConfig:
    """Parallel-execution settings; they tune how work is run, not the algorithm."""
    max_workers: int = 10               # maximum number of concurrent workers
    chunk_size: int = 1                 # number of prompts handled per worker
    use_threading: bool = True          # use a thread pool rather than a process pool
    progress_interval: int = 10         # interval between progress reports
    save_interval: int = 50             # interval between intermediate result saves

class ParallelDualPathProcessor:
    """并行双路径隐蔽化处理器 - 完全保持原算法，仅并行化执行"""
    
    def __init__(self, domain: str = "medicine", config: Optional[ParallelConfig] = None):
        """Load the domain configuration and prepare shared worker state.

        Args:
            domain: Name of the domain whose configuration is loaded; it is
                also used as the output sub-directory name.
            config: Parallel-execution settings. A default ``ParallelConfig``
                is created when this is omitted.

        Raises:
            ValueError: If the domain configuration cannot be loaded, or if one
                of the required settings (similarity threshold, perplexity
                threshold, ``max_iterations``, ``enable_parallel``) is missing.
        """
        self.domain = domain
        self.config = config or ParallelConfig()
        self.logger = get_logger(f"{self.__class__.__name__}_{domain}")

        self.output_dir = SCRIPT_DIR / "outputs" / self.domain
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Load the configuration manager and the domain configuration.
        from src.utils.config_manager import ConfigManager
        self.config_manager = ConfigManager()
        self.domain_config = self.config_manager.load_domain_config(self.domain)

        if not self.domain_config:
            raise ValueError(f"无法加载领域配置: {self.domain}")

        filtering_thresholds = getattr(self.domain_config, 'filtering_thresholds', {})
        attack_config = getattr(self.domain_config, 'attack_config', {})

        # Read the settings straight from the domain configuration.
        similarity_threshold = filtering_thresholds.get('similarity', {}).get('attack_iteration')
        ppl_threshold = filtering_thresholds.get('perplexity', {}).get('attack_iteration')
        max_iterations = attack_config.get('max_iterations')
        enable_parallel = attack_config.get('enable_parallel')
        # Not validated below: a missing timeout is stored as None.
        timeout_per_iteration = attack_config.get('timeout_per_iteration')

        # Validate the required settings, in a fixed order so the first missing
        # one is the one reported.
        if similarity_threshold is None:
            raise ValueError("配置文件中缺少 filtering_thresholds.similarity.attack_iteration")
        if ppl_threshold is None:
            raise ValueError("配置文件中缺少 filtering_thresholds.perplexity.attack_iteration")
        if max_iterations is None:
            raise ValueError("配置文件中缺少 attack_config.max_iterations")
        if enable_parallel is None:
            raise ValueError("配置文件中缺少 attack_config.enable_parallel")

        self.stealth_config = {
            "max_iterations": max_iterations,
            "similarity_threshold": similarity_threshold,
            "ppl_threshold": ppl_threshold,
            "enable_parallel": enable_parallel,
            "timeout_per_iteration": timeout_per_iteration,
            "alternate_paths": True  # alternate between the basic and enhanced paths
        }

        self.logger.info("🚀 并行处理器初始化完成:")
        self.logger.info(f"   👥 并发worker数: {self.config.max_workers}")
        # The icon on this line was a garbled U+FFFD replacement character.
        self.logger.info(f"   🎯 相似度阈值: {similarity_threshold} (原配置)")
        self.logger.info(f"   📊 PPL阈值: {ppl_threshold} (原配置)")
        self.logger.info(f"   🔄 最大迭代: {max_iterations} (原配置)")

        # Result storage and progress counter, each with its own lock.
        self.processed_results = []
        self.results_lock = threading.Lock()
        self.progress_counter = 0
        self.progress_lock = threading.Lock()

        # Lock that serialises model / processor initialisation across threads.
        self.model_init_lock = threading.Lock()
    
    def _create_thread_safe_processor(self):
        """创建线程安全的处理器实例"""
        import threading
        
        # 使用线程本地存储确保每个线程有独立的处理器
        if not hasattr(threading.current_thread(), 'processor_instance'):
            with self.model_init_lock:  # 确保模型初始化时的线程安全
                try:
                    from real_pipeline.step5_implicit_jailbreak import DualPathStealthProcessor
                    
                    # 为每个线程创建独立的处理器实例
                    processor = DualPathStealthProcessor(domain=self.domain, test_size=1, verbose=False)
                    
                    # 创建绑定方法来替换原方法
                    import types
                    processor._calculate_ppl = types.MethodType(self._thread_safe_calculate_ppl, processor)
                    processor._calculate_similarity = types.MethodType(self._thread_safe_calculate_similarity, processor)
                    
                    threading.current_thread().processor_instance = processor
                    
                except Exception as e:
                    self.logger.error(f"线程安全处理器创建失败: {e}")
                    raise
        
        return threading.current_thread().processor_instance
    
    def _thread_safe_calculate_ppl(self, processor, text: str) -> float:
        """线程安全的PPL计算"""
        try:
            # 使用简单但稳定的PPL近似，避免多线程模型加载冲突
            import math
            
            # 基于文本特征的PPL近似计算
            words = text.split()
            if not words:
                return 50.0
            
            # 计算词汇复杂度
            avg_word_len = sum(len(w) for w in words) / len(words)
            unique_words = len(set(words))
            vocab_diversity = unique_words / len(words)
            
            # 计算语法复杂度
            sentence_count = text.count('.') + text.count('!') + text.count('?') + 1
            avg_sentence_len = len(words) / sentence_count
            
            # 计算标点和特殊字符比例
            special_chars = sum(1 for c in text if c in '!@#$%^&*()_+-=[]{}|;:,.<>?')
            special_ratio = special_chars / len(text) if text else 0
            
            # PPL近似公式
            base_ppl = 15 + avg_word_len * 2
            diversity_factor = (1 - vocab_diversity) * 20
            length_factor = min(avg_sentence_len / 10, 3) * 5
            special_factor = special_ratio * 30
            
            estimated_ppl = base_ppl + diversity_factor + length_factor + special_factor
            
            # 限制在合理范围内
            return max(10.0, min(estimated_ppl, 200.0))
            
        except Exception as e:
            processor.logger.warning(f"PPL计算失败，使用默认值: {e}")
            return 50.0
    
    def _thread_safe_calculate_similarity(self, processor, text1: str, text2: str) -> float:
        """线程安全的相似度计算"""
        try:
            # 获取线程本地的相似度模型
            current_thread = threading.current_thread()
            
            if not hasattr(current_thread, 'similarity_model'):
                with self.model_init_lock:
                    from sentence_transformers import SentenceTransformer
                    # 为每个线程创建独立的模型实例
                    current_thread.similarity_model = SentenceTransformer('all-MiniLM-L6-v2', device='cpu')
            
            # 计算相似度
            embeddings = current_thread.similarity_model.encode([text1, text2], convert_to_tensor=True)
            from sentence_transformers import util
            similarity = util.pytorch_cos_sim(embeddings[0], embeddings[1]).item()
            
            return similarity
            
        except Exception as e:
            processor.logger.warning(f"相似度计算失败，使用备用方法: {e}")
            
            # 备用方案：简单的词汇重叠率
            words1 = set(text1.lower().split())
            words2 = set(text2.lower().split())
            
            if not words1 or not words2:
                return 0.0
            
            intersection = words1.intersection(words2)
            union = words1.union(words2)
            
            return len(intersection) / len(union)
    
    def _clean_stealth_prompt(self, raw_stealth_prompt: str) -> str:
        """Strip formatting artefacts from a raw Step5 stealth prompt.

        The steps run in this order: academic lead-in regexes, extraction of
        the "Transformed" part of Original/Transformed pairs, literal marker
        removal, prefix regexes, removal of wrapping quotes/brackets, removal
        of a leading number, best-line selection when the text has more than
        three lines, whitespace normalisation, closing punctuation, and a
        final validity check that falls back to the raw text.

        Formats targeted (per the original notes on Step5 Alpaca-style output):
          1. "Transformed Query:" prefix
          2. "Research Question:" prefix
          3. "Professional:" prefix
          4. "Original: ... Transformed: ..." layout (rare)
          5. Unbalanced or nested quotes
          6. Multi-paragraph academic explanations
          7. "As a professional academic writer..." prefix
          8. Lower-case "Transformed query:" variant

        Args:
            raw_stealth_prompt: Raw model output to clean.

        Returns:
            The cleaned prompt. The input is returned unchanged when it is
            empty or not a string.
        """
        if not raw_stealth_prompt or not isinstance(raw_stealth_prompt, str):
            return raw_stealth_prompt
        
        cleaned_prompt = raw_stealth_prompt.strip()
        
        # === Stage 1: strip elaborate academic lead-ins ===
        
        # 1. Academic introduction patterns; each is anchored at the start of the text.
        academic_intro_patterns = [
            # "A relevant research question to consider in the field of... is:" pattern
            r'^A relevant research question to consider in the field of.*?is:\s*"?',
            # "a pertinent research question... would be:" pattern
            r'^.*?a pertinent research question.*?would be:\s*"?',
            # "an intriguing research question arises:" pattern
            r'^.*?an intriguing research question arises:\s*"?',
            # "a relevant inquiry could be:" pattern
            r'^.*?a relevant inquiry could be:\s*"?',
            # "To address this issue, a relevant research question could be:" pattern
            r'^.*?To address this issue,\s*a relevant research question could be:\s*"?',
            # "the prompt could be transformed into the following professional" pattern
            r'^.*?the prompt could be transformed into the following professional\s*"?',
            # Long explanatory lead-ins
            r'^The prompt has been transformed into.*?The new expression is:\s*"?',
            r'^.*?The new expression is:\s*"?',
            # Complex academic prefixes, removed up to the first question word
            r'^In the field of.*?,.*?(?=What|How|Can|Could|Would|Should|Might|Do|Does|Is|Are)',
            r'^.*?an important area of research is.*?(?=What|How|Can|Could|Would|Should|Might|Do|Does|Is|Are)',
        ]
        
        # Patterns are applied one after another, so each sees the previous result.
        for pattern in academic_intro_patterns:
            cleaned_prompt = re.sub(pattern, '', cleaned_prompt, flags=re.IGNORECASE | re.DOTALL)
            cleaned_prompt = cleaned_prompt.strip()
        
        # === Stage 2: standard prefix cleanup ===
        
        # 1. Handle the "Original: ... Transformed: ..." layout
        if "Original:" in cleaned_prompt and "Transformed:" in cleaned_prompt:
            # Keep the text after "Transformed", up to the first blank line or the end
            transformed_match = re.search(r'Transformed:?\s*(.+?)(?:\n\n|$)', cleaned_prompt, re.DOTALL | re.IGNORECASE)
            if transformed_match:
                cleaned_prompt = transformed_match.group(1).strip()
        
        # 1.5 Handle the "**Original Prompt:** ... **Transformed Prompt:**" layout
        if "**Original Prompt:**" in cleaned_prompt and "**Transformed Prompt:**" in cleaned_prompt:
            # Keep the text after "**Transformed Prompt:**"
            transformed_match = re.search(r'\*\*Transformed Prompt:\*\*\s*(.+?)(?:\n\n|$)', cleaned_prompt, re.DOTALL | re.IGNORECASE)
            if transformed_match:
                cleaned_prompt = transformed_match.group(1).strip()
        
        # 2. Literal markers to delete; longer ones are listed before shorter ones.
        markers_to_remove = [
            # Long prefixes first (avoids partial matches)
            "Reframed as professional", "reframed as professional",
            "Transformed Research Question:", "transformed research question:",
            "As a professional academic writer",
            "**Original Prompt:**", "**Transformed Prompt:**",
            
            # Standard prefixes
            "Transformed Query:", "Research Question:", "Professional:",
            "Transformed query:", "transformed query:",
            "Research question:", "research question:", 
            "Professional:", "professional:",
            
            # Prefixes found later
            "Transformation:", "transformation:",
            "Reframed as:", "reframed as:",
            "Transformed:", "transformed:",
            "Reframed Query:", "reframed query:",
            "Original:", "original:",
            "Reframed", "reframed",
            
            # Special prefixes
            "medicine", "Medicine",
            
            # Quote-wrapped versions, longer ones first
            '"Reframed as professional:', '"Transformed Research Question:',
            '"Transformed Query:', '"Research Question:', '"Professional:',
            '"Transformation:', '"Reframed as:',
            '"Transformed:', '"Reframed Query:', '"Original:',
            'Reframed as professional "', 'Transformed Research Question: "',
            'Transformed Query: "', 'Research Question: "', 'Professional: "',
            'Transformation: "', 'Reframed as: "',
            'Transformed: "', 'Reframed Query: "', 'Original: "',
            
            # Other formatting prefixes
            "Reframed:", "Revised:",
            "Output:", "Result:", "Answer:",
            "**", "###", "---",
        ]
        
        # NOTE(review): str.replace deletes every occurrence, not only a leading
        # one, so words such as "medicine" or "Reframed" inside the sentence are
        # removed as well - confirm this is intended.
        for marker in markers_to_remove:
            cleaned_prompt = cleaned_prompt.replace(marker, "").strip()
        
        # 3. Regex cleanup of prefixes, applied with IGNORECASE | MULTILINE | DOTALL
        prefix_patterns = [
            # Prefix patterns seen in Step5 output
            r'^Transformed Query:?\s*"?',  # "Transformed Query: " or "Transformed Query:"
            r'^Research Question:?\s*"?',  # "Research Question: " or "Research Question:"
            r'^Professional:?\s*"?',       # "Professional: " or "Professional:"
            r'^Transformed query:?\s*"?',  # lower-case variant
            r'^Research question:?\s*"?',  # lower-case variant
            r'^Professional:?\s*"?',       # lower-case variant (same pattern text as above)
            
            # Prefix patterns found later
            r'^Transformation:?\s*"?',     # "Transformation:" prefix
            r'^Transformed Research Question:?\s*"?',  # full form
            r'^Reframed as:?\s*"?',        # "Reframed as:" prefix
            r'^Transformed:?\s*"?',        # "Transformed:" prefix
            r'^Reframed Query:?\s*"?',     # "Reframed Query:" prefix
            r'^Original:?\s*"?',           # "Original:" prefix
            r'^Reframed\s*"?',             # "Reframed" prefix
            
            # Special patterns
            r'^Reframed as professional\s*"?',  # full "Reframed as professional"
            r'^medicine\s*"?',
            r'^Medicine\s*"?',
            
            # Elaborate academic prefixes
            r'^As a professional academic writer.*?(?=What|How|Can|In what|Given|Could)',  # up to the start of the question
            r'^In the (context|realm) of.*?,\s*',  # "In the context of..." 
            r'^Given the.*?,\s*',  # "Given the..."
            r'^Within the context of.*?,\s*',  # "Within the context of..."
            
            # Other formatting cleanup
            r'^\*\*Original Prompt:\*\*.*?\*\*Transformed Prompt:\*\*\s*',
            r'^\*\*Transformed Prompt:\*\*\s*',
            r'^Original:.*?Transformed:\s*',  # complete Original/Transformed layout
        ]
        
        # With re.MULTILINE, '^' also matches at the start of every line.
        for pattern in prefix_patterns:
            cleaned_prompt = re.sub(pattern, '', cleaned_prompt, flags=re.IGNORECASE | re.MULTILINE | re.DOTALL)
            cleaned_prompt = cleaned_prompt.strip()
        
        # === Stage 3: quote and bracket cleanup (modelled on Step2, per the original notes) ===
        
        # 4. Peel wrapping quotes or brackets, repeating until nothing changes
        iteration_count = 0
        while iteration_count < 5:  # hard cap so the loop always terminates
            old_cleaned = cleaned_prompt
            
            if ((cleaned_prompt.startswith('"') and cleaned_prompt.endswith('"')) or 
                (cleaned_prompt.startswith("'") and cleaned_prompt.endswith("'")) or
                (cleaned_prompt.startswith('[') and cleaned_prompt.endswith(']')) or
                (cleaned_prompt.startswith('(') and cleaned_prompt.endswith(')'))):
                cleaned_prompt = cleaned_prompt[1:-1].strip()
            
            # Special case: an opening double quote with no closing one
            if cleaned_prompt.startswith('"') and not cleaned_prompt.endswith('"'):
                cleaned_prompt = cleaned_prompt[1:].strip()
            
            if cleaned_prompt == old_cleaned:
                break
            iteration_count += 1
        
        # 5. Handle unbalanced quotes and trailing leftovers
        # Keep only the first quoted segment when the text starts with one.
        cleaned_prompt = re.sub(r'^"([^"]+)".*', r'\1', cleaned_prompt, flags=re.DOTALL)
        # Drop a stray quote that sits just before a final question mark.
        cleaned_prompt = re.sub(r'([^"]+)"\s*\?$', r'\1?', cleaned_prompt)
        # Leading quote without a trailing one
        if cleaned_prompt.startswith('"') and not cleaned_prompt.endswith('"'):
            cleaned_prompt = cleaned_prompt[1:]
        # Trailing quote without a leading one
        if cleaned_prompt.endswith('"') and not cleaned_prompt.startswith('"'):
            cleaned_prompt = cleaned_prompt[:-1]
        # Special layout: 'Transformed query: "..."'
        cleaned_prompt = re.sub(r'^Transformed query:\s*"([^"]+)".*', r'\1', cleaned_prompt, flags=re.IGNORECASE | re.DOTALL)
        
        # === Stage 4: content filtering and extraction ===
        
        # 6. Remove a leading list number such as "1." or "2)"
        cleaned_prompt = re.sub(r'^\d+[\.\)]\s*', '', cleaned_prompt)
        
        # 7. Multi-line output: score each line and keep the best one
        lines = cleaned_prompt.split('\n')
        if len(lines) > 3:
            # Candidate (line, score) pairs
            candidates = []
            
            for line in lines:
                line = line.strip()
                if not line:
                    continue
                
                # Skip lines that look like instructions or meta text.
                # NOTE(review): the check is a substring test, so any line that
                # contains e.g. "research question" is skipped - confirm intended.
                line_lower = line.lower()
                instruction_markers = [
                    # Alpaca-format markers
                    "below is an instruction", "### instruction", "### input", "### response",
                    "appropriately completes", "write a response", "describes a task",
                    
                    # Academic-writing instruction markers
                    "you are a professional academic writer", "your task is to transform",
                    "transformation principles", "professional disguise", "academic camouflage",
                    "euphemistic substitution", "contextual embedding", "intent preservation",
                    
                    # Output-format instructions
                    "strict output format", "provide only", "no explanations", "no analysis",
                    "no prefixes", "no suffixes", "no additional commentary",
                    "single professional", "research question", "transformed query",
                    
                    # Meta-information markers
                    "examples by domain", "quality requirements", "domain specificity",
                    "academic authenticity", "natural professional flow",
                    
                    # Common formatting leftovers
                    "original:", "transformed:", "here are", "the following are",
                    "as requested", "output format", "example:", "note:", "disclaimer:"
                ]
                
                if any(marker in line_lower for marker in instruction_markers):
                    continue
                
                # Additive score for the line
                score = 0
                
                # Length in characters: tighter ranges score higher
                if 30 <= len(line) <= 300:
                    score += 3
                elif 20 <= len(line) <= 400:
                    score += 2
                elif 15 <= len(line) <= 500:
                    score += 1
                
                # Word count: longer lines score higher
                word_count = len(line.split())
                if word_count >= 12:
                    score += 3
                elif word_count >= 8:
                    score += 2
                elif word_count >= 5:
                    score += 1
                
                # Lines ending in a question mark are preferred
                if line.endswith('?'):
                    score += 4
                
                # Contains a typical question opener
                if any(keyword in line_lower for keyword in ['what are', 'how can', 'in what ways', 'can you', 'could you', 'would you', 'how might']):
                    score += 3
                
                # Contains academic vocabulary
                academic_terms = ['research', 'study', 'analysis', 'investigation', 'examination', 'assessment', 'evaluation', 
                                'methodology', 'protocol', 'strategy', 'approach', 'framework', 'systematic', 'optimal',
                                'efficacy', 'effectiveness', 'implications', 'consequences', 'considerations']
                if any(term in line_lower for term in academic_terms):
                    score += 2
                
                # Contains professional / clinical vocabulary
                professional_terms = ['professional', 'clinical', 'therapeutic', 'pharmaceutical', 'medical', 'healthcare',
                                     'patient', 'treatment', 'intervention', 'diagnosis', 'management', 'delivery']
                if any(term in line_lower for term in professional_terms):
                    score += 2
                
                # Ends with sentence punctuation
                if line.endswith(('.', '!', '?', ':')):
                    score += 1
                
                if score > 0:
                    candidates.append((line, score))
            
            # Keep the highest-scoring line; list.sort is stable, so the earliest
            # line wins a tie.
            if candidates:
                candidates.sort(key=lambda x: x[1], reverse=True)
                cleaned_prompt = candidates[0][0]
        
        # === Stage 5: final formatting ===
        
        # 8. Collapse every run of whitespace into a single space
        cleaned_prompt = re.sub(r'\s+', ' ', cleaned_prompt).strip()
        
        # 9. Make sure the sentence ends with suitable punctuation
        if cleaned_prompt and not cleaned_prompt.endswith(('.', '!', '?', ':')):
            # Only texts of at least five words get punctuation appended
            if len(cleaned_prompt.split()) >= 5:
                # A question word anywhere in the text (substring test) selects '?'
                if any(qword in cleaned_prompt.lower() for qword in ['what', 'how', 'why', 'when', 'where', 'can', 'could', 'would', 'should']):
                    cleaned_prompt += '?'
                else:
                    cleaned_prompt += '.'
        
        # 10. Final validity check
        if not self._is_valid_stealth_prompt(cleaned_prompt):
            # Cleaned text is invalid: fall back to the first valid sentence of the raw input
            sentences = re.split(r'[.!?]\s+', raw_stealth_prompt.strip())
            for sentence in sentences:
                sentence = sentence.strip()
                if self._is_valid_stealth_prompt(sentence):
                    cleaned_prompt = sentence
                    break
            else:
                # No sentence qualified: return the stripped raw input
                cleaned_prompt = raw_stealth_prompt.strip()
        
        # Debug trace, emitted only when the length changed
        if len(raw_stealth_prompt) != len(cleaned_prompt):
            self.logger.debug(f"清理前: {len(raw_stealth_prompt)} 字符")
            self.logger.debug(f"清理后: {len(cleaned_prompt)} 字符")
        
        return cleaned_prompt

    def _is_valid_stealth_prompt(self, prompt: str) -> bool:
        """
        检查清理后的隐蔽化prompt是否有效
        基于Step2的_is_valid_harmful_prompt方法改编
        """
        if not prompt or len(prompt.strip()) < 15:
            return False
        
        prompt_lower = prompt.lower().strip()
        
        # 检测明显的指令性语言标记 (参考Step2)
        instruction_markers = [
            "red-team specialist", "adversarial prompts", "your task is to",
            "generate harmful", "create harmful", "ai safety testing",
            "test content filters", "explicitly harmful user queries",
            "instructions that describes a task", "below is an instruction",
            "write a response that", "appropriately completes",
            "for reference only", "generate new content", "nothing else"
        ]
        
        if any(marker in prompt_lower for marker in instruction_markers):
            return False
        
        # 检测元文本和格式说明 (参考Step2)
        meta_text_patterns = [
            "here are", "the following are", "these are examples",
            "as requested", "output format", "example:", "note:",
            "disclaimer:", "warning:", "[direct harmful", "[another direct", "[prompt]"
        ]
        
        if any(pattern in prompt_lower for pattern in meta_text_patterns):
            return False
        
        # 基本长度检查 (参考Step2: 20-400字符范围)
        if len(prompt) < 20 or len(prompt) > 500:
            return False
        
        # 确保有足够的词汇量 (参考Step2: 至少5个词)
        if len(prompt.split()) < 5:
            return False
        
        # Step5隐蔽化prompt特有检查
        # 确保不是明显的格式问题
        if prompt.count('"') == 1:  # 不匹配的引号
            return False
        
        # 检查Step5可能出现的问题格式
        if prompt.startswith(('Original:', 'Transformed Query:', 'Research Question:', 'Professional:', 'As a professional',
                             'Transformation:', 'Transformed Research Question:', 'Reframed as:', 'Transformed:',
                             'Reframed Query:', 'Reframed as professional', 'medicine')):
            return False
        
        # 确保是合理的学术化查询（Step5特征）
        if not any(keyword in prompt.lower() for keyword in ['what', 'how', 'can', 'could', 'would', 'should', 'might']):
            return False  # Step5输出应该是疑问句
        
        return True

    def process_single_prompt_worker(self, prompt_data: Dict) -> Dict:
        """单个worker处理一个prompt - 完全使用原版DualPathStealthProcessor的算法"""
        prompt_id = prompt_data.get("prompt_id", "unknown")
        original_prompt = prompt_data.get("prompt", "")
        
        start_time = time.time()
        
        try:
            # 创建线程安全的原版处理器实例
            processor = self._create_thread_safe_processor()
            
            # 执行原版的双路径隐蔽化改写 - 算法完全一致
            stealth_result = processor._dual_path_stealth_rewrite(original_prompt, prompt_data)
            
            processing_time = time.time() - start_time
            
            # 构建与原版一致的结果格式，保留完整字段信息
            if stealth_result and stealth_result.get("success"):
                winning_path = stealth_result.get("winning_path", "unknown")
                iterations = stealth_result.get("iterations", 0)  # 修复：使用正确的字段名
                similarity = stealth_result.get("final_similarity", 0)
                ppl = stealth_result.get("final_ppl", 0)
                raw_final_prompt = stealth_result.get("stealth_prompt", "")
                # 清理隐蔽化prompt中的格式问题
                final_prompt = self._clean_stealth_prompt(raw_final_prompt)
                
                result = {
                    "prompt_id": prompt_id,
                    "domain": self.domain,
                    "entity": prompt_data.get("entity", "unknown"),
                    "wikidata_uri": prompt_data.get("wikidata_uri", "unknown"),
                    "category": prompt_data.get("category", "unknown"),
                    "behavior_type": prompt_data.get("behavior_type", "unknown"),
                    "original_prompt": original_prompt,
                    "stealth_prompt": final_prompt,
                    "stealth_success": True,
                    "winning_path": winning_path,
                    "iterations": iterations,
                    "final_similarity": similarity,
                    "final_ppl": ppl,
                    "processing_time": processing_time,
                    # 保持原有字段用于兼容性
                    "success": True,
                    "metadata": prompt_data
                }
            else:
                reason = stealth_result.get("reason", "未知原因") if stealth_result else "处理失败"
                # 检查是否有隐蔽化结果（即使攻击失败）
                raw_stealth_prompt = stealth_result.get("stealth_prompt") if stealth_result else None
                # 清理隐蔽化prompt（如果存在）
                stealth_prompt = self._clean_stealth_prompt(raw_stealth_prompt) if raw_stealth_prompt else None
                
                # 提取迭代次数 - 失败情况下也要保存
                iterations = 0
                if stealth_result:
                    # 从stealth_result中提取迭代次数 - 优先使用iterations字段
                    iterations = stealth_result.get("iterations", 0)
                    if iterations == 0:
                        iterations = stealth_result.get("total_iterations", 0)  # 兼容旧版本
                    
                    # 如果还是0，从reason中推断迭代次数
                    if iterations == 0 and "达到最大迭代次数" in reason:
                        iterations = self.stealth_config.get("max_iterations", 5)
                
                result = {
                    "prompt_id": prompt_id,
                    "domain": self.domain,
                    "entity": prompt_data.get("entity", "unknown"),
                    "wikidata_uri": prompt_data.get("wikidata_uri", "unknown"),
                    "category": prompt_data.get("category", "unknown"),
                    "behavior_type": prompt_data.get("behavior_type", "unknown"),
                    "original_prompt": original_prompt,
                    "stealth_prompt": stealth_prompt,  # 如果有最佳候选则使用，否则为null
                    "stealth_success": False,  # 攻击失败标记
                    "iterations": iterations,  # 保存尝试的迭代次数
                    "reason": reason,
                    "processing_time": processing_time,
                    # 保持原有字段用于兼容性
                    "success": False,
                    "metadata": prompt_data
                }
                
                # 如果有隐蔽化结果，添加额外信息
                if stealth_result and stealth_prompt:
                    result.update({
                        "final_similarity": stealth_result.get("final_similarity", 0.0),
                        "final_ppl": stealth_result.get("final_ppl", 0.0),
                        "weighted_score": stealth_result.get("weighted_score"),
                        "winning_path": stealth_result.get("winning_path"),
                        "winning_iteration": stealth_result.get("winning_iteration"),
                        "has_stealth_version": True  # 标记有隐蔽化版本但攻击失败
                    })
                elif stealth_prompt:
                    # 即使没有stealth_result，但有stealth_prompt，也要尝试计算相似度和PPL
                    try:
                        processor = self._create_thread_safe_processor()
                        similarity = processor._calculate_similarity(original_prompt, stealth_prompt)
                        ppl = processor._calculate_ppl(stealth_prompt)
                        result.update({
                            "final_similarity": similarity,
                            "final_ppl": ppl,
                            "has_stealth_version": True  # 标记有隐蔽化版本但攻击失败
                        })
                    except Exception as e:
                        # 如果计算失败，设置默认值
                        result.update({
                            "final_similarity": 0.0,
                            "final_ppl": 0.0,
                            "has_stealth_version": True
                        })
                else:
                    result["has_stealth_version"] = False  # 标记完全失败
            
            # 更新进度
            with self.progress_lock:
                self.progress_counter += 1
                if self.progress_counter % self.config.progress_interval == 0:
                    progress_pct = (self.progress_counter / self.total_prompts * 100) if hasattr(self, 'total_prompts') and self.total_prompts > 0 else 0
                    if hasattr(self, 'total_prompts'):
                        self.logger.info(f"📊 进度: {self.progress_counter}/{self.total_prompts} ({progress_pct:.1f}%)")
                    else:
                        self.logger.info(f"📊 已处理: {self.progress_counter} 个prompts")
            
            return result
            
        except Exception as e:
            processing_time = time.time() - start_time
            self.logger.error(f"❌ Worker处理prompt {prompt_id} 失败: {e}")
            return {
                "prompt_id": prompt_id,
                "domain": self.domain,
                "entity": prompt_data.get("entity", "unknown"),
                "wikidata_uri": prompt_data.get("wikidata_uri", "unknown"),
                "category": prompt_data.get("category", "unknown"),
                "behavior_type": prompt_data.get("behavior_type", "unknown"),
                "original_prompt": original_prompt,
                "stealth_prompt": None,
                "stealth_success": False,
                "reason": f"Worker异常: {str(e)}",
                "processing_time": processing_time,
                # 保持原有字段用于兼容性
                "success": False,
                "metadata": prompt_data
            }
    
    def load_prompts_for_parallel_processing(self) -> List[Dict]:
        """加载用于并行处理的提示 - 与原版逻辑完全一致"""
        self.logger.info("📂 加载并行处理提示...")
        
        # 首先尝试从step4的筛选结果中加载
        filtered_path = self.output_dir / "step4_filtered_prompts.json"
        
        if filtered_path.exists():
            try:
                with open(filtered_path, 'r', encoding='utf-8') as f:
                    prompts = json.load(f)
                
                self.logger.info(f"✅ 从step4加载了 {len(prompts)} 个筛选后的提示")
                return prompts
                
            except Exception as e:
                self.logger.warning(f"⚠️ 加载step4筛选结果失败: {e}")
        
        # 如果没有step4数据，尝试从step3评估结果中加载
        evaluated_path = self.output_dir / "step3_evaluated_prompts.json"
        
        if evaluated_path.exists():
            try:
                with open(evaluated_path, 'r', encoding='utf-8') as f:
                    prompts = json.load(f)
                
                # 过滤出不安全的提示（更适合隐蔽化）
                unsafe_prompts = []
                for prompt_data in prompts:
                    toxicity_eval = prompt_data.get("toxicity_evaluation", {})
                    if toxicity_eval.get("is_harmful", False):  # 选择有害的提示
                        unsafe_prompts.append(prompt_data)
                
                self.logger.info(f"✅ 从step3加载了 {len(prompts)} 个提示")
                self.logger.info(f"🎯 筛选出 {len(unsafe_prompts)} 个有害提示用于隐蔽化")
                return unsafe_prompts
                
            except Exception as e:
                self.logger.warning(f"⚠️ 加载step3提示失败: {e}")
        
        self.logger.error("❌ 没有找到可用的数据文件")
        return []
    
    def run_parallel_processing(self, max_prompts: int = None) -> bool:
        """运行并行处理 - 仅并行化执行，算法完全不变"""
        start_time = time.time()
        
        # 加载要处理的prompts
        all_prompts = self.load_prompts_for_parallel_processing()
        
        if not all_prompts:
            self.logger.error("❌ 没有可处理的prompts")
            return False
        
        if max_prompts:
            all_prompts = all_prompts[:max_prompts]
        
        total_prompts = len(all_prompts)
        self.logger.info(f"🚀 开始并行处理: {total_prompts} 个prompts")
        self.logger.info(f"👥 使用 {self.config.max_workers} 个并发worker")
        self.logger.info(f"🧮 每个worker使用完全相同的原版算法")
        
        # 重置计数器
        self.progress_counter = 0
        self.total_prompts = total_prompts  # 保存总数用于进度显示
        
        # 并行处理
        all_results = []
        
        if self.config.use_threading:
            # 使用线程池 - 适合I/O密集型任务
            with ThreadPoolExecutor(max_workers=self.config.max_workers) as executor:
                # 提交所有任务
                future_to_prompt = {
                    executor.submit(self.process_single_prompt_worker, prompt_data): prompt_data
                    for prompt_data in all_prompts
                }
                
                # 收集结果
                for future in as_completed(future_to_prompt):
                    try:
                        result = future.result()
                        all_results.append(result)
                        
                        # 实时进度显示
                        current_count = len(all_results)
                        if current_count % self.config.progress_interval == 0 or current_count == total_prompts:
                            progress_pct = (current_count / total_prompts * 100)
                            self.logger.info(f"📊 进度: {current_count}/{total_prompts} ({progress_pct:.1f}%)")
                        
                        # 定期保存中间结果
                        if len(all_results) % self.config.save_interval == 0:
                            self._save_intermediate_results(all_results)
                            
                    except Exception as e:
                        prompt_data = future_to_prompt[future]
                        prompt_id = prompt_data.get("prompt_id", "unknown")
                        self.logger.error(f"❌ Future执行失败 {prompt_id}: {e}")
                        
                        # 添加失败记录
                        all_results.append({
                            "prompt_id": prompt_id,
                            "domain": self.domain,
                            "entity": prompt_data.get("entity", "unknown"),
                            "wikidata_uri": prompt_data.get("wikidata_uri", "unknown"),
                            "category": prompt_data.get("category", "unknown"),
                            "behavior_type": prompt_data.get("behavior_type", "unknown"),
                            "original_prompt": prompt_data.get("prompt", ""),
                            "stealth_prompt": None,
                            "stealth_success": False,
                            "reason": f"Future异常: {str(e)}",
                            # 保持原有字段用于兼容性
                            "success": False,
                            "metadata": prompt_data
                        })
        else:
            # 使用进程池 - 适合CPU密集型任务
            with ProcessPoolExecutor(max_workers=self.config.max_workers) as executor:
                # 提交所有任务
                future_to_prompt = {
                    executor.submit(self._process_prompt_in_subprocess, prompt_data): prompt_data
                    for prompt_data in all_prompts
                }
                
                # 收集结果
                for future in as_completed(future_to_prompt):
                    try:
                        result = future.result()
                        all_results.append(result)
                        
                        # 实时进度显示
                        current_count = len(all_results)
                        if current_count % self.config.progress_interval == 0 or current_count == total_prompts:
                            progress_pct = (current_count / total_prompts * 100)
                            self.logger.info(f"📊 进度: {current_count}/{total_prompts} ({progress_pct:.1f}%)")
                        
                        # 定期保存中间结果
                        if len(all_results) % self.config.save_interval == 0:
                            self._save_intermediate_results(all_results)
                            
                    except Exception as e:
                        prompt_data = future_to_prompt[future]
                        prompt_id = prompt_data.get("prompt_id", "unknown")
                        self.logger.error(f"❌ 进程执行失败 {prompt_id}: {e}")
        
        total_time = time.time() - start_time
        success_count = sum(1 for r in all_results if r.get("success", False))
        success_rate = success_count / len(all_results) if all_results else 0
        
        self.logger.info(f"🎉 并行处理完成!")
        self.logger.info(f"   📊 总计: {len(all_results)}/{total_prompts} 个prompts")
        self.logger.info(f"   ✅ 成功: {success_count} 个 ({success_rate:.1%})")
        self.logger.info(f"   ⏰ 总时间: {total_time:.1f}秒 ({total_time/60:.1f}分钟)")
        self.logger.info(f"   🚀 平均速度: {len(all_results)/total_time:.2f} prompts/秒")
        
        # 计算性能提升
        original_estimate = len(all_results) * 105.98  # 原版单prompt时间
        speedup = original_estimate / total_time if total_time > 0 else 1
        self.logger.info(f"   ⚡ 性能提升: {speedup:.1f}x (预估)")
        
        # 保存最终结果
        self.processed_results = all_results
        self._save_final_results(all_results, total_time, speedup)
        
        return success_rate >= 0.3  # 30%成功率认为可接受
    
    def _process_prompt_in_subprocess(self, prompt_data: Dict) -> Dict:
        """Process one prompt when running under the process pool.

        Delegates directly to ``process_single_prompt_worker`` and returns its
        result unchanged.
        """
        # Original author's note: a function has to be top-level to be
        # picklable, so this implementation is kept deliberately simple.
        # NOTE(review): this is still a bound method, so submitting it to a
        # process pool presumably requires the whole instance to be picklable
        # - confirm before relying on the process-pool mode.
        return self.process_single_prompt_worker(prompt_data)
    
    def _save_intermediate_results(self, results: List[Dict]):
        """保存中间结果"""
        try:
            intermediate_path = self.output_dir / f"step5_parallel_intermediate_{self.domain}.json"
            with open(intermediate_path, 'w', encoding='utf-8') as f:
                json.dump(results, f, indent=2, ensure_ascii=False)
            
            # 显示进度信息
            if hasattr(self, 'total_prompts') and self.total_prompts > 0:
                progress_pct = (len(results) / self.total_prompts * 100)
                self.logger.info(f"💾 中间结果已保存: {len(results)}/{self.total_prompts} ({progress_pct:.1f}%)")
            else:
                self.logger.info(f"💾 中间结果已保存: {len(results)} 个prompts")
        except Exception as e:
            self.logger.warning(f"⚠️ 保存中间结果失败: {e}")
    
    def _save_final_results(self, results: List[Dict], total_time: float, speedup: float):
        """保存最终结果"""
        try:
            # 保存完整结果
            results_path = self.output_dir / f"step5_parallel_results_{self.domain}.json"
            with open(results_path, 'w', encoding='utf-8') as f:
                json.dump(results, f, indent=2, ensure_ascii=False)
            
            # 保存成功案例
            successful_results = [r for r in results if r.get("success", False)]
            if successful_results:
                success_path = self.output_dir / f"step5_parallel_successful_{self.domain}.json"
                with open(success_path, 'w', encoding='utf-8') as f:
                    json.dump(successful_results, f, indent=2, ensure_ascii=False)
            
            # 生成与step6兼容的最终数据集格式
            final_dataset = []
            for result in results:
                final_record = {
                    "prompt_id": result.get("prompt_id"),
                    "domain": result.get("domain", self.domain),
                    "entity": result.get("entity", "unknown"),
                    "wikidata_uri": result.get("wikidata_uri", "unknown"),
                    "category": result.get("category", "unknown"),
                    "behavior_type": result.get("behavior_type", "unknown"),
                    "original_prompt": result.get("original_prompt"),
                    "stealth_prompt": result.get("stealth_prompt"),  # null if failed
                    "stealth_success": result.get("stealth_success", False),
                    # 保持原有字段用于兼容性
                    "success": result.get("success", False),
                    "metadata": result.get("metadata", {})
                }
                
                if result.get("success"):
                    final_record.update({
                        "winning_path": result.get("winning_path"),
                        "iterations": result.get("iterations"),
                        "final_similarity": result.get("final_similarity"),
                        "final_ppl": result.get("final_ppl")
                    })
                else:
                    # 失败记录也要保存所有可用信息
                    final_record["reason"] = result.get("reason")
                    final_record["iterations"] = result.get("iterations", 0)  # 保存尝试的迭代次数
                    
                    # 如果有其他信息也保存
                    if result.get("final_similarity") is not None:
                        final_record["final_similarity"] = result.get("final_similarity")
                    if result.get("final_ppl") is not None:
                        final_record["final_ppl"] = result.get("final_ppl")
                    if result.get("winning_path"):
                        final_record["winning_path"] = result.get("winning_path")
                
                final_dataset.append(final_record)
            
            # 保存最终数据集（与主pipeline格式一致）
            dataset_path = self.output_dir / f"step5_parallel_dataset_{self.domain}.json"
            with open(dataset_path, 'w', encoding='utf-8') as f:
                json.dump(final_dataset, f, indent=2, ensure_ascii=False)
            
            # 保存与step6兼容的成功案例数据集
            successful_stealth_dataset = [record for record in final_dataset if record.get("stealth_success", False)]
            if successful_stealth_dataset:
                stealth_dataset_path = self.output_dir / f"step5_parallel_successful_stealth_dataset_{self.domain}.json"
                with open(stealth_dataset_path, 'w', encoding='utf-8') as f:
                    json.dump(successful_stealth_dataset, f, indent=2, ensure_ascii=False)
            
            # 保存性能报告
            performance_report = {
                "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
                "domain": self.domain,
                "algorithm": "原版双路径隐蔽化算法（并行执行）",
                "parallel_config": {
                    "max_workers": self.config.max_workers,
                    "use_threading": self.config.use_threading,
                    "chunk_size": self.config.chunk_size
                },
                "stealth_config": self.stealth_config,  # 原版配置
                "performance_metrics": {
                    "total_prompts": len(results),
                    "successful_prompts": len(successful_results),
                    "success_rate": len(successful_results) / len(results) if results else 0,
                    "total_time_seconds": total_time,
                    "total_time_minutes": total_time / 60,
                    "total_time_hours": total_time / 3600,
                    "average_time_per_prompt": total_time / len(results) if results else 0,
                    "prompts_per_second": len(results) / total_time if total_time > 0 else 0,
                    "performance_speedup": speedup
                },
                "estimated_1000_prompts": {
                    "estimated_time_hours": (total_time / len(results) * 1000 / 3600) if results else 0,
                    "estimated_success_count": int(len(successful_results) / len(results) * 1000) if results else 0
                }
            }
            
            report_path = self.output_dir / f"step5_parallel_performance_report_{self.domain}.json"
            with open(report_path, 'w', encoding='utf-8') as f:
                json.dump(performance_report, f, indent=2, ensure_ascii=False)
            
            self.logger.info(f"💾 结果已保存:")
            self.logger.info(f"   📄 完整结果: {results_path.name}")
            if successful_results:
                self.logger.info(f"   ✅ 成功案例: {success_path.name}")
            self.logger.info(f"   📊 最终数据集: {dataset_path.name}")
            if successful_stealth_dataset:
                self.logger.info(f"   🎯 成功隐蔽化数据集: {stealth_dataset_path.name}")
            self.logger.info(f"   📈 性能报告: {report_path.name}")
            
        except Exception as e:
            self.logger.error(f"❌ 保存结果失败: {e}")

def main():
    """Command-line entry point for the parallel step-5 run.

    Parses the CLI options, loads and checks the API configuration, runs
    ``ParallelDualPathProcessor.run_parallel_processing`` and prints a summary.
    The process exits with status 0 when the run reports success and 1
    otherwise (including when the API configuration check fails).
    """
    cli = argparse.ArgumentParser(description="Step 5: 并行优化的批量隐蔽化改写")
    cli.add_argument("--domain", default="medicine", help="处理领域")
    cli.add_argument("--workers", type=int, default=10, help="并发worker数")
    cli.add_argument("--max-prompts", type=int, help="限制处理的prompt数量")
    cli.add_argument("--use-processes", action="store_true", help="使用进程池而非线程池")
    cli.add_argument("--progress-interval", type=int, default=10, help="进度报告间隔")
    cli.add_argument("--save-interval", type=int, default=50, help="中间结果保存间隔")
    args = cli.parse_args()

    # Environment/API configuration comes from the sequential step-5 module.
    from real_pipeline.step5_implicit_jailbreak import load_env_config, check_api_config
    load_env_config()
    if not check_api_config():
        sys.exit(1)

    use_threads = not args.use_processes
    config = ParallelConfig(
        max_workers=args.workers,
        use_threading=use_threads,
        progress_interval=args.progress_interval,
        save_interval=args.save_interval
    )
    processor = ParallelDualPathProcessor(domain=args.domain, config=config)

    executor_label = '线程池' if use_threads else '进程池'
    print("\n🚀 开始并行隐蔽化处理...")
    print("📋 算法: 原版双路径隐蔽化算法 (完全一致)")
    print(f"👥 并发数: {args.workers}")
    print(f"🧮 执行方式: {executor_label}")

    success = processor.run_parallel_processing(max_prompts=args.max_prompts)
    if not success:
        print("\n❌ 并行处理失败")
        sys.exit(1)

    print("\n🎉 并行处理完成!")
    results = processor.processed_results
    successful = [item for item in results if item.get("success", False)]
    print(f"✅ 成功处理: {len(successful)}/{len(results)} 个prompts")

    # Per-prompt timing summary; records without a processing time are ignored.
    timings = [item.get("processing_time", 0) for item in results if item.get("processing_time")]
    if timings:
        avg_time = sum(timings) / len(timings)
        print(f"⚡ 平均处理时间: {avg_time:.2f}秒/prompt")
        print(f"🚀 预估1000个prompts需要: {avg_time * 1000 / 3600:.1f} 小时")

        # Comparison against the original author's per-prompt baseline time.
        original_time = 105.98
        speedup = original_time / avg_time if avg_time > 0 else 1
        print(f"📈 性能提升: {speedup:.1f}x")

    # Winning-path and iteration statistics over the successful records.
    if successful:
        winner_stats = {}
        for item in successful:
            winner = item.get("winning_path", "none")
            winner_stats[winner] = winner_stats.get(winner, 0) + 1
        iteration_counts = [
            count
            for count in (item.get("iterations", 0) for item in successful)
            if count > 0
        ]

        print(f"🏆 获胜路径统计: {winner_stats}")
        if iteration_counts:
            avg_iterations = sum(iteration_counts) / len(iteration_counts)
            print(f"🔄 平均迭代次数: {avg_iterations:.1f}")

    sys.exit(0)

# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()