#!/usr/bin/env python3
"""
敏感度实验：测试 entropy threshold 和 similarity threshold 对 cluster 数量的影响

用法:
    python sensitivity_experiment.py --input src/attack_results.json --output sensitivity_results.json
"""

import argparse
import json
import logging
import os
import pickle
import sys
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional

from tqdm import tqdm

# Put the ``src`` directory that sits next to this script at the front of
# sys.path. This must run before the project imports that follow
# (risk_tree, memory_defender), which are presumably located there.
src_path = Path(__file__).parent / 'src'
sys.path.insert(0, str(src_path))

from risk_tree import RiskTree
from memory_defender import MemoryDefender
from openai import OpenAI


def count_clusters(tree: "RiskTree") -> Dict[str, Any]:
    """
    Count the clusters and leaves stored in a risk tree.

    The tree is walked as three fixed levels: every child of ``tree.root`` is
    treated as a category node, every child of a category node as a cluster
    node, and every child of a cluster node as a leaf.

    Args:
        tree: Object exposing ``root``. Nodes must expose ``children``;
            category nodes must also expose ``label``.

    Returns:
        dict with the keys:
            - ``total_clusters``: number of cluster nodes over all categories
            - ``total_leaves``: number of leaves over all clusters
            - ``categories``: ``{label: {'clusters': int, 'leaves': int}}``
            - ``clusters_per_category``: ``{label: cluster count}``
    """
    stats: Dict[str, Any] = {
        'total_clusters': 0,
        'total_leaves': 0,
        'categories': {},
        'clusters_per_category': {}
    }

    for cat_node in tree.root.children:
        # One entry per cluster under this category: its number of leaves.
        leaves_per_cluster = [len(cluster.children) for cluster in cat_node.children]
        cluster_count = len(leaves_per_cluster)
        leaf_count = sum(leaves_per_cluster)

        stats['categories'][cat_node.label] = {
            'clusters': cluster_count,
            'leaves': leaf_count
        }
        # This key used to be declared but never filled, so it was always {}.
        stats['clusters_per_category'][cat_node.label] = cluster_count

        # Totals are accumulated separately so they stay correct even if two
        # category nodes happen to share a label (the per-label entries above
        # would then keep only the last one).
        stats['total_clusters'] += cluster_count
        stats['total_leaves'] += leaf_count

    return stats


def build_tree_from_data(
    data_path: str,
    threshold: float,
    attack_similarity_threshold: float,
    llm_client: OpenAI,
    model_name: str = "qwen-7b",
    sample_limit: Optional[int] = None,
    logger: Optional[logging.Logger] = None
) -> RiskTree:
    """
    Build a RiskTree by feeding every sample of a JSON file to a MemoryDefender.

    A fresh ``RiskTree`` is created with the given thresholds (safety
    projection disabled), wrapped in a ``MemoryDefender`` with attack
    enhancement enabled, and ``defender.evolve_step`` is called once per
    sample. A sample whose ``evolve_step`` raises is logged and skipped; the
    number of skipped samples is reported once at the end.

    Args:
        data_path: Path to the JSON data file. The loaded value is sliced and
            iterated, so it is expected to be a list of samples.
        threshold: Passed to ``RiskTree`` as ``threshold`` (the entropy /
            information-gain threshold of the experiment).
        attack_similarity_threshold: Passed to ``RiskTree`` as
            ``attack_similarity_threshold``.
        llm_client: LLM client handed to ``MemoryDefender``.
        model_name: Model name handed to ``MemoryDefender``.
        sample_limit: If truthy and smaller than the number of samples, only
            the first ``sample_limit`` samples are processed.
        logger: Logger to report progress to; defaults to this module's logger.

    Returns:
        RiskTree: the tree after all samples have been processed.
    """
    if logger is None:
        logger = logging.getLogger(__name__)

    # Fresh tree per call, configured with the thresholds under test.
    tree = RiskTree(
        threshold=threshold,
        attack_similarity_threshold=attack_similarity_threshold,
        safety_projector_path=None,  # no safety projector in this experiment
        enable_safety_projection=False
    )

    defender = MemoryDefender(
        risk_tree=tree,
        llm_client=llm_client,
        model_name=model_name,
        enable_attack_enhancement=True
    )

    logger.info(f"📂 加载数据: {data_path}")
    with open(data_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    if sample_limit and sample_limit < len(data):
        logger.info(f"📊 限制处理前 {sample_limit} 个样本（总共 {len(data)} 个）")
        data = data[:sample_limit]

    logger.info(f"✅ 加载了 {len(data)} 个样本")

    logger.info(f"🔨 构建树（threshold={threshold}, attack_sim_threshold={attack_similarity_threshold}）...")
    failed_count = 0
    for sample in tqdm(data, desc="Processing samples"):
        try:
            defender.evolve_step(sample)
        except Exception as e:
            # Best effort: one bad sample must not abort the whole build,
            # but keep count so the loss is visible afterwards.
            failed_count += 1
            logger.warning(f"⚠️ 处理样本失败: {e}")

    if failed_count:
        # Without this summary a tree built from mostly-failed samples is
        # indistinguishable from a clean run.
        logger.warning(f"⚠️ 共有 {failed_count}/{len(data)} 个样本处理失败并被跳过")

    return tree


def _sibling_output_path(output_path: str, suffix: str) -> Path:
    """Return ``<dir>/<stem><suffix>`` next to ``output_path``, whatever its extension."""
    target = Path(output_path)
    return target.with_name(f"{target.stem}{suffix}")


def _configure_experiment_logger(log_path: Path) -> logging.Logger:
    """Attach a file handler and a stdout handler to this module's logger, replacing old ones."""
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)

    # Remove AND close handlers left over from a previous run; merely clearing
    # the list would leave their log files open.
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()

    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    handlers = (
        logging.FileHandler(str(log_path), encoding='utf-8'),
        logging.StreamHandler(sys.stdout),
    )
    for handler in handlers:
        handler.setLevel(logging.INFO)
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    return logger


def _dump_json(path, payload: Dict[str, Any]) -> None:
    """Write ``payload`` to ``path`` as indented UTF-8 JSON."""
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(payload, f, indent=2, ensure_ascii=False)


def _format_summary_table(results: Dict[str, Any], include_error_detail: bool = False) -> List[str]:
    """Render the per-combination summary table as a list of lines (without newlines)."""
    lines = [
        f"{'Threshold':<12} {'Attack Sim':<12} {'Clusters':<10} {'Leaves':<10} {'Avg Leaves/Cluster':<18}",
        "-" * 70,
    ]
    for result in results['results']:
        prefix = f"{result['threshold']:<12.2f} {result['attack_similarity_threshold']:<12.2f} "
        if 'error' not in result:
            lines.append(
                prefix
                + f"{result['total_clusters']:<10} {result['total_leaves']:<10} "
                + f"{result['avg_leaves_per_cluster']:<18.2f}"
            )
        elif include_error_detail:
            lines.append(prefix + f"{'ERROR':<10} - {result['error']}")
        else:
            lines.append(prefix + f"{'ERROR':<10}")
    return lines


def run_sensitivity_experiment(
    input_data_path: str,
    output_path: str,
    llm_api_url: str = "http://127.0.0.1:8030/v1",
    llm_api_key: str = "EMPTY",
    model_name: str = "qwen-7b",
    sample_limit: Optional[int] = None,
    threshold_range: Optional[List[float]] = None,
    attack_similarity_range: Optional[List[float]] = None
):
    """
    Run the sensitivity experiment over a grid of threshold combinations.

    For every (threshold, attack similarity threshold) pair a tree is built
    with ``build_tree_from_data``, measured with ``count_clusters`` and saved
    with ``tree.save``. A combination that raises is recorded with an
    ``error`` entry and the run continues with the next one.

    Files written, all next to ``output_path`` and named after its stem
    (independent of its extension, so they can never collide with
    ``output_path`` or with each other):
        - ``output_path``: intermediate results, rewritten after every
          combination (successful or failed)
        - ``<stem>_final_results.json``: results after the last combination
        - ``<stem>_process.log``: process log
        - ``<stem>_summary.txt``: human-readable summary table
        - ``<stem>_trees/``: one saved tree per successful combination

    Args:
        input_data_path: Path of the input data file.
        output_path: Path of the intermediate results file; also the base for
            the other output names.
        llm_api_url: Base URL passed to the ``OpenAI`` client.
        llm_api_key: API key passed to the ``OpenAI`` client.
        model_name: Model name forwarded to ``build_tree_from_data``.
        sample_limit: Sample limit forwarded to ``build_tree_from_data``.
        threshold_range: Entropy thresholds to test
            (default: [0.1, 0.2, 0.3, 0.4, 0.5]).
        attack_similarity_range: Attack similarity thresholds to test
            (default: [0.3, 0.4, 0.5, 0.6, 0.7]).

    Returns:
        dict: ``{'experiment_config': {...}, 'results': [...]}`` with one entry
        per parameter combination.
    """
    if threshold_range is None:
        threshold_range = [0.1, 0.2, 0.3, 0.4, 0.5]
    if attack_similarity_range is None:
        attack_similarity_range = [0.3, 0.4, 0.5, 0.6, 0.7]

    # Output layout. Creating the tree directory also creates the output
    # directory itself, which the log file handler below needs.
    output_dir = Path(output_path).parent
    output_basename = Path(output_path).stem
    trees_dir = output_dir / f"{output_basename}_trees"
    trees_dir.mkdir(parents=True, exist_ok=True)

    # Derived from the stem rather than by replacing '.json' in the string:
    # with an output path lacking '.json' the replacement was a no-op and the
    # log, final results and summary all overwrote ``output_path``.
    log_path = _sibling_output_path(output_path, '_process.log')
    final_results_path = _sibling_output_path(output_path, '_final_results.json')
    summary_path = _sibling_output_path(output_path, '_summary.txt')

    logger = _configure_experiment_logger(log_path)

    llm_client = OpenAI(base_url=llm_api_url, api_key=llm_api_key)

    results = {
        'experiment_config': {
            'input_data': input_data_path,
            'sample_limit': sample_limit,
            'threshold_range': threshold_range,
            'attack_similarity_range': attack_similarity_range,
            'model_name': model_name
        },
        'results': []
    }

    # Same order as nested loops: thresholds outer, similarity inner.
    combinations = [
        (threshold, attack_sim_threshold)
        for threshold in threshold_range
        for attack_sim_threshold in attack_similarity_range
    ]
    total_combinations = len(combinations)
    logger.info(f"\n🔬 开始敏感度实验")
    logger.info(f"   参数组合数: {total_combinations}")
    logger.info(f"   Threshold 范围: {threshold_range}")
    logger.info(f"   Attack Similarity 范围: {attack_similarity_range}")
    logger.info(f"   样本限制: {sample_limit if sample_limit else '全部'}")
    logger.info(f"   树文件目录: {trees_dir}")
    logger.info(f"   中间过程日志: {log_path}")
    logger.info(f"   最终结果文件: {final_results_path}\n")

    for experiment_idx, (threshold, attack_sim_threshold) in enumerate(combinations, start=1):
        logger.info(f"\n{'='*60}")
        logger.info(f"实验 {experiment_idx}/{total_combinations}")
        logger.info(f"Threshold (entropy): {threshold}")
        logger.info(f"Attack Similarity Threshold: {attack_sim_threshold}")
        logger.info(f"{'='*60}\n")

        try:
            tree = build_tree_from_data(
                data_path=input_data_path,
                threshold=threshold,
                attack_similarity_threshold=attack_sim_threshold,
                llm_client=llm_client,
                model_name=model_name,
                sample_limit=sample_limit,
                logger=logger
            )

            stats = count_clusters(tree)

            # NOTE: values are formatted to two decimals, so thresholds that
            # differ only beyond the second decimal share one file name.
            tree_filename = trees_dir / f"tree_th{threshold:.2f}_sim{attack_sim_threshold:.2f}.pkl"
            tree.save(str(tree_filename))
            logger.info(f"💾 树已保存到: {tree_filename}")

            result = {
                'threshold': threshold,
                'attack_similarity_threshold': attack_sim_threshold,
                'tree_file': str(tree_filename),
                'cluster_stats': stats,
                'total_clusters': stats['total_clusters'],
                'total_leaves': stats['total_leaves'],
                'avg_leaves_per_cluster': stats['total_leaves'] / stats['total_clusters'] if stats['total_clusters'] > 0 else 0
            }
        except Exception as e:
            # Top-level boundary of one experiment: record the failure (with
            # traceback in the log) and carry on with the next combination.
            logger.exception(f"❌ 实验失败: {e}")
            results['results'].append({
                'threshold': threshold,
                'attack_similarity_threshold': attack_sim_threshold,
                'error': str(e),
                'cluster_stats': None,
                'tree_file': None
            })
        else:
            results['results'].append(result)

            logger.info(f"\n✅ 实验结果:")
            logger.info(f"   总 Cluster 数: {stats['total_clusters']}")
            logger.info(f"   总 Leaf 数: {stats['total_leaves']}")
            logger.info(f"   平均每个 Cluster 的 Leaf 数: {result['avg_leaves_per_cluster']:.2f}")
            logger.info(f"   各 Category 的 Cluster 数:")
            for cat_name, cat_stats in stats['categories'].items():
                logger.info(f"     - {cat_name}: {cat_stats['clusters']} clusters, {cat_stats['leaves']} leaves")

        # Checkpoint after EVERY combination, failed ones included, so an
        # interrupted run keeps everything recorded so far. A checkpoint that
        # cannot be written must not abort the remaining experiments.
        try:
            _dump_json(output_path, results)
            logger.info(f"\n💾 中间结果已保存到: {output_path}")
        except OSError as e:
            logger.warning(f"⚠️ 中间结果保存失败: {e}")

    _dump_json(final_results_path, results)

    logger.info(f"\n{'='*60}")
    logger.info(f"🎉 所有实验完成！")
    logger.info(f"💾 最终结果已保存到: {final_results_path}")
    logger.info(f"📁 树文件目录: {trees_dir}")
    logger.info(f"📝 中间过程日志: {log_path}")
    logger.info(f"{'='*60}\n")

    # Summary table in the log (error rows without the error text).
    logger.info("📊 实验摘要:")
    for line in _format_summary_table(results):
        logger.info(line)

    # Summary file: configuration header plus the table with error details.
    config = results['experiment_config']
    summary_lines = [
        "=" * 70,
        "敏感度实验摘要",
        "=" * 70,
        "",
        "实验配置:",
        f"  输入文件: {config['input_data']}",
        f"  样本限制: {config['sample_limit'] or '全部'}",
        f"  Threshold 范围: {config['threshold_range']}",
        f"  Attack Similarity 范围: {config['attack_similarity_range']}",
        f"  模型: {config['model_name']}",
        f"  树文件目录: {trees_dir}",
        "",
        *_format_summary_table(results, include_error_detail=True),
    ]
    with open(summary_path, 'w', encoding='utf-8') as f:
        f.write("\n".join(summary_lines) + "\n")

    logger.info(f"📄 摘要已保存到: {summary_path}")

    return results


if __name__ == "__main__":
    # Command-line entry point: declare the options, parse them, and hand
    # everything to run_sensitivity_experiment.
    cli = argparse.ArgumentParser(description="敏感度实验：测试 entropy 和 similarity 阈值对 cluster 数量的影响")

    # (flag, keyword arguments for add_argument), registered in this order.
    option_table = [
        ("--input", {
            "type": str,
            "default": "src/attack_results.json",
            "help": "输入数据文件路径（默认: src/attack_results.json）",
        }),
        ("--output", {
            "type": str,
            "default": "sensitivity_results.json",
            "help": "结果输出文件路径（默认: sensitivity_results.json）",
        }),
        ("--llm_api_url", {
            "type": str,
            "default": "http://127.0.0.1:8030/v1",
            "help": "LLM API URL（默认: http://127.0.0.1:8030/v1）",
        }),
        ("--llm_api_key", {
            "type": str,
            "default": "EMPTY",
            "help": "LLM API Key（默认: EMPTY）",
        }),
        ("--model_name", {
            "type": str,
            "default": "qwen-7b",
            "help": "模型名称（默认: qwen-7b）",
        }),
        ("--sample_limit", {
            "type": int,
            "default": None,
            "help": "限制处理的样本数量（默认: None，处理全部）",
        }),
        ("--threshold_range", {
            "type": float,
            "nargs": '+',
            "default": [0.1, 0.2, 0.3, 0.4, 0.5],
            "help": "Entropy threshold 范围（默认: 0.1 0.2 0.3 0.4 0.5）",
        }),
        ("--attack_similarity_range", {
            "type": float,
            "nargs": '+',
            "default": [0.3, 0.4, 0.5, 0.6, 0.7],
            "help": "Attack similarity threshold 范围（默认: 0.3 0.4 0.5 0.6 0.7）",
        }),
    ]
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    cli_args = cli.parse_args()

    # Run the experiment with the parsed options.
    results = run_sensitivity_experiment(
        input_data_path=cli_args.input,
        output_path=cli_args.output,
        llm_api_url=cli_args.llm_api_url,
        llm_api_key=cli_args.llm_api_key,
        model_name=cli_args.model_name,
        sample_limit=cli_args.sample_limit,
        threshold_range=cli_args.threshold_range,
        attack_similarity_range=cli_args.attack_similarity_range,
    )
