#!/usr/bin/env python3
"""
熵阈值消融实验脚本
批量运行不同阈值下的 memory tree 构建，并收集统计信息
支持并行建树（三个一组）
"""

import os
import sys
import json
import subprocess
from datetime import datetime
from pathlib import Path
from concurrent.futures import ProcessPoolExecutor, as_completed
import multiprocessing

# 添加 src 目录到路径
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))

from memory_defender import main_evolution
from risk_tree import RiskTree

def run_single_experiment(threshold, small_batch=True, output_dir="./threshold_ablation_results", process_id=None):
    """
    Build a memory tree for one entropy threshold and collect its cluster statistics.

    When ``process_id`` is given, the run happens inside a private scratch
    directory ``<output_dir>/temp_work_<process_id>`` so that files created in
    the current working directory by concurrent runs cannot collide. The
    original working directory is restored and the scratch directory is
    removed before the function returns, whether the run succeeded or not.

    Args:
        threshold: Entropy threshold, forwarded to ``main_evolution`` as
            ``entropy_threshold`` and used in the output file names.
        small_batch: Forwarded unchanged to ``main_evolution``.
        output_dir: Directory receiving ``threshold_<t>_stats.json`` and
            ``memory_tree_threshold_<t>.pkl``. Converted to an absolute path.
        process_id: Identifier used in log lines and in the scratch directory
            name. ``None`` runs in the current working directory.

    Returns:
        dict | None: The mapping returned by ``RiskTree.count_clusters()`` with
        an extra ``'experiment'`` entry holding run metadata, or ``None`` when
        no ``final_memory_*.pkl`` file was found or any exception was raised.
    """
    import shutil
    import traceback

    # Resolve to an absolute path up front: the working directory may change below.
    output_dir = os.path.abspath(output_dir)
    original_cwd = os.getcwd()
    temp_work_dir = None

    # Give every identified run its own working directory to avoid file clashes.
    if process_id is not None:
        temp_work_dir = os.path.join(output_dir, f"temp_work_{process_id}")
        os.makedirs(temp_work_dir, exist_ok=True)
        os.chdir(temp_work_dir)

    try:
        print(f"\n{'='*60}")
        print(f"[进程 {process_id}] 运行实验: threshold = {threshold:.2f}")
        print(f"{'='*60}")

        os.makedirs(output_dir, exist_ok=True)

        # Build from scratch: no existing pkl is loaded.
        main_evolution(
            load_existing_pkl=None,
            skip_attack_evolution=False,
            small_batch=small_batch,
            enable_attack_enhancement=True,
            entropy_threshold=threshold
        )

        # Pick the most recently modified pkl in the current working directory
        # (presumably the one main_evolution just wrote).
        newest = max(Path('.').glob('final_memory_*.pkl'), key=os.path.getmtime, default=None)
        if newest is None:
            print(f"⚠️ [进程 {process_id}] 警告: 未找到生成的 pkl 文件")
            return None
        latest_pkl = os.path.abspath(str(newest))

        # The projector weights are looked up relative to the directory the
        # script was started from, trying "src/models" first and then "models".
        safety_projector_path = os.path.join(original_cwd, "src", "models", "safety_projector_metric.pth")
        if not os.path.exists(safety_projector_path):
            safety_projector_path = os.path.join(original_cwd, "models", "safety_projector_metric.pth")

        tree = RiskTree.load(latest_pkl, safety_projector_path=safety_projector_path)
        stats = tree.count_clusters()

        # Move the pkl to its final location BEFORE writing the stats file, so
        # the JSON records a path that still exists after the scratch directory
        # is deleted.
        output_pkl = os.path.join(output_dir, f"memory_tree_threshold_{threshold:.2f}.pkl")
        final_pkl = latest_pkl
        if os.path.exists(latest_pkl):
            shutil.move(latest_pkl, output_pkl)
            final_pkl = output_pkl

        # Attach run metadata to the statistics.
        stats['experiment'] = {
            'threshold': threshold,
            'timestamp': datetime.now().isoformat(),
            'pkl_file': final_pkl,
            'small_batch': small_batch,
            'process_id': process_id
        }

        result_file = os.path.join(output_dir, f"threshold_{threshold:.2f}_stats.json")
        with open(result_file, 'w', encoding='utf-8') as f:
            json.dump(stats, f, indent=2, ensure_ascii=False)

        print(f"\n✓ [进程 {process_id}] 实验完成: threshold = {threshold:.2f}")
        print(f"  - Clusters: {stats['total_clusters']}")
        print(f"  - Categories: {stats['total_categories']}")
        print(f"  - 结果保存到: {result_file}")

        return stats

    except Exception as e:
        # Boundary handler: one failed threshold must not abort the whole sweep.
        print(f"✗ [进程 {process_id}] 实验失败: threshold = {threshold:.2f}, 错误: {e}")
        traceback.print_exc()
        return None
    finally:
        if temp_work_dir is not None:
            # Leave the scratch directory before deleting it.
            os.chdir(original_cwd)
            # Best-effort cleanup: a leftover scratch directory is not an error.
            shutil.rmtree(temp_work_dir, ignore_errors=True)


def main():
    """
    Command-line entry point for the entropy-threshold ablation.

    Parses the command-line options, runs ``run_single_experiment`` once per
    threshold (in a process pool or serially), writes all successful results,
    sorted by threshold, to ``threshold_ablation_summary.json`` in the output
    directory, and prints a summary table.
    """
    import argparse

    parser = argparse.ArgumentParser(description='熵阈值消融实验（支持并行建树）')
    # Suggested values go from coarse to fine and cover the main range of change:
    #   0.3-0.5: low thresholds, merging is easy, fewer clusters
    #   0.6:     default value, the balance point
    #   0.7-0.9: high thresholds, splitting is easy, more clusters
    parser.add_argument('--thresholds', type=float, nargs='+',
                        default=[0.3, 0.4, 0.5, 0.55, 0.6, 0.65, 0.7, 0.8, 0.9],
                        help='要测试的阈值列表（默认: 0.3 0.4 0.5 0.55 0.6 0.65 0.7 0.8 0.9，在0.5-0.7区间更密集）')
    parser.add_argument('--output_dir', type=str, default='./threshold_ablation_results',
                        help='输出目录（默认: ./threshold_ablation_results）')
    parser.add_argument('--small_batch', action='store_true', default=True,
                        help='使用小批量数据（默认: True，用于快速测试）')
    parser.add_argument('--no_small_batch', dest='small_batch', action='store_false',
                        help='不使用小批量数据（使用全部数据）')
    parser.add_argument('--parallel', action='store_true', default=True,
                        help='启用并行建树（默认: True）')
    parser.add_argument('--no_parallel', dest='parallel', action='store_false',
                        help='禁用并行建树（串行运行）')
    parser.add_argument('--max_workers', type=int, default=3,
                        help='并行进程数（默认: 3，三个一组并行建树）')

    args = parser.parse_args()

    print("="*60)
    print("熵阈值消融实验")
    print("="*60)
    print(f"阈值列表: {args.thresholds}")
    print(f"输出目录: {args.output_dir}")
    print(f"小批量模式: {args.small_batch}")
    print(f"并行模式: {args.parallel}")
    if args.parallel:
        print(f"并行进程数: {args.max_workers}")
    print("="*60)

    os.makedirs(args.output_dir, exist_ok=True)

    # Successful per-threshold results; failed runs (None) are left out.
    all_results = []

    if args.parallel:
        # Force the 'fork' start method everywhere except Windows.
        if sys.platform != 'win32':
            try:
                multiprocessing.set_start_method('fork', force=True)
            except RuntimeError:
                pass  # start method could not be changed; keep the current one

        print(f"\n🚀 使用并行模式，{args.max_workers} 个进程同时建树...")
        with ProcessPoolExecutor(max_workers=args.max_workers) as executor:
            # Submit every threshold; the enumeration index doubles as the
            # process_id that names each run's scratch directory.
            future_to_threshold = {
                executor.submit(run_single_experiment, threshold, args.small_batch, args.output_dir, idx): threshold
                for idx, threshold in enumerate(args.thresholds)
            }

            # Collect results in completion order.
            total = len(args.thresholds)
            for completed, future in enumerate(as_completed(future_to_threshold), start=1):
                threshold = future_to_threshold[future]
                try:
                    result = future.result()
                    if result:
                        all_results.append(result)
                        print(f"\n✅ [{completed}/{total}] 完成阈值 {threshold:.2f}")
                except Exception as e:
                    print(f"\n❌ [{completed}/{total}] 阈值 {threshold:.2f} 失败: {e}")
    else:
        print(f"\n🚀 使用串行模式...")
        for idx, threshold in enumerate(args.thresholds):
            result = run_single_experiment(threshold, small_batch=args.small_batch, output_dir=args.output_dir, process_id=idx)
            # Collect inside the loop so every threshold's result is kept,
            # not just the last one.
            if result:
                all_results.append(result)

    # Completion order is arbitrary in parallel mode, so sort by threshold.
    all_results.sort(key=lambda x: x['experiment']['threshold'])

    summary_file = os.path.join(args.output_dir, 'threshold_ablation_summary.json')
    with open(summary_file, 'w', encoding='utf-8') as f:
        json.dump(all_results, f, indent=2, ensure_ascii=False)

    # Print the summary table.
    print("\n" + "="*60)
    print("实验结果汇总")
    print("="*60)
    print(f"{'Threshold':<12} {'Categories':<12} {'Clusters':<12} {'Avg/Category':<15}")
    print("-"*60)
    for result in all_results:
        threshold = result['experiment']['threshold']
        categories = result['total_categories']
        clusters = result['total_clusters']
        # Guard against division by zero when a run produced no categories.
        avg = clusters / categories if categories > 0 else 0
        print(f"{threshold:<12.2f} {categories:<12} {clusters:<12} {avg:<15.2f}")
    print("="*60)
    print(f"\n汇总结果已保存到: {summary_file}")


# Run the ablation only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()
