#!/usr/bin/env python3
"""
SafetyProjector 分类准确性验证脚本

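用法:
    python validate_safety_projector.py --test_data test_queries.json --model_path src/models/safety_projector_metric.pth
    python validate_safety_projector.py --test_data test_queries.json --model_path src/models/safety_projector_metric.pth --output results.json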
"""

import json
import argparse
import torch
import numpy as np
from tqdm import tqdm
from pathlib import Path
import sys
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, classification_report

# Put the `src` directory that sits next to this script at the front of the
# import path, so the project-local `SafetyProjector` module can be imported.
src_path = Path(__file__).with_name('src')
sys.path.insert(0, f"{src_path}")

from SafetyProjector import SafetyProjector
from sentence_transformers import SentenceTransformer


def load_model(model_path, device=None):
    """Load a SafetyProjector checkpoint and prepare it for inference.

    Args:
        model_path: Path to a file written by ``torch.save`` that holds a dict
            with a ``'model_state_dict'`` entry and, optionally, ``'input_dim'``.
        device: Device to load onto; defaults to CUDA when available, else CPU.

    Returns:
        (model, device): the SafetyProjector in eval mode on ``device``, and
        the device that was actually used.

    Raises:
        KeyError: if the checkpoint has no ``'model_state_dict'`` entry.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print(f"📂 加载模型: {model_path}")
    print(f"   设备: {device}")

    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints (recent torch versions default to weights_only=True, which
    # restricts what may be unpickled).
    checkpoint = torch.load(model_path, map_location=device)

    # Fall back to 384, the embedding size of all-MiniLM-L6-v2, when the
    # checkpoint does not record its input dimension.
    input_dim = checkpoint.get('input_dim', 384)

    model = SafetyProjector(input_dim=input_dim, device=device)

    # strict=False tolerates key mismatches, but they are reported so that a
    # partially loaded (partly randomly initialised) model is never evaluated
    # silently. getattr keeps this safe if load_state_dict returns nothing.
    incompatible = model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    missing_keys = list(getattr(incompatible, 'missing_keys', None) or [])
    unexpected_keys = list(getattr(incompatible, 'unexpected_keys', None) or [])
    if missing_keys:
        print(f"⚠️ checkpoint 中缺少 {len(missing_keys)} 个参数，将保留初始值: {missing_keys}")
    if unexpected_keys:
        print(f"⚠️ checkpoint 中有 {len(unexpected_keys)} 个未使用的参数: {unexpected_keys}")

    model.to(device)
    model.eval()

    print(f"✅ 模型加载成功 (input_dim={input_dim})")

    # temperature may be a tensor/Parameter or a plain number; float() accepts
    # both, whereas .item() only works on tensors.
    temperature = getattr(model, 'temperature', None)
    if temperature is not None:
        print(f"   温度系数: {float(temperature):.4f}")

    return model, device


def predict_query(model, embedding_model, query, device):
    """Classify a single query with the SafetyProjector.

    Args:
        model: Called on the float32 embedding tensor; must return a 2-tuple
            whose second element is the logits, where column 0 is the benign
            score and column 1 the harmful score.
        embedding_model: Object exposing a sentence-transformers style
            ``encode(texts, normalize_embeddings=...)``.
        query: Text to classify.
        device: Device on which the embedding tensor is created.

    Returns:
        (predicted_label, harmful_prob, logits): the argmax class
        (0 = benign, 1 = harmful), the softmax probability of the harmful
        class, and the query's raw logits as a NumPy array.
    """
    batch = embedding_model.encode([query], normalize_embeddings=True)
    features = torch.tensor(batch, dtype=torch.float32, device=device)

    with torch.no_grad():
        _projection, scores = model(features)
        # Row 0 is the single query in the batch; column 1 is the harmful class.
        label = scores.argmax(dim=1).item()
        harmful_probability = torch.softmax(scores, dim=1)[0, 1].item()

    return label, harmful_probability, scores[0].cpu().numpy()


def _normalize_label(raw_label):
    """Return ``(label_str, label_id)`` for a raw ``label`` field.

    ``label_str`` is the label lower-cased; strings are also stripped of
    surrounding whitespace, and non-strings (e.g. a JSON null) are passed
    through ``str`` first so they can be reported instead of crashing.
    ``label_id`` is 1 for "harmful", 0 for "benign", and None otherwise.
    """
    if isinstance(raw_label, str):
        label_str = raw_label.strip().lower()
    else:
        label_str = str(raw_label).lower()
    return label_str, {'harmful': 1, 'benign': 0}.get(label_str)


def _print_error_cases(results):
    """Print the first 10 misclassified entries of ``results``, if any."""
    print(f"\n📋 错误案例示例 (前10个):")
    errors = [result for result in results if not result['correct']]
    for number, result in enumerate(errors[:10], start=1):
        print(f"\n  错误 #{number}:")
        print(f"    ID: {result['id']}")
        print(f"    真实标签: {result['true_label']}")
        print(f"    预测标签: {result['predicted_label']} (概率: {result['harmful_prob']:.4f})")
        print(f"    类别: {result['category']}")
        print(f"    查询: {result['query']}")

    if not errors:
        print("  ✅ 没有错误案例！")


def evaluate_model(test_data_path, model_path, output_path=None):
    """Evaluate a SafetyProjector checkpoint on a labelled JSON query set.

    Args:
        test_data_path: JSON file holding a list of sample dicts. Each sample
            is read for ``user_query``, ``label`` ("harmful"/"benign",
            case-insensitive), and optionally ``id`` and ``category``. Samples
            with any other label are skipped with a warning.
        model_path: SafetyProjector checkpoint, loaded via ``load_model``.
        output_path: If given, aggregate metrics and per-sample results are
            written there as JSON (missing parent directories are created).

    Returns:
        dict with ``accuracy``, ``precision``, ``recall`` and ``f1`` (harmful
        is the positive class) and ``confusion_matrix``, always 2x2 with rows
        = true label and columns = predicted label, in order [benign, harmful].

    Raises:
        ValueError: if no sample carries a recognised label.
    """
    print(f"\n📂 加载测试数据: {test_data_path}")
    with open(test_data_path, 'r', encoding='utf-8') as f:
        test_data = json.load(f)

    print(f"✅ 加载了 {len(test_data)} 个测试样本")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model, device = load_model(model_path, device)

    print(f"\n🔧 初始化 embedding 模型...")
    embedding_model = SentenceTransformer('all-MiniLM-L6-v2', device=device)
    input_dim = embedding_model.get_sentence_embedding_dimension()
    print(f"✅ Embedding 模型初始化完成 (dim={input_dim})")

    true_labels = []
    predicted_labels = []
    results = []

    print(f"\n🔍 开始预测...")
    for item in tqdm(test_data, desc="Processing queries"):
        # A missing or null query is encoded as the empty string.
        query = item.get('user_query') or ''
        true_label_str, true_label = _normalize_label(item.get('label', ''))
        if true_label is None:
            print(f"⚠️ 跳过未知标签: {true_label_str} (ID: {item.get('id', 'unknown')})")
            continue

        pred_label, harmful_prob, _ = predict_query(model, embedding_model, query, device)

        true_labels.append(true_label)
        predicted_labels.append(pred_label)
        results.append({
            'id': item.get('id', ''),
            'query': query[:100] + '...' if len(query) > 100 else query,
            'true_label': true_label_str,
            'predicted_label': 'harmful' if pred_label == 1 else 'benign',
            'harmful_prob': harmful_prob,
            'correct': pred_label == true_label,
            'category': item.get('category', '')
        })

    # sklearn cannot score an empty set; fail with an actionable message.
    if not true_labels:
        raise ValueError(f"没有可评估的样本（label 需为 harmful 或 benign）: {test_data_path}")

    print(f"\n📊 计算评估指标...")
    accuracy = accuracy_score(true_labels, predicted_labels)
    precision = precision_score(true_labels, predicted_labels, average='binary', zero_division=0)
    recall = recall_score(true_labels, predicted_labels, average='binary', zero_division=0)
    f1 = f1_score(true_labels, predicted_labels, average='binary', zero_division=0)

    # Pin the label order: without `labels`, a run where only one class occurs
    # yields a 1x1 matrix and the 4-way unpacking below fails.
    cm = confusion_matrix(true_labels, predicted_labels, labels=[0, 1])
    tn, fp, fn, tp = (int(value) for value in cm.ravel())

    harmful_count = sum(true_labels)
    benign_count = len(true_labels) - harmful_count

    harmful_precision = tp / (tp + fp) if tp + fp else 0.0
    harmful_recall = tp / (tp + fn) if tp + fn else 0.0
    benign_precision = tn / (tn + fn) if tn + fn else 0.0
    benign_recall = tn / (tn + fp) if tn + fp else 0.0

    # `labels` keeps the two target_names aligned even if a class is absent.
    print(classification_report(true_labels, predicted_labels,
                                labels=[0, 1],
                                target_names=['Benign', 'Harmful'],
                                digits=4,
                                zero_division=0))

    if output_path:
        output_data = {
            'model_path': model_path,
            'test_data_path': test_data_path,
            'metrics': {
                'accuracy': float(accuracy),
                'precision': float(precision),
                'recall': float(recall),
                'f1_score': float(f1),
                'confusion_matrix': {
                    'tn': tn,
                    'fp': fp,
                    'fn': fn,
                    'tp': tp
                },
                'harmful': {
                    'count': harmful_count,
                    'correct': tp,  # correctly predicted harmful == true positives
                    'precision': harmful_precision,
                    'recall': harmful_recall
                },
                'benign': {
                    'count': benign_count,
                    'correct': tn,  # correctly predicted benign == true negatives
                    'precision': benign_precision,
                    'recall': benign_recall
                }
            },
            'results': results
        }

        Path(output_path).parent.mkdir(parents=True, exist_ok=True)
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(output_data, f, indent=2, ensure_ascii=False)

        print(f"\n💾 详细结果已保存到: {output_path}")

    _print_error_cases(results)

    print(f"\n{'='*70}\n")

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'confusion_matrix': cm
    }


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="SafetyProjector 分类准确性验证")
    
    parser.add_argument(
        "--test_data",
        type=str,
        default="test_queries.json",
        help="测试数据文件路径（默认: test_queries.json）"
    )
    
    parser.add_argument(
        "--model_path",
        type=str,
        default="src/models/safety_projector_metric.pth",
        help="SafetyProjector 模型路径（默认: src/models/safety_projector_metric.pth）"
    )
    
    parser.add_argument(
        "--output",
        type=str,
        default=None,
        help="结果输出文件路径（可选，默认不保存）"
    )
    
    args = parser.parse_args()
    
    # 检查文件是否存在
    if not Path(args.test_data).exists():
        print(f"❌ 错误: 测试数据文件不存在: {args.test_data}")
        sys.exit(1)
    
    if not Path(args.model_path).exists():
        print(f"❌ 错误: 模型文件不存在: {args.model_path}")
        sys.exit(1)
    
    # 运行评估
    evaluate_model(
        test_data_path=args.test_data,
        model_path=args.model_path,
        output_path=args.output
    )
