#!/usr/bin/env python3
"""
三种检索方法时延对比脚本

对比 Memory Tree、RAG、A-mem 三种方法的：
1. 检索时延 (Retrieval Latency)
2. 推理时延 (Inference Latency) - 使用同一个 LLM 进行评估
"""

import json
import sys
import os
import pickle
import time
from pathlib import Path
from typing import List, Dict, Any
import numpy as np
from openai import OpenAI
from concurrent.futures import ThreadPoolExecutor, as_completed

# Put the `src` directory that sits next to this script at the front of the
# import path, so the local `risk_tree` module imported below can be found.
src_path = Path(__file__).parent / 'src'
sys.path.insert(0, str(src_path))

from risk_tree import RiskTree


def load_resources(tree_path: str, safety_projector_path: str, benign_data_path: str, amem_path: str):
    """Load the Risk Tree and the (optional) A-mem memory system.

    Args:
        tree_path: Saved Risk Tree passed to ``RiskTree.load``.
        safety_projector_path: Safety projector weights passed to ``RiskTree.load``.
        benign_data_path: Benign data file passed to ``RiskTree.load``.
        amem_path: Pickle file containing the A-mem memory system.

    Returns:
        ``(risk_tree, amem_system, tree_load_time, amem_load_time)``. Load
        times are in seconds. If the A-mem pickle cannot be loaded,
        ``amem_system`` is ``None`` and ``amem_load_time`` is ``0.0``.

    Raises:
        Whatever ``RiskTree.load`` raises. A Risk Tree failure is fatal,
        unlike an A-mem failure.
    """
    print(f"\n📂 加载 Risk Tree: {tree_path}")
    # perf_counter is monotonic and high-resolution, unlike time.time().
    start = time.perf_counter()
    risk_tree = RiskTree.load(
        tree_path,
        safety_projector_path=safety_projector_path,
        benign_data_path=benign_data_path,
        enable_safety_projection=True,
        use_single_branch_mode=False,
        regenerate_boundary_rules=False
    )
    tree_load_time = time.perf_counter() - start
    print(f"✅ Risk Tree 加载成功 (耗时: {tree_load_time:.2f}s)")

    print(f"\n📂 加载 A-mem: {amem_path}")
    start = time.perf_counter()
    try:
        # SECURITY: pickle.load can execute arbitrary code embedded in the
        # file. Only point amem_path at pickles you produced yourself.
        with open(amem_path, 'rb') as f:
            amem_system = pickle.load(f)
    except Exception as e:
        # Best effort: the benchmark still runs Memory Tree and RAG without A-mem.
        print(f"⚠️ A-mem 加载失败: {e}")
        amem_system = None
        amem_load_time = 0.0
    else:
        amem_load_time = time.perf_counter() - start
        print(f"✅ A-mem 加载成功 (耗时: {amem_load_time:.2f}s)")

    return risk_tree, amem_system, tree_load_time, amem_load_time


def memory_tree_retrieval(risk_tree: RiskTree, query: str, top_k: int = 3) -> Dict[str, Any]:
    """Run Memory Tree (hierarchical) retrieval for a single user query.

    Wraps ``query`` as a one-message user conversation and delegates to
    ``risk_tree.retrieve_query``. It requests ``top_k`` clusters, one leaf per
    cluster, and passes ``prompt=False``. The tree's result is returned unchanged.
    """
    conversation = [{"role": "user", "content": query}]
    retrieval_options = dict(
        prompt=False,
        top_k_clusters=top_k,
        top_k_leaves_per_cluster=1,
    )
    return risk_tree.retrieve_query(conversation, **retrieval_options)


def rag_retrieval(risk_tree: RiskTree, query: str, top_k: int = 3) -> Dict[str, Any]:
    """Run flat vector-similarity (RAG) retrieval over the tree's defense store.

    Each hit becomes one multi-line rule string: ``Intent`` (the ``intent``
    metadata, falling back to ``query``), ``Pattern``, an optional
    ``Assistant Outcome`` (only when ``response`` metadata is non-empty), and
    ``Content`` (page content with newlines flattened to spaces).

    Returns:
        A dict shaped like the Memory Tree result. It is labelled ``"RAG"``,
        or ``"Unknown"`` with no rules when ``risk_tree.defense_vector_store``
        is falsy.
    """
    store = risk_tree.defense_vector_store
    if not store:
        return {"dynamic_rules": [], "benign_boundary_rules": [], "topic_label": "Unknown"}

    def _render(doc) -> str:
        meta = doc.metadata
        flattened = doc.page_content.replace('\n', ' ').strip()
        intent = meta.get('intent', '') or meta.get('query', '')
        lines = [f"Intent: {intent}", f"Pattern: {meta.get('pattern', '')}"]
        outcome = meta.get('response', '')
        if outcome:
            lines.append(f"Assistant Outcome: {outcome}")
        lines.append(f"Content: {flattened}")
        return "\n".join(lines)

    hits = store.similarity_search(query, k=top_k)
    rendered = [_render(hit) for hit in hits]
    return {"dynamic_rules": rendered, "benign_boundary_rules": [], "topic_label": "RAG"}


def amem_retrieval(amem_system, query: str, top_k: int = 3) -> List[Dict[str, Any]]:
    """Retrieve the ``top_k`` A-mem memories for ``query`` (best effort).

    Returns ``[]`` in three cases: no A-mem system is loaded (``None``),
    ``search`` returns a falsy value, or ``search`` raises. In the last case
    the error is printed, not propagated.
    """
    memories = []
    if amem_system is not None:
        try:
            memories = amem_system.search(query, k=top_k) or []
        except Exception as err:
            print(f"⚠️ A-mem 检索失败: {err}")
            memories = []
    return memories


def extract_rules_text(result: Any, method_name: str, max_rules: int = 3, max_chars: int = 500) -> str:
    """Flatten one method's retrieval result into the rule text given to the LLM.

    Args:
        result: Raw retrieval output.
            For ``"A-mem"``: a list of memory dicts whose ``'content'`` is used.
            For ``"Memory Tree"``: a dict whose ``'benign_boundary_rules'``
            entries are dicts with ``'harmful_rule'`` / ``'exemptions'``.
            Otherwise (RAG): a dict whose ``'dynamic_rules'`` entries are
            stringified.
        method_name: ``"A-mem"`` or ``"Memory Tree"``. Any other value is
            handled as RAG.
        max_rules: Maximum number of entries taken from the result.
        max_chars: Per-entry character cap for A-mem and RAG entries. Memory
            Tree rules are not truncated.

    Returns:
        The entries joined with ``---`` separator lines. Returns ``""`` when
        the result is ``None``/malformed or contains no usable text.
    """
    separator = "\n---\n"

    if method_name == "A-mem":
        if not isinstance(result, list):
            return ""
        rules = []
        for mem_dict in result[:max_rules]:
            # Skip malformed entries instead of failing the whole benchmark.
            if not isinstance(mem_dict, dict):
                continue
            # `or ''` also covers a key that is present but set to None.
            content = mem_dict.get('content') or ''
            if content:
                rules.append(str(content)[:max_chars])
        return separator.join(rules)

    if result is None:
        return ""

    if method_name == "Memory Tree":
        rules = []
        # `or []` guards against the key being present with a None value.
        for rule_dict in (result.get('benign_boundary_rules') or [])[:max_rules]:
            if not isinstance(rule_dict, dict):
                continue
            # `.get(key, '')` would return None for a present-but-None key,
            # and `.strip()` would then raise AttributeError.
            harmful_rule = str(rule_dict.get('harmful_rule') or '').strip()
            exemptions = str(rule_dict.get('exemptions') or '').strip()
            rule_parts = []
            if harmful_rule:
                rule_parts.append(f"[Harmful Rule] {harmful_rule}")
            if exemptions:
                rule_parts.append(f"[Exemption] {exemptions}")
            if rule_parts:
                rules.append("\n".join(rule_parts))
        return separator.join(rules)

    # RAG
    dynamic_rules = result.get('dynamic_rules') or []
    return separator.join(str(rule)[:max_chars] for rule in dynamic_rules[:max_rules])


def llm_inference(client: OpenAI, model_name: str, query: str, rules_text: str) -> float:
    """Ask the LLM for a safe/unsafe verdict and return the call latency.

    The completion itself is discarded. Only the duration of the
    ``client.chat.completions.create`` round trip is measured.

    Args:
        client: OpenAI-compatible client.
        model_name: Model to query.
        query: User query being judged.
        rules_text: Retrieved rules inserted into the prompt.

    Returns:
        Elapsed seconds. If the request raises, the error is printed and the
        time until the failure is still returned. Failed calls are therefore
        counted in the latency statistics.
    """
    prompt = f"""Based on the following safety rules, determine if the user query is safe or unsafe.

Safety Rules:
{rules_text}

User Query: {query}

Output only "safe" or "unsafe"."""

    # perf_counter is monotonic and high-resolution. time.time() can jump when
    # the system clock is adjusted and is coarse on some platforms.
    start = time.perf_counter()
    try:
        # The response is intentionally unused: only latency is benchmarked.
        client.chat.completions.create(
            model=model_name,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.1,
            max_tokens=10,
            timeout=30,
        )
    except Exception as e:
        print(f"⚠️ LLM 推理失败: {e}")
    return time.perf_counter() - start


def _benchmark_method(
    method_name: str,
    retrieve,
    query: str,
    llm_client: OpenAI,
    model_name: str,
    top_k: int,
) -> Dict[str, float]:
    """Time one retrieval method plus the LLM judgement on its rules.

    ``retrieve`` is a zero-argument callable performing the retrieval. Only
    that call counts as retrieval time; rule-text extraction is not timed.
    """
    start = time.perf_counter()
    retrieved = retrieve()
    retrieval_time = time.perf_counter() - start

    rules = extract_rules_text(retrieved, method_name, top_k)
    # When a method yields no rule text, the LLM is not called and inference
    # time is recorded as 0. This lowers that method's mean inference time.
    inference_time = llm_inference(llm_client, model_name, query, rules) if rules else 0.0
    return {
        "retrieval_time": retrieval_time,
        "inference_time": inference_time,
        "total_time": retrieval_time + inference_time,
        "rules_length": len(rules),
    }


def benchmark_single_query(
    query: str,
    risk_tree: RiskTree,
    amem_system,
    llm_client: OpenAI,
    model_name: str,
    top_k: int = 3
) -> Dict[str, Dict[str, float]]:
    """Benchmark every available retrieval method on a single query.

    Args:
        query: User query to retrieve rules for and judge.
        risk_tree: Loaded Risk Tree, used for both Memory Tree and RAG retrieval.
        amem_system: Loaded A-mem system, or ``None`` to skip A-mem.
        llm_client: OpenAI-compatible client used for the inference step.
        model_name: Model used for the inference step (shared by all methods).
        top_k: Forwarded to each retrieval function and to rule extraction.

    Returns:
        ``{method: {"retrieval_time", "inference_time", "total_time",
        "rules_length"}}`` for "Memory Tree", "RAG" and, when ``amem_system``
        is not ``None``, "A-mem". Times are in seconds.
    """
    results = {
        "Memory Tree": _benchmark_method(
            "Memory Tree", lambda: memory_tree_retrieval(risk_tree, query, top_k),
            query, llm_client, model_name, top_k,
        ),
        "RAG": _benchmark_method(
            "RAG", lambda: rag_retrieval(risk_tree, query, top_k),
            query, llm_client, model_name, top_k,
        ),
    }
    # `is not None` matches amem_retrieval. A loaded memory object could be
    # falsy (e.g. if it defines __len__) and would otherwise be skipped silently.
    if amem_system is not None:
        results["A-mem"] = _benchmark_method(
            "A-mem", lambda: amem_retrieval(amem_system, query, top_k),
            query, llm_client, model_name, top_k,
        )
    return results


def _summarize_method(records: List[Dict[str, float]]) -> Dict[str, Any]:
    """Aggregate one method's per-query records into mean/std/p50/p95 statistics.

    ``records`` must be non-empty because ``np.percentile`` raises on empty
    input. Times are in seconds.
    """
    stats: Dict[str, Any] = {}
    for metric in ("retrieval_time", "inference_time", "total_time"):
        values = [record[metric] for record in records]
        stats[f"{metric}_mean"] = float(np.mean(values))
        stats[f"{metric}_std"] = float(np.std(values))
        stats[f"{metric}_p50"] = float(np.percentile(values, 50))
        stats[f"{metric}_p95"] = float(np.percentile(values, 95))
    stats["rules_length_mean"] = float(np.mean([record["rules_length"] for record in records]))
    stats["num_queries"] = len(records)
    return stats


def run_benchmark(
    test_queries_path: str,
    tree_path: str,
    safety_projector_path: str,
    benign_data_path: str,
    amem_path: str,
    llm_api_url: str,
    llm_api_key: str,
    model_name: str,
    top_k: int = 3,
    sample_size: int = 50,
    output_path: str = "./latency_comparison_results.json"
):
    """Benchmark retrieval and LLM-inference latency of Memory Tree, RAG and A-mem.

    Steps:
        1. Load the resources and the test queries.
        2. Optionally sample the queries.
        3. Time every method on every query.
        4. Print a summary table (milliseconds).
        5. Write summary, config and per-query results to ``output_path`` as JSON.

    Args:
        test_queries_path: JSON file with a list of test cases, each carrying
            a ``'user_query'`` field. Cases with an empty query are skipped.
        tree_path, safety_projector_path, benign_data_path, amem_path:
            Passed to ``load_resources``.
        llm_api_url, llm_api_key, model_name: OpenAI-compatible endpoint used
            for the inference step.
        top_k: Forwarded to each retrieval method and to rule extraction.
        sample_size: If positive and smaller than the number of queries,
            evaluate a reproducible random sample (seed 42) of that size.
            Otherwise all queries are used.
        output_path: Destination JSON file.

    Returns:
        Per-method summary statistics (seconds). Methods that produced no
        samples are omitted.
    """
    risk_tree, amem_system, tree_load_time, amem_load_time = load_resources(
        tree_path, safety_projector_path, benign_data_path, amem_path
    )

    print(f"\n📋 加载测试查询: {test_queries_path}")
    with open(test_queries_path, 'r', encoding='utf-8') as f:
        test_queries = json.load(f)

    # The explicit `0 <` check keeps a negative size from crashing random.sample.
    if sample_size and 0 < sample_size < len(test_queries):
        import random
        # A dedicated generator draws the same sample as random.seed(42) +
        # random.sample, without reseeding the process-wide RNG.
        test_queries = random.Random(42).sample(test_queries, sample_size)
    print(f"📊 使用 {len(test_queries)} 条查询进行测试")

    print(f"\n🤖 初始化 LLM 客户端 (模型: {model_name})...")
    llm_client = OpenAI(base_url=llm_api_url, api_key=llm_api_key)

    methods = ["Memory Tree", "RAG"]
    # Same `is None` convention as amem_retrieval.
    if amem_system is not None:
        methods.append("A-mem")

    all_results = []
    for idx, test_case in enumerate(test_queries):
        query = test_case.get('user_query', '')
        if not query:
            continue

        if (idx + 1) % 10 == 0:
            print(f"进度: {idx + 1}/{len(test_queries)}")

        result = benchmark_single_query(
            query, risk_tree, amem_system, llm_client, model_name, top_k
        )
        result["query_id"] = idx + 1
        result["query"] = (query[:100] + "...") if len(query) > 100 else query
        all_results.append(result)

    summary = {}
    for method in methods:
        records = [r[method] for r in all_results if method in r]
        if not records:
            # np.percentile raises on empty input, e.g. when every query was blank.
            print(f"⚠️ {method} 没有有效样本，跳过统计")
            continue
        summary[method] = _summarize_method(records)

    # Header row, so the printed columns (all in ms except rule length) are labelled.
    header = f"{'Method':<15} {'Retrieval (ms)':<18} {'Inference (ms)':<18} {'Total (ms)':<18} {'Rules Len':<12}"
    print("\n" + header)
    print("-" * len(header))
    for method, s in summary.items():
        print(f"{method:<15} {s['retrieval_time_mean']*1000:<18.2f} {s['inference_time_mean']*1000:<18.2f} {s['total_time_mean']*1000:<18.2f} {s['rules_length_mean']:<12.0f}")

    output_data = {
        "summary": summary,
        "config": {
            "model_name": model_name,
            "top_k": top_k,
            "sample_size": len(test_queries),
            # Queries actually benchmarked (blank queries are skipped).
            "num_evaluated": len(all_results),
            "tree_load_time": tree_load_time,
            "amem_load_time": amem_load_time
        },
        "detailed_results": all_results
    }

    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(output_data, f, indent=2, ensure_ascii=False)

    print(f"\n✅ 结果已保存到: {output_path}")

    return summary


def main():
    """Parse command-line arguments and run the latency benchmark."""
    import argparse

    parser = argparse.ArgumentParser(
        description="三种检索方法时延对比",
        # Appends "(default: ...)" to every argument's help line.
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--test_queries", type=str, default="test_queries.json",
                        help="JSON file with a list of test cases, each carrying a 'user_query' field")
    parser.add_argument("--tree_path", type=str, default="final_memory_20260124_2322.pkl",
                        help="Saved Risk Tree (Memory Tree) file loaded with RiskTree.load")
    parser.add_argument("--safety_projector_path", type=str, default="src/models/safety_projector_metric.pth",
                        help="Safety projector weights passed to RiskTree.load")
    parser.add_argument("--benign_data_path", type=str, default="/path/to/agentharm/agent_align_data_v3.json",
                        help="Benign data file passed to RiskTree.load (the default is a placeholder; override it)")
    parser.add_argument("--amem_path", type=str, default="memory_system_new.pkl",
                        help="Pickled A-mem memory system (only load trusted pickles); A-mem is skipped if loading fails")
    parser.add_argument("--llm_api_url", type=str, default="http://127.0.0.1:7008/v1",
                        help="Base URL of the OpenAI-compatible inference endpoint")
    parser.add_argument("--llm_api_key", type=str, default="EMPTY",
                        help="API key for the inference endpoint")
    parser.add_argument("--model_name", type=str, default="qwen-72b",
                        help="Model used for the safe/unsafe judgement, shared by every method")
    parser.add_argument("--top_k", type=int, default=3,
                        help="Number of items retrieved per method")
    parser.add_argument("--sample_size", type=int, default=50,
                        help="Number of queries randomly sampled with seed 42; 0 uses all queries")
    parser.add_argument("--output_path", type=str, default="./latency_comparison_results.json",
                        help="Where to write the JSON results")

    args = parser.parse_args()

    run_benchmark(
        test_queries_path=args.test_queries,
        tree_path=args.tree_path,
        safety_projector_path=args.safety_projector_path,
        benign_data_path=args.benign_data_path,
        amem_path=args.amem_path,
        llm_api_url=args.llm_api_url,
        llm_api_key=args.llm_api_key,
        model_name=args.model_name,
        top_k=args.top_k,
        sample_size=args.sample_size,
        output_path=args.output_path
    )


# Run the benchmark only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
