#!/usr/bin/env python3
"""
Safety Projector 消融实验脚本

对 margin 和 lambda (cls_weight) 参数进行消融实验
使用 test_queries.json 数据集评估分类性能
"""

import os
import sys
import json
import time
import argparse
import numpy as np
from pathlib import Path
from datetime import datetime
from typing import Dict, List, Any, Tuple
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm

# 添加 src 目录到路径
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))

import torch
import torch.nn as nn
from sentence_transformers import SentenceTransformer


class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that understands NumPy scalars and arrays.

    The stdlib ``json`` module cannot serialize NumPy scalar types
    (``np.float32``, ``np.int64``, ``np.bool_``, ...) or ``np.ndarray``;
    this encoder converts them to the equivalent built-in Python values.
    Anything else is delegated to the base class (which raises ``TypeError``).
    """

    def default(self, obj):
        if isinstance(obj, np.bool_):
            # np.bool_ subclasses neither bool nor np.integer, so without this
            # branch results of NumPy comparisons could not be encoded.
            return bool(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)


def load_test_data(data_path: str) -> List[Dict]:
    """Read the evaluation set from a UTF-8 JSON file.

    Args:
        data_path: Path to the JSON file (e.g. test_queries.json).

    Returns:
        The parsed JSON document, returned as-is.
    """
    with open(data_path, encoding='utf-8') as handle:
        return json.load(handle)


def load_training_data(data_path: str) -> Tuple[List[str], List[str]]:
    """Load the training set and split it into harmful and benign user texts.

    Non-dict records are skipped. A record's text is the ``content`` of its
    first ``role == 'user'`` message in ``messages`` (agent_align_data_v3.json
    layout); if that is missing or empty, the first non-empty of
    ``user_query``, ``prompt`` or ``text`` is used. Records without text are
    skipped.

    The label comes from the ``id`` prefix (``harmful*`` / ``benign*``) when
    present, otherwise from the case-insensitive ``label`` field (``harmful``
    vs ``benign``/``safe``). Records matching neither are ignored.

    Args:
        data_path: Path to a UTF-8 JSON file whose top level is a list.

    Returns:
        ``(harmful_texts, benign_texts)``, each in file order.
    """
    with open(data_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    harmful_texts = []
    benign_texts = []

    for item in data:
        if not isinstance(item, dict):
            continue

        # `or ''` + str() tolerate explicit nulls and non-string ids
        # (e.g. ints), which would otherwise crash on .startswith().
        item_id = str(item.get('id') or '')

        # Prefer the first user message; a non-list / non-dict entry is skipped.
        text = None
        for msg in item.get('messages') or []:
            if isinstance(msg, dict) and msg.get('role') == 'user':
                text = msg.get('content', '')
                break

        # Fall back to flat text fields when there is no usable user message.
        if not text:
            text = item.get('user_query', '') or item.get('prompt', '') or item.get('text', '')

        if not text:
            continue

        if item_id.startswith('harmful'):
            harmful_texts.append(text)
        elif item_id.startswith('benign'):
            benign_texts.append(text)
        else:
            # Same null/non-string guard as for the id above.
            label = str(item.get('label') or '').lower()
            if label == 'harmful':
                harmful_texts.append(text)
            elif label in ('benign', 'safe'):
                benign_texts.append(text)

    print(f"✅ 加载训练数据: {len(harmful_texts)} harmful, {len(benign_texts)} benign")
    return harmful_texts, benign_texts


def prepare_training_triplets(harmful_texts: List[str], benign_texts: List[str], 
                              embedding_func, max_triplets: int = 5000) -> List[Tuple]:
    """Build ``(anchor, positive, negative)`` embedding triplets for training.

    Anchor and positive are two *different* harmful samples; the negative is a
    benign sample drawn uniformly at random using the global ``np.random``
    state.

    Pairs are enumerated round-robin by offset: round ``k`` pairs every
    harmful sample ``i`` with sample ``(i + k) % n``. When ``max_triplets``
    truncates the list, the anchors are therefore spread over all harmful
    samples rather than concentrated on the first few.

    Args:
        harmful_texts: Harmful training texts.
        benign_texts: Benign training texts.
        embedding_func: Callable mapping a list of texts to an indexable
            sequence of embedding vectors, one per text.
        max_triplets: Upper bound on the number of triplets returned
            (``<= 0`` returns an empty list).

    Returns:
        List of ``(anchor_emb, positive_emb, negative_emb)`` tuples; at most
        ``min(max_triplets, n * (n - 1))`` for ``n`` harmful samples.

    Raises:
        ValueError: If triplets would be generated but ``benign_texts`` is
            empty, so there are no negatives to sample.
    """
    print("📊 生成 embeddings...")
    harmful_embs = embedding_func(harmful_texts)
    benign_embs = embedding_func(benign_texts)

    print("🔧 挖掘三元组...")
    n_harmful = len(harmful_embs)
    n_benign = len(benign_embs)

    if n_harmful >= 2 and max_triplets > 0 and n_benign == 0:
        raise ValueError("cannot build triplets: benign_texts is empty, no negatives to sample")

    # A plain nested i/j loop would spend the whole budget on the first
    # ceil(max_triplets / (n - 1)) anchors; iterating offsets in the outer
    # loop gives every anchor one positive per round instead.
    pairs = (
        (i, (i + offset) % n_harmful)
        for offset in range(1, n_harmful)
        for i in range(n_harmful)
    )

    triplets = []
    for anchor_idx, positive_idx in pairs:
        if len(triplets) >= max_triplets:
            break
        neg_idx = np.random.randint(0, n_benign)
        triplets.append((
            harmful_embs[anchor_idx],
            harmful_embs[positive_idx],
            benign_embs[neg_idx],
        ))

    print(f"✅ 生成 {len(triplets)} 个三元组")
    return triplets


class SafetyProjectorAblation(nn.Module):
    """Safety Projector variant used for the margin / lambda ablation.

    An MLP projects an input embedding to a 128-d feature, which is
    L2-normalized. The logits are the cosine similarities to two learnable
    class prototypes, divided by a learnable temperature. Column 1 is the
    class reported as "harmful" by ``predict_proba``.

    Args:
        input_dim: Dimension of the incoming embeddings.
        device: Optional device to move the module to after construction.
        min_temperature: Lower bound applied to the learnable temperature in
            ``forward``. Nothing keeps the raw parameter positive during
            optimization. Without the bound it can reach 0 (logits blow up)
            or turn negative (the logit order, and thus every prediction,
            flips).
    """

    def __init__(self, input_dim=384, device=None, min_temperature=1e-2):
        super().__init__()

        self.net = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, 128)
        )

        self.prototypes = nn.Parameter(torch.randn(2, 128))
        self.temperature = nn.Parameter(torch.ones(1) * 0.5)
        # Plain float (not a parameter/buffer): state_dict keys stay unchanged,
        # so checkpoints written before this bound existed still load.
        self.min_temperature = float(min_temperature)

        if device is not None:
            self.to(device)

    def forward(self, x):
        """Return ``(normalized_features, scaled_logits)`` for a batch ``x``.

        ``x`` is expected to be 2-D ``(batch, input_dim)``: normalization runs
        over dim 1 and ``torch.mm`` requires a matrix.
        """
        feat = self.net(x)
        query_emb = nn.functional.normalize(feat, p=2, dim=1)
        protos_norm = nn.functional.normalize(self.prototypes, p=2, dim=1)
        similarity = torch.mm(query_emb, protos_norm.T)
        # Keep the divisor strictly positive so the sign/order of the
        # similarities is preserved and logits stay finite.
        temperature = self.temperature.clamp(min=self.min_temperature)
        scaled_logits = similarity / temperature
        return query_emb, scaled_logits

    def predict_proba(self, x):
        """Return the softmax probability of column 1 (harmful) per row of ``x``.

        Runs without gradient tracking. The caller is responsible for putting
        the module in eval mode, otherwise dropout stays active.
        """
        with torch.no_grad():
            _, logits = self.forward(x)
            probs = torch.softmax(logits, dim=1)
            return probs[:, 1]  # harmful probability


def train_safety_projector(triplets: List[Tuple], input_dim: int, 
                           margin: float, cls_weight: float,
                           epochs: int = 15, batch_size: int = 32, 
                           lr: float = 1e-3, device: str = 'cuda') -> SafetyProjectorAblation:
    """Train a fresh SafetyProjectorAblation on precomputed embedding triplets.

    Per-batch loss = ``TripletMarginLoss(margin)`` on the normalized
    projections + ``cls_weight`` * cross-entropy on the prototype logits.
    Anchors and positives are labeled 1 (harmful), negatives 0 (benign).

    Args:
        triplets: ``(anchor, positive, negative)`` tuples; each element is
            expected to be a 1-D vector of length ``input_dim``.
        input_dim: Embedding dimension fed to the projector.
        margin: Triplet-loss margin (ablated hyperparameter).
        cls_weight: Weight (lambda) of the classification loss (ablated).
        epochs: Number of passes over the triplets.
        batch_size: Mini-batch size.
        lr: Adam learning rate.
        device: Preferred device; falls back to CPU when CUDA is unavailable.

    Returns:
        The trained model, switched to eval mode.

    Raises:
        ValueError: If ``triplets`` is empty. Previously this silently
            returned a randomly initialized model.
    """
    if not triplets:
        raise ValueError("train_safety_projector: triplets is empty, nothing to train on")

    device = torch.device(device if torch.cuda.is_available() else 'cpu')

    def stack_role(role):
        # One contiguous array per role: torch.tensor() over a Python list of
        # ndarrays is extremely slow (and warns) for thousands of rows.
        return torch.as_tensor(np.stack([t[role] for t in triplets]), dtype=torch.float32)

    anchors = stack_role(0)
    positives = stack_role(1)
    negatives = stack_role(2)

    dataset = torch.utils.data.TensorDataset(anchors, positives, negatives)
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

    model = SafetyProjectorAblation(input_dim=input_dim, device=device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    criterion_triplet = nn.TripletMarginLoss(margin=margin, p=2)
    criterion_cls = nn.CrossEntropyLoss()

    model.train()
    avg_loss = 0.0

    for _ in range(epochs):
        total_loss = 0.0
        num_batches = 0

        for a, p, n in loader:
            a = a.to(device)
            p = p.to(device)
            n = n.to(device)

            optimizer.zero_grad()

            a_emb, a_logits = model(a)
            p_emb, p_logits = model(p)
            n_emb, n_logits = model(n)

            loss_triplet = criterion_triplet(a_emb, p_emb, n_emb)

            # Anchors/positives are harmful (class 1), negatives benign (class 0).
            all_logits = torch.cat([a_logits, p_logits, n_logits], dim=0)
            label_a = torch.ones(a.size(0), dtype=torch.long, device=device)
            label_p = torch.ones(p.size(0), dtype=torch.long, device=device)
            label_n = torch.zeros(n.size(0), dtype=torch.long, device=device)
            all_labels = torch.cat([label_a, label_p, label_n], dim=0)

            loss_cls = criterion_cls(all_logits, all_labels)

            loss = loss_triplet + cls_weight * loss_cls

            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        avg_loss = total_loss / num_batches if num_batches > 0 else 0.0

    # Surface the last epoch's mean loss so diverging configs are visible.
    print(f"  最终平均 loss: {avg_loss:.4f}")

    model.eval()
    return model


def evaluate_model(model: SafetyProjectorAblation, test_data: List[Dict], 
                   embedding_func, device: str = 'cuda', 
                   threshold: float = 0.5) -> Dict[str, Any]:
    """Score the test set with ``model`` and compute binary safety metrics.

    A query is predicted harmful when its harmful probability is
    ``>= threshold``; harmful is the positive class.

    Args:
        model: Trained projector; it is moved to ``device`` and set to eval mode.
        test_data: Records with ``user_query`` and ``label`` keys. A label equal
            to ``'harmful'`` (case-insensitive) is positive, anything else
            negative.
        embedding_func: Callable mapping a list of texts to an array of
            embeddings, one row per text.
        device: Preferred device; falls back to CPU when CUDA is unavailable.
        threshold: Decision threshold on the harmful probability.

    Returns:
        Dict with:
            - ``accuracy``, ``precision``, ``recall``, ``f1_score``.
            - ``frr``: benign queries flagged harmful / all benign.
            - ``asr``: harmful queries missed / all harmful.
            - The confusion counts as plain ``int``, plus ``threshold``.
            - The mean harmful probability per true class. ``None`` when
              that class is absent; ``np.mean([])`` would give NaN, which
              ``json.dump`` writes as invalid JSON.
    """
    device = torch.device(device if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.eval()

    queries = [item['user_query'] for item in test_data]
    # Case-insensitive, matching how the training loader lowercases labels.
    labels = np.array(
        [1 if str(item['label']).lower() == 'harmful' else 0 for item in test_data],
        dtype=int,
    )

    if queries:
        embeddings = embedding_func(queries)
        embeddings_tensor = torch.tensor(embeddings, dtype=torch.float32).to(device)
        # predict_proba already disables gradient tracking.
        probs = model.predict_proba(embeddings_tensor).cpu().numpy()
    else:
        # An empty batch would fail inside forward(); score nothing instead.
        probs = np.empty(0, dtype=np.float32)

    predictions = (probs >= threshold).astype(int)

    # Vectorized confusion matrix, cast to int for clean JSON output.
    tp = int(np.sum((predictions == 1) & (labels == 1)))
    tn = int(np.sum((predictions == 0) & (labels == 0)))
    fp = int(np.sum((predictions == 1) & (labels == 0)))
    fn = int(np.sum((predictions == 0) & (labels == 1)))

    total = len(labels)
    accuracy = (tp + tn) / total if total > 0 else 0
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

    frr = fp / (fp + tn) if (fp + tn) > 0 else 0  # False Refusal Rate
    asr = fn / (fn + tp) if (fn + tp) > 0 else 0  # Attack Success Rate

    harmful_probs = probs[labels == 1]
    benign_probs = probs[labels == 0]

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1_score': f1,
        'frr': frr,
        'asr': asr,
        'tp': tp,
        'tn': tn,
        'fp': fp,
        'fn': fn,
        'threshold': threshold,
        'avg_harmful_prob': float(harmful_probs.mean()) if harmful_probs.size else None,
        'avg_benign_prob': float(benign_probs.mean()) if benign_probs.size else None,
    }


def run_ablation_experiment(margin_values: List[float], lambda_values: List[float],
                            training_data_path: str, test_data_path: str,
                            output_dir: str, epochs: int = 15,
                            max_triplets: int = 5000) -> Dict[str, Any]:
    """Train and evaluate one projector per (margin, lambda) grid point.

    The data, the ``all-MiniLM-L6-v2`` embedding model and the training
    triplets are built once and shared by every configuration; only the
    projector is retrained. For each configuration:
        - The weights are saved to ``model_margin{m}_lambda{l}.pth`` in
          ``output_dir``.
        - ``ablation_results.json`` is rewritten, so a crash part-way
          through the sweep keeps every configuration finished so far.

    Args:
        margin_values: Triplet-loss margins to sweep.
        lambda_values: Classification-loss weights to sweep.
        training_data_path: JSON file read by ``load_training_data``.
        test_data_path: JSON file read by ``load_test_data``.
        output_dir: Directory for checkpoints and results (created if missing).
        epochs: Training epochs per configuration.
        max_triplets: Cap on the number of training triplets.

    Returns:
        Dict with ``timestamp``, ``config`` and an ``experiments`` list holding,
        per configuration, its hyperparameters, train/eval time and metrics.
    """
    os.makedirs(output_dir, exist_ok=True)
    results_path = os.path.join(output_dir, 'ablation_results.json')

    print("📂 加载数据...")
    test_data = load_test_data(test_data_path)
    harmful_texts, benign_texts = load_training_data(training_data_path)

    print("🔧 初始化 embedding 模型...")
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    embedding_model = SentenceTransformer('all-MiniLM-L6-v2', device=device)
    input_dim = embedding_model.get_sentence_embedding_dimension()

    def embedding_func(texts):
        return embedding_model.encode(texts, show_progress_bar=False)

    print("🔧 准备训练三元组...")
    triplets = prepare_training_triplets(harmful_texts, benign_texts, embedding_func, max_triplets)

    results = {
        'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
        'config': {
            'margin_values': margin_values,
            'lambda_values': lambda_values,
            'epochs': epochs,
            'max_triplets': max_triplets,
            'test_samples': len(test_data),
            'training_harmful': len(harmful_texts),
            'training_benign': len(benign_texts)
        },
        'experiments': []
    }

    def save_results():
        # Write to a sibling temp file, then atomically swap it in, so an
        # interrupted write never leaves a truncated results file behind.
        tmp_path = results_path + '.tmp'
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2, ensure_ascii=False, cls=NumpyEncoder)
        os.replace(tmp_path, results_path)

    total_experiments = len(margin_values) * len(lambda_values)
    print(f"\n🚀 开始消融实验: {total_experiments} 组配置")
    print(f"   Margin 值: {margin_values}")
    print(f"   Lambda 值: {lambda_values}")
    print("=" * 80)

    exp_idx = 0
    for margin in margin_values:
        for cls_weight in lambda_values:
            exp_idx += 1
            print(f"\n[{exp_idx}/{total_experiments}] margin={margin}, lambda={cls_weight}")

            start_time = time.time()

            print("  训练中...")
            model = train_safety_projector(
                triplets=triplets,
                input_dim=input_dim,
                margin=margin,
                cls_weight=cls_weight,
                epochs=epochs,
                device=device
            )

            train_time = time.time() - start_time

            print("  评估中...")
            eval_start = time.time()
            metrics = evaluate_model(model, test_data, embedding_func, device)
            eval_time = time.time() - eval_start

            exp_result = {
                'margin': margin,
                'lambda': cls_weight,
                'train_time': train_time,
                'eval_time': eval_time,
                **metrics
            }
            results['experiments'].append(exp_result)

            print(f"  ✅ 完成: Acc={metrics['accuracy']:.4f}, F1={metrics['f1_score']:.4f}, "
                  f"FRR={metrics['frr']:.4f}, ASR={metrics['asr']:.4f}")

            model_path = os.path.join(output_dir, f"model_margin{margin}_lambda{cls_weight}.pth")
            torch.save(model.state_dict(), model_path)

            # Checkpoint results after every configuration, not only at the end.
            save_results()

    # Final write also covers the degenerate case of an empty grid.
    save_results()

    print(f"\n✅ 结果已保存到: {results_path}")

    return results


def print_results_table(results: Dict[str, Any]):
    """Print every ablation run sorted by F1, then the best config per metric.

    Only ``results['experiments']`` is read. Each entry needs ``margin``,
    ``lambda``, ``accuracy``, ``precision``, ``recall``, ``f1_score``, ``frr``
    and ``asr``. An empty or missing list prints a notice instead of raising
    (``max()`` / ``min()`` on an empty sequence would raise ``ValueError``).
    """
    experiments = results.get('experiments') or []

    print("\n" + "=" * 120)
    print("Safety Projector 消融实验结果")
    print("=" * 120)

    if not experiments:
        print("(没有实验结果)")
        return

    # Sort by F1, best first.
    sorted_exps = sorted(experiments, key=lambda x: x['f1_score'], reverse=True)

    print(f"\n{'Margin':<10} {'Lambda':<10} {'Accuracy':<12} {'Precision':<12} {'Recall':<12} {'F1':<12} {'FRR':<10} {'ASR':<10}")
    print("-" * 100)

    for exp in sorted_exps:
        print(f"{exp['margin']:<10.2f} {exp['lambda']:<10.2f} {exp['accuracy']:<12.4f} "
              f"{exp['precision']:<12.4f} {exp['recall']:<12.4f} {exp['f1_score']:<12.4f} "
              f"{exp['frr']:<10.4f} {exp['asr']:<10.4f}")

    # Best configuration per metric (higher is better for F1/Acc, lower for FRR/ASR).
    best_f1 = max(experiments, key=lambda x: x['f1_score'])
    best_acc = max(experiments, key=lambda x: x['accuracy'])
    lowest_frr = min(experiments, key=lambda x: x['frr'])
    lowest_asr = min(experiments, key=lambda x: x['asr'])

    print("\n" + "=" * 80)
    print("最佳配置:")
    print(f"  最高 F1:     margin={best_f1['margin']}, lambda={best_f1['lambda']} (F1={best_f1['f1_score']:.4f})")
    print(f"  最高 Acc:    margin={best_acc['margin']}, lambda={best_acc['lambda']} (Acc={best_acc['accuracy']:.4f})")
    print(f"  最低 FRR:    margin={lowest_frr['margin']}, lambda={lowest_frr['lambda']} (FRR={lowest_frr['frr']:.4f})")
    print(f"  最低 ASR:    margin={lowest_asr['margin']}, lambda={lowest_asr['lambda']} (ASR={lowest_asr['asr']:.4f})")
    print("=" * 80)


def main():
    """Command-line entry point: parse the sweep grid, run it, print a summary.

    ``--margins`` / ``--lambdas`` take comma-separated floats. Empty entries
    (e.g. a trailing comma) are ignored. A non-numeric or empty list is
    reported through ``argparse`` (usage message + exit status 2) instead of
    a traceback. The output directory gets a ``_YYYYmmdd_HHMMSS`` suffix to
    keep separate runs apart.
    """
    parser = argparse.ArgumentParser(description='Safety Projector 消融实验')
    parser.add_argument('--training_data', type=str, default='./agent_align_data_v3.json',
                        help='训练数据路径')
    parser.add_argument('--test_data', type=str, default='./test_queries.json',
                        help='测试数据路径')
    parser.add_argument('--output_dir', type=str, default='./safety_projector_ablation_results',
                        help='输出目录')
    parser.add_argument('--margins', type=str, default='0.1,0.3,0.5,0.7,1.0',
                        help='Margin 值列表 (逗号分隔)')
    parser.add_argument('--lambdas', type=str, default='0.1,0.5,1.0,2.0,5.0',
                        help='Lambda (cls_weight) 值列表 (逗号分隔)')
    parser.add_argument('--epochs', type=int, default=15,
                        help='训练轮数')
    parser.add_argument('--max_triplets', type=int, default=5000,
                        help='最大三元组数量')

    args = parser.parse_args()

    def parse_float_list(raw: str, option: str) -> List[float]:
        # float() tolerates surrounding whitespace; blank entries are skipped.
        try:
            values = [float(part) for part in raw.split(',') if part.strip()]
        except ValueError:
            parser.error(f"{option} 必须是逗号分隔的数字列表, 收到: {raw!r}")
        if not values:
            parser.error(f"{option} 不能为空")
        return values

    margin_values = parse_float_list(args.margins, '--margins')
    lambda_values = parse_float_list(args.lambdas, '--lambdas')

    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    output_dir = f"{args.output_dir}_{timestamp}"

    print("=" * 80)
    print("Safety Projector 消融实验")
    print("=" * 80)
    print(f"训练数据: {args.training_data}")
    print(f"测试数据: {args.test_data}")
    print(f"输出目录: {output_dir}")
    print(f"Margin 值: {margin_values}")
    print(f"Lambda 值: {lambda_values}")
    print(f"训练轮数: {args.epochs}")
    print("=" * 80)

    results = run_ablation_experiment(
        margin_values=margin_values,
        lambda_values=lambda_values,
        training_data_path=args.training_data,
        test_data_path=args.test_data,
        output_dir=output_dir,
        epochs=args.epochs,
        max_triplets=args.max_triplets
    )

    print_results_table(results)


if __name__ == '__main__':
    main()
