#!/usr/bin/env python3
"""
Flattened Memory Tree - 规则拉平的 Memory Tree 封装
用于消融实验：将所有规则拉平，通过 RAG 检索
"""

import numpy as np
from typing import List, Dict, Any, Optional, Tuple
from risk_tree import RiskTree, TreeNode


class FlattenedMemoryTree:
    """
    Wrap a RiskTree and flatten every cluster rule into one list that is
    searched by embedding similarity (RAG) instead of by walking the tree.

    Used for the ablation experiment in which the tree hierarchy is removed.

    Attributes:
        risk_tree: The wrapped RiskTree instance.
        flattened_rules: One dict per rule with the keys ``type``
            (``'defense_strategy'`` or ``'benign_boundary_rule'``), ``rule``,
            ``category``, ``topic_label``, ``cluster_label``,
            ``center_embedding``, ``benign_boundary_rule`` and, for benign
            rules only, ``exemptions``.
        rule_embeddings: Embedding used to score each rule, index-aligned with
            ``flattened_rules``; ``None`` when no embedding could be obtained.
    """

    # Thresholds of the simplified branch heuristic used by retrieve_query.
    TH_HARM_HIGH = 0.92
    TH_HARM_LOW = 0.50
    TH_BENIGN_HIGH = 0.65

    def __init__(self, risk_tree: "RiskTree"):
        """
        Build the flattened rule index from ``risk_tree``.

        Args:
            risk_tree: RiskTree instance. Its ``root.children`` are category
                nodes whose ``children`` are the cluster nodes holding the
                rules; its ``embedding_model.encode`` is used to embed rule
                texts and queries.
        """
        self.risk_tree = risk_tree
        self.flattened_rules: List[Dict[str, Any]] = []  # all flattened rules
        self.rule_embeddings: List[Optional[Any]] = []  # aligned with flattened_rules
        self._flatten_rules()

    # ------------------------------------------------------------------
    # Index construction
    # ------------------------------------------------------------------

    def _encode_text(self, text: str) -> Optional[Any]:
        """Encode one text with the tree's embedding model; return None on failure."""
        try:
            return self.risk_tree.embedding_model.encode([text])[0]
        except Exception as e:
            # Best effort: a rule without an embedding scores -1.0 and is
            # therefore never selected by retrieve_query.
            print(f"⚠️ [Flattened Memory Tree] 编码规则失败: {e}")
            return None

    def _add_rule(self, rule_entry: Dict[str, Any], center_embedding: Any, text_to_encode: str) -> None:
        """
        Append a rule entry and its embedding, keeping both lists index-aligned.

        The cluster's center embedding is preferred; only when it is missing is
        ``text_to_encode`` (passed unstripped) embedded instead.
        """
        self.flattened_rules.append(rule_entry)
        if center_embedding is not None:
            self.rule_embeddings.append(center_embedding)
        else:
            self.rule_embeddings.append(self._encode_text(text_to_encode))

    @staticmethod
    def _get_defense_strategy(cluster_node) -> Any:
        """Return the node's ``defense_strategy``, else ``value['defense_response']``, else None."""
        strategy = getattr(cluster_node, 'defense_strategy', None)
        if strategy:
            return strategy
        if cluster_node.value and 'defense_response' in cluster_node.value:
            return cluster_node.value['defense_response']
        return None

    @staticmethod
    def _get_topic_label(cluster_node) -> str:
        """Pick a topic label: sub_category, then category, then a non-generic node label."""
        value = cluster_node.value
        if value:
            if 'sub_category' in value:
                return value['sub_category']
            if 'category' in value:
                return value['category']
            # Auto-generated "Cluster_*" labels carry no topic information.
            if not cluster_node.label.startswith("Cluster_"):
                return cluster_node.label
        return "General_Risk"

    @staticmethod
    def _get_center_embedding(cluster_node) -> Any:
        """Prefer ``projected_center_embedding``; fall back to ``center_embedding`` (may be None)."""
        projected = getattr(cluster_node, 'projected_center_embedding', None)
        if projected is not None:
            return projected
        return getattr(cluster_node, 'center_embedding', None)

    @staticmethod
    def _normalize_benign_rule(benign_boundary_rule: Any, topic_label: str) -> Optional[Tuple[str, Any, Any]]:
        """
        Extract ``(rule_text, exemptions, topic)`` from a benign boundary rule.

        Accepts either a dict (keys ``rule_text``, ``exemptions``,
        ``cluster_topic``) or a plain string. Returns None when there is no
        non-blank rule text.
        """
        if isinstance(benign_boundary_rule, dict):
            rule_text = benign_boundary_rule.get('rule_text', '')
            if rule_text and rule_text.strip():
                return (
                    rule_text,
                    benign_boundary_rule.get('exemptions', ''),
                    benign_boundary_rule.get('cluster_topic', topic_label),
                )
            return None
        if isinstance(benign_boundary_rule, str) and benign_boundary_rule.strip():
            return benign_boundary_rule, '', topic_label
        return None

    def _flatten_rules(self):
        """
        Walk root -> category -> cluster nodes and collect every cluster's
        defense strategy (dynamic rule) and benign boundary rule into
        ``flattened_rules`` / ``rule_embeddings``. Rebuilds both lists from scratch.
        """
        print("[Flattened Memory Tree] 开始拉平所有规则...")
        self.flattened_rules = []
        self.rule_embeddings = []

        for category_node in self.risk_tree.root.children:
            category_name = category_node.label

            for cluster_node in category_node.children:
                defense_strategy = self._get_defense_strategy(cluster_node)
                # Falsy values (None, "", {}) are treated as "no benign rule".
                benign_boundary_rule = getattr(cluster_node, 'benign_boundary_rule', None) or None
                topic_label = self._get_topic_label(cluster_node)
                center_embedding = self._get_center_embedding(cluster_node)

                if defense_strategy and defense_strategy.strip():
                    self._add_rule(
                        {
                            'type': 'defense_strategy',
                            'rule': defense_strategy.strip(),
                            'category': category_name,
                            'topic_label': topic_label,
                            'cluster_label': cluster_node.label,
                            'center_embedding': center_embedding,
                            'benign_boundary_rule': benign_boundary_rule,
                        },
                        center_embedding,
                        defense_strategy,
                    )

                benign = self._normalize_benign_rule(benign_boundary_rule, topic_label)
                if benign is not None:
                    rule_text, exemptions, benign_topic = benign
                    self._add_rule(
                        {
                            'type': 'benign_boundary_rule',
                            'rule': rule_text.strip(),
                            'exemptions': exemptions,
                            'category': category_name,
                            'topic_label': benign_topic,
                            'cluster_label': cluster_node.label,
                            'center_embedding': center_embedding,
                            'benign_boundary_rule': benign_boundary_rule,
                        },
                        center_embedding,
                        rule_text,
                    )

        print(f"[Flattened Memory Tree] 规则拉平完成，共 {len(self.flattened_rules)} 条规则")

    # ------------------------------------------------------------------
    # Retrieval
    # ------------------------------------------------------------------

    @staticmethod
    def _empty_result(messages) -> Dict[str, Any]:
        """Result dict returned when there is nothing to retrieve (treated as SAFE)."""
        return {
            "is_harmful": False,
            "rag_content": "",
            "dynamic_rule": "",
            "dynamic_rules": [],
            "benign_boundary_rules": [],
            "branch": "SAFE",
            "original_messages": messages if messages else [],
            "topic_label": "General",
            "harmful_prob": 0.0,
            "max_benign_sim": 0.0,
        }

    def _similarity(self, query_embedding, rule_embedding):
        """Cosine similarity, delegating to the tree's ``_cosine_similarity`` when it exists."""
        if hasattr(self.risk_tree, '_cosine_similarity'):
            return self.risk_tree._cosine_similarity(query_embedding, rule_embedding)
        dot_product = np.dot(query_embedding, rule_embedding)
        norm_query = np.linalg.norm(query_embedding)
        norm_rule = np.linalg.norm(rule_embedding)
        if norm_query > 0 and norm_rule > 0:
            return dot_product / (norm_query * norm_rule)
        return 0.0

    def _score_rules(self, query_embedding) -> List[Any]:
        """Similarity of the query to every flattened rule; -1.0 marks unusable rules."""
        similarities = []
        for i, rule_emb in enumerate(self.rule_embeddings):
            if rule_emb is None:
                similarities.append(-1.0)
                continue
            try:
                similarities.append(self._similarity(query_embedding, rule_emb))
            except Exception as e:
                print(f"⚠️ [Flattened Memory Tree] 计算相似度失败 (rule {i}): {e}")
                similarities.append(-1.0)
        return similarities

    def _select_rules(self, similarities, top_k):
        """
        Take the top-k rules by similarity (negative scores skipped) and split
        them into dynamic rules, benign-boundary-rule dicts and topic labels,
        each ordered from most to least similar.
        """
        ranked = sorted(range(len(similarities)), key=similarities.__getitem__, reverse=True)[:top_k]

        dynamic_rules: List[str] = []
        benign_boundary_rules: List[Dict[str, Any]] = []
        topic_labels: List[str] = []
        for idx in ranked:
            if similarities[idx] < 0:
                continue
            entry = self.flattened_rules[idx]
            if entry['type'] == 'defense_strategy':
                dynamic_rules.append(entry['rule'])
            elif entry['type'] == 'benign_boundary_rule':
                full_rule = entry.get('benign_boundary_rule')
                if isinstance(full_rule, dict):
                    benign_boundary_rules.append(full_rule)
                else:
                    # String-typed (or missing) rules are normalized to the same
                    # dict shape so callers always receive dicts.
                    benign_boundary_rules.append({
                        'rule_text': entry['rule'],
                        'exemptions': entry.get('exemptions', ''),
                        'cluster_topic': entry['topic_label'],
                    })
            else:
                continue
            topic_labels.append(entry['topic_label'])
        return dynamic_rules, benign_boundary_rules, topic_labels

    def retrieve_query(self, messages, top_k: int = 3, prompt: bool = False):
        """
        Retrieve flattened rules relevant to the last user message.

        Args:
            messages: Chat messages (dicts with ``role`` / ``content``).
            top_k: Number of highest-similarity rules to consider.
            prompt: If True, return a message list with RAG content injected;
                otherwise return a result dict.

        Returns:
            If ``prompt`` is False, a dict with:
                - is_harmful: bool - whether the request should be refused
                - rag_content: str - content from the wrapped tree's RAG lookup
                - dynamic_rule: str - first retrieved dynamic rule ("" if none)
                - dynamic_rules: list - all retrieved dynamic rules
                - benign_boundary_rules: list[dict] - retrieved benign rules
                - branch: str - "SAFE", "BLOCK" or "AMBIGUOUS"
                - original_messages: list - the input messages
                - topic_label: str - topic of the most similar retrieved rule
                  ("General_Risk" if none; "General" when nothing was scored)
                - harmful_prob: float - heuristic harm score in [0, 1]
                - max_benign_sim: float - highest rule similarity
            If ``prompt`` is True, the message list produced by the wrapped
            tree's ``_generate_rag_prompt`` (or the input messages unchanged
            when there is no usable user query).
        """
        if not messages or not isinstance(messages, list):
            if prompt:
                return messages if messages else []
            return self._empty_result(messages)

        # Most recent user turn is the query.
        user_query = next(
            (m.get('content', "") for m in reversed(messages) if m.get('role') == 'user'),
            "",
        )
        if not user_query:
            return messages if prompt else self._empty_result(messages)

        try:
            query_embedding = self.risk_tree.embedding_model.encode([user_query])[0]
        except Exception as e:
            print(f"⚠️ [Flattened Memory Tree] 编码查询失败: {e}")
            return messages if prompt else self._empty_result(messages)

        if prompt:
            # Prompt mode delegates entirely to the wrapped tree, so the
            # flattened scoring below would be discarded; skip it.
            # NOTE(review): this means prompt mode does not use the flattened
            # rules at all — confirm that is intended for the ablation.
            return self.risk_tree._generate_rag_prompt(messages, top_k=top_k)

        similarities = self._score_rules(query_embedding)
        dynamic_rules, benign_boundary_rules, topic_labels = self._select_rules(similarities, top_k)

        # Primary topic = topic of the highest-similarity retrieved rule.
        topic_label = topic_labels[0] if topic_labels else "General_Risk"

        # RAG text comes from the wrapped tree's own lookup, not the flattened rules.
        rag_content = ""
        if hasattr(self.risk_tree, '_get_rag_content'):
            rag_content = self.risk_tree._get_rag_content(user_query, top_k=top_k)

        # Simplified heuristic: harm score = 1 - best similarity (0.5 on an
        # exact match), clamped so it stays a valid probability even when the
        # best similarity is negative or every rule was unusable (-1.0).
        # NOTE(review): higher similarity to a defense rule *lowers* the harm
        # score here, which looks inverted — confirm the intended heuristic.
        max_sim = max(similarities) if similarities else 0.0
        harmful_prob = 1.0 - max_sim if max_sim < 1.0 else 0.5
        harmful_prob = float(min(1.0, max(0.0, harmful_prob)))
        max_benign_sim = float(max_sim)

        if harmful_prob < self.TH_HARM_LOW and max_benign_sim > self.TH_BENIGN_HIGH:
            branch = "SAFE"
            is_harmful = False
            dynamic_rules = []  # SAFE branch needs no defense rules
        elif harmful_prob > self.TH_HARM_HIGH:
            branch = "BLOCK"
            is_harmful = True
        else:
            branch = "AMBIGUOUS"
            is_harmful = False

        return {
            "is_harmful": is_harmful,
            "rag_content": rag_content,
            "dynamic_rule": dynamic_rules[0] if dynamic_rules else "",  # backward compatibility
            "dynamic_rules": dynamic_rules,
            "benign_boundary_rules": benign_boundary_rules,
            "branch": branch,
            "original_messages": messages,
            "topic_label": topic_label,
            "harmful_prob": harmful_prob,
            "max_benign_sim": max_benign_sim,
        }
