from concurrent.futures import thread
import numpy as np
from typing import List, Dict, Any
from sentence_transformers import SentenceTransformer
from openai import OpenAI
import pickle
import json
import os
from datetime import datetime
from tqdm import tqdm

import faiss
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain_core.documents import Document
import threading
# from memory_defender import process_single_benign_sample

import torch

# The SentenceTransformer is not loaded at import time: get_embedding_model()
# creates it on first use and caches it in this module-level slot.
_model_instance = None


def get_embedding_model():
    """Return the shared 'all-MiniLM-L6-v2' SentenceTransformer, loading it once.

    The first call places the model on CUDA when ``torch.cuda.is_available()``
    reports a GPU, otherwise on the CPU, and logs the load time. Later calls
    return the cached instance unchanged.
    """
    global _model_instance
    if _model_instance is not None:
        return _model_instance

    import time
    import torch

    target_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"[{time.strftime('%H:%M:%S')}] 首次加载 SentenceTransformer 模型 (device: {target_device})...")
    t0 = time.time()
    _model_instance = SentenceTransformer('all-MiniLM-L6-v2', device=target_device)
    elapsed = time.time() - t0
    print(f"[{time.strftime('%H:%M:%S')}] SentenceTransformer 模型加载完成 (耗时: {elapsed:.2f}s, device: {target_device})")
    return _model_instance

# Backward compatibility: the module still exposes a ``model`` object, but it
# only triggers loading of the embedding model on first attribute access.
class _LazyModel:
    """Attribute-forwarding proxy for the lazily loaded embedding model."""

    def __getattr__(self, attr):
        # Invoked only for attributes not found on the proxy itself, so every
        # model attribute/method is resolved on the real (possibly just loaded) model.
        real_model = get_embedding_model()
        return getattr(real_model, attr)


model = _LazyModel()

def process_single_benign_sample(record):
    """
    Normalize a raw benign record into structured fields.

    Two record layouts are accepted:
      * chat layout: ``record['messages']`` is a list of message dicts; the
        first ``user`` message supplies the intent and the last ``assistant``
        message supplies the outcome;
      * flat layout: ``record['query']`` and ``record['response']``.

    The pattern is taken from ``record['pattern']`` (a list is joined with
    " -> ", a string is used as-is); otherwise it is rebuilt from the function
    names found in the messages' ``tool_calls``, falling back to
    "System.General_Process".

    Returns:
        None when the intent or the outcome is empty/missing, otherwise a dict:
            - intent: the user query string
            - response: the assistant outcome
            - pattern: tool-chain string
            - combined: human-readable string (for debugging/logs)
            - source: the original ``record`` object (not copied)
    """
    if 'messages' in record:
        messages = record['messages']
        user_query = next(
            (m.get('content', "") for m in messages if m.get('role') == 'user'), "")
        assistant_response = next(
            (m.get('content', "") for m in reversed(messages) if m.get('role') == 'assistant'), "")
    else:
        user_query = record.get('query', "")
        assistant_response = record.get('response', "")

    if not user_query or not assistant_response:
        return None

    raw_pattern = record.get('pattern')
    if raw_pattern and isinstance(raw_pattern, list):
        # str() keeps non-string steps (ints, None, ...) from making join() raise TypeError.
        final_pattern_str = " -> ".join(str(step) for step in raw_pattern)
    elif raw_pattern and isinstance(raw_pattern, str):
        final_pattern_str = raw_pattern
    else:
        tool_chain = []
        for msg in record.get('messages') or []:
            for tc in msg.get('tool_calls') or []:
                # A tool call may be malformed or carry ``"function": None``;
                # skip those instead of raising AttributeError.
                function = tc.get('function') if isinstance(tc, dict) else None
                fname = function.get('name') if isinstance(function, dict) else None
                if fname:
                    tool_chain.append(str(fname))
        final_pattern_str = " -> ".join(tool_chain) if tool_chain else "System.General_Process"

    combined_str = f"User Intent: {user_query}\nPattern: {final_pattern_str}\nAssistant Outcome: {assistant_response}"

    return {
        "intent": user_query,
        "response": assistant_response,
        "pattern": final_pattern_str,
        "combined": combined_str,
        "source": record
    }


class TreeNode:
    """
    A node of the risk tree (root, category, cluster or exemplar leaf).

    Besides the tree links (``children``) and the attached sample dict
    (``value``), a node carries routing embeddings, rule slots and benign
    statistics that are filled in by RiskTree code.
    """

    def __init__(self, label):
        self.label = label  # node label (e.g. "social_engineering", "malware")
        self.children = []
        self.value = None

        # Center vector used for routing during retrieval.
        self.center_embedding = None
        # Projected center vector (per the original notes, precomputed in
        # inject_benign_dataset when Safety Projection is enabled).
        self.projected_center_embedding = None
        self.benign_exemplars = []

        # Defense strategy (slot for LLM-generated defense rules).
        self.defense_strategy = None

        # Benign boundary rule (slot for rules extracted from benign samples).
        self.benign_boundary_rule = None

        # Benign center vector and count (used for boundary calibration).
        self.benign_center_embedding = None
        # Projected benign center vector (same precomputation note as above).
        self.projected_benign_center_embedding = None
        self.benign_count = 0

    def add_child(self, child_node):
        """Append ``child_node`` to this node's children."""
        self.children.append(child_node)

    def set_value(self, value):
        """Attach ``value``; if it has an 'embedding' key, adopt it as the center."""
        self.value = value
        if 'embedding' not in value:
            return
        self.center_embedding = value['embedding']
            
# Must be defined before the RiskTree class, which instantiates it.
class SentenceTransformerAdapter:
    """
    Wrap a SentenceTransformer so LangChain's FAISS store can use it.

    Lets the FAISS vector store reuse the already loaded embedding model
    instead of loading a second copy. Every entry point requests
    L2-normalized embeddings, so query vectors live in the same space as the
    normalized vectors stored in the index.
    """

    def __init__(self, model):
        self.model = model

    def embed_documents(self, texts):
        """Embed a batch of texts; returns a list of float lists (LangChain format)."""
        # normalize_embeddings=True: on unit vectors, L2 ranking equals cosine ranking.
        embeddings = self.model.encode(texts, normalize_embeddings=True)
        return embeddings.tolist()

    def embed_query(self, text):
        """Embed a single query; returns a list of floats (LangChain format)."""
        embedding = self.model.encode(text, normalize_embeddings=True)
        return embedding.tolist()

    def __call__(self, texts):
        """
        Embed a string or a list of strings, returning the raw ``encode`` output.

        langchain_community's FAISS store calls its embedding object directly
        when it is not an ``Embeddings`` subclass (this adapter is not), so
        this is the path used to embed queries at search time. It therefore
        normalizes exactly like ``embed_documents``; previously it did not.
        """
        return self.model.encode(texts, normalize_embeddings=True)

class RiskTree:
    def __init__(self, threshold=0.3, attack_similarity_threshold=0.5, k=5, score_log_file="./logs/score_log.jsonl", safety_projector_path="./models/safety_projector.pth", enable_safety_projection=True, use_single_branch_mode=False, retrieval_similarity_threshold=0.4, retrieval_adaptive_threshold=False, llm_port=8030, model_name="qwen-72b", benign_data_path='/path/to/agentharm/agent_align_data_v3.json'):
        """
        Build an empty risk tree together with its LLM client, embedding model,
        optional Safety Projector and benign-sample FAISS store.

        Args:
            threshold: rule information-gain threshold; in ``add_node`` a gain
                above it adds the sample as a leaf without an LLM rule merge.
            attack_similarity_threshold: minimum attack-prompt similarity to the
                best cluster; below it ``add_node`` starts a new cluster.
            k: stored as ``self.k`` (presumably a retrieval top-k used
                elsewhere — verify).
            score_log_file: score log path; None disables score logging.
            safety_projector_path: checkpoint of the optional Safety Projector.
            enable_safety_projection: ablation switch; when False the projector
                is never loaded.
            use_single_branch_mode: ablation switch (all branches AMBIGUOUS).
            retrieval_similarity_threshold: retrieval similarity threshold.
            retrieval_adaptive_threshold: adaptive retrieval threshold switch.
            llm_port: port of the local OpenAI-compatible LLM endpoint.
            model_name: LLM model name.
            benign_data_path: JSON dataset used to build the benign FAISS store.
        """
        # Global lock, kept for operations that need tree-wide synchronization.
        self.lock = threading.Lock()
        # One lock per category so different categories can be processed in
        # parallel; _category_locks_lock guards the category_locks dict itself.
        self.category_locks = {}
        self._category_locks_lock = threading.Lock()

        self.root = TreeNode("root")
        self.threshold = threshold
        self.attack_similarity_threshold = attack_similarity_threshold
        self.retrieval_similarity_threshold = retrieval_similarity_threshold
        self.retrieval_adaptive_threshold = retrieval_adaptive_threshold

        # Client for the local OpenAI-compatible LLM endpoint.
        self.llm_port = llm_port
        self.model_name = model_name
        self.client = OpenAI(base_url=f"http://127.0.0.1:{llm_port}/v1", api_key="EMPTY")

        # Shared embedding model (loaded on first use) and its LangChain adapter.
        self.embedding_model = get_embedding_model()
        self.embedding_func = SentenceTransformerAdapter(self.embedding_model)

        self.k = k
        # Score log path (None disables logging) and number of entries written.
        self.score_log_file = score_log_file
        self._score_log_count = 0

        # Statistic: number of discarded rules.
        self.discarded_rules_count = 0

        # Ablation switches.
        self.enable_safety_projection = enable_safety_projection
        self.use_single_branch_mode = use_single_branch_mode

        # Safety Projector state. ``safety_projector`` holds a loaded model or
        # None — never the checkpoint path string — so None/truthiness checks on
        # it are reliable; the path is kept separately.
        self.safety_projector_path = safety_projector_path
        self.safety_projector = None
        self.use_safety_projection = False
        self.device = None

        if not enable_safety_projection:
            # Disabled on purpose: do not report the checkpoint as missing.
            print("[RiskTree] Safety projection disabled (enable_safety_projection=False)")
        elif safety_projector_path and os.path.exists(safety_projector_path):
            self._load_safety_projector(safety_projector_path)
        elif safety_projector_path:
            print(f"⚠️ Warning: Safety Projector path not found: {safety_projector_path}")
            print("   Continuing without safety projection...")

        print(f"[RiskTree] Initializing Static RAG from dataset...")
        self.defense_vector_store = self._init_static_rag(benign_data_path)

    def _load_safety_projector(self, safety_projector_path):
        """
        Load the Safety Projector checkpoint at ``safety_projector_path``.

        On success sets ``self.safety_projector`` (in eval mode), ``self.device``
        and ``self.use_safety_projection = True``. On any exception it prints the
        traceback and leaves projection disabled (``self.safety_projector = None``);
        ``self.device`` may already be set if the failure happens after device
        detection.
        """
        try:
            from SafetyProjector import SafetyProjector

            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            print(f"✓ 检测到设备: {self.device}")

            input_dim = self.embedding_model.get_sentence_embedding_dimension()

            # NOTE(review): torch.load unpickles objects; only load trusted checkpoints.
            checkpoint = torch.load(safety_projector_path, map_location=self.device)

            # Accept both {'model_state_dict': ...} checkpoints and bare state dicts.
            if isinstance(checkpoint, dict):
                state_dict = checkpoint.get('model_state_dict', checkpoint)
            else:
                state_dict = checkpoint

            # Older checkpoints may lack 'prototypes' / 'temperature'.
            has_prototypes = 'prototypes' in state_dict
            has_temperature = 'temperature' in state_dict

            # Construct directly on the target device.
            self.safety_projector = SafetyProjector(input_dim=input_dim, device=self.device)

            # strict=False keeps older checkpoints loadable.
            missing_keys, unexpected_keys = self.safety_projector.load_state_dict(
                state_dict, strict=False
            )

            if not has_prototypes or not has_temperature:
                print(f"⚠️ Warning: 检测到旧版本模型（缺少 prototypes 或 temperature）")

                # The values may have been saved as top-level checkpoint entries.
                if isinstance(checkpoint, dict):
                    if 'prototypes' in checkpoint and not has_prototypes:
                        print("  → 从 checkpoint 顶层加载 prototypes")
                        try:
                            with torch.no_grad():
                                proto_tensor = checkpoint['prototypes']
                                if isinstance(proto_tensor, torch.Tensor):
                                    # Copy only when the shape matches the freshly built parameter.
                                    if proto_tensor.shape == self.safety_projector.prototypes.shape:
                                        self.safety_projector.prototypes.copy_(proto_tensor.to(self.device))
                                        has_prototypes = True
                                    else:
                                        print(f"    ⚠️ prototypes 形状不匹配: {proto_tensor.shape} vs {self.safety_projector.prototypes.shape}")
                        except Exception as e:
                            print(f"    ⚠️ 加载 prototypes 失败: {e}")

                    if 'temperature' in checkpoint and not has_temperature:
                        print("  → 从 checkpoint 顶层加载 temperature")
                        try:
                            with torch.no_grad():
                                temp_tensor = checkpoint['temperature']
                                if isinstance(temp_tensor, torch.Tensor):
                                    self.safety_projector.temperature.copy_(temp_tensor.to(self.device))
                                    has_temperature = True
                                elif isinstance(temp_tensor, (int, float)):
                                    self.safety_projector.temperature.fill_(float(temp_tensor))
                                    has_temperature = True
                        except Exception as e:
                            print(f"    ⚠️ 加载 temperature 失败: {e}")

                # Still missing: the constructor's initial values stay in place.
                if not has_prototypes or not has_temperature:
                    print(f"  → 使用默认初始化的 prototypes 和 temperature")
                    print(f"     (prototypes: 随机初始化, temperature: {self.safety_projector.temperature.item():.3f})")
                    print(f"     注意：使用随机初始化的参数可能影响模型性能，建议重新训练模型")

            if missing_keys:
                print(f"⚠️ Warning: Missing keys when loading Safety Projector: {missing_keys}")
            if unexpected_keys:
                print(f"⚠️ Warning: Unexpected keys when loading Safety Projector: {unexpected_keys}")

            # Re-assert the device after load_state_dict, then switch to inference mode.
            self.safety_projector.to(self.device)
            self.safety_projector.eval()
            self.use_safety_projection = True
            print(f"✓ Safety Projector loaded from {safety_projector_path} (device: {self.device})")

        except Exception as e:
            print(f"⚠️ Warning: Failed to load Safety Projector: {e}")
            import traceback
            traceback.print_exc()
            print("   Continuing without safety projection...")
            self.use_safety_projection = False
            self.safety_projector = None


    def _init_static_rag(self, benign_data_path, cache_dir="./cache/faiss_index"):
        """
        Build (or load from cache) the FAISS vector store over benign samples.

        - Embedding source: the user query (intent) only, so retrieval matches on intent.
        - Page content: "User Intent: ...\\nPattern: ..."; the assistant response is
          kept only in the document metadata.
        - Cache: index + docstore are stored in ``cache_dir`` and reused when the
          MD5 of the data file matches the stored hash.

        Only records with ``category == 'benign'`` that ``process_single_benign_sample``
        accepts are indexed.

        Returns:
            A LangChain ``FAISS`` store, or None when the file is missing or
            unreadable, yields no documents, or index construction fails.
        """
        import hashlib

        if not os.path.exists(benign_data_path):
            print(f"⚠️ Warning: Benign dataset not found at {benign_data_path}")
            return None

        # MD5 of the data file for cache validation (chunked to bound memory use).
        try:
            digest = hashlib.md5()
            with open(benign_data_path, 'rb') as f:
                for chunk in iter(lambda: f.read(1 << 20), b""):
                    digest.update(chunk)
            file_hash = digest.hexdigest()
        except Exception as e:
            print(f"⚠️ Failed to compute file hash: {e}")
            file_hash = None

        os.makedirs(cache_dir, exist_ok=True)
        cache_index_path = os.path.join(cache_dir, "benign_faiss.index")
        cache_metadata_path = os.path.join(cache_dir, "benign_metadata.pkl")
        cache_hash_path = os.path.join(cache_dir, "benign_data_hash.txt")

        # Try the cache first; it is trusted only when the stored hash matches.
        cache_files = (cache_index_path, cache_metadata_path, cache_hash_path)
        if file_hash and all(os.path.exists(p) for p in cache_files):
            try:
                with open(cache_hash_path, 'r') as f:
                    cached_hash = f.read().strip()

                if cached_hash == file_hash:
                    print(f"[RAG Init] Loading cached FAISS index from {cache_dir}...")
                    index = faiss.read_index(cache_index_path)

                    # NOTE: pickle.load can execute code; cache_dir must be trusted.
                    with open(cache_metadata_path, 'rb') as f:
                        cache_data = pickle.load(f)
                    docstore = cache_data['docstore']
                    index_to_docstore_id = cache_data['index_to_docstore_id']

                    vector_store = FAISS(
                        embedding_function=self.embedding_func,
                        index=index,
                        docstore=docstore,
                        index_to_docstore_id=index_to_docstore_id
                    )

                    print(f"✓ Loaded cached FAISS index ({index.ntotal} vectors)")
                    return vector_store
                print(f"[RAG Init] Data file changed, rebuilding index...")
            except Exception as e:
                print(f"⚠️ Failed to load cache: {e}, rebuilding index...")

        # No valid cache: rebuild from the dataset.
        print(f"[RAG Init] Building FAISS index from {benign_data_path}...")
        try:
            with open(benign_data_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except Exception as e:
            print(f"⚠️ Failed to load JSON: {e}")
            return None

        # -------------------------------------------------------------
        # 1. Preprocess: one Document per accepted benign record
        # -------------------------------------------------------------
        documents = []
        embedding_texts = []  # text used for the vector (query only)

        for raw_item in data:
            # Cheap category filter first, so non-benign records are never normalized.
            if raw_item.get('category') != 'benign':
                continue
            process_item = process_single_benign_sample(raw_item)
            if not process_item:
                continue

            raw_id = raw_item.get('id', '')
            intent = process_item.get('intent')
            pattern = process_item.get('pattern', '')
            response = process_item.get('response', '')  # assistant outcome

            # Text shown to the LLM: query + pattern (response deliberately omitted).
            page_content = f"User Intent: {intent}\nPattern: {pattern}"

            doc = Document(
                page_content=page_content,
                metadata={
                    "id": raw_id,
                    "pattern": pattern,
                    "query": intent,  # duplicate of intent, convenient for debugging
                    "intent": intent,  # read by the retrieval code
                    "response": response  # available to callers at retrieval time
                }
            )

            documents.append(doc)
            # Vector source: the query only, for the tightest semantic match.
            embedding_texts.append(intent)

        if not documents:
            print("⚠️ Warning: No valid documents found.")
            return None

        print(f"✓ Prepared {len(documents)} samples.")
        print(f"  - Vector Source: User Query ONLY")
        print(f"  - Content View : User Query + Pattern")

        # -------------------------------------------------------------
        # 2. Build the FAISS index manually (embedding text != page content)
        # -------------------------------------------------------------
        try:
            print(f"Computing embeddings for {len(embedding_texts)} items...")
            raw_embeddings = self.embedding_func.embed_documents(embedding_texts)
            embeddings_matrix = np.array(raw_embeddings, dtype=np.float32)

            dimension = embeddings_matrix.shape[1]
            index = faiss.IndexFlatL2(dimension)

            # On unit vectors, L2 distance ranks exactly like cosine similarity.
            faiss.normalize_L2(embeddings_matrix)
            index.add(embeddings_matrix)

            # Docstore ids must be unique: InMemoryDocstore rejects duplicates, and
            # every record without an 'id' would otherwise share the id ''.
            id_to_doc = {}
            index_to_docstore_id = {}
            for i, doc in enumerate(documents):
                raw_id = doc.metadata.get('id')
                doc_id = '' if raw_id is None else str(raw_id)
                if not doc_id or doc_id in id_to_doc:
                    doc_id = f"{doc_id or 'benign'}#{i}"
                    while doc_id in id_to_doc:
                        doc_id += "_"
                id_to_doc[doc_id] = doc
                index_to_docstore_id[i] = doc_id
            # Built in one go: per-document add() copies the whole dict each time.
            docstore = InMemoryDocstore(id_to_doc)

            vector_store = FAISS(
                embedding_function=self.embedding_func,  # used to embed queries
                index=index,
                docstore=docstore,
                index_to_docstore_id=index_to_docstore_id
            )

            if file_hash:
                try:
                    print(f"[RAG Init] Saving FAISS index cache to {cache_dir}...")
                    # Remove the old hash first and write the new one last, so a
                    # partially written cache is never accepted as valid.
                    if os.path.exists(cache_hash_path):
                        os.remove(cache_hash_path)

                    faiss.write_index(index, cache_index_path)

                    cache_data = {
                        'docstore': docstore,
                        'index_to_docstore_id': index_to_docstore_id
                    }
                    with open(cache_metadata_path, 'wb') as f:
                        pickle.dump(cache_data, f)

                    with open(cache_hash_path, 'w') as f:
                        f.write(file_hash)

                    print(f"✓ FAISS index cached successfully")
                except Exception as e:
                    print(f"⚠️ Failed to save cache: {e} (continuing without cache)")

            print(f"✓ FAISS Index built successfully. (Dim: {dimension})")
            return vector_store

        except Exception as e:
            print(f"⚠️ FAISS Init Error: {e}")
            import traceback
            traceback.print_exc()
            return None


    def retrieve_for_verifier(self, query_text, k=3):
        """
        Build the chat messages sent to the verifier model.

        Returns a two-element list:
          1. a fixed system message describing the evaluator role;
          2. a user message containing ``query_text`` followed by up to ``k``
             benign reference patterns from ``self.defense_vector_store``
             (omitted when the store is unset, returns nothing, or raises).
        """
        references = ""
        if self.defense_vector_store:
            try:
                hits = self.defense_vector_store.similarity_search(query_text, k=k)
                if hits:
                    references = "\n\n[Standard Safety References]:\n"
                    for rank, hit in enumerate(hits, start=1):
                        # Flatten newlines so each reference stays on one line.
                        snippet = hit.page_content.replace('\n', ' ').strip()
                        references += f"{rank}. {snippet}\n"
            except Exception as exc:
                print(f"⚠️ [Verifier Retrieve Error] {exc}")

        system_prompt = (
            "You are an AI Safety Evaluator.\n"
            "Your task is to analyze the user's query and the provided reference safety patterns."
        )
        return [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": f"User Query: {query_text}{references}"},
        ]


    def _get_category_lock(self, category_label):
        """
        Return the lock dedicated to ``category_label``, creating it on first use.

        Lookup and creation both happen while holding
        ``self._category_locks_lock``, so concurrent callers always receive the
        same lock object for a category. Distinct categories get distinct locks,
        which lets work on separate subtrees run in parallel.
        """
        with self._category_locks_lock:
            lock = self.category_locks.get(category_label)
            if lock is None:
                lock = threading.Lock()
                self.category_locks[category_label] = lock
            return lock

    def add_node(self, parent_label, child_label, value=None):
        """
        Insert an attack sample under category ``parent_label``, either as a new
        cluster or merged into the most similar existing cluster.

        Flow:
          1. Outside the lock: derive the harmful rule text from ``value``. If
             there is none and ``value['attack_prompt']`` is set, generate one via
             ``self.generate_boundary_rule`` (1st LLM call) and write it back into
             ``value`` ('rule_dict', 'defense_response', 'benign_boundary_rule').
          2. Inside the per-category lock: create the category node if needed.
             Start a new cluster when the category has no children, no child has
             a center embedding, or the best attack-prompt similarity is below
             ``self.attack_similarity_threshold``.
          3. Otherwise merge into the best cluster. With no rule text, or a rule
             information gain above ``self.threshold``, this uses
             ``merge_action(..., skip_llm=True)``; otherwise
             ``merge_action(..., skip_llm=False)`` (LLM rule merge).

        Args:
            parent_label: category label; also the key of the per-category lock.
            child_label: sample identifier, used only in the log line.
            value: sample dict. Required in practice despite the ``None`` default.
                Keys read here: 'embedding' (attack-prompt embedding),
                'attack_prompt', 'defense_response', optional 'rule_dict'.
                Mutated in place when a rule is generated.
        """
        # The category label is the lock key, so distinct categories run in parallel.
        category_lock = self._get_category_lock(parent_label)

        # Resolve the harmful rule outside the lock (a possible LLM call follows).
        rule_dict = value.get('rule_dict', {})
        new_rule_text = value.get('defense_response', '')

        if rule_dict and isinstance(rule_dict, dict):
            new_rule_text = rule_dict.get('harmful_rule', '') or rule_dict.get('rule_text', '')
        elif new_rule_text:
            import re
            # Keep only the harmful part of "<harmful>, however <benign>" style rules.
            if 'however' in new_rule_text.lower() or 'allow' in new_rule_text.lower():
                parts = re.split(r'\s+however\s+|\s+but\s+allow\s+', new_rule_text, flags=re.IGNORECASE)
                new_rule_text = parts[0].strip() if parts else new_rule_text

            # A short "Refuse requests related to:" text is treated as no rule
            # (presumably a placeholder), so a rule is generated below instead.
            if new_rule_text.startswith("Refuse requests related to:") and len(new_rule_text) < 100:
                new_rule_text = ""

        # No usable rule: generate one first (1st LLM call, outside the lock).
        if not new_rule_text:
            print(f">>> No Rule Text: Generating rule first (1st LLM call, equivalent to add_new_leaf)")
            query_text = value.get('attack_prompt', '')
            if query_text:
                harmful_list = [query_text]
                benign_list = []
                if self.defense_vector_store:
                    try:
                        results = self.defense_vector_store.similarity_search(query_text, k=3)
                        for doc in results:
                            intent = doc.metadata.get("query", "") or doc.metadata.get("intent", "")
                            if intent and intent.strip():
                                benign_list.append(intent.strip())
                    except Exception as e:
                        # Best-effort: continue without benign contrast examples, but say so.
                        print(f"⚠️ [add_node] Benign retrieval failed: {e}")

                boundary_rule = self.generate_boundary_rule(harmful_list, benign_list)
                if boundary_rule:
                    rule_text = boundary_rule.get('rule_text', '')
                    harmful_rule = boundary_rule.get('harmful_rule', '')
                    benign_rule = boundary_rule.get('benign_rule', '') or boundary_rule.get('exemptions', '')

                    if not rule_text:
                        if harmful_rule and benign_rule:
                            rule_text = f"{harmful_rule}, however, {benign_rule}"
                        elif harmful_rule:
                            rule_text = harmful_rule
                        else:
                            rule_text = benign_rule

                    # Record the generated rule on the sample.
                    value['rule_dict'] = {
                        'harmful_rule': harmful_rule or rule_text,
                        'benign_rule': benign_rule,
                        'rule_text': rule_text
                    }
                    value['defense_response'] = rule_text
                    value['benign_boundary_rule'] = boundary_rule
                    new_rule_text = harmful_rule or rule_text
                    print(f">>> Generated rule (1st LLM call), continuing with merge decision")
                else:
                    # Generation failed; the fallback merge happens inside the lock.
                    print(f">>> Failed to generate rule, will use fallback in lock")
            else:
                print(f">>> No attack_prompt, will use fallback in lock")

        # All tree mutations happen under the category lock.
        with category_lock:
            parent_node = self.find_node_by_label(self.root, parent_label)

            # 1. Create the category node on first use.
            if not parent_node:
                parent_node = TreeNode(parent_label)
                self.root.add_child(parent_node)

            existing_nodes = parent_node.children

            if not existing_nodes:
                # Cold start: the first sample becomes the first cluster.
                self.add_new_leaf(parent_node, value, is_cluster_root=True)
                return

            new_embed = value['embedding']  # attack-prompt embedding

            # 2. Only children with a center embedding can be compared. Keep each
            # similarity paired with its own node: indexing the unfiltered
            # ``existing_nodes`` with a position from a filtered list selects the
            # wrong cluster whenever some child lacks an embedding.
            candidates = [node for node in existing_nodes if node.center_embedding is not None]

            if not candidates:
                self.add_new_leaf(parent_node, value, is_cluster_root=True)
                return

            sims = [self._cosine_similarity(node.center_embedding, new_embed) for node in candidates]
            max_sim = max(sims)

            # 3. Attack prompt too dissimilar from every cluster: new cluster.
            if max_sim < self.attack_similarity_threshold:
                print(f">>> Low Attack Similarity ({max_sim:.3f} < {self.attack_similarity_threshold:.3f}): Creating New Cluster")
                self.add_new_leaf(parent_node, value, is_cluster_root=True)
                return

            # 4. Most similar cluster (first one on ties).
            best_cluster = candidates[sims.index(max_sim)]

            # 5. No rule (generation failed or no attack prompt): fallback merge without LLM.
            if not new_rule_text:
                print(f">>> Using fallback strategy")
                self.merge_action(best_cluster, value, skip_llm=True)
                return

            # 6. Information gain of the new harmful rule w.r.t. the cluster.
            info_gain = self._calculate_rule_entropy_gain(best_cluster, new_rule_text)

            print(f"Sample: {child_label} | Attack Sim: {max_sim:.3f} | Rule Info Gain: {info_gain:.4f} | Attack Sim Threshold: {self.attack_similarity_threshold:.3f} | Rule Info Gain Threshold: {self.threshold:.3f}")

            # 7. High gain: the rule adds new information, add a leaf without an LLM
            # merge. Low gain: the rules are similar, merge them with the LLM.
            if info_gain > self.threshold:
                print(f">>> High Rule Info Gain ({info_gain:.4f} > {self.threshold:.3f}): Rules bring new information, Adding leaf node to Cluster")
                self.merge_action(best_cluster, value, skip_llm=True)
            else:
                print(f">>> Low Rule Info Gain ({info_gain:.4f} <= {self.threshold:.3f}): Rules are similar, Merging rules with LLM")
                # NOTE: this LLM call runs while holding the category lock; only
                # this category is blocked, other categories proceed in parallel.
                self.merge_action(best_cluster, value, skip_llm=False)
    

    def find_node_by_label(self, node, label):
        """
        Return the first node (depth-first, pre-order, starting with ``node``
        itself) whose ``label`` equals ``label``, or None if there is none.
        """
        pending = [node]
        while pending:
            current = pending.pop()
            if current.label == label:
                return current
            # Push children in reverse so they are popped left-to-right, which
            # gives the same visiting order as a recursive pre-order walk.
            pending.extend(reversed(current.children))
        return None

    def choose_similar(self, child_nodes, existing_embeddings, new_embedding):
        """
        Pick the node whose embedding is most similar to ``new_embedding``.

        ``existing_embeddings[i]`` is expected to belong to ``child_nodes[i]``.
        Similarity comes from ``self._cosine_similarity``. Ties keep the earliest
        node. Returns None when there are no embeddings or no score exceeds -1.
        """
        winner = None
        best_score = -1
        for position, candidate_embedding in enumerate(existing_embeddings):
            score = self._cosine_similarity(candidate_embedding, new_embedding)
            if not score > best_score:
                continue
            best_score = score
            winner = child_nodes[position]
        return winner

    def _generate_benign_rule_for_leaf(self, leaf_node, attack_prompt, defense_response):
        """
        Generate and attach a ``benign_boundary_rule`` to ``leaf_node``.

        Retrieves up to 3 benign samples similar to ``attack_prompt`` and passes
        them, together with the harmful side, to ``self.generate_boundary_rule``.
        The harmful side is ``defense_response`` when it is non-blank, otherwise
        ``attack_prompt``.

        Args:
            leaf_node: node that receives the rule.
            attack_prompt: attack sample text, used as the retrieval query.
            defense_response: defense (harmful) rule text.

        Returns:
            bool: True if a rule was generated and attached, else False (no store,
            blank prompt, no retrieval hits, no usable benign intents, empty rule,
            or an exception, which is logged).
        """
        if not (self.defense_vector_store and attack_prompt and attack_prompt.strip()):
            return False

        try:
            hits = self.defense_vector_store.similarity_search(attack_prompt, k=3)
            if not hits:
                return False

            has_rule = defense_response and defense_response.strip()
            harmful_examples = [defense_response if has_rule else attack_prompt]

            benign_examples = []
            for hit in hits:
                candidate = hit.metadata.get("query", "") or hit.metadata.get("intent", "")
                if candidate and candidate.strip():
                    benign_examples.append(candidate.strip())

            if not (harmful_examples and benign_examples):
                return False

            boundary_rule = self.generate_boundary_rule(harmful_examples, benign_examples)
            if not boundary_rule:
                return False

            leaf_node.benign_boundary_rule = boundary_rule
            print(f"   ✓ Generated benign rule for leaf node {leaf_node.label}")
            return True

        except Exception as e:
            print(f"   ⚠️ Failed to generate benign rule for leaf {leaf_node.label}: {e}")

        return False

    def merge_action(self, similar_node, value, updated_strategy=None, skip_llm=False):
        """
        Merge a new sample into an existing cluster node.

        Always appends a new leaf built from `value` under `similar_node` and
        recomputes the cluster center via `update_node_center_from_children`.
        The cluster's defense rule is then updated by exactly one path:

        - Case A (`updated_strategy` is truthy): store the precomputed strategy
          on the cluster as-is (meant for strategies computed outside a lock).
          The new leaf's rule fields are not touched.
        - Case B (`skip_llm` is False): merge the cluster's rule with the new
          sample's rule through `_merge_rules_with_llm` and copy the merged
          rule onto both the cluster and the new leaf. Falls back to
          `_handle_fallback_strategy` when `value` carries no usable rule
          (no 'rule_dict' and an empty or short placeholder 'defense_response'),
          or when anything in the merge raises.
        - Case C (`skip_llm` is True): leave the cluster rule untouched and give
          only the new leaf the sample's own rule.

        Args:
            similar_node: cluster node that receives the new leaf.
            value: sample dict; reads 'id' and, depending on the path,
                'defense_response', 'rule_dict' and 'benign_boundary_rule'.
            updated_strategy: optional precomputed rule text (Case A).
            skip_llm: when True (and no `updated_strategy`), take Case C.

        NOTE(review): the new leaf's `value` comes from `TreeNode.set_value(value)`;
        whether that copies or aliases `value` is not visible here, so writes to
        `new_leaf.value` may also show up in `value` -- confirm.
        """
        # 1. Add the sample under the node as a new child leaf (an exemplar).
        new_leaf = TreeNode(value['id'])
        new_leaf.set_value(value)
        similar_node.add_child(new_leaf)
        
        # 2. Recompute the node's center embedding from its children.
        self.update_node_center_from_children(similar_node)
        
        # =========================================================
        # 3. Update the defense strategy (strategy evolution)
        # =========================================================
        
        # Case A: a precomputed strategy was passed in (best practice: compute outside the lock, update inside it).
        if updated_strategy:
            similar_node.defense_strategy = updated_strategy
            if not similar_node.value: similar_node.value = {}
            similar_node.value['defense_response'] = updated_strategy

        # Case B: merge rules with the LLM (LLM not skipped).
        elif not skip_llm:
            try:
                # The sample should already carry a rule (generated in add_new_leaf);
                # if it does not, use the fallback strategy instead.
                defense_response = value.get('defense_response', '')
                rule_dict = value.get('rule_dict', {})
                
                if not rule_dict and (not defense_response or (defense_response.startswith("Refuse requests related to:") and len(defense_response) < 100)):
                    # No rule, or only add_new_leaf's short placeholder rule: nothing worth
                    # merging. add_new_leaf should have produced a real rule before this point.
                    print(f"⚠️ [Merge Action] No valid rule found in value, using fallback strategy")
                    self._handle_fallback_strategy(similar_node, value)
                    return
                
                # Ask the LLM to merge the cluster's current rule with the new sample's rule.
                merged_rule = self._merge_rules_with_llm(similar_node, value)
                
                if merged_rule:
                    harmful_rule = merged_rule.get('harmful_rule', '')
                    benign_rule = merged_rule.get('benign_rule', '')
                    rule_text = merged_rule.get('rule_text', '')
                    
                    # No rule_text: compose it as "<harmful>, however, <benign>".
                    if not rule_text:
                        if harmful_rule and benign_rule:
                            rule_text = f"{harmful_rule}, however, {benign_rule}"
                        elif harmful_rule:
                            rule_text = harmful_rule
                        else:
                            rule_text = benign_rule
                    
                    # Store the rule text (capped at 500 chars) as the cluster's defense strategy.
                    if rule_text:
                        if len(rule_text) > 500:
                            rule_text = rule_text[:500] + "..."
                        
                        similar_node.defense_strategy = rule_text
                        if not similar_node.value: similar_node.value = {}
                        similar_node.value['defense_response'] = rule_text
                        
                        # Keep the structured rule; _merge_rules_with_llm reads
                        # value['rule_dict'] on later merges into this cluster.
                        similar_node.value['rule_dict'] = {
                            'harmful_rule': harmful_rule,
                            'benign_rule': benign_rule,
                            'rule_text': rule_text
                        }
                        
                        # Cluster-level benign_boundary_rule (full dict).
                        similar_node.benign_boundary_rule = {
                            'harmful_rule': harmful_rule,
                            'benign_rule': benign_rule,
                            'rule_text': rule_text,
                            'exemptions': benign_rule  # kept for compatibility with the old format
                        }
                        
                        # Give the new leaf the same merged rule. The dicts are shared
                        # (same objects) between the cluster and the leaf, not copied.
                        new_leaf.benign_boundary_rule = similar_node.benign_boundary_rule
                        new_leaf.value['defense_response'] = rule_text
                        new_leaf.value['rule_dict'] = similar_node.value['rule_dict']
                        
                        print(f"✓ [Merged Rule] Merged rules with LLM for Node {similar_node.label[-6:]}")
                    else:
                        raise ValueError("Empty rule_text from merged rule")
                else:
                    # Safety net: _merge_rules_with_llm returns a fallback dict on its own
                    # errors, so a falsy result is not expected here.
                    raise ValueError("Failed to merge rules with LLM")

            except Exception as e:
                print(f"⚠️ [Merge Error] LLM merge failed: {e}")
                import traceback
                traceback.print_exc()
                self._handle_fallback_strategy(similar_node, value)

        # Case C: LLM skipped (high information gain): add the leaf only, do not update the cluster rule.
        else:
            # The new rule brings new information, so the leaf keeps the sample's own
            # rule and the cluster's rule is left as it is.
            if value.get('rule_dict'):
                # The sample carries a structured rule: copy it onto the new leaf.
                new_leaf.value['rule_dict'] = value['rule_dict']
                rule_text = value['rule_dict'].get('rule_text', '')
                new_leaf.value['defense_response'] = rule_text  # kept for backward compatibility
                if 'benign_boundary_rule' in value:
                    new_leaf.benign_boundary_rule = value['benign_boundary_rule']
            elif value.get('defense_response'):
                # Backward compatibility: only a defense_response string is available.
                new_leaf.value['defense_response'] = value['defense_response']
            
            print(f"✓ [Skip LLM] Added leaf node with new rule (high info gain)")


    def _merge_rules_with_llm(self, existing_node, new_value):
        """
        使用 LLM 合并现有 cluster 的规则和新样本的规则
        
        Args:
            existing_node: 现有的 cluster 节点
            new_value: 新样本的 value 字典
            
        Returns:
            dict: 合并后的规则字典，包含 harmful_rule, benign_rule, rule_text
        """
        import json
        import re
        
        # 1. 提取现有 cluster 的规则
        existing_harmful_rule = ""
        existing_benign_rule = ""
        
        # 优先从 benign_boundary_rule 中提取
        if hasattr(existing_node, 'benign_boundary_rule') and existing_node.benign_boundary_rule:
            if isinstance(existing_node.benign_boundary_rule, dict):
                existing_harmful_rule = existing_node.benign_boundary_rule.get('harmful_rule', '')
                existing_benign_rule = existing_node.benign_boundary_rule.get('benign_rule', '') or existing_node.benign_boundary_rule.get('exemptions', '')
        
        # 如果没有，从 rule_dict 中提取
        if not existing_harmful_rule:
            if existing_node.value and 'rule_dict' in existing_node.value:
                rule_dict = existing_node.value['rule_dict']
                if isinstance(rule_dict, dict):
                    existing_harmful_rule = rule_dict.get('harmful_rule', '')
                    existing_benign_rule = rule_dict.get('benign_rule', '')
        
        # 如果还没有，从 defense_strategy 或 defense_response 中提取
        if not existing_harmful_rule:
            existing_rule_text = getattr(existing_node, 'defense_strategy', '') or (existing_node.value.get('defense_response', '') if existing_node.value else '')
            if existing_rule_text:
                # 尝试从 rule_text 中提取 harmful_rule 和 benign_rule
                if 'however' in existing_rule_text.lower() or 'allow' in existing_rule_text.lower():
                    parts = re.split(r'\s+however\s+|\s+but\s+allow\s+', existing_rule_text, flags=re.IGNORECASE)
                    existing_harmful_rule = parts[0].strip() if parts else existing_rule_text
                    existing_benign_rule = parts[1].strip() if len(parts) > 1 else ''
                else:
                    existing_harmful_rule = existing_rule_text
        
        # 2. 提取新样本的规则
        new_harmful_rule = ""
        new_benign_rule = ""
        
        # 优先从 rule_dict 中提取
        if 'rule_dict' in new_value and new_value['rule_dict']:
            rule_dict = new_value['rule_dict']
            if isinstance(rule_dict, dict):
                new_harmful_rule = rule_dict.get('harmful_rule', '')
                new_benign_rule = rule_dict.get('benign_rule', '')
        
        # 如果没有，从 benign_boundary_rule 中提取
        if not new_harmful_rule and 'benign_boundary_rule' in new_value and new_value['benign_boundary_rule']:
            if isinstance(new_value['benign_boundary_rule'], dict):
                new_harmful_rule = new_value['benign_boundary_rule'].get('harmful_rule', '')
                new_benign_rule = new_value['benign_boundary_rule'].get('benign_rule', '') or new_value['benign_boundary_rule'].get('exemptions', '')
        
        # 如果还没有，从 defense_response 中提取
        if not new_harmful_rule:
            new_rule_text = new_value.get('defense_response', '')
            if new_rule_text:
                # 检查是否是默认规则格式
                if new_rule_text.startswith("Refuse requests related to:") and len(new_rule_text) < 100:
                    # 默认规则：不应该进入 _merge_rules_with_llm，规则应该在 add_new_leaf 中生成
                    print(f"⚠️ [_merge_rules_with_llm] Default rule detected, should have been generated in add_new_leaf")
                    # 使用空规则继续，让合并逻辑处理
                    new_harmful_rule = new_rule_text
                else:
                    # 非默认规则，提取 harmful_rule 部分
                    if 'however' in new_rule_text.lower() or 'allow' in new_rule_text.lower():
                        parts = re.split(r'\s+however\s+|\s+but\s+allow\s+', new_rule_text, flags=re.IGNORECASE)
                        new_harmful_rule = parts[0].strip() if parts else new_rule_text
                        new_benign_rule = parts[1].strip() if len(parts) > 1 else ''
                    else:
                        new_harmful_rule = new_rule_text
        
        # 3. 构建 LLM prompt 来合并规则
        merge_prompt = f"""你是一个安全规则合并专家。现在需要合并两个相似的攻击样本和规则，请决定应该保留哪些内容。

[现有节点信息]
恶性规则 (harmful_rule): {existing_harmful_rule if existing_harmful_rule else '无'}
良性规则 (benign_rule): {existing_benign_rule if existing_benign_rule else '无'}

[新节点信息]
恶性规则 (harmful_rule): {new_harmful_rule if new_harmful_rule else '无'}
良性规则 (benign_rule): {new_benign_rule if new_benign_rule else '无'}

[任务]
请分析这两个节点的内容，决定合并后应该保留哪些信息。原则：
1. 保留更通用、更抽象的内容
2. 如果两个规则相似，保留更完整的版本
3. 恶性规则和良性规则应该分别合并，保留更全面的场景

请以JSON格式输出，格式如下：
{{
    "merged_harmful_rule": "合并后的恶性规则（选择更通用、更完整的）",
    "merged_benign_rule": "合并后的良性规则（合并两个规则，保留更全面的豁免场景）",
    "merged_rule_text": "合并后的完整规则（由 merged_harmful_rule 和 merged_benign_rule 组合）",
    "reasoning": "说明为什么选择这些内容"
}}

注意：
- harmful_rule 和 benign_rule 应该分别更新，不要混淆
- 如果某个字段在现有节点或新节点中都不存在，可以设为空字符串
- merged_rule_text 应该由 merged_harmful_rule 和 merged_benign_rule 组合而成
- 请只输出JSON，不要包含其他内容。"""

        # 4. 调用 LLM
        try:
            rule_resp = self.client.chat.completions.create(
                model=self.model_name,
                messages=[{"role": "user", "content": merge_prompt}],
                temperature=0.1,
                max_tokens=1024,
                timeout=90
            )
            
            response_text = rule_resp.choices[0].message.content.strip()
            
            # 5. 解析 LLM 响应
            # 尝试提取 JSON（可能包含在代码块中）
            json_match = re.search(r'\{[\s\S]*\}', response_text)
            if json_match:
                response_text = json_match.group(0)
            
            merge_result = json.loads(response_text)
            
            # 6. 提取合并后的规则
            merged_harmful_rule = merge_result.get('merged_harmful_rule', existing_harmful_rule or new_harmful_rule)
            merged_benign_rule = merge_result.get('merged_benign_rule', existing_benign_rule or new_benign_rule)
            merged_rule_text = merge_result.get('merged_rule_text', '')
            
            # 如果没有 merged_rule_text，从 merged_harmful_rule 和 merged_benign_rule 组合
            if not merged_rule_text:
                if merged_harmful_rule and merged_benign_rule:
                    merged_rule_text = f"{merged_harmful_rule}, however, {merged_benign_rule}"
                elif merged_harmful_rule:
                    merged_rule_text = merged_harmful_rule
                else:
                    merged_rule_text = merged_benign_rule
            
            return {
                'harmful_rule': merged_harmful_rule,
                'benign_rule': merged_benign_rule,
                'rule_text': merged_rule_text
            }
            
        except Exception as e:
            print(f"⚠️ [Merge Rules Error] Failed to merge rules with LLM: {e}")
            import traceback
            traceback.print_exc()
            # Fallback: 返回现有规则或新规则
            return {
                'harmful_rule': existing_harmful_rule or new_harmful_rule or "Refuse requests related to: harmful content",
                'benign_rule': existing_benign_rule or new_benign_rule,
                'rule_text': existing_harmful_rule or new_harmful_rule or "Refuse requests related to: harmful content"
            }

    def _find_most_similar_leaf(self, cluster_node, new_embedding, new_value):
        """
        找到 cluster 中最相似的叶子节点
        
        Args:
            cluster_node: cluster 节点
            new_embedding: 新样本的 embedding
            new_value: 新样本的 value 字典
            
        Returns:
            最相似的叶子节点，如果没有则返回 None
        """
        if not cluster_node.children:
            return None
        
        best_leaf = None
        highest_sim = -1
        
        for leaf_node in cluster_node.children:
            if not leaf_node.value or 'embedding' not in leaf_node.value:
                continue
            
            leaf_embedding = leaf_node.value['embedding']
            sim = self._cosine_similarity(leaf_embedding, new_embedding)
            
            if sim > highest_sim:
                highest_sim = sim
                best_leaf = leaf_node
        
        if best_leaf:
            print(f"   ✓ Found most similar leaf node (similarity: {highest_sim:.3f})")
        
        return best_leaf

    def _handle_fallback_strategy(self, node, new_value):
        """
        [Helper] 兜底策略处理逻辑
        原则：优先保留现有的通用规则，只有当节点还是空的时候，才用新样本的规则填充。
        防止用"特例"覆盖了"通例"。
        """
        # 1. 获取新样本带来的策略
        new_strategy = new_value.get('defense_response', '')
        
        # 2. 获取节点现有的策略
        current_strategy = getattr(node, 'defense_strategy', None)
        if not current_strategy and node.value:
            current_strategy = node.value.get('defense_response', None)
            
        # 3. 决策逻辑
        if current_strategy:
            # 如果已有策略，保持不变 (因为已有策略通常比单一样本更通用)
            # 或者你可以选择 append: node.defense_strategy += f"; {new_strategy}"
            pass 
        elif new_strategy:
            # 如果没有策略，才使用新样本的策略初始化
            node.defense_strategy = new_strategy
            if not node.value: node.value = {}
            node.value['defense_response'] = new_strategy
            print(f"⚠️ Init strategy from sample (Fallback): {new_strategy[:50]}...")
        else:
            # 都没有，设为默认
            default_rule = "Refuse requests that violate safety policies."
            node.defense_strategy = default_rule
            if not node.value: node.value = {}
            node.value['defense_response'] = default_rule
    

    # Helper method (originally drafted as a TreeNode / MemoryManager helper).
    def update_node_center_from_children(self, node):
        """
        Recompute `node.center_embedding` from its children's embeddings.

        The "center" is the element-wise maximum over the children's embeddings
        (max pooling, not a mean). It is then L2-normalized so it can be compared
        with cosine similarity or dot product. It is rebuilt from the full
        `children` list each time, so no sample count needs to be stored on the
        node.

        Children with no value, or whose value has no (or a None) 'embedding',
        are skipped. Such nodes can exist, since add_new_leaf only reads
        'embedding' when present. If no child has an embedding, the current
        center is left unchanged.

        Args:
            node: node whose `children` provide the embeddings.
        """
        if not node.children:
            return

        # Flatten each child's embedding to a 1-D float32 vector.
        child_embeddings = [
            np.array(child.value['embedding'], dtype=np.float32).flatten()
            for child in node.children
            if child.value and child.value.get('embedding') is not None
        ]

        if not child_embeddings:
            return

        # Element-wise maximum over all child vectors (axis=0 = per dimension).
        max_vec = np.max(child_embeddings, axis=0)

        # Normalize so the center works with cosine similarity; a zero vector is kept as-is.
        norm = np.linalg.norm(max_vec)
        if norm > 0:
            max_vec = max_vec / norm

        node.center_embedding = max_vec


    def add_new_leaf(self, parent_node, value, is_cluster_root=False):
        """
        Insert a sample into the tree after generating its boundary rule.

        The rule is generated the same way in both modes:

        - The sample's 'attack_prompt' is the harmful example.
        - Benign intents are retrieved from `self.defense_vector_store`
          (similarity_search with k=3).
        - `self.generate_boundary_rule` produces the rule from both.

        If no rule text results, the placeholder
        "Refuse requests related to: <sub_category>" is used.

        Args:
            parent_node: node that receives the new node.
            value: sample dict. Reads 'id', 'attack_prompt', 'sub_category',
                and, in cluster mode, 'embedding'.
            is_cluster_root: when True, create a new cluster node
                ("Cluster_<id>", a new sub-category on layer 2) that holds one
                leaf for the sample. When False, attach a plain leaf and
                recompute the parent's center.
        """
        import numpy as np

        def _benign_intents(query):
            """Stripped benign intents retrieved for `query` ([] without a store)."""
            found = []
            if self.defense_vector_store:
                try:
                    for hit in self.defense_vector_store.similarity_search(query, k=3):
                        intent = hit.metadata.get("query", "") or hit.metadata.get("intent", "")
                        if intent and intent.strip():
                            found.append(intent.strip())
                except Exception as e:
                    # Retrieval failures are non-fatal: keep whatever was collected.
                    print(f"   ⚠️ [Benign Retrieval Error] Failed to retrieve benign samples: {e}")
            return found

        attack_text = value.get('attack_prompt', '')
        boundary_rule = None
        rule_text = ""
        harmful_rule = ""
        benign_rule = ""

        # Step 1: generate the rule (identical for cluster roots and plain leaves).
        if attack_text and attack_text.strip():
            try:
                boundary_rule = self.generate_boundary_rule([attack_text], _benign_intents(attack_text))
                if not boundary_rule:
                    print(f"   ⚠️ Warning: LLM returned None when generating rule")
                else:
                    rule_text = boundary_rule.get('rule_text', '')
                    harmful_rule = boundary_rule.get('harmful_rule', '')
                    benign_rule = boundary_rule.get('benign_rule', '') or boundary_rule.get('exemptions', '')
                    # No rule_text: compose "<harmful>, however, <benign>" or use whichever part exists.
                    if not rule_text:
                        rule_text = (f"{harmful_rule}, however, {benign_rule}"
                                     if harmful_rule and benign_rule
                                     else (harmful_rule or benign_rule))
                    print(f"   >>> Generated rule (harmful + benign) for {'cluster' if is_cluster_root else 'leaf'}")
            except Exception as e:
                print(f"   ⚠️ Failed to generate rule: {e}")
                import traceback
                traceback.print_exc()

        # Placeholder rule when nothing could be generated.
        if not rule_text:
            rule_text = f"Refuse requests related to: {value.get('sub_category', 'harmful content')}"

        rule_dict = {
            'harmful_rule': harmful_rule or rule_text,
            'benign_rule': benign_rule,
            'rule_text': rule_text,
            'topic_label': boundary_rule.get('cluster_topic', 'harmful content') if boundary_rule else 'harmful content'
        }

        def _stamp(node):
            """Copy the generated rule onto `node` (value fields, then boundary rule)."""
            node.value['defense_response'] = rule_text
            node.value['rule_dict'] = rule_dict
            if boundary_rule:
                node.benign_boundary_rule = boundary_rule

        # Step 2a: plain leaf under an existing node.
        if not is_cluster_root:
            leaf_node = TreeNode(value['id'])
            leaf_node.set_value(value)
            _stamp(leaf_node)
            parent_node.add_child(leaf_node)
            # Keep the parent's center in sync with its new set of children.
            self.update_node_center_from_children(parent_node)
            return

        # Step 2b: new cluster node (dynamic defense node) holding the sample as its first leaf.
        cluster_node = TreeNode(f"Cluster_{value['id']}")
        cluster_node.set_value(value)

        center = None
        if 'embedding' in value:
            center = np.array(value['embedding'], dtype=np.float32).flatten()
            magnitude = np.linalg.norm(center)
            if magnitude > 0:
                center = center / magnitude
        cluster_node.center_embedding = center

        cluster_node.defense_strategy = rule_text
        _stamp(cluster_node)

        actual_leaf = TreeNode(value['id'])
        actual_leaf.set_value(value)
        _stamp(actual_leaf)
        cluster_node.add_child(actual_leaf)

        parent_node.add_child(cluster_node)

    def _calculate_ig(self, existing_embeddings, new_embedding, defense_success=True, temperature=0.1):
        """
        [Deprecated] 计算"节点脆弱性熵"（基于 attack_prompt embedding）
        
        注意：此方法已不再使用，add_node 现在基于规则信息增益决策。
        保留此方法仅用于向后兼容或调试。
        
        简化版：统一熵值尺度到 0~1，去掉魔法数字
        """
        import numpy as np
        
        if not existing_embeddings:
            return 1.0  # 第一个节点，默认为高熵，建立新节点

        # 计算与现有 Cluster 中心的相似度
        sims = [self._cosine_similarity(embed, new_embedding) for embed in existing_embeddings]
        sims = np.array([max(s, 0.001) for s in sims])  # 避免除0

        # 如果与所有节点的相似度都很低，说明是 OOD 数据，熵很高
        max_sim = np.max(sims)
        if max_sim < 0.5:
            # 统一到 0~1 范围：相似度越低，熵越高
            return 1.0 - max_sim * 2  # max_sim=0 → entropy=1.0, max_sim=0.5 → entropy=0.0

        # 标准熵计算 (Softmax distribution)
        # 如果新样本同时像 A 和 B (sim_A ≈ sim_B)，熵会高 -> Split (因为边界模糊)
        # 如果新样本只像 A (sim_A >> sim_B)，熵会低 -> Merge into A
        probs = sims / np.sum(sims)
        # Sharpen distribution
        probs = probs ** (1.0 / temperature)
        probs = probs / np.sum(probs)
        
        entropy = self._calculate_entropy(probs)
        
        # 归一化熵 (0~1)
        k = len(existing_embeddings)
        if k > 1:
            entropy = entropy / np.log2(k)
        
        return entropy

    def _calculate_rule_entropy_gain(self, cluster_node, new_rule_text):
        """
        计算加入新规则后的信息增益（熵增）
        
        基于规则 embedding 计算（只使用 harmful_rule 部分）：
        - 如果新规则与现有规则相似度高 → 熵增小（信息增益低）→ 不应该 merge（不能丰富 cluster）
        - 如果新规则与现有规则相似度低 → 熵增大（信息增益高）→ 应该 merge（可以丰富 cluster）
        
        Args:
            cluster_node: 目标 cluster 节点
            new_rule_text: 新样本的 defense_response 规则文本（可以是字符串或字典）
        
        Returns:
            float: 信息增益（熵增），值越大说明新规则带来的信息越多，越应该 merge
        """
        import numpy as np
        import re
        
        # 提取 harmful_rule 部分用于信息增益计算
        if isinstance(new_rule_text, dict):
            # 如果是字典格式，提取 harmful_rule
            new_rule_text = new_rule_text.get('harmful_rule', '') or new_rule_text.get('rule_text', '')
        elif isinstance(new_rule_text, str):
            # 如果是字符串，尝试提取 harmful_rule 部分（如果包含 exemption）
            if 'however' in new_rule_text.lower() or 'allow' in new_rule_text.lower():
                parts = re.split(r'\s+however\s+|\s+but\s+allow\s+', new_rule_text, flags=re.IGNORECASE)
                new_rule_text = parts[0].strip() if parts else new_rule_text
        
        if not new_rule_text or not new_rule_text.strip():
            return 0.0
        
        # 1. 收集 cluster 中所有现有规则的文本（只使用 harmful_rule 部分用于信息增益计算）
        existing_rules = []
        for child in cluster_node.children:
            if child.value and 'defense_response' in child.value:
                rule_text = child.value['defense_response']
                if rule_text and rule_text.strip():
                    existing_rules.append(rule_text.strip())
        
        # 如果 cluster 有 defense_strategy，也加入
        if hasattr(cluster_node, 'defense_strategy') and cluster_node.defense_strategy:
            existing_rules.append(cluster_node.defense_strategy.strip())
        
        if not existing_rules:
            # 如果没有现有规则，新规则的信息增益最大
            return 1.0
        
        # 2. 计算现有规则的 embedding
        existing_embeddings = []
        for rule_text in existing_rules:
            try:
                rule_emb = self.embedding_model.encode(rule_text, normalize_embeddings=True)
                rule_emb = np.array(rule_emb, dtype=np.float32).flatten()
                norm = np.linalg.norm(rule_emb)
                if norm > 0:
                    rule_emb = rule_emb / norm
                existing_embeddings.append(rule_emb)
            except Exception:
                continue
        
        if not existing_embeddings:
            return 1.0
        
        # 3. 计算新规则的 embedding
        try:
            new_rule_emb = self.embedding_model.encode(new_rule_text, normalize_embeddings=True)
            new_rule_emb = np.array(new_rule_emb, dtype=np.float32).flatten()
            norm = np.linalg.norm(new_rule_emb)
            if norm > 0:
                new_rule_emb = new_rule_emb / norm
        except Exception:
            return 0.0
        
        # 【冷启动处理】如果只有1个现有规则，直接基于相似度判断
        if len(existing_embeddings) == 1:
            # 冷启动：只有1个规则，直接基于新规则与现有规则的相似度判断
            similarity = np.dot(new_rule_emb, existing_embeddings[0])
            # 相似度高 → 信息增益低（应该 merge）
            # 相似度低 → 信息增益高（应该新建 cluster）
            # 使用 1 - similarity 作为信息增益的代理
            # 相似度 0.7 → 信息增益 0.3（< 0.6，触发 merge）
            # 相似度 0.3 → 信息增益 0.7（> 0.6，触发新建 cluster）
            info_gain = 1.0 - similarity
            return max(0.0, info_gain)
        
        # 4. 计算加入前的熵（基于现有规则之间的相似度分布）
        # 计算现有规则两两之间的相似度
        existing_sims = []
        for i, emb1 in enumerate(existing_embeddings):
            for j, emb2 in enumerate(existing_embeddings):
                if i < j:
                    sim = np.dot(emb1, emb2)
                    existing_sims.append(sim)
        
        # 【冷启动处理】当规则数量较少时（≤3个），简化熵计算
        if len(existing_embeddings) <= 3:
            # 当只有2个规则时，existing_sims 只有1个值，熵计算不稳定
            # 直接基于新规则与现有规则的最大相似度判断
            new_sims = [np.dot(new_rule_emb, emb) for emb in existing_embeddings]
            max_similarity = max(new_sims) if new_sims else 0.0
            
            # 相似度高 → 信息增益低（应该 merge）
            # 相似度低 → 信息增益高（应该新建 cluster）
            info_gain = 1.0 - max_similarity
            return max(0.0, info_gain)
        
        # 5. 标准熵计算（当规则数量 > 3 时）
        if not existing_sims:
            entropy_before = 0.0
        else:
            # 将相似度转换为概率分布（使用 softmax）
            sims_array = np.array(existing_sims)
            sims_array = np.clip(sims_array, 0.001, 1.0)  # 避免0或负数
            probs = sims_array / np.sum(sims_array)
            entropy_before = -np.sum(probs * np.log2(probs + 1e-10))
        
        # 6. 计算加入后的熵（加入新规则与现有规则的相似度）
        new_sims = [np.dot(new_rule_emb, emb) for emb in existing_embeddings]
        all_sims = existing_sims + new_sims
        
        if not all_sims:
            entropy_after = 0.0
        else:
            sims_array = np.array(all_sims)
            sims_array = np.clip(sims_array, 0.001, 1.0)
            probs = sims_array / np.sum(sims_array)
            entropy_after = -np.sum(probs * np.log2(probs + 1e-10))
        
        # 7. 信息增益 = 熵增（加入后 - 加入前）
        info_gain = entropy_after - entropy_before
        
        # 归一化到 0~1（根据实际熵值范围调整）
        # 如果信息增益 > 0，说明新规则带来了新信息
        # 如果信息增益 ≈ 0，说明新规则与现有规则相似
        
        return max(0.0, info_gain)

    def _project_embedding(self, embedding):
        """
        使用 Safety Projector 投影 embedding 到安全空间
        如果未加载 Safety Projector，则返回harmful score
        """
        if not self.use_safety_projection or self.safety_projector is None:
            return embedding
        
        import torch
        # 从模型的参数中获取设备，确保一致性
        device = next(self.safety_projector.parameters()).device
        
        # 转换为 torch tensor 并移动到指定设备
        emb_tensor = torch.tensor(embedding, dtype=torch.float32, device=device).unsqueeze(0)
        
        # 投影
        with torch.no_grad():
            projected, logits = self.safety_projector(emb_tensor)
            # 获取概率
            prob_harmful = torch.sigmoid(logits).item() # 0~1 之间的数
        
        # 转换回 numpy（如果 tensor 在 GPU 上，需要先移到 CPU）
        if projected.is_cuda:
            projected = projected.cpu()
        return projected.squeeze(0).numpy(), prob_harmful

    def retrieve_query(self, messages, top_k=3, prompt=False, top_k_clusters=None, top_k_leaves_per_cluster=3):
        """
        Core routing entry point for the two-stage LLM defense.

        Combines three signals -- the Safety Projector harmful probability, the
        best benign-sample similarity from ``self.defense_vector_store`` and the
        memory-tree clusters nearest to the query -- and routes the query into one
        of three branches: SAFE, BLOCK or AMBIGUOUS (always AMBIGUOUS when
        ``self.use_single_branch_mode`` is set, for ablation runs).

        Args:
            messages: chat history as a list of ``{'role', 'content'}`` dicts; the
                last ``'user'`` message is the query that gets routed.
            top_k: top-k for RAG retrieval; also the default cluster count when
                ``top_k_clusters`` is None.
            prompt: if True, return a message list (guidance injected by the
                branch-specific prompt builder, or ``messages`` unchanged for the
                SAFE branch and fallback cases); if False, return a dict.
            top_k_clusters: number of memory-tree clusters to draw rules from
                (defaults to ``top_k``). A list is interpreted as its length and
                other values are coerced with ``int()``.
            top_k_leaves_per_cluster: currently unused -- each selected cluster
                contributes at most one rule. Kept for interface compatibility.

        Returns:
            When ``prompt`` is False, a dict with:
                - is_harmful (bool): True only on the BLOCK branch
                - rag_content: result of ``self._get_rag_content`` (second LLM call)
                - dynamic_rule (str): first harmful rule (backward compatible)
                - dynamic_rules (list[str]): all harmful rules collected
                - benign_boundary_rules (list[dict]): one merged rule per cluster
                - branch (str): "SAFE", "BLOCK" or "AMBIGUOUS"
                - original_messages (list): deep copy of the input messages
                - topic_label (str): semantic label of the nearest cluster
                - harmful_prob, max_benign_sim: the routing signals
              Fallback results (no messages, no user query, embedding failure)
              use branch "SAFE" and omit the two signal keys.
            When ``prompt`` is True, a list of messages.
        """
        import copy

        # Resolve how many clusters to draw rules from.
        if top_k_clusters is None:
            top_k_clusters = top_k
        if isinstance(top_k_clusters, list):
            num_top_k_clusters = len(top_k_clusters) if top_k_clusters else 1
        elif isinstance(top_k_clusters, int):
            num_top_k_clusters = top_k_clusters
        else:
            try:
                num_top_k_clusters = int(top_k_clusters)
            except (ValueError, TypeError):
                num_top_k_clusters = top_k if top_k else 3

        # 1. Basic extraction.
        if not messages or not isinstance(messages, list):
            return self._route_default_result(messages, prompt)

        user_query = next(
            (m.get('content', "") for m in reversed(messages) if m.get('role') == 'user'),
            "",
        )
        if not user_query:
            return self._route_default_result(messages, prompt)

        original_messages = copy.deepcopy(messages)

        # -------------------------------------------------------------
        # Step 1: signals
        # -------------------------------------------------------------
        query_emb = self._route_encode_query(user_query)
        if query_emb is None:
            # Embedding failed: "no defense applied", in the format the caller asked for.
            return self._route_default_result(messages, prompt)

        harmful_prob = self._route_harmful_prob(query_emb)
        print(f"🛡️ [Signal] Projector Harmful Prob: {harmful_prob:.4f}")

        max_benign_sim = self._route_benign_similarity(user_query)
        print(f"🔍 [Signal] Benign Similarity: {max_benign_sim:.4f}")

        query_vec = np.array(query_emb, dtype=np.float32).flatten()
        q_norm = np.linalg.norm(query_vec)
        if q_norm > 0:
            query_vec = query_vec / q_norm

        print(f"\n🔍 [Retrieval] Query: {user_query[:100]}...")
        print(f"  Top-k clusters: {num_top_k_clusters}, Rules per cluster: 1")

        # Use a dedicated list (not the ``top_k_clusters`` parameter) so an empty
        # tree yields [] instead of leaving an int to be iterated later.
        selected_clusters = self._route_select_clusters(query_vec, num_top_k_clusters)
        benign_boundary_rules = self._route_collect_rules(query_vec, selected_clusters)
        print(f"  Total rules collected: {len(benign_boundary_rules)}")

        first_cluster = selected_clusters[0][0] if selected_clusters else None
        topic_label = self._route_topic_label(first_cluster)

        # -------------------------------------------------------------
        # Step 2: routing decision
        # -------------------------------------------------------------
        TH_HARM_HIGH = 0.92
        TH_HARM_LOW = 0.30
        TH_BENIGN_HIGH = 0.80

        # RAG content for the second LLM call.
        rag_content = self._get_rag_content(user_query, top_k=top_k)

        # Harmful rules, in cluster order (dynamic_rule keeps the first for compatibility).
        dynamic_rules = [
            rule['harmful_rule']
            for rule in benign_boundary_rules
            if isinstance(rule, dict) and rule.get('harmful_rule')
        ]
        dynamic_rule = dynamic_rules[0] if dynamic_rules else ""
        ambiguous_default_rule = (
            f"Evaluate carefully. This query may relate to {topic_label}. "
            "If it violates safety policies, refuse it."
        )

        if self.use_single_branch_mode:
            # Ablation: route everything to AMBIGUOUS so only the LLM decides.
            branch = "AMBIGUOUS"
            is_harmful = False
            print(f"⚖️ Route: AMBIGUOUS (Single Branch Mode - Ablation). Topic: {topic_label}")
            print(f"   Harmful Prob: {harmful_prob:.4f}, Benign Sim: {max_benign_sim:.4f}")
            if not dynamic_rules:
                dynamic_rules = [ambiguous_default_rule]
                dynamic_rule = ambiguous_default_rule
            print(f"[Memory Tree] Found {len(dynamic_rules)} top defense rules (max 3) from category")
            if prompt:
                return self._generate_rule_prompt(messages, first_cluster, harmful_prob, max_benign_sim, benign_boundary_rules)
        elif harmful_prob < TH_HARM_LOW and max_benign_sim > TH_BENIGN_HIGH:
            # Rescue a likely false positive: low projector risk, close benign neighbour.
            branch = "SAFE"
            is_harmful = False
            print(f"✅ Route: SAFE (Rescue). Topic: {topic_label}")
            dynamic_rules = []
            dynamic_rule = ""
            # benign_boundary_rules is intentionally kept so callers can read rule_text.
            if prompt:
                return messages
        elif harmful_prob > TH_HARM_HIGH:
            # Obvious attack.
            branch = "BLOCK"
            is_harmful = True
            print(f"⛔ Route: BLOCK (High Risk). Topic: {topic_label}")
            if not dynamic_rules:
                block_default_rule = f"Strictly refuse to assist with {topic_label} related activities."
                dynamic_rules = [block_default_rule]
                dynamic_rule = block_default_rule
            print(f"[Memory Tree] Found {len(benign_boundary_rules)} rules from {len(selected_clusters)} clusters")
            if prompt:
                return self._generate_reframe_prompt(messages, first_cluster, harmful_prob, max_benign_sim, benign_boundary_rules)
        else:
            # Ambiguous: let the LLM make the final call.
            branch = "AMBIGUOUS"
            is_harmful = False
            print(f"⚖️ Route: AMBIGUOUS. Topic: {topic_label}")
            if not dynamic_rules:
                dynamic_rules = [ambiguous_default_rule]
                dynamic_rule = ambiguous_default_rule
            print(f"[Memory Tree] Found {len(benign_boundary_rules)} rules from {len(selected_clusters)} clusters")
            if prompt:
                return self._generate_rule_prompt(messages, first_cluster, harmful_prob, max_benign_sim, benign_boundary_rules)

        self._log_score(user_query, harmful_prob, max_benign_sim, topic_label, branch)

        return {
            "is_harmful": is_harmful,
            "rag_content": rag_content,
            "dynamic_rule": dynamic_rule,  # backward compatible: first rule
            "dynamic_rules": dynamic_rules,
            "benign_boundary_rules": benign_boundary_rules,
            "branch": branch,
            "original_messages": original_messages,
            "topic_label": topic_label,
            "harmful_prob": harmful_prob,
            "max_benign_sim": max_benign_sim
        }

    @staticmethod
    def _route_default_result(messages, prompt):
        """Fallback "no defense applied" result, in the format selected by ``prompt``."""
        if prompt:
            return messages if isinstance(messages, list) else []
        return {
            "is_harmful": False,
            "rag_content": "",
            "dynamic_rule": "",
            "dynamic_rules": [],
            "benign_boundary_rules": [],
            "branch": "SAFE",
            "original_messages": messages if messages else [],
            "topic_label": "General"
        }

    def _route_encode_query(self, user_query):
        """Embed ``user_query`` as an L2-normalised float32 vector; return None on failure.

        Lazily initialises ``self.embedding_model``. An AttributeError mentioning
        ``embedding_func`` triggers one re-initialisation and retry.
        """
        def _encode():
            emb = np.array(self.embedding_model.encode(user_query, normalize_embeddings=True), dtype=np.float32)
            # encode() already normalises; re-normalising is a cheap safety net.
            norm = np.linalg.norm(emb)
            if norm > 0:
                emb /= norm
            return emb

        try:
            if getattr(self, 'embedding_model', None) is None:
                self.embedding_model = get_embedding_model()
            return _encode()
        except AttributeError as e:
            print(f"⚠️ Embedding failed: {e}")
            if 'embedding_func' not in str(e):
                return None
            # Legacy objects may still reference an old embedding_func attribute.
            print("   Attempting to reinitialize embedding_model...")
            try:
                self.embedding_model = get_embedding_model()
                return _encode()
            except Exception as e2:
                print(f"⚠️ Embedding reinitialization failed: {e2}")
                return None
        except Exception as e:
            print(f"⚠️ Embedding failed: {e}")
            return None

    def _route_harmful_prob(self, query_emb):
        """Map Safety Projector prototype similarities to a harmful score in [0, 1].

        Returns 0.5 when the projector is disabled, missing, or raises.
        """
        harmful_prob = 0.50
        if not (self.enable_safety_projection and self.use_safety_projection and self.safety_projector):
            print(f"🛡️ [Signal] Safety Projector not available, using default harmful_prob: {harmful_prob:.4f}")
            return harmful_prob
        try:
            with torch.no_grad():
                t = torch.tensor(query_emb, device=self.device).float()
                if t.dim() == 1:
                    t = t.unsqueeze(0)
                # The projector's logits are similarity / temperature.
                _, logits = self.safety_projector(t)
                # float() accepts both plain floats and one-element tensors/parameters.
                current_temp = float(getattr(self.safety_projector, 'temperature', 1.0))
                raw_logits = logits * current_temp
                sim_safe = raw_logits[0, 0].item()  # similarity to the safe prototype
                sim_harm = raw_logits[0, 1].item()  # similarity to the harmful prototype
                # The cosine difference spans [-2, 2]; map it linearly around 0.5 and clip.
                linear_score = 0.5 + (sim_harm - sim_safe) / 2.0
                harmful_prob = max(0.0, min(1.0, linear_score))
        except Exception as e:
            print(f"⚠️ Projector failed: {e}")
            harmful_prob = 0.5
        return harmful_prob

    def _route_benign_similarity(self, user_query, k_neighbors=3):
        """Return the max approximate cosine similarity to the nearest benign samples (0.0 if unavailable)."""
        max_benign_sim = 0.0
        if not self.defense_vector_store:
            return max_benign_sim
        try:
            results = self.defense_vector_store.similarity_search_with_score(user_query, k=k_neighbors)
            # For unit vectors, cos = 1 - L2^2 / 2; this assumes the score is the
            # *unsquared* L2 distance. NOTE(review): FAISS IndexFlatL2 reports squared
            # distances, in which case 1 - d / 2 is the right mapping -- confirm index type.
            sim_scores = [
                max(0.0, 1.0 - (res[1] ** 2) / 2.0)
                for res in (results or [])
                if isinstance(res, (tuple, list)) and len(res) >= 2
            ]
            # Use the top-1 match: one very close benign sample is enough to suspect a false positive.
            if sim_scores:
                max_benign_sim = max(sim_scores)
        except Exception as e:
            print(f"⚠️ FAISS search failed: {e}")
            max_benign_sim = 0.0
        return max_benign_sim

    def _route_select_clusters(self, query_vec, num_clusters):
        """Return the ``num_clusters`` most similar ``(cluster, similarity, category_label)`` tuples."""
        scored = []
        for cat_node in self.root.children:
            for cluster_node in cat_node.children:
                _, similarity = self._calculate_cluster_quality_score(cluster_node, query_vec)
                scored.append((cluster_node, similarity, cat_node.label))

        if not scored:
            print("  ⚠️ No clusters found at all")
            return []

        scored.sort(key=lambda item: item[1], reverse=True)
        selected = scored[:num_clusters]
        print(f"  ✓ Selected {len(selected)} clusters (requested {num_clusters})")
        for i, (cluster_node, sim, cat_label) in enumerate(selected, 1):
            print(f"    {i}. {cluster_node.label[:30]}, similarity: {sim:.3f}, category: {cat_label}")
        return selected

    def _route_collect_rules(self, query_vec, selected_clusters):
        """Collect at most one merged (harmful + benign) rule per selected cluster.

        The best leaf's rule is preferred. The cluster-level rule is the fallback.
        Clusters without children are skipped entirely.
        """
        rules = []
        for cluster_node, _, category_label in selected_clusters:
            if not cluster_node.children:
                continue
            rule = self._route_leaf_rule(query_vec, cluster_node, category_label)
            if rule is None:
                rule = self._route_cluster_rule(cluster_node, category_label)
            if rule is not None:
                rules.append(rule)
        return rules

    @staticmethod
    def _route_split_benign_rule(rule):
        """Split a dict/str benign rule into ``(rule_text, exemptions, cluster_topic)``."""
        if isinstance(rule, dict):
            return rule.get('rule_text', ''), rule.get('exemptions', ''), rule.get('cluster_topic', '')
        if isinstance(rule, str):
            return rule, '', ''
        return '', '', ''

    def _route_leaf_rule(self, query_vec, cluster_node, category_label):
        """Build a rule dict from the rule-bearing leaf most similar to the query, or return None."""
        candidates = []
        for leaf_node in cluster_node.children:
            value = leaf_node.value
            if not (value and 'embedding' in value):
                continue
            try:
                leaf_emb = np.array(value['embedding'], dtype=np.float32).flatten()
                leaf_norm = np.linalg.norm(leaf_emb)
                if not (leaf_norm > 0):
                    continue
                leaf_sim = np.dot(query_vec, leaf_emb / leaf_norm)
                benign_rule = getattr(leaf_node, 'benign_boundary_rule', None)
                harmful_rule = value.get('defense_response', '')
                # Non-string responses count as "no harmful rule" instead of crashing on .strip().
                if not isinstance(harmful_rule, str):
                    harmful_rule = ''
                if benign_rule or harmful_rule.strip():
                    candidates.append((leaf_node, leaf_sim, benign_rule, harmful_rule))
            except Exception as e:
                print(f"  ⚠️ Error calculating similarity for leaf {leaf_node.label}: {e}")

        if not candidates:
            return None

        leaf_node, leaf_sim, benign_rule, harmful_rule = max(candidates, key=lambda c: c[1])
        print(f"  📋 Cluster {cluster_node.label[:30]}: Found {len(candidates)} leaves, selected top 1 (sim={leaf_sim:.3f})")

        harmful_rule_text = harmful_rule.strip()
        rule_text, exemptions, cluster_topic = self._route_split_benign_rule(benign_rule)
        if not (harmful_rule_text or rule_text):
            return None
        return {
            'harmful_rule': harmful_rule_text,  # harmful (prohibition) rule
            'rule_text': rule_text,             # benign rule: prohibition part
            'exemptions': exemptions,           # benign rule: exemption part
            'cluster_topic': cluster_topic,
            'category': category_label,
            'leaf_id': leaf_node.label,
            'similarity': float(leaf_sim),
            'cluster_label': cluster_node.label
        }

    def _route_cluster_rule(self, cluster_node, category_label):
        """Build a rule dict from cluster-level rules, or return None when the cluster has none."""
        harmful_rule = ""
        if getattr(cluster_node, 'defense_strategy', None):
            harmful_rule = cluster_node.defense_strategy
        elif cluster_node.value and 'defense_response' in cluster_node.value:
            harmful_rule = cluster_node.value['defense_response']

        rule_text, exemptions, cluster_topic = self._route_split_benign_rule(
            getattr(cluster_node, 'benign_boundary_rule', None)
        )
        if not (harmful_rule or rule_text):
            return None
        rule = {
            'harmful_rule': harmful_rule.strip() if harmful_rule else '',
            'rule_text': rule_text,
            'exemptions': exemptions,
            'cluster_topic': cluster_topic,
            'category': category_label,
            'leaf_id': None,  # cluster-level rule
            'cluster_label': cluster_node.label
        }
        print(f"  📋 Cluster {cluster_node.label[:30]}: Using cluster-level rule")
        return rule

    @staticmethod
    def _route_topic_label(cluster):
        """Pick a semantic topic label: sub_category, then category, then a non-``Cluster_`` label."""
        if cluster is None:
            return "General_Risk"
        value = cluster.value
        if value and 'sub_category' in value:
            return value['sub_category']
        if value and 'category' in value:
            return value['category']
        if not cluster.label.startswith("Cluster_"):
            return cluster.label
        return "General_Risk"
    
    def _find_cluster_for_update(self, embedding, category_filter=None):
        """
        [Pure Similarity] 计算 User Query 与 Cluster Center (Attack Prompt Avg) 的相似度
        """
        import numpy as np
        
        best_cluster = None
        max_sim = -1.0
        
        # 1. Query 归一化
        query_vec = np.array(embedding, dtype=np.float32).flatten()
        q_norm = np.linalg.norm(query_vec)
        if q_norm > 0:
            query_vec = query_vec / q_norm

        # 2. 确定搜索范围 (复用之前的逻辑，先定大类)
        target_categories = self.root.children 
        if category_filter:
            cat = self._get_category_node_by_label(category_filter)
            if cat: target_categories = [cat]

        # 3. 纯相似度搜索
        for cat_node in target_categories:
            for cluster_node in cat_node.children:
                
                # 直接取 center_embedding (它现在代表 Attack Prompts 的平均)
                center_vec = cluster_node.center_embedding
                
                if center_vec is None:
                    continue
                
                # 计算 Cosine Similarity
                # query_vec 已经归一化了，需要确保 center_vec 也归一化
                try:
                    # 确保 center_vec 是 array
                    c_vec = np.array(center_vec, dtype=np.float32).flatten()
                    # 【修复】确保归一化（防止新节点未归一化的情况）
                    c_norm = np.linalg.norm(c_vec)
                    if c_norm > 0:
                        c_vec = c_vec / c_norm
                    # 现在两个向量都归一化了，点积就是余弦相似度
                    sim = np.dot(query_vec, c_vec)
                    
                    if sim > max_sim:
                        max_sim = sim
                        best_cluster = cluster_node
                except Exception:
                    continue

        return best_cluster
    
    def _calculate_cluster_similarity(self, query_vec, cluster_node, use_samples=True, max_samples=10):
        """
        计算 query 与 cluster 的综合相似度（改进版）
        
        Args:
            query_vec: 归一化的查询向量
            cluster_node: cluster 节点
            use_samples: 是否使用 cluster 内样本
            max_samples: 最多使用的样本数
        
        Returns:
            float: 综合相似度 (0~1)
        """
        import numpy as np
        
        center_vec = cluster_node.center_embedding
        if center_vec is None:
            return 0.0
        
        # 1. 归一化 center embedding
        c_vec = np.array(center_vec, dtype=np.float32).flatten()
        c_norm = np.linalg.norm(c_vec)
        if c_norm > 0:
            c_vec = c_vec / c_norm
        
        # 2. 计算与 center 的相似度
        center_sim = np.dot(query_vec, c_vec)
        
        if not use_samples or not cluster_node.children:
            return center_sim
        
        # 3. 计算与 cluster 内样本的最大相似度（更准确）
        sample_sims = []
        sample_count = 0
        
        for child in cluster_node.children:
            if sample_count >= max_samples:
                break
            
            if child.value and 'embedding' in child.value:
                try:
                    sample_emb = np.array(child.value['embedding'], dtype=np.float32).flatten()
                    sample_norm = np.linalg.norm(sample_emb)
                    if sample_norm > 0:
                        sample_emb = sample_emb / sample_norm
                        sample_sim = np.dot(query_vec, sample_emb)
                        sample_sims.append(sample_sim)
                        sample_count += 1
                except Exception:
                    continue
        
        if not sample_sims:
            return center_sim
        
        # 4. 返回最大值（center 和样本中的最大值）
        max_sample_sim = max(sample_sims) if sample_sims else 0.0
        return max(center_sim, max_sample_sim)
    
    def _calculate_cluster_quality_score(self, cluster_node, query_vec):
        """
        计算 cluster 的质量评分（用于排序）
        
        只使用相似度，不使用加权平均
        
        Returns:
            tuple: (quality_score, similarity) - 质量评分和原始相似度
        """
        import numpy as np
        
        # 只使用相似度
        similarity = self._calculate_cluster_similarity(query_vec, cluster_node)
        
        # 质量评分 = 相似度（不再加权）
        quality_score = similarity
        
        return quality_score, similarity

    # [辅助方法] 按 Label 查找 Category 节点
    def _get_category_node_by_label(self, label):
        for child in self.root.children:
            if child.label == label:
                return child
        return None

    def _inject_blocking_prompt(self, messages, cluster, topic_label):
        """
        [Block Branch] 强力拒绝分支
        利用 Tree 中存储的具体 Rule 进行精准防御。
        """
        # 1. 提取防御规则 (Defense Rule)
        # 优先级：
        # A. 专门生成的 defense_strategy (通常是 "Refuse requests that...")
        # B. 节点 value 里的 defense_response (可能是之前的拒绝回复)
        # C. 兜底通用规则
        
        specific_rule = ""
        
        if cluster:
            if getattr(cluster, 'defense_strategy', None):
                specific_rule = cluster.defense_strategy
            elif cluster.value and 'defense_response' in cluster.value:
                # 如果存的是之前的 agent response，可能比较长，我们截取或直接用
                specific_rule = cluster.value['defense_response']
        
        if not specific_rule:
            specific_rule = f"Strictly refuse to assist with {topic_label} related activities."

        # 2. 构造 System Blocking Prompt
        # 这里的 Prompt 必须极其强硬，覆盖之前的任何指令
        block_msg = f"""
[System Security Alert]
This user request may be malicious. Please evaluate it according to the Safety Protocol below.
Risk Category: {topic_label}

[Mandatory Defense Rule]
You are required to follow this specific safety protocol:
>>> {specific_rule}

[Action Required]
1. If the user request complies with the aforementioned security rules, proceed to call the tool normally to execute the relevant tasks.
2. If the user request is found to violate the aforementioned rules, immediately reject the request and refrain from calling any tools.
"""
        return self._inject_msg(messages, block_msg)

  
    # ------------------------------------------------------------------
    def _generate_rule_prompt(self, messages, cluster, harmful_prob=None, max_benign_sim=None, benign_boundary_rules=None):
        import copy
        
        # 1. 获取风险节点的描述信息
        # defense_response 通常包含该聚类下的关键词或具体的攻击特征描述
        # 修复：defense_response 存储在 cluster.value 字典中，而不是 cluster 的直接属性
        risk_description = ""
        if hasattr(cluster, 'value') and cluster.value and isinstance(cluster.value, dict):
            risk_description = cluster.value.get('defense_response', "")
        # 如果 value 中没有，尝试从 defense_strategy 获取（两者内容相同）
        if not risk_description:
            risk_description = getattr(cluster, 'defense_strategy', "")
        
        # 2. 提取语义化的 Topic Label（与 retrieve_query 逻辑一致）
        risk_label = "General_Risk"
        if cluster:
            # 优先从 value 字典里取 sub_category (比如 "Homemade Explosives")
            if hasattr(cluster, 'value') and cluster.value and isinstance(cluster.value, dict):
                if 'sub_category' in cluster.value:
                    risk_label = cluster.value['sub_category']
                # 其次取 category (比如 "Physical Harm")
                elif 'category' in cluster.value:
                    risk_label = cluster.value['category']
            # 最后才用 label (防止是 Cluster_ID)
            if risk_label == "General_Risk" and hasattr(cluster, 'label'):
                if not cluster.label.startswith("Cluster_"):
                    risk_label = cluster.label

        # 2. 兜底策略 (Fallback)
        # 如果节点没有具体描述，不再使用"越狱模版"，而是使用"通用风险提示"
        if not risk_description:
            risk_description = (
                "This request exhibits patterns associated with potentially unsafe "
                "or restricted topics. It may be an attempt to bypass safety filters."
            )

        # 3. 构造风险评分信息
        risk_score_info = ""
        if harmful_prob is not None or max_benign_sim is not None:
            risk_score_info = "\n    Risk Assessment Scores:\n"
            if harmful_prob is not None:
                risk_score_info += f"    - Harmful Probability: {harmful_prob:.3f} (higher = more likely harmful)\n"
            if max_benign_sim is not None:
                risk_score_info += f"    - Max Benign Similarity: {max_benign_sim:.3f} (higher = more similar to safe patterns)\n"

        # 4. 构造良性边界规则部分
        benign_boundary_section = ""
        if benign_boundary_rules:
            benign_harmful_rules_text = ""  # 禁止部分（harmful rules）
            benign_exemptions_text = ""  # 豁免部分（exemptions）
            
            for i, benign_rule in enumerate(benign_boundary_rules, 1):
                if isinstance(benign_rule, dict):
                    cluster_topic = benign_rule.get('cluster_topic', '')
                    rule_text = benign_rule.get('rule_text', '')
                    exemptions = benign_rule.get('exemptions', '')
                    
                    # 添加禁止规则
                    if rule_text:
                        harmful_rule_prefix = f"Prohibition {i}"
                        if cluster_topic:
                            harmful_rule_prefix += f" (Topic: {cluster_topic})"
                        benign_harmful_rules_text += f"{harmful_rule_prefix}: {rule_text}\n\n"
                    
                    # 添加豁免规则
                    if exemptions:
                        exemption_prefix = f"Exemption {i}"
                        if cluster_topic:
                            exemption_prefix += f" (Topic: {cluster_topic})"
                        benign_exemptions_text += f"{exemption_prefix}: {exemptions}\n\n"
                elif isinstance(benign_rule, str):
                    benign_harmful_rules_text += f"Prohibition {i}: {benign_rule}\n\n"
            
            # 组合良性边界规则部分
            section_parts = []
            if benign_harmful_rules_text:
                section_parts.append(
                    f"[Benign Boundary Rules - PROHIBITED Actions]\n"
                    f"The following rules define what is PROHIBITED:\n\n"
                    f"{benign_harmful_rules_text}\n"
                )
            
            if benign_exemptions_text:
                section_parts.append(
                    f"[Benign Exemptions - What is ALLOWED]\n"
                    f"The following exemptions define what is PERMITTED:\n\n"
                    f"{benign_exemptions_text}\n"
                    f"IMPORTANT: These exemptions specify what is allowed. If a query matches these benign patterns, "
                    f"it should be ALLOWED even if it touches on sensitive topics.\n\n"
                )
            
            if section_parts:
                benign_boundary_section = "\n" + "\n".join(section_parts)

        # 5. 构造防御性提示 (Defensive Prompt)
        # 核心逻辑：提供上下文 -> 指出潜在风险 -> 要求LLM进行最终意图判断
        override_prompt = f"""
    [SYSTEM SECURITY CONTEXT]
    --------------------------------------------------
    ATTENTION: The user's input below has been flagged by the external Risk Tree system.
    Detected Risk Category: [{risk_label}]
    Potential Risk Description: {risk_description}{risk_score_info}{benign_boundary_section}
    GUIDANCE FOR MODEL:
    1. Awareness: Be aware that this request might be related to the restricted topic above.
    2. Judgment: Analyze the user's intent. The Risk Tree matches patterns, but ignores context.
    - If this is a harmful request or attempts to violate safety policies regarding [{risk_label}], you MUST REFUSE it politely but firmly and MUST NOT call any tools.
    - If this is a safe, educational, or benign context (a false positive), please ANSWER normally.
    - Pay special attention to the [Benign Exemptions] above - queries matching those patterns should be ALLOWED.
    --------------------------------------------------
    """
        return self._inject_msg(messages, override_prompt)


    def _find_nearest_neighbor_exemplars(self, empty_cluster):
        """
        当命中空节点时，寻找最近的非空节点借用数据。
        """
        if empty_cluster is None:
            return []

        best_neighbor = None
        max_sim = -1.0
        
        # 收集所有有 benign_exemplars 的 cluster 节点作为候选
        candidates = []
        for cat_node in self.root.children:
            for cluster_node in cat_node.children:
                if cluster_node == empty_cluster:
                    continue
                if hasattr(cluster_node, 'benign_exemplars') and cluster_node.benign_exemplars:
                    if cluster_node.center_embedding is not None:
                        candidates.append(cluster_node)
        
        if not candidates:
            return []

        empty_vec = empty_cluster.center_embedding
        if empty_vec is None:
            return []

        norm_empty = np.linalg.norm(empty_vec)
        if norm_empty == 0:
            return []

        for node in candidates:
            # 计算相似度
            node_vec = node.center_embedding
            node_norm = np.linalg.norm(node_vec)
            if node_norm == 0:
                continue
            sim = np.dot(empty_vec, node_vec) / (norm_empty * node_norm + 1e-9)
            
            if sim > max_sim:
                max_sim = sim
                best_neighbor = node
        
        if best_neighbor and best_neighbor.benign_exemplars:
            # print(f"🔄 Borrowed memories from neighbor: {best_neighbor.label} for {empty_cluster.label}")
            return best_neighbor.benign_exemplars
        
        return []

    def _find_nearest_benign_neighbor(self, current_cluster, all_risk_clusters):
        """
        寻找最近邻居的良性中心向量（用于借用）
        
        Args:
            current_cluster: 当前 cluster 节点
            all_risk_clusters: 所有风险 cluster 节点列表
            
        Returns:
            良性中心向量（numpy array），如果没有找到则返回 None
        """
        if current_cluster is None or not all_risk_clusters:
            return None
        
        best_neighbor = None
        max_sim = -1.0
        
        # 获取当前 cluster 的中心向量
        current_vec = None
        if hasattr(current_cluster, 'projected_center_embedding') and current_cluster.projected_center_embedding is not None:
            current_vec = current_cluster.projected_center_embedding
        elif hasattr(current_cluster, 'center_embedding') and current_cluster.center_embedding is not None:
            current_vec = current_cluster.center_embedding
        
        if current_vec is None:
            return None
        
        current_norm = np.linalg.norm(current_vec)
        if current_norm == 0:
            return None
        
        # 遍历所有其他 cluster，找到有良性中心且最相似的
        for cluster in all_risk_clusters:
            if cluster == current_cluster:
                continue
            
            # 优先使用投影后的良性中心
            benign_vec = None
            if hasattr(cluster, 'projected_benign_center_embedding') and cluster.projected_benign_center_embedding is not None:
                benign_vec = cluster.projected_benign_center_embedding
            elif hasattr(cluster, 'benign_center_embedding') and cluster.benign_center_embedding is not None:
                benign_vec = cluster.benign_center_embedding
            
            if benign_vec is None:
                continue
            
            # 计算相似度
            benign_norm = np.linalg.norm(benign_vec)
            if benign_norm == 0:
                continue
            
            sim = np.dot(current_vec, benign_vec) / (current_norm * benign_norm + 1e-9)
            
            if sim > max_sim:
                max_sim = sim
                best_neighbor = cluster
        
        # 返回最佳邻居的良性中心向量
        if best_neighbor:
            if hasattr(best_neighbor, 'projected_benign_center_embedding') and best_neighbor.projected_benign_center_embedding is not None:
                return best_neighbor.projected_benign_center_embedding
            elif hasattr(best_neighbor, 'benign_center_embedding') and best_neighbor.benign_center_embedding is not None:
                return best_neighbor.benign_center_embedding
        
        return None

    # ------------------------------------------------------------------
    # [Branch A Implementation] RAG-based / In-Context Learning Mode
    # 适用场景：良性或低风险任务 (Low Risk / Benign)
    # 核心逻辑：利用 "Pattern Matching" 原理，伪造成功案例，诱导模型模仿
    # ------------------------------------------------------------------
    def generate_boundary_rule(self, harmful_list, benign_list):
        """
        生成边界规则：分析有害样本和良性样本的意图差异，生成既禁止有害行为又明确允许良性行为的规则
        
        Args:
            harmful_list: 有害攻击样本列表（必须）
            benign_list: 良性使用样本列表（可选，如果没有则只生成 harmful_rule）
            
        Returns:
            dict: 包含 cluster_topic, harmful_rule, benign_rule, rule_text 的字典，解析失败返回兜底规则
        """
        if not harmful_list:
            return None
        
        # 根据是否有 benign_list 构建不同的 prompt
        if benign_list:
            prompt = f"""
You are a Safety Policy Architect. I will show you two groups of queries related to the same topic.

[Group A: HARMFUL ATTACKS] (Must be blocked)
{json.dumps(harmful_list, indent=2, ensure_ascii=False)}

[Group B: BENIGN USAGE] (Must be allowed)
{json.dumps(benign_list, indent=2, ensure_ascii=False)}

YOUR TASK:
Analyze the intent difference between Group A and Group B. 
Write a safety rule that strictly prohibits Group A but explicitly EXEMPTS Group B.

IMPORTANT: The rule_text MUST include BOTH the prohibition AND the exemption in a single coherent sentence.
The rule_text should be in the format: "Refuse [prohibition], however, allow [exemption]."

OUTPUT FORMAT (JSON):
{{
    "cluster_topic": "Short topic name (e.g., File Deletion)",
    "harmful_rule": "A strict sentence prohibiting the malicious intent found in Group A. Format: 'Refuse requests regarding [Abstracted Action] of [Abstracted Category].'",
    "benign_rule": "A specific clause listing what is ALLOWED (as seen in Group B). Format: 'However, allow [specific benign actions].'",
    "rule_text": "A complete rule that combines harmful_rule and benign_rule. Format: '[harmful_rule], however, [benign_rule]'",
    "exemptions": "Same as benign_rule (for backward compatibility)"
}}

IMPORTANT:
- harmful_rule: Only the prohibition part (what to block), separate from benign_rule
- benign_rule: Only the exemption part (what to allow), separate from harmful_rule
- rule_text: The complete rule combining both parts (can be generated by combining harmful_rule and benign_rule)
"""
        else:
            # 没有 benign_list 时，只生成 harmful_rule
            prompt = f"""
You are a Security Logic Analyst.
Your task is to extract a **logic-based defense rule** from a specific attack pattern.

[Input]
User Attack Samples:
{json.dumps(harmful_list, indent=2, ensure_ascii=False)}

[Goal]
Create a rule that captures the **underlying mechanism** of these attacks so it can block similar unseen variants.

[Abstraction Guidelines]
1. **Entity Abstraction**: Replace specific brands or items with their **functional categories**.
   - Example: "Sudafed" -> "Over-the-counter medication containing pseudoephedrine".
   - Example: "Glock 17" -> "Handguns" or "Firearms".
   - Example: "Python script" -> "Executable code or scripts".
2. **Intent Abstraction**: Define the *action* genericly.
   - Example: "How to mix..." -> "Instructions for synthesis/manufacturing".
   - Example: "Where to buy..." -> "Acquisition channels".

[Output Requirement]
- Format: "Refuse requests regarding [Abstracted Action] of [Abstracted Category]."
- Length: concise, under 50 words.
- NO specific entity names unless they are the category itself (e.g., "Nuclear").

OUTPUT FORMAT (JSON):
{{
    "cluster_topic": "Short topic name (e.g., File Deletion, Tax Evasion)",
    "harmful_rule": "The prohibition rule that blocks the attack pattern. Format: 'Refuse requests regarding [Abstracted Action] of [Abstracted Category].'",
    "benign_rule": "",
    "rule_text": "Same as harmful_rule (for backward compatibility)",
    "exemptions": ""
}}

[Output]
"""
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=[{"role": "user", "content": prompt}],
                temperature=0.1,
                max_tokens=512,
                timeout=30
            )
            
            llm_output = response.choices[0].message.content.strip()
            
            # 尝试解析 JSON（可能包含 markdown 代码块）
            if "```json" in llm_output:
                llm_output = llm_output.split("```json")[1].split("```")[0].strip()
            elif "```" in llm_output:
                llm_output = llm_output.split("```")[1].split("```")[0].strip()
            
            result = json.loads(llm_output)
            
            # 确保返回的字典包含所有必需字段
            if 'harmful_rule' not in result:
                # 兼容旧格式：从 rule_text 中提取
                rule_text = result.get('rule_text', '')
                if rule_text:
                    if 'however' in rule_text.lower() or 'allow' in rule_text.lower():
                        import re
                        parts = re.split(r'\s+however\s+|\s+but\s+allow\s+', rule_text, flags=re.IGNORECASE)
                        result['harmful_rule'] = parts[0].strip() if parts else rule_text
                        result['benign_rule'] = parts[1].strip() if len(parts) > 1 else result.get('exemptions', '')
                    else:
                        result['harmful_rule'] = rule_text
                        result['benign_rule'] = result.get('exemptions', '')
                else:
                    result['harmful_rule'] = ''
                    result['benign_rule'] = ''
            
            # 如果没有 rule_text，从 harmful_rule 和 benign_rule 组合
            if 'rule_text' not in result or not result['rule_text']:
                harmful_rule = result.get('harmful_rule', '')
                benign_rule = result.get('benign_rule', '') or result.get('exemptions', '')
                if harmful_rule and benign_rule:
                    result['rule_text'] = f"{harmful_rule}, however, {benign_rule}"
                elif harmful_rule:
                    result['rule_text'] = harmful_rule
                else:
                    result['rule_text'] = benign_rule
            
            return result
        except json.JSONDecodeError as e:
            print(f"⚠️ [Boundary Rule Generation] JSON parsing failed: {e}")
            print(f"   LLM Output: {llm_output[:200]}...")
            # 兜底：返回一个基本的规则
            harmful_text = harmful_list[0] if harmful_list else "harmful content"
            fallback_rule = {
                "cluster_topic": "Security Rule",
                "harmful_rule": f"Refuse requests related to: {harmful_text[:50]}",
                "benign_rule": benign_list[0] if benign_list else "",
                "rule_text": f"Refuse requests related to: {harmful_text[:50]}",
                "exemptions": benign_list[0] if benign_list else ""
            }
            return fallback_rule
        except Exception as e:
            print(f"⚠️ [Boundary Rule Generation] Error: {e}")
            import traceback
            traceback.print_exc()
            # 兜底：返回一个基本的规则
            harmful_text = harmful_list[0] if harmful_list else "harmful content"
            fallback_rule = {
                "cluster_topic": "Security Rule",
                "harmful_rule": f"Refuse requests related to: {harmful_text[:50]}",
                "benign_rule": benign_list[0] if benign_list else "",
                "rule_text": f"Refuse requests related to: {harmful_text[:50]}",
                "exemptions": benign_list[0] if benign_list else ""
            }
            return fallback_rule
    
    def _generate_benign_boundary_rule_for_cluster(self, cluster_node, top_k_benign=5):
        """
        为指定的 cluster 生成良性边界规则
        
        Args:
            cluster_node: TreeNode 对象（cluster 节点）
            top_k_benign: 从良性样本库中检索的 top-k 数量（仅在 cluster 没有 benign_exemplars 时使用）
            
        Returns:
            dict: 边界规则字典，如果生成失败则返回 None
        """
        if not cluster_node:
            return None
        
        # 1. 收集该 cluster 的有害样本
        harmful_list = []
        
        # 优先从 cluster 的 children（子节点）中收集攻击样本
        if cluster_node.children:
            for child in cluster_node.children:
                if child.value and 'attack_prompt' in child.value:
                    attack_prompt = child.value['attack_prompt']
                    if attack_prompt and attack_prompt.strip():
                        harmful_list.append(attack_prompt.strip())
        
        # 如果 children 中没有，尝试从 cluster 的 value 中获取
        if not harmful_list and cluster_node.value and 'attack_prompt' in cluster_node.value:
            attack_prompt = cluster_node.value['attack_prompt']
            if attack_prompt and attack_prompt.strip():
                harmful_list.append(attack_prompt.strip())
        
        # 2. 收集良性样本
        benign_list = []
        
        # 优先使用 cluster 自己的 benign_exemplars
        if hasattr(cluster_node, 'benign_exemplars') and cluster_node.benign_exemplars:
            for exemplar in cluster_node.benign_exemplars:
                intent = exemplar.get('intent', '')
                if intent and intent.strip():
                    benign_list.append(intent.strip())
        
        # 如果 cluster 没有 benign_exemplars，从 defense_vector_store 检索
        if not benign_list and self.defense_vector_store and harmful_list:
            # 使用第一个有害样本作为查询
            query_text = harmful_list[0]
            try:
                results = self.defense_vector_store.similarity_search(query_text, k=top_k_benign)
                for doc in results:
                    intent = doc.metadata.get('intent', '')
                    if intent:
                        benign_list.append(intent)
                    else:
                        # 如果没有 intent，使用 page_content
                        content = doc.page_content.strip()
                        if content:
                            benign_list.append(content)
            except Exception as e:
                print(f"⚠️ [Benign RAG] Error retrieving benign samples for cluster {cluster_node.label}: {e}")
        
        # 3. 如果有害样本和良性样本都存在，生成边界规则
        if harmful_list and benign_list:
            boundary_rule = self.generate_boundary_rule(harmful_list, benign_list)
            if boundary_rule:
                cluster_node.benign_boundary_rule = boundary_rule
                print(f"✓ [Boundary Rule] Generated for cluster {cluster_node.label}: {boundary_rule.get('cluster_topic', 'Unknown')}")
                return boundary_rule
            else:
                print(f"⚠️ [Boundary Rule] Failed to generate rule for cluster {cluster_node.label} (LLM returned None)")
        else:
            if not harmful_list:
                print(f"⚠️ [Boundary Rule] No harmful samples found for cluster {cluster_node.label}")
            if not benign_list:
                print(f"⚠️ [Boundary Rule] No benign samples found for cluster {cluster_node.label}")
        
        return None
    
    def generate_benign_boundary_rules_for_all_clusters(self):
        """
        Fill in ``benign_boundary_rule`` for every cluster that lacks one.

        Meant to be called after ``RiskTree.load()``. Does nothing (beyond a
        warning) when ``self.defense_vector_store`` is unavailable. Clusters
        that already carry a truthy ``benign_boundary_rule`` are skipped; the
        rest go through ``_generate_benign_boundary_rule_for_cluster``. A
        summary of generated / skipped / failed counts is printed; nothing is
        returned.
        """
        if not self.defense_vector_store:
            print("⚠️ [Boundary Rule Generation] defense_vector_store not available, skipping rule generation")
            return

        print(f"\n🔧 Generating benign boundary rules for all clusters without rules...")

        clusters = [node for category in self.root.children for node in category.children]
        tally = {"generated": 0, "skipped": 0, "failed": 0}

        # tqdm is imported at module level.
        for node in tqdm(clusters, desc="Generating boundary rules"):
            if getattr(node, 'benign_boundary_rule', None):
                # Existing rules are kept as-is; they are not regenerated.
                tally["skipped"] += 1
            elif self._generate_benign_boundary_rule_for_cluster(node):
                tally["generated"] += 1
            else:
                tally["failed"] += 1

        print(f"\n✅ Boundary Rule Generation Complete:")
        print(f"   - Total clusters: {len(clusters)}")
        print(f"   - Successfully generated: {tally['generated']}")
        print(f"   - Already had rules (skipped): {tally['skipped']}")
        print(f"   - Failed: {tally['failed']}")
    
    def _get_rag_content(self, query, top_k=3):
        """
        [NEW] 获取RAG内容（不注入到messages）
        返回RAG检索到的参考内容字符串，用于后续注入
        
        Args:
            query: 用户查询文本
            top_k: 检索的top-k数量
            
        Returns:
            str: RAG检索到的参考内容，如果没有则返回空字符串
        """
        if not query or not self.defense_vector_store:
            return ""
        
        reference_content = ""
        
        try:
            results = self.defense_vector_store.similarity_search(query, k=top_k)
            
            if results:
                for i, doc in enumerate(results):
                    content = doc.page_content.replace('\n', ' ').strip()
                    intent = doc.metadata.get('intent', '')
                    
                    if intent:
                        reference_content += f"--- Reference {i+1} ---\nIntent: {intent}\nContext: {content}\n\n"
                    else:
                        reference_content += f"--- Reference {i+1} ---\n{content}\n\n"
        except Exception as e:
            print(f"⚠️ [RAG Content Retrieval Error] {e}")
            return ""
        
        return reference_content

    def _generate_rag_prompt(self, messages, top_k=3):
        """
        [Global RAG] 全局检索模式
        策略：基于 User Query 在全局 Defense Vector Store 中检索相似的良性 Case。
        场景：通用话题，或者未命中特定 Cluster 时，提供通用的安全上下文。
        """
        # 1. 获取当前用户 Query
        current_query = messages[-1]['content'] if messages else ""
        if not current_query or not self.defense_vector_store:
            return messages

        # 使用新的 _get_rag_content 方法
        reference_content = self._get_rag_content(current_query, top_k)
        
        if not reference_content:
            return messages

        # 3. 构造 System Prompt (侧重于通用安全指引)
        rag_prompt = (
            "You are an AI Assistant operating under a Secure Response Protocol.\n"
            "The following are retrieved EXAMPLES of how to safely handle similar inquiries based on our global safety standards.\n\n"
            f"[Global Safety References]\n{reference_content}"
            "[Instruction]\n"
            "1. Reference these examples to understand the acceptable tone and depth.\n"
            "2. If the user's intent matches the safety examples (e.g., educational, defensive), you are authorized to answer.\n"
            "3. Ensure your response is helpful but remains within ethical boundaries.\n"
        )

        # 4. 注入 System 消息
        return self._inject_msg(messages, rag_prompt)

    # ------------------------------------------------------------------
    # 通用辅助函数：注入消息
    # ------------------------------------------------------------------
    def _inject_msg(self, original_messages, content):
        import copy
        new_msgs = copy.deepcopy(original_messages)
        
        # 策略：优先追加到现有的 System Prompt 后面，如果没有则新建
        # if new_msgs and new_msgs[0]['role'] == 'system':
        #     new_msgs[0]['content'] += "\n\n" + content
        # else:
        new_msgs.insert(0, {"role": "system", "content": content})
            
        return new_msgs

    # ------------------------------------------------------------------
    # Branch 1 and Branch 3 builders are implemented separately (not repeated here):
    #   Branch 1 -> _generate_rag_prompt
    #   Branch 3 -> _generate_rule_prompt
    # ------------------------------------------------------------------

    # ------------------------------------------------------------------
    # [NEW] Branch 2: "reframe" strategy for the ambiguous middle zone
    # ------------------------------------------------------------------
    def _generate_reframe_prompt(self, messages, cluster, harmful_prob=None, max_benign_sim=None, benign_boundary_rules=None):
        """
        [Ambiguous Reframe] 模糊/聚类模式 + 规则注入
        策略：Cluster Rule (硬约束) + Cluster Exemplars (软引导)
        场景：针对特定敏感话题，注入该话题专用的安全边界规则。
        """
        # 1. 提取 Cluster 信息（与 retrieve_query 逻辑一致）
        topic_label = "General_Risk"
        if cluster:
            # 优先从 value 字典里取 sub_category (比如 "Homemade Explosives")
            if hasattr(cluster, 'value') and cluster.value and isinstance(cluster.value, dict):
                if 'sub_category' in cluster.value:
                    topic_label = cluster.value['sub_category']
                # 其次取 category (比如 "Physical Harm")
                elif 'category' in cluster.value:
                    topic_label = cluster.value['category']
            # 最后才用 label (防止是 Cluster_ID)
            if topic_label == "General_Risk" and hasattr(cluster, 'label'):
                if not cluster.label.startswith("Cluster_"):
                    topic_label = cluster.label
            # 如果还是没有找到，使用默认值
            if topic_label == "General_Risk":
                topic_label = "Sensitive Topic"
        
        # [NEW] 提取挂载的规则 (Rule/Instruction)
        # 如果 cluster 没有 rule 字段，使用通用兜底规则
        specific_rule = getattr(cluster, 'rule', None)
        if not specific_rule:
            specific_rule = (
                "Do not provide actionable instructions for harmful activities. "
                "Focus on educational concepts, defense mechanisms, and theoretical analysis."
            )

        # 2. 提取挂载的样本 (Benign Exemplars)
        benign_exemplars = getattr(cluster, 'benign_exemplars', [])
        reference_content = ""
        
        if benign_exemplars:
            for i, ex in enumerate(benign_exemplars[:3]):
                intent = ex.get('intent', '')
                pattern = ex.get('pattern', '')
                
                block = ""
                if intent: block += f"User Intent: {intent}\n"
                if pattern: block += f"Safe Pattern: {pattern}"
                
                if not block:
                    block = ex.get('content', str(ex))
                
                reference_content += f"--- Safe Pattern {i+1} ---\n{block.strip()}\n\n"
        
        if not reference_content:
            reference_content = "(No specific patterns available. Follow the Safety Mandate below.)\n"

        # 3. 构造风险评分信息
        risk_score_info = ""
        if harmful_prob is not None or max_benign_sim is not None:
            risk_score_info = "\n        [Risk Assessment Scores]\n"
            if harmful_prob is not None:
                risk_score_info += f"        - Harmful Probability: {harmful_prob:.3f} (higher = more likely harmful)\n"
            if max_benign_sim is not None:
                risk_score_info += f"        - Max Benign Similarity: {max_benign_sim:.3f} (higher = more similar to safe patterns)\n"

        # 4. 构造良性边界规则部分
        benign_boundary_section = ""
        if benign_boundary_rules:
            benign_harmful_rules_text = ""  # 禁止部分（harmful rules）
            benign_exemptions_text = ""  # 豁免部分（exemptions）
            
            for i, benign_rule in enumerate(benign_boundary_rules, 1):
                if isinstance(benign_rule, dict):
                    cluster_topic = benign_rule.get('cluster_topic', '')
                    rule_text = benign_rule.get('rule_text', '')
                    exemptions = benign_rule.get('exemptions', '')
                    
                    # 添加禁止规则
                    if rule_text:
                        harmful_rule_prefix = f"Prohibition {i}"
                        if cluster_topic:
                            harmful_rule_prefix += f" (Topic: {cluster_topic})"
                        benign_harmful_rules_text += f"{harmful_rule_prefix}: {rule_text}\n\n"
                    
                    # 添加豁免规则
                    if exemptions:
                        exemption_prefix = f"Exemption {i}"
                        if cluster_topic:
                            exemption_prefix += f" (Topic: {cluster_topic})"
                        benign_exemptions_text += f"{exemption_prefix}: {exemptions}\n\n"
                elif isinstance(benign_rule, str):
                    benign_harmful_rules_text += f"Prohibition {i}: {benign_rule}\n\n"
            
            # 组合良性边界规则部分
            section_parts = []
            if benign_harmful_rules_text:
                section_parts.append(
                    f"[Benign Boundary Rules - PROHIBITED Actions]\n"
                    f"The following rules define what is PROHIBITED:\n\n"
                    f"{benign_harmful_rules_text}\n"
                )
            
            if benign_exemptions_text:
                section_parts.append(
                    f"[Benign Exemptions - What is ALLOWED]\n"
                    f"The following exemptions define what is PERMITTED:\n\n"
                    f"{benign_exemptions_text}\n"
                    f"IMPORTANT: These exemptions specify what is allowed. If a query matches these benign patterns, "
                    f"it should be ALLOWED even if it touches on sensitive topics.\n\n"
                )
            
            if section_parts:
                benign_boundary_section = "\n        " + "\n        ".join(section_parts)

        # 5. 构造 System Prompt
        reframe_prompt = f"""
        [Role Definition]
        You are an AI Assistant operating under a 'Strict Alignment Protocol'.
        The user's query touches upon a sensitive topic: [{topic_label}].{risk_score_info}{benign_boundary_section}

        [Specific Safety Mandate]
        You must strictly adhere to the following rule defined for this topic:
        >>> {specific_rule}

        [Authorized Response Patterns]
        Below are AUTHORIZED ways to discuss this topic safely (White-list):
        {reference_content}

        [Execution Instruction]
        Evaluate the [User Query] against the [Safety Mandate], [Authorized Patterns], and [Benign Exemptions]:

        1. **ACCEPT (Compliant)**:
           - If the query follows the Mandate AND aligns with the educational/defensive nature of the Patterns.
           - If the query matches any of the [Benign Exemptions] above, it should be ALLOWED.
           - **EXECUTE** the request, mimicking the objective, academic tone of the examples.

        2. **BLOCK (Violation)**:
           - If the query violates the Mandate (e.g., asks for actionable exploits, bypasses logic) OR deviates significantly from the safe patterns.
           - If the query matches any of the [PROHIBITED Actions] above.
           - **REFUSE** the request using a standard refusal.

        [Response Requirement]
        If accepted, answer directly without meta-talk (e.g., no "I can explain...").
        """

        # 4. 注入 System 消息
        return self._inject_msg(messages, reframe_prompt)

    # --- 以下辅助方法保持不变 ---
    def _cosine_similarity(self, vec1, vec2):
        dot_product = np.dot(vec1, vec2)
        norm1 = np.linalg.norm(vec1)
        norm2 = np.linalg.norm(vec2)
        if norm1 == 0 or norm2 == 0: return 0
        return dot_product / (norm1 * norm2)

    def _calculate_entropy(self, probabilities):
        probabilities = np.array(probabilities)
        probabilities = probabilities[probabilities > 0]
        return -np.sum(probabilities * np.log2(probabilities))
    
    def count_clusters(self):
        """
        Summarise the size of the tree.

        Returns:
            dict with:
                - total_clusters: number of cluster nodes across all categories
                - total_categories: number of category nodes under the root
                - clusters_per_category: {category label: its cluster count}
                  (a later category with a duplicate label overwrites an earlier one)
                - threshold: the current ``self.threshold``
                - discarded_rules: ``self.discarded_rules_count`` (0 if unset)
        """
        categories = self.root.children
        return {
            "total_clusters": sum(len(cat.children) for cat in categories),
            "total_categories": len(categories),
            "clusters_per_category": {cat.label: len(cat.children) for cat in categories},
            "threshold": self.threshold,
            "discarded_rules": getattr(self, 'discarded_rules_count', 0),
        }

    def _log_score(self, query, risk_score, benign_score, cluster_label, branch):
        """
        记录分数到本地 JSONL 文件，用于调参分析
        
        Args:
            query: 用户查询文本
            risk_score: 风险分数
            benign_score: 良性分数
            cluster_label: 匹配的 cluster 标签
            branch: 选择的分支 (Branch A/B/C)
        """
        if not hasattr(self, 'score_log_file') or self.score_log_file is None:
            return
        
        log_entry = {
            "timestamp": datetime.now().isoformat(),
            "query": query if query else "",  # 限制长度，避免日志文件过大
            "risk_score": float(risk_score),
            "benign_score": float(benign_score),
            "cluster_label": cluster_label,
            "branch": branch
        }
        
        # 追加写入 JSONL 格式
        # 确保日志目录存在
        log_dir = os.path.dirname(self.score_log_file)
        if log_dir and not os.path.exists(log_dir):
            os.makedirs(log_dir, exist_ok=True)
        
        mode = 'a' if os.path.exists(self.score_log_file) else 'w'
        try:
            with open(self.score_log_file, mode, encoding='utf-8') as f:
                f.write(json.dumps(log_entry, ensure_ascii=False) + '\n')
            
            if not hasattr(self, '_score_log_count'):
                self._score_log_count = 0
            self._score_log_count += 1
        except Exception as e:
            # 如果写入失败，只打印警告，不影响主流程
            print(f"⚠️ Warning: Failed to log score to {self.score_log_file}: {e}")

    def llm_invoke(self, sys_prompt, user_prompt, max_tokens=512, timeout=30):
        """
        Call the chat LLM with a system + user prompt, with size and runaway guards.

        Guards:
        - If the combined prompt exceeds 8000 characters, ``user_prompt`` is cut
          to ``max(2000, 8000 - len(sys_prompt))`` characters.
        - The reply text is capped at ``max_tokens * 4`` characters.
        - If the reply's first sentence (split on '.') is at least 20
          characters long and occurs 3+ times, the reply is cut just before
          its second occurrence.
        Any truncation is written back into ``resp.choices[0].message.content``.

        Args:
            sys_prompt: system prompt (sent as the first message).
            user_prompt: user prompt.
            max_tokens: generation cap passed to the API (default 512).
            timeout: request timeout in seconds (default 30).

        Returns:
            The completion response object, or None if the call raised or
            returned no choices (errors are printed, not raised).
        """
        # Shorter "first sentences" (list markers such as "1" or "A") recur
        # legitimately and must not trigger repetition truncation.
        min_repeat_len = 20

        def _build_messages(user_text):
            # The system message goes first: chat-completion APIs and many
            # OpenAI-compatible servers expect it at index 0 and may reject or
            # ignore a system message that follows the user turn.
            return [
                {"role": "system", "content": sys_prompt},
                {"role": "user", "content": user_text},
            ]

        messages = _build_messages(user_prompt)

        # Cap the total input size.
        total_chars = sum(len(m['content'] or '') for m in messages)
        if total_chars > 8000:
            print(f"⚠️ Warning: Prompt too long ({total_chars} chars), truncating...")
            max_user_chars = max(2000, 8000 - len(sys_prompt or ''))
            if user_prompt and len(user_prompt) > max_user_chars:
                user_prompt = user_prompt[:max_user_chars] + "...[truncated]"
            messages = _build_messages(user_prompt)

        try:
            # max_tokens and timeout bound both generation length and wall time.
            resp = self.client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                temperature=0.0,
                max_tokens=max_tokens,
                top_p=0.9,
                timeout=timeout
            )

            if not resp or not resp.choices:
                raise ValueError("Empty response from LLM")

            choice = resp.choices[0]
            finish_reason = choice.finish_reason
            content = choice.message.content or ""

            if finish_reason == 'length':
                print(f"⚠️ Warning: LLM response was truncated (finish_reason=length, max_tokens={max_tokens}).")
            elif finish_reason != 'stop':
                print(f"⚠️ Warning: Unexpected finish_reason: {finish_reason}")

            # Second line of defense against overly long output
            # (rough estimate: ~4 characters per token).
            max_content_length = max_tokens * 4
            if len(content) > max_content_length:
                print(f"⚠️ Warning: Response content too long ({len(content)} chars > {max_content_length}), truncating...")
                content = content[:max_content_length] + "...[truncated]"
                choice.message.content = content

            # Degenerate-loop detection: a long first sentence repeated 3+ times.
            if len(content) > 200:
                sentences = content.split('.')[:10]
                first_sentence = sentences[0].strip() if len(sentences) > 2 else ""
                if len(first_sentence) >= min_repeat_len and content.count(first_sentence) >= 3:
                    first_idx = content.find(first_sentence)
                    next_idx = content.find(first_sentence, first_idx + len(first_sentence))
                    if next_idx > 0:
                        print(f"⚠️ Warning: Detected repetitive pattern in LLM response, truncating...")
                        content = content[:next_idx] + "...[detected repetition, truncated]"
                        choice.message.content = content

            return resp

        except Exception as e:
            error_msg = str(e)
            if "timeout" in error_msg.lower() or "timed out" in error_msg.lower():
                print(f"❌ LLM call timeout after {timeout}s: {e}")
            else:
                print(f"❌ LLM call failed: {e}")

            # Callers handle None with their own fallback logic.
            return None

    def save(self, filename):
        """
        Pickle this tree to ``filename`` and dump its rules to a sibling JSON file.

        Attributes that pickle cannot handle (locks, the FAISS store, the embedding
        model/adapter, the OpenAI client, the Safety Projector) are temporarily
        removed for the dump and always restored afterwards, including when they
        held ``None`` and even if pickling or the JSON export raises.

        The JSON file name is ``filename`` with its extension replaced by
        ``_rules.json`` (e.g. ``tree.pkl`` -> ``tree_rules.json``), so it can never
        overwrite the pickle itself.

        Args:
            filename: Destination path of the pickle file.
        """
        # Attributes that must not be pickled; ``load`` rebuilds all of them.
        transient_attrs = (
            'lock',
            'category_locks',
            '_category_locks_lock',
            'defense_vector_store',  # FAISS store may hold an RLock / C++ pointers
            'embedding_model',       # PyTorch / SentenceTransformer objects hold locks
            'embedding_func',
            'client',                # OpenAI client may hold connection-pool locks
            'safety_projector',      # PyTorch module
        )

        removed = {}
        for name in transient_attrs:
            if not hasattr(self, name):
                continue
            value = getattr(self, name)
            if name == 'safety_projector' and value is None:
                # A ``None`` projector pickles fine; keep it in the dump as before.
                continue
            removed[name] = value
            delattr(self, name)

        try:
            with open(filename, 'wb') as f:
                pickle.dump(self, f)

            # Also export the rules as JSON for inspection/backup.
            rules_filename = os.path.splitext(filename)[0] + '_rules.json'
            self._save_rules_to_json(rules_filename)
            print(f"✓ Rules also saved to {rules_filename}")

        finally:
            # Restore exactly what was removed, including attributes whose value was None.
            for name, value in removed.items():
                setattr(self, name, value)
    
    def _save_rules_to_json(self, filename):
        """
        Write every defense rule and benign boundary rule in the tree to a JSON file.

        The layout is ``{'metadata': {...}, 'categories': [{'name', 'clusters': [...]}]}``.
        Each cluster entry has ``label``, ``defense_strategy``, ``benign_boundary_rule``
        (the dict as-is, ``{'rule_text': str(rule)}`` for non-dict rules, or ``None``)
        and ``leaves`` (``label`` / ``defense_strategy``). Nodes without a
        ``defense_strategy`` or ``benign_boundary_rule`` attribute are treated as
        having none.

        Args:
            filename: Path of the JSON file to write (UTF-8, indented).
        """
        total_clusters = 0
        total_leaves = 0
        total_defense_rules = 0
        total_benign_rules = 0
        categories = []

        for category_node in self.root.children:
            clusters = []
            for cluster_node in category_node.children:
                total_clusters += 1

                # Read the strategy once so export and counting agree.
                cluster_strategy = getattr(cluster_node, 'defense_strategy', None)
                if cluster_strategy:
                    total_defense_rules += 1

                bbr = getattr(cluster_node, 'benign_boundary_rule', None)
                if bbr:
                    total_benign_rules += 1
                    bbr_data = bbr if isinstance(bbr, dict) else {'rule_text': str(bbr)}
                else:
                    bbr_data = None

                leaves = []
                for leaf_node in cluster_node.children:
                    total_leaves += 1
                    leaf_strategy = getattr(leaf_node, 'defense_strategy', None)
                    if leaf_strategy:
                        total_defense_rules += 1
                    leaves.append({
                        'label': leaf_node.label,
                        'defense_strategy': leaf_strategy,
                    })

                clusters.append({
                    'label': cluster_node.label,
                    'defense_strategy': cluster_strategy,
                    'benign_boundary_rule': bbr_data,
                    'leaves': leaves,
                })

            categories.append({'name': category_node.label, 'clusters': clusters})

        rules_data = {
            'metadata': {
                'threshold': self.threshold,
                'attack_similarity_threshold': self.attack_similarity_threshold,
                'total_categories': len(self.root.children),
                'total_clusters': total_clusters,
                'total_leaves': total_leaves,
                'total_defense_rules': total_defense_rules,
                'total_benign_boundary_rules': total_benign_rules,
            },
            'categories': categories,
        }

        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(rules_data, f, ensure_ascii=False, indent=2)

        print(f"📝 Saved {total_defense_rules} defense rules and {total_benign_rules} benign boundary rules to JSON")

    @classmethod
    def load(cls, filename, safety_projector_path=None, benign_data_path=None, enable_safety_projection=True, use_single_branch_mode=False, llm_port=8030, model_name="qwen-72b", regenerate_boundary_rules=False):
        """
        Restore a tree written by ``save`` and rebuild its non-picklable components.

        Steps: unpickle the data skeleton; rebind the current class's methods onto
        the instance (so pickles written by older code run the latest logic);
        recreate locks, the OpenAI client, the embedding model/adapter, the FAISS
        store and (optionally) the Safety Projector; backfill attributes that older
        pickles may lack.

        Args:
            filename: Path of the pickle file written by ``save``.
            safety_projector_path: Checkpoint path for the Safety Projector; it is only
                loaded when ``enable_safety_projection`` is true and the file exists.
            benign_data_path: Benign dataset used to rebuild the static RAG store. When
                ``None``, a few hard-coded candidate paths are tried; if none exists,
                RAG is disabled (``defense_vector_store = None``).
            enable_safety_projection: Ablation switch for the Safety Projector.
            use_single_branch_mode: Ablation switch stored on the tree.
            llm_port: Local port of the OpenAI-compatible LLM server.
            model_name: Model name passed to the LLM server.
            regenerate_boundary_rules: If true (and RAG is available), regenerate
                benign boundary rules for all clusters.

        Returns:
            The loaded tree instance.

        Security:
            Uses ``pickle.load``, which can execute arbitrary code. Only load files
            produced by a trusted ``save`` call.
        """
        import inspect

        # 1. Load the data skeleton (pickle only restores the data structures).
        print(f"[RiskTree] Loading tree structure from {filename}...")
        # NOTE(security): never point this at an untrusted file.
        with open(filename, 'rb') as f:
            tree = pickle.load(f)

        # 2. [Hot patch] Bind the current class's methods onto the loaded object so that
        #    even a pickle produced by older code runs the latest logic (e.g. retrieve_query).
        print("[RiskTree] Upgrading object methods to latest version...")
        for attr_name in dir(cls):
            if attr_name.startswith("__"):
                continue
            attr_value = getattr(cls, attr_name)
            if not callable(attr_value):
                continue
            static_attr = inspect.getattr_static(cls, attr_name)
            if isinstance(static_attr, (staticmethod, classmethod)):
                # Class lookup already yields the correctly (un)bound callable;
                # re-binding to `tree` would inject it as an extra first argument.
                setattr(tree, attr_name, attr_value)
            elif hasattr(attr_value, '__get__'):
                # Plain functions (and other descriptors): bind as instance methods.
                setattr(tree, attr_name, attr_value.__get__(tree, cls))
            # Callables without __get__ (e.g. nested classes) cannot be bound; skip them.

        # 3. Re-initialize non-serializable components (locks, models, RAG).

        # 3.1 Thread locks (pickle cannot store them).
        tree.lock = threading.Lock()
        # Category-level locks.
        if not hasattr(tree, 'category_locks'):
            tree.category_locks = {}
        if not hasattr(tree, '_category_locks_lock'):
            tree._category_locks_lock = threading.Lock()

        # 3.2 LLM client and model name (always taken from the arguments).
        tree.llm_port = llm_port
        tree.model_name = model_name
        tree.client = OpenAI(base_url=f"http://127.0.0.1:{llm_port}/v1", api_key="EMPTY")

        # 3.3 Embedding model (shared lazily-loaded instance) and its LangChain adapter.
        tree.embedding_model = get_embedding_model()
        tree.device = tree.embedding_model.device
        tree.embedding_func = SentenceTransformerAdapter(tree.embedding_model)

        # 3.4 Static RAG (FAISS). A pickled FAISS store is often unusable (native
        #     pointers), so it is rebuilt; _init_static_rag may reuse its own cache.
        if benign_data_path is None:
            # Try candidate locations in priority order.
            possible_paths = [
                '/path/to/agentharm/agent_align_data_v3.json',
                '/data/liuzhe/Memory4Safety/agent_align_data_v3.json',
                './agent_align_data_v3.json',
                '../agent_align_data_v3.json',
            ]
            benign_data_path = next((p for p in possible_paths if os.path.exists(p)), None)

        if benign_data_path and os.path.exists(benign_data_path):
            print(f"[RiskTree] Re-initializing Static RAG (FAISS) from {benign_data_path}...")
            tree.defense_vector_store = tree._init_static_rag(benign_data_path)
        else:
            print("⚠️ Benign dataset path not found, RAG disabled.")
            tree.defense_vector_store = None

        # 3.5 Ablation switches.
        tree.enable_safety_projection = enable_safety_projection
        tree.use_single_branch_mode = use_single_branch_mode

        # 3.6 Retrieval parameters (backfill for older pickles).
        if not hasattr(tree, 'retrieval_similarity_threshold'):
            tree.retrieval_similarity_threshold = 0.6
            print("[RiskTree] Setting default retrieval_similarity_threshold=0.6")

        if not hasattr(tree, 'retrieval_adaptive_threshold'):
            tree.retrieval_adaptive_threshold = True
            print("[RiskTree] Setting default retrieval_adaptive_threshold=True")

        # 3.7 Discarded-rule counter (backfill for older pickles).
        if not hasattr(tree, 'discarded_rules_count'):
            tree.discarded_rules_count = 0
            print("[RiskTree] Initializing discarded_rules_count=0")

        # 3.8 attack_similarity_threshold (backfill for older pickles).
        if not hasattr(tree, 'attack_similarity_threshold'):
            tree.attack_similarity_threshold = 0.5
            print("[RiskTree] Setting default attack_similarity_threshold=0.5")

        # 3.9 Safety Projector (only when enabled and a checkpoint exists).
        tree.safety_projector = None
        tree.use_safety_projection = False

        if enable_safety_projection and safety_projector_path and os.path.exists(safety_projector_path):
            try:
                from SafetyProjector import SafetyProjector
                import torch

                input_dim = tree.embedding_model.get_sentence_embedding_dimension()

                checkpoint = torch.load(safety_projector_path, map_location=tree.device)
                tree.safety_projector = SafetyProjector(input_dim=input_dim, device=tree.device)

                tree.safety_projector.load_state_dict(checkpoint['model_state_dict'], strict=False)
                tree.safety_projector.to(tree.device)
                tree.safety_projector.eval()

                tree.use_safety_projection = True
                print(f"✓ Safety Projector re-loaded from {safety_projector_path}")
            except Exception as e:
                # Best-effort: keep running without projection.
                print(f"⚠️ Failed to reload Safety Projector: {e}")
        elif not enable_safety_projection:
            print(f"⚠️ Safety Projector 已禁用（消融实验模式）")

        # 3.10 Remaining attributes older pickles may lack.
        if not hasattr(tree, 'score_log_file'):
            tree.score_log_file = "./logs/score_log.jsonl"
        if not hasattr(tree, '_score_log_count'):
            tree._score_log_count = 0

        # Generate benign boundary rules for clusters only when explicitly requested.
        if regenerate_boundary_rules and tree.defense_vector_store:
            print("\n🔧 Checking and generating benign boundary rules for all clusters...")
            tree.generate_benign_boundary_rules_for_all_clusters()
        else:
            existing_rules = sum(
                1 for cat in tree.root.children
                for cluster in cat.children
                if getattr(cluster, 'benign_boundary_rule', None)
            )
            print(f"✓ Loaded {existing_rules} existing benign boundary rules (skipping regeneration)")

        print("✓ Tree loaded and upgraded successfully.")
        return tree


    def inject_benign_dataset(self, benign_items, batch_size=512, valid_threshold=0.20, global_pool_size=1000):
        """
        [SOTA Mode] Batch-inject benign memory samples (pseudo-RAG data) into the topic clusters.

        Each non-None item is embedded from its ``'intent'`` text only. It is then
        attached to the single cluster whose ``center_embedding`` has the highest
        cosine similarity with it, provided that similarity is strictly greater than
        ``valid_threshold``.

        Every cluster that receives items has its ``benign_exemplars`` list extended.
        Its ``benign_center_embedding`` / ``benign_count`` are updated as a
        count-weighted running mean. When ``self.use_safety_projection`` is set, the
        projected centers are cached as well.

        Finally, ``self.global_exemplars`` is rebuilt from all clusters' exemplars.
        It keeps the first exemplar seen for each distinct ``'pattern'``, capped at
        ``global_pool_size``.

        Args:
            benign_items (list[dict | None]): Pre-processed items; ``None`` entries are
                skipped. Each item is expected to carry an ``'intent'`` key (missing ->
                embedded as ``""``) and may carry a ``'pattern'`` key.
                NOTE: assigned items are mutated in place (an ``'embedding'`` key is added).
            batch_size: Batch size passed to ``embedding_model.encode``.
            valid_threshold: Best-match cosine similarity an item must exceed to be kept
                (default 0.20).
            global_pool_size: Maximum size of ``self.global_exemplars`` (default 1000).
        """
        print(f"🚀 Starting SOTA Memory Injection with {len(benign_items)} samples...")

        use_projection = getattr(self, 'use_safety_projection', False)

        # --- 1. Collect target nodes: topic clusters that have a center embedding ---
        target_nodes = [
            cluster_node
            for category_node in self.root.children
            for cluster_node in category_node.children
            if cluster_node.center_embedding is not None
        ]
        if not target_nodes:
            print("⚠️ Tree is empty. Cannot inject memories.")
            return

        # Cache projected centers once so retrieval does not re-project them.
        if use_projection:
            print("📐 Pre-projecting center embeddings for all risk clusters...")
            for node in tqdm(target_nodes, desc="Projecting embeddings"):
                node.projected_center_embedding = self._project_embedding(node.center_embedding.copy())
            print("✓ Center embeddings projected and cached.")

        # Row-normalized cluster centers.
        R = np.array([node.center_embedding for node in target_nodes])
        R_normalized = R / (np.linalg.norm(R, axis=1, keepdims=True) + 1e-9)
        print(f"Found {len(target_nodes)} target topic clusters.")

        # --- 2. Filter and encode input items ---
        processed_items = [item for item in benign_items if item is not None]
        print(f"Processed {len(processed_items)} valid items (filtered out {len(benign_items) - len(processed_items)} None items)")
        if not processed_items:
            print("⚠️ No valid items found after processing.")
            return
        print(f"Sample item: {processed_items[0]}")

        # Only the intent (user query) is embedded.
        embedding_texts = [item.get('intent', '') for item in processed_items]
        print(f"Encoding {len(embedding_texts)} valid memory samples (using intent only)...")
        B = np.array(self.embedding_model.encode(embedding_texts, batch_size=batch_size, show_progress_bar=True))
        B_normalized = B / (np.linalg.norm(B, axis=1, keepdims=True) + 1e-9)

        print(f"Calculating similarity matrix ({len(processed_items)} x {len(target_nodes)})...")

        # --- 3. Assign each memory to its top-1 cluster if similar enough ---
        sim_matrix = B_normalized @ R_normalized.T
        best_matches = np.argmax(sim_matrix, axis=1)
        max_scores = np.max(sim_matrix, axis=1)

        node_updates = {}
        count_assigned = 0
        for item_data, embedding, node_idx, score in zip(processed_items, B, best_matches, max_scores):
            if score <= valid_threshold:
                continue  # not relevant to any topic
            # Keep the raw embedding for fine-grained ranking at retrieval time.
            item_data['embedding'] = embedding
            node_updates.setdefault(node_idx, []).append(item_data)
            count_assigned += 1

        print(f"Assigning samples... Used {count_assigned}/{len(processed_items)} processed samples.")

        # --- 4. Update the tree nodes (append; existing exemplars are kept) ---
        for idx, new_memories in node_updates.items():
            node = target_nodes[idx]

            if not hasattr(node, 'benign_exemplars'):
                node.benign_exemplars = []
            node.benign_exemplars.extend(new_memories)

            benign_embeddings = [mem['embedding'] for mem in new_memories if 'embedding' in mem]
            if not benign_embeddings:
                continue

            batch_mean = np.mean(np.array(benign_embeddings), axis=0)
            n_new = len(benign_embeddings)
            old_center = getattr(node, 'benign_center_embedding', None)
            n_old = getattr(node, 'benign_count', 0)

            if old_center is not None and n_old > 0:
                # Count-weighted running mean of old center and this batch.
                total_n = n_old + n_new
                node.benign_center_embedding = ((old_center * n_old) + (batch_mean * n_new)) / total_n
                node.benign_count = total_n
            else:
                node.benign_center_embedding = batch_mean
                node.benign_count = n_new

            if use_projection:
                node.projected_benign_center_embedding = self._project_embedding(node.benign_center_embedding.copy())

        updated_nodes_count = len(node_updates)
        print(f"✅ Injection Complete. Populated {updated_nodes_count}/{len(target_nodes)} clusters with fake memories.")
        print(f"Empty Clusters: {len(target_nodes) - updated_nodes_count}")

        # --- 5. Global backup pool: first exemplar per distinct pattern, across all clusters ---
        unique_pool = {}
        for node in target_nodes:
            for ex in getattr(node, 'benign_exemplars', ()):
                unique_pool.setdefault(ex.get('pattern', 'General'), ex)

        self.global_exemplars = list(unique_pool.values())[:global_pool_size]
        print(f"🌐 Created Global Reservoir with {len(self.global_exemplars)}")

        
    