#!/usr/bin/env python3
"""
RiskTree Threshold 敏感度分析脚本

分析不同的 threshold 值对簇的大小和检索精度的影响。
threshold 决定了何时分裂新节点：entropy > threshold 时会创建新簇。
"""

import json
import sys
import os
import pickle
from pathlib import Path
from collections import defaultdict
import numpy as np
from tqdm import tqdm

# Put the sibling `src` directory at the front of sys.path so that modules
# stored there can be imported by the statements that follow.
src_path = Path(__file__).parent / 'src'
sys.path.insert(0, str(src_path))

from risk_tree import RiskTree


def count_clusters(tree):
    """Count the clusters in a tree and measure the size of each one.

    The tree is walked two levels below ``tree.root``: every child of the
    root is treated as a category, and every child of a category is treated
    as a cluster. A cluster's size is the number of its direct children.

    Args:
        tree: Object exposing ``root.children``; each nested node must in
            turn expose a ``children`` sequence.

    Returns:
        tuple: ``(cluster_count, cluster_sizes)`` where ``cluster_sizes`` is
        a list with one entry per cluster, in traversal order.
    """
    sizes = [
        len(cluster.children)
        for category in tree.root.children
        for cluster in category.children
    ]
    return len(sizes), sizes


def evaluate_retrieval_accuracy(tree, test_queries, top_k=3):
    """Measure how often tree retrieval returns the expected branch and topic.

    Every test case is sent to ``tree.retrieve_query`` as a single user
    message. The predicted branch must equal the expected branch exactly;
    the topic counts as a match when one label contains the other,
    ignoring case. A query whose processing raises is reported on stdout,
    counted in ``failed_queries`` and recorded as a topic mismatch, so it
    still lowers both accuracies.

    Args:
        tree: Object exposing ``retrieve_query(messages, prompt=..., top_k=...)``
            whose result supports ``.get('branch')`` and ``.get('topic_label')``.
        test_queries: Sequence of dicts. ``'user_query'`` is required;
            ``'branch'`` and ``'topic_label'`` default to ``''``.
        top_k: Forwarded unchanged to ``tree.retrieve_query``.

    Returns:
        dict with the keys:
            ``branch_accuracy`` / ``topic_accuracy``: hits divided by the
                number of test queries (``0.0`` for an empty input).
            ``total_queries``: ``len(test_queries)``.
            ``failed_queries``: number of queries that raised.
            ``branch_confusion``: plain nested dict
                ``{expected_branch: {predicted_branch: count}}``.
            ``topic_matches``: one bool per test query, in input order.

    Raises:
        KeyError: If a test case has no ``'user_query'`` entry.
    """
    total = len(test_queries)
    correct_branch = 0
    failed = 0

    branch_confusion = defaultdict(lambda: defaultdict(int))
    topic_matches = []

    for test_case in tqdm(test_queries, desc="评估检索精度"):
        user_query = test_case['user_query']
        expected_branch = test_case.get('branch', '')
        expected_topic = test_case.get('topic_label', '')

        try:
            messages = [{"role": "user", "content": user_query}]
            result = tree.retrieve_query(messages, prompt=False, top_k=top_k)

            predicted_branch = result.get('branch', 'UNKNOWN')
            predicted_topic = result.get('topic_label', 'Unknown')

            # Branch accuracy uses exact equality.
            if predicted_branch == expected_branch:
                correct_branch += 1
            branch_confusion[expected_branch][predicted_branch] += 1

            # Topic accuracy uses a loose, case-insensitive containment test
            # in either direction; a missing label on either side is a miss.
            topic_hit = False
            if expected_topic and predicted_topic:
                expected_lc = expected_topic.lower()
                predicted_lc = predicted_topic.lower()
                topic_hit = expected_lc in predicted_lc or predicted_lc in expected_lc
            topic_matches.append(topic_hit)

        except Exception as e:
            # str() keeps the handler itself from raising on non-string queries.
            print(f"⚠️ 查询处理失败: {str(user_query)[:50]}... 错误: {e}")
            failed += 1
            # Keep topic_matches aligned one-to-one with test_queries.
            topic_matches.append(False)
            continue

    correct_topic = sum(topic_matches)
    branch_accuracy = correct_branch / total if total > 0 else 0.0
    topic_accuracy = correct_topic / total if total > 0 else 0.0

    return {
        'branch_accuracy': branch_accuracy,
        'topic_accuracy': topic_accuracy,
        'total_queries': total,
        'failed_queries': failed,
        # Convert to plain dicts so lookups by consumers cannot create keys.
        'branch_confusion': {
            expected: dict(predictions)
            for expected, predictions in branch_confusion.items()
        },
        'topic_matches': topic_matches
    }


def _collect_samples(tree):
    """Flatten the category -> cluster -> sample levels of a tree into records.

    Sample nodes whose ``value`` is falsy are skipped.
    """
    return [
        {
            'category': category_node.label,
            'cluster': cluster_node.label,
            'value': sample_node.value
        }
        for category_node in tree.root.children
        for cluster_node in category_node.children
        for sample_node in cluster_node.children
        if sample_node.value
    ]


def _cluster_size_stats(cluster_sizes):
    """Return ``(mean, std, max, min)`` of cluster sizes; zeros when empty."""
    if not cluster_sizes:
        return 0.0, 0.0, 0, 0
    return (
        float(np.mean(cluster_sizes)),
        float(np.std(cluster_sizes)),
        int(max(cluster_sizes)),
        int(min(cluster_sizes)),
    )


def analyze_threshold_sensitivity(
    tree_path,
    test_queries_path,
    threshold_values,
    safety_projector_path=None,
    benign_data_path=None,
    output_dir="./sensitivity_results/threshold",
    max_queries=100,
    top_k=3,
):
    """Rebuild the tree under several thresholds and compare the outcomes.

    Steps:
      1. Load the test queries and, when there are more than ``max_queries``,
         draw a fixed-seed random subset.
      2. Load the tree stored at ``tree_path`` and harvest every sample value.
      3. For each threshold, create a fresh ``RiskTree``, re-insert all
         samples, gather cluster statistics via ``count_clusters`` and
         retrieval metrics via ``evaluate_retrieval_accuracy``.
      4. Write all results to
         ``<output_dir>/threshold_sensitivity_results.json`` and print a
         summary table.

    Args:
        tree_path: Path of the stored RiskTree file.
        test_queries_path: Path of a JSON file containing the test queries.
        threshold_values: Iterable of threshold values to evaluate,
            e.g. ``[0.3, 0.4, 0.5]``.
        safety_projector_path: Passed to ``RiskTree.load`` and to every
            rebuilt ``RiskTree``.
        benign_data_path: Passed to ``RiskTree.load`` only.
        output_dir: Directory for the score logs and the result file;
            created if missing.
        max_queries: Upper bound on the number of evaluated queries;
            ``None`` disables sampling.
        top_k: Forwarded to ``evaluate_retrieval_accuracy``.

    Returns:
        list of per-threshold result dicts (a threshold whose evaluation
        raised is reported on stdout and omitted), or ``None`` when the tree
        file does not exist or holds no sample data.
    """
    os.makedirs(output_dir, exist_ok=True)

    # 1. Load the test queries.
    print(f"📋 加载测试查询: {test_queries_path}")
    with open(test_queries_path, 'r', encoding='utf-8') as f:
        test_queries = json.load(f)
    print(f"✅ 加载了 {len(test_queries)} 条测试查询")

    # Cap the evaluation cost. A private Random(42) draws the same subset as
    # seeding the module-level generator would, without reseeding global state.
    if max_queries is not None and len(test_queries) > max_queries:
        import random
        test_queries = random.Random(42).sample(test_queries, max_queries)
        print(f"📊 随机采样了 {max_queries} 条查询用于评估")

    # 2. Load the original tree; it is only used as the source of samples
    #    and of the `k` setting.
    print(f"\n📂 加载原始 Risk Tree: {tree_path}")
    if not os.path.exists(tree_path):
        print(f"❌ Risk Tree 文件不存在: {tree_path}")
        return

    # NOTE(review): tree_path is described as a pickle file; presumably
    # RiskTree.load unpickles it, so only point this at trusted files.
    original_tree = RiskTree.load(
        tree_path,
        safety_projector_path=safety_projector_path,
        benign_data_path=benign_data_path,
        enable_safety_projection=True,
        use_single_branch_mode=False
    )
    print("✅ 原始 Risk Tree 加载成功")

    # 3. Harvest every sample so the tree can be rebuilt from scratch.
    print("\n📊 收集树中的所有样本数据...")
    all_samples = _collect_samples(original_tree)
    print(f"✅ 收集了 {len(all_samples)} 个样本")

    if not all_samples:
        # Nothing to rebuild from; the message states that evaluation is
        # skipped because that is what actually happens here.
        print("⚠️ 警告：没有找到样本数据，无法重新构建树")
        print("   已跳过评估（threshold 参数不会影响已构建的树）")
        print("   建议：使用包含样本数据的树文件，或者重新训练树")
        return

    # 4. Evaluate each threshold independently.
    results = []

    for threshold in threshold_values:
        print(f"\n{'='*80}")
        print(f"🔬 测试 threshold = {threshold}")
        print(f"{'='*80}")

        print(f"🔨 使用 threshold={threshold} 重新构建树...")
        try:
            # NOTE(review): benign_data_path is not forwarded here, unlike
            # RiskTree.load above — confirm whether the constructor needs it.
            new_tree = RiskTree(
                threshold=threshold,
                k=original_tree.k,
                score_log_file=os.path.join(
                    output_dir, f"score_log_threshold_{threshold}.jsonl"
                ),
                safety_projector_path=safety_projector_path,
                enable_safety_projection=True,
                use_single_branch_mode=False
            )

            # Replay every sample into the new tree.
            print(f"   添加 {len(all_samples)} 个样本...")
            failed_samples = 0
            for sample in tqdm(all_samples, desc=f"构建树 (threshold={threshold})"):
                try:
                    value = sample['value']
                    # The sample's own id serves as the child label. This
                    # lookup is inside the try so that one malformed value
                    # skips a sample instead of aborting the whole threshold.
                    child_label = value.get('id', 'unknown')
                    new_tree.add_node(sample['category'], child_label, value)
                except Exception as e:
                    failed_samples += 1
                    print(f"   ⚠️ 添加样本失败: {e}")
            if failed_samples:
                print(f"   ⚠️ {failed_samples}/{len(all_samples)} 个样本添加失败")

            # Cluster statistics of the rebuilt tree.
            cluster_count, cluster_sizes = count_clusters(new_tree)
            (avg_cluster_size, std_cluster_size,
             max_cluster_size, min_cluster_size) = _cluster_size_stats(cluster_sizes)

            print(f"\n📊 簇统计信息:")
            print(f"   - 簇数量: {cluster_count}")
            print(f"   - 平均簇大小: {avg_cluster_size:.2f}")
            print(f"   - 簇大小标准差: {std_cluster_size:.2f}")
            print(f"   - 最大簇大小: {max_cluster_size}")
            print(f"   - 最小簇大小: {min_cluster_size}")

            # Retrieval quality of the rebuilt tree.
            print(f"\n🔍 评估检索精度...")
            eval_results = evaluate_retrieval_accuracy(
                new_tree, test_queries, top_k=top_k
            )

            result = {
                'threshold': threshold,
                'cluster_count': cluster_count,
                'avg_cluster_size': avg_cluster_size,
                'std_cluster_size': std_cluster_size,
                'max_cluster_size': max_cluster_size,
                'min_cluster_size': min_cluster_size,
                'failed_samples': failed_samples,
                'branch_accuracy': eval_results['branch_accuracy'],
                'topic_accuracy': eval_results['topic_accuracy'],
                'total_queries': eval_results['total_queries'],
                'branch_confusion': eval_results['branch_confusion']
            }

            results.append(result)

            print(f"\n✅ Threshold {threshold} 评估完成:")
            print(f"   - 分支准确率: {result['branch_accuracy']:.4f}")
            print(f"   - 主题准确率: {result['topic_accuracy']:.4f}")

        except Exception as e:
            # One failing threshold must not discard the others' results.
            print(f"❌ Threshold {threshold} 评估失败: {e}")
            import traceback
            traceback.print_exc()
            continue

    # 5. Persist the results and print a summary table.
    output_file = os.path.join(output_dir, "threshold_sensitivity_results.json")
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)

    print(f"\n{'='*80}")
    print(f"📊 敏感度分析完成！")
    print(f"{'='*80}")
    print(f"\n结果已保存到: {output_file}")
    print(f"\n汇总结果:")
    print(f"{'Threshold':<12} {'簇数量':<10} {'平均簇大小':<12} {'分支准确率':<12} {'主题准确率':<12}")
    print(f"{'-'*60}")
    for r in results:
        print(f"{r['threshold']:<12.2f} {r['cluster_count']:<10} {r['avg_cluster_size']:<12.2f} "
              f"{r['branch_accuracy']:<12.4f} {r['topic_accuracy']:<12.4f}")

    return results


def main():
    """Parse the command-line options and launch the sensitivity analysis."""
    import argparse

    # One row per option: (flag, value type, default, help text, extra kwargs).
    option_specs = [
        ("--tree_path", str, "src/final_memory_20260119_2202.pkl",
         "RiskTree pickle 文件路径", {}),
        ("--test_queries", str, "test_queries.json",
         "测试查询 JSON 文件路径", {}),
        ("--threshold_values", float, [0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
         "要测试的 threshold 值列表", {"nargs": "+"}),
        ("--safety_projector_path", str, "src/models/safety_projector_metric.pth",
         "Safety Projector 模型路径", {}),
        ("--benign_data_path", str, "/path/to/agentharm/agent_align_data_v3.json",
         "良性数据路径", {}),
        ("--output_dir", str, "./sensitivity_results/threshold",
         "结果输出目录", {}),
    ]

    cli = argparse.ArgumentParser(description="RiskTree Threshold 敏感度分析")
    for flag, value_type, default_value, help_text, extra in option_specs:
        cli.add_argument(
            flag,
            type=value_type,
            default=default_value,
            help=help_text,
            **extra
        )

    opts = cli.parse_args()

    # Note the rename: the --test_queries option feeds test_queries_path.
    analyze_threshold_sensitivity(
        tree_path=opts.tree_path,
        test_queries_path=opts.test_queries,
        threshold_values=opts.threshold_values,
        safety_projector_path=opts.safety_projector_path,
        benign_data_path=opts.benign_data_path,
        output_dir=opts.output_dir
    )


# Run the command-line entry point only when this file is executed directly.
if __name__ == '__main__':
    main()
