#!/usr/bin/env python3
"""
SafetyProjector Loss Weight 敏感度分析脚本

分析不同的 cls_loss 权重对 SafetyProjector 训练的影响。
loss = loss_triplet + cls_weight * loss_cls
"""

import json
import sys
import os
import torch
import numpy as np
from pathlib import Path
from tqdm import tqdm
import time

# Put the `src` directory next to this script at the front of the import
# search path (presumably where the SafetyProjector module imported below
# lives -- confirm against the repository layout).
src_path = Path(__file__).parent / 'src'
sys.path.insert(0, str(src_path))

from SafetyProjector import (
    SafetyProjector,
    train_on_agent_align,
    prepare_training_triplets,
    parse_agent_align_data
)
from sentence_transformers import SentenceTransformer


def evaluate_model_performance(model, test_triplets, device, margin=0.5):
    """
    Evaluate a trained projector on held-out triplets.

    Every triplet is ``(anchor, positive, negative)``. For the classification
    head, anchors and positives are labelled 1 (harmful) and negatives are
    labelled 0 (benign).

    Args:
        model: Object exposing ``eval()`` and returning ``(embedding, logits)``
            when called on a batch of input vectors.
        test_triplets: Non-empty sequence of triplets of equal-length numeric
            vectors (lists or numpy arrays).
        device: ``torch.device`` or device string the inputs are moved to.
        margin: Margin of the triplet loss computed here. Defaults to 0.5;
            pass the margin used for training to make the losses comparable.

    Returns:
        dict with the keys ``loss_triplet``, ``loss_cls``, ``accuracy``,
        ``separation`` (cosine similarity of the class centroids, lower is
        better), ``harmful_intra_distance`` (mean distance of harmful samples
        to their centroid, lower is better) and ``inter_class_distance``
        (distance between the two centroids, higher is better).

    Raises:
        ValueError: If ``test_triplets`` is empty.
    """
    if len(test_triplets) == 0:
        raise ValueError("test_triplets is empty; cannot evaluate model performance")

    model.eval()

    triplet_criterion = torch.nn.TripletMarginLoss(margin=margin, p=2)
    cls_criterion = torch.nn.CrossEntropyLoss()
    functional = torch.nn.functional

    def _column(position):
        # Stack through numpy first: building a tensor straight from a Python
        # list of arrays converts element by element and is much slower.
        stacked = np.asarray([triplet[position] for triplet in test_triplets], dtype=np.float32)
        return torch.as_tensor(stacked).to(device)

    anchors, positives, negatives = _column(0), _column(1), _column(2)

    with torch.no_grad():
        a_emb, a_logits = model(anchors)
        p_emb, p_logits = model(positives)
        n_emb, n_logits = model(negatives)

        # Losses
        loss_triplet = triplet_criterion(a_emb, p_emb, n_emb)

        all_logits = torch.cat([a_logits, p_logits, n_logits], dim=0)
        all_labels = torch.cat(
            [
                torch.ones(anchors.size(0), dtype=torch.long, device=device),
                torch.ones(positives.size(0), dtype=torch.long, device=device),
                torch.zeros(negatives.size(0), dtype=torch.long, device=device),
            ],
            dim=0,
        )
        loss_cls = cls_criterion(all_logits, all_labels)

        # Classification accuracy over anchors + positives + negatives
        preds = torch.argmax(all_logits, dim=1)
        accuracy = (preds == all_labels).float().mean().item()

        # Class centroids in the projected space
        harmful_emb = torch.cat([a_emb, p_emb], dim=0)
        harmful_center = harmful_emb.mean(dim=0)
        benign_center = n_emb.mean(dim=0)

        # Cosine similarity between the centroids (lower = better separated)
        separation = functional.cosine_similarity(
            harmful_center.unsqueeze(0),
            benign_center.unsqueeze(0),
        ).item()

        # Mean distance of harmful samples to their own centroid (lower = tighter)
        harmful_distances = functional.pairwise_distance(
            harmful_emb,
            harmful_center.unsqueeze(0).expand_as(harmful_emb),
        ).mean().item()

        # Distance between the two centroids (higher = better separated)
        inter_class_distance = functional.pairwise_distance(
            harmful_center.unsqueeze(0),
            benign_center.unsqueeze(0),
        ).item()

    return {
        'loss_triplet': float(loss_triplet.item()),
        'loss_cls': float(loss_cls.item()),
        'accuracy': accuracy,
        'separation': separation,
        'harmful_intra_distance': harmful_distances,
        'inter_class_distance': inter_class_distance,
    }


def analyze_loss_weight_sensitivity(
    data_path,
    cls_weight_values,
    input_dim=384,
    batch_size=32,
    epochs=15,
    lr=1e-3,
    margin=0.5,
    train_ratio=0.8,
    output_dir="./sensitivity_results/loss_weight",
    seed=42,
):
    """
    Train and evaluate one projector per ``cls_weight`` and report the metrics.

    The triplets are mined once, split once into train/test, and the same
    split is reused for every weight. The random generators are re-seeded
    before each training run so that differences between runs come from
    ``cls_weight`` rather than from a different initialisation or batch order.

    Args:
        data_path: Path of the JSON training data file.
        cls_weight_values: Iterable of ``cls_weight`` values to test,
            e.g. ``[0.1, 0.5, 1.0, 2.0, 5.0]``.
        input_dim: Embedding dimension used only when the sentence encoder
            does not report one.
        batch_size: Training batch size.
        epochs: Number of training epochs.
        lr: Learning rate.
        margin: Triplet loss margin used for training.
        train_ratio: Fraction of the triplets used for training.
        output_dir: Directory receiving the model checkpoints and the
            ``loss_weight_sensitivity_results.json`` summary.
        seed: Seed for the train/test shuffle and for the per-run re-seeding
            of ``random``, ``numpy`` and ``torch``.

    Returns:
        List with one result dict per successfully evaluated weight, or
        ``None`` when the data file is missing, contains no harmful samples,
        or yields an empty train or test split.
    """
    import random
    import traceback

    os.makedirs(output_dir, exist_ok=True)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"使用设备: {device}")

    # 1. Load the data
    print(f"\n📂 加载数据: {data_path}")
    if not os.path.exists(data_path):
        print(f"❌ 数据文件不存在: {data_path}")
        return None

    with open(data_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    harmful_texts, benign_texts = parse_agent_align_data(data)
    print(f"✅ 解析数据: {len(harmful_texts)} 个有害样本, {len(benign_texts)} 个良性样本")

    if len(harmful_texts) == 0:
        print("❌ 没有找到有害样本")
        return None

    # 2. Initialise the sentence encoder (its device argument is documented as a string)
    print("\n🔧 初始化 embedding 模型...")
    embedding_model = SentenceTransformer('all-MiniLM-L6-v2', device=str(device))
    # The encoder may report no dimension; fall back to the caller's value then.
    input_dim_actual = embedding_model.get_sentence_embedding_dimension() or input_dim
    print(f"✅ Embedding 维度: {input_dim_actual}")

    def embedding_func(text_list):
        return embedding_model.encode(text_list, show_progress_bar=False)

    # 3. Mine the triplets
    print("\n⛏️ 挖掘训练三元组...")
    triplets = prepare_training_triplets(harmful_texts, benign_texts, embedding_func)
    print(f"✅ 生成了 {len(triplets)} 个三元组")

    # 4. Train/test split on a shuffled copy (the mined list itself is untouched).
    # NOTE(review): the split is per triplet, so the same text may end up on
    # both sides of the split -- confirm whether that leakage is acceptable.
    shuffled = list(triplets)
    random.Random(seed).shuffle(shuffled)
    split_idx = int(len(shuffled) * train_ratio)
    train_triplets = shuffled[:split_idx]
    test_triplets = shuffled[split_idx:]
    print(f"📊 训练集: {len(train_triplets)} 个三元组, 测试集: {len(test_triplets)} 个三元组")

    if not train_triplets or not test_triplets:
        # Fail before spending time on training runs that cannot be evaluated.
        print(f"❌ 训练集或测试集为空，请检查数据量或 train_ratio={train_ratio}")
        return None

    # 5. Train and evaluate once per cls_weight
    metric_keys = (
        'loss_triplet',
        'loss_cls',
        'accuracy',
        'separation',
        'harmful_intra_distance',
        'inter_class_distance',
    )
    results = []

    for cls_weight in cls_weight_values:
        print(f"\n{'='*80}")
        print(f"🔬 测试 cls_weight = {cls_weight}")
        print(f"{'='*80}")

        model_save_path = os.path.join(output_dir, f"safety_projector_cls_weight_{cls_weight}.pth")

        # Same starting RNG state for every weight so the runs are comparable.
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

        print(f"\n🚀 开始训练 (cls_weight={cls_weight})...")
        start_time = time.time()

        try:
            model = train_on_agent_align(
                triplets=train_triplets,
                input_dim=input_dim_actual,
                batch_size=batch_size,
                epochs=epochs,
                lr=lr,
                margin=margin,
                cls_weight=cls_weight,
                save_path=model_save_path
            )

            training_time = time.time() - start_time
            print(f"✅ 训练完成 (耗时: {training_time:.2f}秒)")

            # NOTE(review): evaluation is called without a margin, so its
            # triplet loss uses that function's built-in 0.5 rather than `margin`.
            print("\n🔍 评估模型性能...")
            eval_results = evaluate_model_performance(model, test_triplets, device)

            result = {'cls_weight': cls_weight, 'training_time': training_time}
            result.update((key, eval_results[key]) for key in metric_keys)
            result['model_path'] = model_save_path
            results.append(result)

            print(f"\n✅ Cls Weight {cls_weight} 评估完成:")
            print(f"   - 准确率: {result['accuracy']:.4f}")
            print(f"   - Triplet Loss: {result['loss_triplet']:.4f}")
            print(f"   - Cls Loss: {result['loss_cls']:.4f}")
            print(f"   - 分离度 (余弦相似度): {result['separation']:.4f} (越低越好)")
            print(f"   - 类内距离: {result['harmful_intra_distance']:.4f} (越小越好)")
            print(f"   - 类间距离: {result['inter_class_distance']:.4f} (越大越好)")

        except Exception as e:
            # One failing weight must not abort the whole sweep; report and move on.
            print(f"❌ Cls Weight {cls_weight} 训练/评估失败: {e}")
            traceback.print_exc()
            continue

    # 6. Persist and summarise the results
    output_file = os.path.join(output_dir, "loss_weight_sensitivity_results.json")
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)

    print(f"\n{'='*80}")
    print("📊 敏感度分析完成！")
    print(f"{'='*80}")
    print(f"\n结果已保存到: {output_file}")
    print("\n汇总结果:")
    print(f"{'Cls Weight':<12} {'准确率':<10} {'Triplet Loss':<15} {'Cls Loss':<12} "
          f"{'分离度':<10} {'类内距离':<12} {'类间距离':<12}")
    print('-' * 90)
    for r in results:
        print(f"{r['cls_weight']:<12.2f} {r['accuracy']:<10.4f} {r['loss_triplet']:<15.4f} "
              f"{r['loss_cls']:<12.4f} {r['separation']:<10.4f} {r['harmful_intra_distance']:<12.4f} "
              f"{r['inter_class_distance']:<12.4f}")

    return results


def main():
    """Parse the command-line options and run the cls_weight sensitivity sweep."""
    import argparse

    cli = argparse.ArgumentParser(description="SafetyProjector Loss Weight 敏感度分析")

    # (flag, add_argument options) pairs, registered in this order. Each flag's
    # destination name matches a keyword of analyze_loss_weight_sensitivity.
    option_table = [
        ("--data_path", {
            "type": str,
            "default": "/path/to/agentharm/agent_align_data_v3.json",
            "help": "训练数据路径",
        }),
        ("--cls_weight_values", {
            "type": float,
            "nargs": "+",
            "default": [0.1, 0.5, 1.0, 2.0, 5.0, 10.0],
            "help": "要测试的 cls_weight 值列表",
        }),
        ("--input_dim", {
            "type": int,
            "default": 384,
            "help": "Embedding 维度（会自动检测，此参数作为备用）",
        }),
        ("--batch_size", {
            "type": int,
            "default": 32,
            "help": "批次大小",
        }),
        ("--epochs", {
            "type": int,
            "default": 15,
            "help": "训练轮数",
        }),
        ("--lr", {
            "type": float,
            "default": 1e-3,
            "help": "学习率",
        }),
        ("--margin", {
            "type": float,
            "default": 0.5,
            "help": "Triplet loss margin",
        }),
        ("--train_ratio", {
            "type": float,
            "default": 0.8,
            "help": "训练集比例",
        }),
        ("--output_dir", {
            "type": str,
            "default": "./sensitivity_results/loss_weight",
            "help": "结果输出目录",
        }),
    ]
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    parsed = cli.parse_args()

    # The namespace holds exactly the nine options above, so it can be
    # forwarded as keyword arguments unchanged.
    analyze_loss_weight_sensitivity(**vars(parsed))


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()
