#!/usr/bin/env python3
"""
超参数敏感性实验脚本

测试两个关键超参数的不同组合：
1. attack_similarity_threshold: attack_prompt相似度阈值，决定是否进入规则判断
2. threshold: 规则信息增益阈值，决定是否合并规则

评估指标：
- 树结构指标：category数量、cluster数量、每个category的cluster数、cluster大小的均值/标准差/最大值/最小值（树深度目前未统计）
- 性能指标：实验耗时、处理成功/失败的样本数（LLM调用次数目前未统计）
- 质量指标：检索精度（尚未实现，需要测试集）
"""

import json
import sys
import os
import pickle
import time
from pathlib import Path
from collections import defaultdict
import numpy as np
from tqdm import tqdm
from datetime import datetime
import argparse

# 添加 src 目录到路径
src_path = Path(__file__).parent / 'src'
sys.path.insert(0, str(src_path))

from risk_tree import RiskTree
from memory_defender import main_evolution


def count_tree_statistics(tree):
    """
    Summarise the structure of a RiskTree.

    The tree is walked two levels below the root: each child of ``tree.root``
    is treated as a category node (identified by its ``label``), each child of
    a category node as a cluster, and the number of children of a cluster as
    that cluster's sample count.

    Args:
        tree: object exposing ``tree.root.children`` with the layout above.

    Returns:
        dict: ``total_categories``, ``total_clusters``, ``total_samples``,
        ``avg_cluster_size`` / ``std_cluster_size`` (float; numpy's default
        population std), ``max_cluster_size`` / ``min_cluster_size`` (int) and
        ``clusters_per_category`` (label -> cluster count; categories without
        any cluster are omitted). All size statistics are 0 when the tree has
        no clusters.
    """
    categories = tree.root.children
    clusters_by_label = defaultdict(int)
    sizes = []

    for category in categories:
        label = category.label
        for cluster in category.children:
            clusters_by_label[label] += 1
            sizes.append(len(cluster.children))

    if sizes:
        mean_size = np.mean(sizes)
        spread = np.std(sizes)
        largest = max(sizes)
        smallest = min(sizes)
    else:
        # Empty tree: report zeros instead of failing on empty aggregates.
        mean_size = spread = largest = smallest = 0

    return {
        'total_categories': len(categories),
        'total_clusters': len(sizes),
        'total_samples': sum(sizes),
        'avg_cluster_size': float(mean_size),
        'std_cluster_size': float(spread),
        'max_cluster_size': int(largest),
        'min_cluster_size': int(smallest),
        'clusters_per_category': dict(clusters_by_label),
    }


def run_single_experiment(
    attack_similarity_threshold,
    info_gain_threshold,
    attack_results_file,
    small_batch=True,
    max_samples=None,
    llm_port=8030,
    model_name="qwen-72b",
    safety_projector_path=None,
    output_dir="./hyperparameter_sensitivity_results"
):
    """
    Run one hyperparameter combination: build a RiskTree from the attack
    samples, collect its structural statistics and pickle the tree.

    Args:
        attack_similarity_threshold: attack_prompt similarity threshold,
            forwarded to ``RiskTree(attack_similarity_threshold=...)``.
        info_gain_threshold: rule information-gain threshold, forwarded to
            ``RiskTree(threshold=...)``.
        attack_results_file: path to a JSON file with the attack records
            (presumably a list of dicts, since it is sliced and ``.get`` is
            called on each element).
        small_batch: when True and ``max_samples`` is falsy, only the first
            100 records are used.
        max_samples: when truthy, only the first ``max_samples`` records are
            used; takes precedence over ``small_batch``.
        llm_port: port of the local OpenAI-compatible LLM API.
        model_name: LLM model name.
        safety_projector_path: Safety Projector model path, forwarded to RiskTree.
        output_dir: root output directory. The tree is written to
            ``<output_dir>/trees/tree_<exp_id>.pkl`` and copied into
            ``<output_dir>/<exp_id>/``.

    Returns:
        dict: On success: both thresholds, ``experiment_id``, ``tree_file``,
        ``elapsed_time`` (whole run, including data loading and saving),
        ``build_time`` (the per-sample evolve loop only),
        ``processed_samples``, ``failed_samples``, ``total_samples`` and the
        fields produced by ``count_tree_statistics``. On failure: both
        thresholds, ``experiment_id``, ``error`` and ``elapsed_time``.
    """
    import shutil

    print(f"\n{'='*80}")
    print(f"🔬 实验参数:")
    print(f"   - attack_similarity_threshold: {attack_similarity_threshold}")
    print(f"   - info_gain_threshold: {info_gain_threshold}")
    print(f"{'='*80}")

    # Experiment ID such as "as0.50_ig0.30". Thresholds are rounded to two
    # decimals here, so values differing only beyond that share an ID and files.
    exp_id = f"as{attack_similarity_threshold:.2f}_ig{info_gain_threshold:.2f}"
    exp_output_dir = os.path.join(output_dir, exp_id)
    os.makedirs(exp_output_dir, exist_ok=True)

    # Shared directory that collects the tree files of every experiment.
    trees_dir = os.path.join(output_dir, "trees")
    os.makedirs(trees_dir, exist_ok=True)

    start_time = time.time()

    try:
        # 1. Load the attack data first so a bad path fails fast, before the
        #    RiskTree (and whatever it loads) is constructed.
        print(f"\n📂 加载攻击数据: {attack_results_file}")
        if not os.path.exists(attack_results_file):
            raise FileNotFoundError(f"攻击结果文件不存在: {attack_results_file}")

        with open(attack_results_file, "r", encoding="utf-8") as f:
            attacks = json.load(f)

        # A truthy max_samples overrides small_batch.
        if max_samples:
            attacks = attacks[:max_samples]
        elif small_batch:
            attacks = attacks[:100]

        print(f"✅ 加载了 {len(attacks)} 个攻击样本")

        # 2. A fresh tree for this hyperparameter combination.
        print(f"\n📦 创建新的 RiskTree 实例...")
        tree = RiskTree(
            threshold=info_gain_threshold,
            attack_similarity_threshold=attack_similarity_threshold,
            safety_projector_path=safety_projector_path,
            llm_port=llm_port,
            model_name=model_name
        )

        # 3. MemoryDefender backed by the local OpenAI-compatible endpoint.
        from openai import OpenAI
        from memory_defender import MemoryDefender
        client = OpenAI(base_url=f"http://127.0.0.1:{llm_port}/v1", api_key="EMPTY")
        defender = MemoryDefender(tree, client, model_name, enable_attack_enhancement=True)

        # 4. Build the tree. A failing sample is logged and counted but does
        #    not abort the experiment.
        print(f"\n🌳 开始建树过程...")
        processed_count = 0
        failed_count = 0
        build_start = time.time()

        for i, attack in enumerate(tqdm(attacks, desc="处理攻击样本")):
            try:
                sample = {
                    'original_id': attack.get('id', f'sample_{i}'),
                    'category': attack.get('category', 'unknown'),
                    'sub_category': attack.get('sub_category', 'unknown'),
                    # Fall back to the mutated prompt when no attack_prompt is present.
                    'attack_prompt': attack.get('attack_prompt', attack.get('mutated_prompt', '')),
                    'mutated_prompt': attack.get('mutated_prompt', ''),
                    'original_prompt': attack.get('original_prompt', '')
                }
                defender.evolve_step(sample, skip_simulation=True, skip_rule_generation=True)
                processed_count += 1
            except Exception as e:
                print(f"\n⚠️ 处理样本 {i} 失败: {e}")
                failed_count += 1

        # Time spent building the tree only (excludes loading and saving).
        build_time = time.time() - build_start

        # 5. Structural statistics.
        print(f"\n📊 统计树结构...")
        tree_stats = count_tree_statistics(tree)

        # 6. Persist the tree once into the shared trees directory...
        tree_file = os.path.join(trees_dir, f"tree_{exp_id}.pkl")
        print(f"\n💾 保存树文件...")
        with open(tree_file, 'wb') as f:
            pickle.dump(tree, f)
        print(f"✅ 树已保存到: {os.path.abspath(tree_file)}")

        # ...and copy the file into the experiment directory instead of
        # serializing the (possibly large) tree a second time.
        shutil.copyfile(tree_file, os.path.join(exp_output_dir, f"tree_{exp_id}.pkl"))

        # 7. Total wall-clock time, including loading and saving.
        elapsed_time = time.time() - start_time

        result = {
            'attack_similarity_threshold': attack_similarity_threshold,
            'info_gain_threshold': info_gain_threshold,
            'experiment_id': exp_id,
            'tree_file': tree_file,
            'elapsed_time': elapsed_time,
            'build_time': build_time,
            'processed_samples': processed_count,
            'failed_samples': failed_count,
            'total_samples': len(attacks),
            **tree_stats
        }

        print(f"\n✅ 实验完成!")
        print(f"   - 处理样本数: {processed_count}/{len(attacks)}")
        print(f"   - 失败样本数: {failed_count}")
        print(f"   - 总cluster数: {tree_stats['total_clusters']}")
        print(f"   - 平均cluster大小: {tree_stats['avg_cluster_size']:.2f}")
        print(f"   - 建树耗时: {build_time:.2f}秒")
        print(f"   - 总耗时: {elapsed_time:.2f}秒")

        return result

    except Exception as e:
        # Report the failure as a result row so the sweep can continue.
        print(f"\n❌ 实验失败: {e}")
        import traceback
        traceback.print_exc()
        return {
            'attack_similarity_threshold': attack_similarity_threshold,
            'info_gain_threshold': info_gain_threshold,
            'experiment_id': exp_id,
            'error': str(e),
            'elapsed_time': time.time() - start_time
        }


def run_hyperparameter_sensitivity_experiment(
    attack_similarity_thresholds,
    info_gain_thresholds,
    attack_results_file,
    small_batch=True,
    max_samples=None,
    llm_port=8030,
    model_name="qwen-72b",
    safety_projector_path=None,
    output_dir="./hyperparameter_sensitivity_results"
):
    """
    Run every hyperparameter combination and summarise the results.

    Each (attack_similarity_threshold, info_gain_threshold) pair from the
    Cartesian product of the two lists is passed to ``run_single_experiment``.
    The accumulated results are rewritten to ``results_<timestamp>.json`` after
    every experiment, so partial progress survives a crash mid-sweep. They are
    written once more to ``final_results_<timestamp>.json`` at the end, and a
    summary table is printed.

    Args:
        attack_similarity_thresholds: list of attack_similarity_threshold values.
        info_gain_thresholds: list of info_gain_threshold values.
        attack_results_file: path to the attack results file.
        small_batch: use only the first 100 samples (ignored when max_samples is truthy).
        max_samples: maximum number of samples; takes precedence over small_batch.
        llm_port: LLM API port.
        model_name: LLM model name.
        safety_projector_path: Safety Projector model path.
        output_dir: output directory for results and tree files.

    Returns:
        list[dict]: one result per combination, in grid order. A failed
        experiment's entry carries an ``error`` key.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Dedicated directory for the tree files written by each experiment.
    trees_dir = os.path.join(output_dir, "trees")
    os.makedirs(trees_dir, exist_ok=True)
    print(f"📁 树文件将保存到: {os.path.abspath(trees_dir)}")

    param_combinations = [
        (as_thresh, ig_thresh)
        for as_thresh in attack_similarity_thresholds
        for ig_thresh in info_gain_thresholds
    ]

    # Describe the sample limit with the same precedence run_single_experiment
    # applies: a truthy max_samples overrides small_batch.
    if max_samples:
        sample_limit = f'max_samples={max_samples}'
    elif small_batch:
        sample_limit = 'small_batch (100)'
    else:
        sample_limit = '全部'

    print(f"\n{'='*80}")
    print(f"🚀 超参数敏感性实验")
    print(f"{'='*80}")
    print(f"📋 实验配置:")
    print(f"   - attack_similarity_thresholds: {attack_similarity_thresholds}")
    print(f"   - info_gain_thresholds: {info_gain_thresholds}")
    print(f"   - 总组合数: {len(param_combinations)}")
    print(f"   - 攻击数据文件: {attack_results_file}")
    print(f"   - 样本限制: {sample_limit}")
    print(f"{'='*80}")

    results = []
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    results_file = os.path.join(output_dir, f"results_{timestamp}.json")

    for i, (as_thresh, ig_thresh) in enumerate(param_combinations, 1):
        print(f"\n\n{'#'*80}")
        print(f"实验 {i}/{len(param_combinations)}")
        print(f"{'#'*80}")

        result = run_single_experiment(
            attack_similarity_threshold=as_thresh,
            info_gain_threshold=ig_thresh,
            attack_results_file=attack_results_file,
            small_batch=small_batch,
            max_samples=max_samples,
            llm_port=llm_port,
            model_name=model_name,
            safety_projector_path=safety_projector_path,
            output_dir=output_dir
        )

        results.append(result)

        # Checkpoint after every experiment so a crash mid-sweep loses nothing.
        with open(results_file, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2, ensure_ascii=False)
        print(f"💾 中间结果已保存到: {results_file}")

    final_results_file = os.path.join(output_dir, f"final_results_{timestamp}.json")
    with open(final_results_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)

    succeeded = [r for r in results if 'error' not in r]

    print(f"\n\n{'='*80}")
    print(f"📊 实验汇总结果")
    print(f"{'='*80}")
    print(f"\n结果已保存到: {final_results_file}")
    print(f"所有树文件保存在: {os.path.abspath(trees_dir)}")
    print(f"   - 共 {len(succeeded)} 个树文件")

    # Summary table: one row per combination; failed runs show ERROR.
    print(f"\n{'攻击相似度阈值':<15} {'信息增益阈值':<15} {'Cluster数':<12} {'平均大小':<12} {'总样本数':<12} {'耗时(秒)':<12}")
    print(f"{'-'*80}")

    for r in results:
        if 'error' not in r:
            print(f"{r['attack_similarity_threshold']:<15.2f} "
                  f"{r['info_gain_threshold']:<15.2f} "
                  f"{r.get('total_clusters', 0):<12} "
                  f"{r.get('avg_cluster_size', 0):<12.2f} "
                  f"{r.get('total_samples', 0):<12} "
                  f"{r.get('elapsed_time', 0):<12.2f}")
        else:
            print(f"{r['attack_similarity_threshold']:<15.2f} "
                  f"{r['info_gain_threshold']:<15.2f} "
                  f"{'ERROR':<12} "
                  f"{'-':<12} "
                  f"{'-':<12} "
                  f"{r.get('elapsed_time', 0):<12.2f}")

    return results


def main(argv=None):
    """
    Command-line entry point: parse arguments and run the sweep.

    Args:
        argv: optional argument list; when None, argparse reads ``sys.argv[1:]``.
            Useful for invoking the script programmatically.
    """
    parser = argparse.ArgumentParser(description="超参数敏感性实验")

    parser.add_argument(
        "--attack_similarity_thresholds",
        type=float,
        nargs="+",
        default=[0.3, 0.4, 0.5, 0.6, 0.7],
        help="attack_similarity_threshold 值列表"
    )
    parser.add_argument(
        "--info_gain_thresholds",
        type=float,
        nargs="+",
        default=[0.2, 0.3, 0.4, 0.5],
        help="info_gain_threshold 值列表"
    )
    parser.add_argument(
        "--attack_results_file",
        type=str,
        default="attack_result/attack_results_original.json",
        help="攻击结果文件路径"
    )
    # small_batch is on by default. A bare `store_true` flag whose default is
    # True can never be switched off, so a paired `store_false` flag writing
    # to the same dest is provided to disable it.
    parser.add_argument(
        "--small_batch",
        dest="small_batch",
        action="store_true",
        help="使用小批量（前100条样本）"
    )
    parser.add_argument(
        "--full_dataset",
        "--no_small_batch",
        dest="small_batch",
        action="store_false",
        help="使用全部样本（关闭small_batch；--max_samples仍然优先）"
    )
    # Parser-level default applies to both flags above that share this dest.
    parser.set_defaults(small_batch=True)
    parser.add_argument(
        "--max_samples",
        type=int,
        default=None,
        help="最大样本数量（覆盖small_batch）"
    )
    parser.add_argument(
        "--llm_port",
        type=int,
        default=8030,
        help="LLM API端口"
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="qwen-72b",
        help="LLM模型名称"
    )
    parser.add_argument(
        "--safety_projector_path",
        type=str,
        default="src/models/safety_projector_metric.pth",
        help="Safety Projector模型路径"
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./hyperparameter_sensitivity_results",
        help="结果输出目录"
    )

    args = parser.parse_args(argv)

    run_hyperparameter_sensitivity_experiment(
        attack_similarity_thresholds=args.attack_similarity_thresholds,
        info_gain_thresholds=args.info_gain_thresholds,
        attack_results_file=args.attack_results_file,
        small_batch=args.small_batch,
        max_samples=args.max_samples,
        llm_port=args.llm_port,
        model_name=args.model_name,
        safety_projector_path=args.safety_projector_path,
        output_dir=args.output_dir
    )


# Run the sweep only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
