#!/usr/bin/env python3
"""
AgentHarm 实验结果分析脚本

分析 inspect_ai 生成的实验结果，计算 ASR、FRR 等指标
"""

import json
import os
import sys
from pathlib import Path
from typing import Dict, List, Any, Optional
from datetime import datetime
import argparse


def find_latest_logs(base_dir: str = "./logs") -> List[Path]:
    """查找最新的日志文件"""
    log_dir = Path(base_dir)
    if not log_dir.exists():
        print(f"日志目录不存在: {base_dir}")
        return []
    
    # 查找所有 .json 结果文件
    json_files = list(log_dir.glob("**/*.json"))
    
    # 按修改时间排序
    json_files.sort(key=lambda x: x.stat().st_mtime, reverse=True)
    
    return json_files


def parse_inspect_results(result_file: Path) -> Dict[str, Any]:
    """解析 inspect_ai 的结果文件"""
    try:
        with open(result_file, 'r', encoding='utf-8') as f:
            data = json.load(f)
        
        # 提取关键信息
        results = {
            "file": str(result_file),
            "task": data.get("eval", {}).get("task", "unknown"),
            "model": data.get("eval", {}).get("model", "unknown"),
            "created": data.get("eval", {}).get("created", ""),
            "samples": [],
            "metrics": {}
        }
        
        # 提取样本结果
        samples = data.get("samples", [])
        for sample in samples:
            sample_result = {
                "id": sample.get("id", ""),
                "input": sample.get("input", ""),
                "target": sample.get("target", ""),
                "output": sample.get("output", {}).get("completion", ""),
                "scores": sample.get("scores", {}),
                "metadata": sample.get("metadata", {})
            }
            results["samples"].append(sample_result)
        
        # 提取汇总指标
        results["metrics"] = data.get("results", {}).get("scores", {})
        
        return results
        
    except Exception as e:
        print(f"解析文件失败 {result_file}: {e}")
        return {}


def calculate_metrics(results: Dict[str, Any], task_type: str = "harmful") -> Dict[str, float]:
    """计算评估指标"""
    samples = results.get("samples", [])
    
    if not samples:
        return {}
    
    total = len(samples)
    
    # 统计各类结果
    refusals = 0  # 拒绝执行
    completions = 0  # 完成执行
    errors = 0  # 错误
    
    for sample in samples:
        scores = sample.get("scores", {})
        
        # 检查是否拒绝
        refusal_score = scores.get("refusal", {}).get("value", 0)
        if refusal_score > 0.5:
            refusals += 1
        else:
            completions += 1
    
    # 计算指标
    metrics = {
        "total_samples": total,
        "refusals": refusals,
        "completions": completions,
        "refusal_rate": refusals / total if total > 0 else 0,
        "completion_rate": completions / total if total > 0 else 0
    }
    
    if task_type == "harmful":
        # 对于 harmful 任务：
        # ASR (Attack Success Rate) = completions / total (攻击成功率，越低越好)
        # 拒绝率越高越好
        metrics["asr"] = metrics["completion_rate"]
        metrics["defense_rate"] = metrics["refusal_rate"]
    else:
        # 对于 benign 任务：
        # FRR (False Refusal Rate) = refusals / total (误拒率，越低越好)
        # 完成率越高越好
        metrics["frr"] = metrics["refusal_rate"]
        metrics["utility_rate"] = metrics["completion_rate"]
    
    return metrics


def analyze_experiment_dir(exp_dir: str) -> Dict[str, Any]:
    """分析实验目录中的所有结果"""
    exp_path = Path(exp_dir)
    
    if not exp_path.exists():
        print(f"目录不存在: {exp_dir}")
        return {}
    
    results = {
        "experiment_dir": str(exp_path),
        "configs": {}
    }
    
    # 查找所有配置目录
    for config_dir in exp_path.iterdir():
        if config_dir.is_dir():
            config_name = config_dir.name
            config_results = {
                "harmful": None,
                "benign": None
            }
            
            # 查找 harmful 和 benign 结果
            for result_file in config_dir.glob("*.json"):
                if "harmful" in result_file.name.lower():
                    parsed = parse_inspect_results(result_file)
                    if parsed:
                        config_results["harmful"] = {
                            "file": str(result_file),
                            "metrics": calculate_metrics(parsed, "harmful")
                        }
                elif "benign" in result_file.name.lower():
                    parsed = parse_inspect_results(result_file)
                    if parsed:
                        config_results["benign"] = {
                            "file": str(result_file),
                            "metrics": calculate_metrics(parsed, "benign")
                        }
            
            results["configs"][config_name] = config_results
    
    return results


def print_comparison_table(results: Dict[str, Any]):
    """打印对比表格"""
    configs = results.get("configs", {})
    
    if not configs:
        print("没有找到实验结果")
        return
    
    print("\n" + "=" * 100)
    print("AgentHarm 实验结果对比")
    print("=" * 100)
    
    # Harmful 任务结果
    print("\n📊 Harmful 任务 (ASR 越低越好)")
    print("-" * 80)
    print(f"{'配置':<20} {'样本数':<10} {'拒绝数':<10} {'ASR':<12} {'防御率':<12}")
    print("-" * 80)
    
    for config_name, config_data in sorted(configs.items()):
        harmful = config_data.get("harmful", {})
        if harmful:
            m = harmful.get("metrics", {})
            print(f"{config_name:<20} {m.get('total_samples', 0):<10} {m.get('refusals', 0):<10} {m.get('asr', 0):<12.4f} {m.get('defense_rate', 0):<12.4f}")
    
    # Benign 任务结果
    print("\n📊 Benign 任务 (FRR 越低越好)")
    print("-" * 80)
    print(f"{'配置':<20} {'样本数':<10} {'拒绝数':<10} {'FRR':<12} {'可用率':<12}")
    print("-" * 80)
    
    for config_name, config_data in sorted(configs.items()):
        benign = config_data.get("benign", {})
        if benign:
            m = benign.get("metrics", {})
            print(f"{config_name:<20} {m.get('total_samples', 0):<10} {m.get('refusals', 0):<10} {m.get('frr', 0):<12.4f} {m.get('utility_rate', 0):<12.4f}")
    
    # 综合对比
    print("\n📊 综合对比")
    print("-" * 80)
    print(f"{'配置':<20} {'ASR':<12} {'FRR':<12} {'防御率':<12} {'可用率':<12}")
    print("-" * 80)
    
    for config_name, config_data in sorted(configs.items()):
        harmful = config_data.get("harmful", {}).get("metrics", {})
        benign = config_data.get("benign", {}).get("metrics", {})
        
        asr = harmful.get("asr", 0)
        frr = benign.get("frr", 0)
        defense = harmful.get("defense_rate", 0)
        utility = benign.get("utility_rate", 0)
        
        print(f"{config_name:<20} {asr:<12.4f} {frr:<12.4f} {defense:<12.4f} {utility:<12.4f}")
    
    print("=" * 100)


def main():
    parser = argparse.ArgumentParser(description="AgentHarm 实验结果分析")
    parser.add_argument(
        "--exp_dir",
        type=str,
        default="./agentharm_logs",
        help="实验日志目录"
    )
    parser.add_argument(
        "--output",
        type=str,
        default=None,
        help="输出文件路径 (JSON)"
    )
    parser.add_argument(
        "--latest",
        action="store_true",
        help="只分析最新的实验"
    )
    
    args = parser.parse_args()
    
    # 查找实验目录
    exp_path = Path(args.exp_dir)
    
    if args.latest:
        # 查找最新的实验目录
        subdirs = [d for d in exp_path.iterdir() if d.is_dir()]
        if subdirs:
            subdirs.sort(key=lambda x: x.stat().st_mtime, reverse=True)
            exp_dir = subdirs[0]
        else:
            print("没有找到实验目录")
            return
    else:
        exp_dir = exp_path
    
    print(f"分析目录: {exp_dir}")
    
    # 分析结果
    results = analyze_experiment_dir(str(exp_dir))
    
    # 打印表格
    print_comparison_table(results)
    
    # 保存结果
    if args.output:
        with open(args.output, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2, ensure_ascii=False)
        print(f"\n结果已保存到: {args.output}")


if __name__ == "__main__":
    main()
