import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import os
import random
import numpy as np
from tqdm import tqdm
import json
from sentence_transformers import SentenceTransformer

# ==========================================
# 1. 模型定义：基于原型的度量学习 (Metric Learning)
# ==========================================
class SafetyProjector(nn.Module):
    """Prototype-based metric-learning head for benign/harmful classification.

    Input feature vectors are passed through a small MLP
    (input_dim -> 512 -> 128) and L2-normalised. Two learnable prototypes
    (row 0: benign, row 1: harmful) are L2-normalised as well. The cosine
    similarity of each query to each prototype, divided by a learnable
    temperature, is returned as 2-class logits.

    Args:
        input_dim: Width of the input feature vectors. NOTE(review): the
            default 1536 does not match the 384-d MiniLM embeddings used by
            the training entry point in this file; callers pass it explicitly.
        device: Optional device the module is moved to after construction.
        min_temperature: Lower bound applied to the learnable temperature in
            ``forward``. Without it the optimizer can drive the temperature to
            0 (division by zero -> inf/nan logits) or below 0 (which silently
            inverts the logits). Must be > 0. It is stored as a plain float,
            so the ``state_dict`` layout is unchanged.

    Raises:
        ValueError: if ``min_temperature`` is not positive.
    """

    def __init__(self, input_dim=1536, device=None, min_temperature=1e-2):
        super().__init__()

        if min_temperature <= 0:
            raise ValueError(f"min_temperature must be > 0, got {min_temperature}")
        self.min_temperature = float(min_temperature)

        # [Part 1] Embedding projector: semantic mapping + dimensionality
        # reduction to 128-d. Dropout is there to limit overfitting.
        self.net = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, 128),
        )

        # [Part 2] Learnable prototypes, shape (2, 128):
        # row 0 = benign prototype, row 1 = harmful prototype.
        # Instead of learning a separating hyperplane we learn two "centres".
        self.prototypes = nn.Parameter(torch.randn(2, 128))

        # [Part 3] Learnable temperature, initialised to 0.5.
        # Smaller T -> sharper softmax; larger T -> smoother.
        self.temperature = nn.Parameter(torch.ones(1) * 0.5)

        if device is not None:
            self.to(device)

    def forward(self, x):
        """Project ``x`` and score it against both prototypes.

        Args:
            x: 2-D tensor of shape (batch, input_dim).

        Returns:
            (query_emb, scaled_logits): ``query_emb`` is the L2-normalised
            (batch, 128) projection. ``scaled_logits`` is (batch, 2), where
            column 0 is the benign score and column 1 the harmful score,
            suitable for ``nn.CrossEntropyLoss``.
        """
        # 1. Query embedding, normalised so we compare angles only.
        query_emb = F.normalize(self.net(x), p=2, dim=1)

        # 2. Normalised prototypes.
        protos_norm = F.normalize(self.prototypes, p=2, dim=1)

        # 3. Cosine similarity: (batch, 128) x (128, 2) -> (batch, 2).
        similarity = torch.mm(query_emb, protos_norm.T)

        # 4. Temperature scaling. The clamp keeps the divisor strictly
        # positive; above the bound the gradient is unchanged.
        temperature = self.temperature.clamp(min=self.min_temperature)
        scaled_logits = similarity / temperature

        return query_emb, scaled_logits

# ==========================================
# 2. 训练循环：CrossEntropy (Prototype) + Triplet
# ==========================================
def train_on_agent_align(triplets, input_dim=384, batch_size=32, epochs=20, lr=1e-3, margin=0.5, cls_weight=1.0, save_path="./models/safety_projector.pth"):
    """Train a prototype-based SafetyProjector on mined embedding triplets.

    Each triplet is ``(anchor, positive, negative)``. Anchor and positive are
    harmful-sample embeddings (label 1) and negative is a benign-sample
    embedding (label 0). The per-batch objective is::

        TripletMarginLoss(anchor, positive, negative)
            + cls_weight * CrossEntropy(prototype logits of all three, labels)

    Whenever an epoch's mean *training* loss improves on the best so far, a
    checkpoint dict is written to ``save_path``. It has the keys
    'model_state_dict', 'input_dim', 'prototypes' and 'temperature'; the last
    two are stored as detached CPU tensors.

    Args:
        triplets: Sequence of (anchor, positive, negative) vectors, each of
            length ``input_dim``.
        input_dim: Embedding width expected by the projector.
        batch_size: Triplets per optimisation step.
        epochs: Number of passes over the data.
        lr: Adam learning rate.
        margin: Margin for ``nn.TripletMarginLoss``.
        cls_weight: Weight of the prototype classification loss.
        save_path: Where the best checkpoint is written; parent dirs are
            created.

    Returns:
        The model in its state after the final epoch. This may differ from
        the best checkpoint written to disk.

    Raises:
        ValueError: if ``triplets`` is empty or the embedding width does not
            match ``input_dim``.
    """
    if len(triplets) == 0:
        raise ValueError("train_on_agent_align: `triplets` is empty; nothing to train on.")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)

    # Data preparation
    print(f"📊 Preparing {len(triplets)} triplets...")

    def _column(idx):
        # Stack one triplet member into an (N, D) float32 tensor. A single
        # np.asarray call is far faster than torch.tensor() on a Python list
        # of ndarrays (which PyTorch explicitly warns about).
        return torch.as_tensor(np.asarray([t[idx] for t in triplets], dtype=np.float32))

    anchors = _column(0)    # harmful
    positives = _column(1)  # harmful
    negatives = _column(2)  # benign
    if anchors.dim() != 2 or anchors.size(1) != input_dim:
        raise ValueError(
            f"train_on_agent_align: expected embeddings of width {input_dim}, "
            f"got tensor of shape {tuple(anchors.shape)}"
        )

    dataset = TensorDataset(anchors, positives, negatives)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, pin_memory=(device.type == 'cuda'))

    model = SafetyProjector(input_dim=input_dim, device=device)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    criterion_triplet = nn.TripletMarginLoss(margin=margin, p=2)
    criterion_cls = nn.CrossEntropyLoss()  # applies log-softmax internally

    print(f"🚀 Training Metric Safety Projector (Prototype-based)...")

    best_loss = float('inf')
    model.train()

    for epoch in range(epochs):
        total_loss = 0.0
        total_acc = 0.0
        num_batches = 0

        for a, p, n in loader:
            bsz = a.size(0)
            a = a.to(device, non_blocking=True)
            p = p.to(device, non_blocking=True)
            n = n.to(device, non_blocking=True)

            optimizer.zero_grad()

            # One forward pass over anchors, positives and negatives together
            # instead of three separate calls (no batch-dependent layers in
            # the projector, so this is equivalent).
            all_emb, all_logits = model(torch.cat([a, p, n], dim=0))
            a_emb, p_emb, n_emb = all_emb.split(bsz, dim=0)

            # 1. Triplet loss keeps semantic clustering in embedding space.
            loss_triplet = criterion_triplet(a_emb, p_emb, n_emb)

            # 2. Prototype classification loss.
            # Prototype index 0 = benign, 1 = harmful:
            # anchors + positives -> 1, negatives -> 0.
            all_labels = torch.cat([
                torch.ones(2 * bsz, dtype=torch.long, device=device),
                torch.zeros(bsz, dtype=torch.long, device=device),
            ])
            loss_cls = criterion_cls(all_logits, all_labels)

            # 3. Total loss; cls_weight balances the two terms.
            loss = loss_triplet + cls_weight * loss_cls

            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            preds = all_logits.argmax(dim=1)
            total_acc += (preds == all_labels).float().mean().item()
            num_batches += 1

        avg_loss = total_loss / num_batches if num_batches > 0 else 0
        avg_acc = total_acc / num_batches if num_batches > 0 else 0
        # Read after this epoch's updates so the log reflects the trained value.
        current_temp = model.temperature.item()

        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save({
                'model_state_dict': model.state_dict(),
                'input_dim': input_dim,
                # Detached CPU copies: loadable without autograd state or a GPU.
                'prototypes': model.prototypes.detach().cpu(),
                'temperature': model.temperature.detach().cpu(),
            }, save_path)

        print(f"Epoch {epoch+1}/{epochs}: Loss={avg_loss:.4f} | Acc={avg_acc*100:.1f}% | Temp={current_temp:.3f}")

    print(f"✅ Training Done. Best loss: {best_loss:.4f}")
    return model

# ==========================================
# 3. Data preparation: hard-triplet mining and dataset parsing
# ==========================================
def prepare_training_triplets(harmful_texts, benign_texts, embedding_func):
    """Mine (anchor, positive, negative) embedding triplets for metric learning.

    For every harmful sample (the anchor):

    * positive = the most cosine-similar *other* harmful sample. If the
      anchor is the only harmful sample, the positive is the anchor itself.
    * negative = the most cosine-similar benign sample (a "hard" negative).

    Every 5th anchor (indices 0, 5, 10, ...) also gets a second triplet with
    the same positive and a uniformly random benign negative. That negative is
    drawn from the global ``random`` module; seed it for reproducibility.

    Args:
        harmful_texts: Harmful sample texts.
        benign_texts: Benign sample texts.
        embedding_func: Callable mapping a list of texts to equal-length
            vectors (anything ``np.array`` turns into a 2-D array).

    Returns:
        List of ``(anchor, positive, negative)`` tuples of *unnormalised*
        embedding rows.

    Raises:
        ValueError: if either text list is empty. Mining needs at least one
            sample of each class.
    """
    if len(harmful_texts) == 0 or len(benign_texts) == 0:
        raise ValueError(
            "prepare_training_triplets needs at least one harmful and one benign text "
            f"(got {len(harmful_texts)} harmful, {len(benign_texts)} benign)"
        )

    print("Encoding Harmful samples for mining...")
    h_embs = np.array(embedding_func(harmful_texts))
    print("Encoding Benign samples for mining...")
    b_embs = np.array(embedding_func(benign_texts))

    # Row-normalise so dot products are cosine similarities (epsilon guards
    # against zero vectors).
    h_norm = h_embs / (np.linalg.norm(h_embs, axis=1, keepdims=True) + 1e-9)
    b_norm = b_embs / (np.linalg.norm(b_embs, axis=1, keepdims=True) + 1e-9)

    print("⛏️ Mining Hard Triplets...")

    sim_h2b = h_norm @ b_norm.T
    sim_h2h = h_norm @ h_norm.T
    # Exclude self-matches. -inf (rather than -1.0) cannot tie with a genuinely
    # opposite vector, so self is chosen only when no other harmful sample exists.
    np.fill_diagonal(sim_h2h, -np.inf)

    # Vectorised nearest-neighbour lookups (argmax picks the first index on ties).
    hard_neg_idx = sim_h2b.argmax(axis=1)
    semantic_pos_idx = sim_h2h.argmax(axis=1)

    triplets = []
    for i, anchor in enumerate(h_embs):
        semantic_positive = h_embs[semantic_pos_idx[i]]
        triplets.append((anchor, semantic_positive, b_embs[hard_neg_idx[i]]))

        # Every 5th anchor also gets a random (easier) negative for diversity.
        if i % 5 == 0:
            rand_neg_idx = random.randrange(len(b_embs))
            triplets.append((anchor, semantic_positive, b_embs[rand_neg_idx]))

    print(f"✅ Generated {len(triplets)} semantic triplets.")
    return triplets

def parse_agent_align_data(dataset):
    """Split AgentAlign-style records into harmful and benign user prompts.

    Labelling rules, applied in order:

    * harmful if the record's ``id`` contains ``"harmful"`` or its
      ``category`` is exactly ``"self_harm"``;
    * benign if its ``category`` contains ``"benign"``;
    * otherwise the record is skipped.

    The extracted text is the ``content`` of the first message whose
    ``role`` is ``"user"``. A record is skipped when there is no such message
    or its content is empty or not a string.

    Malformed input is skipped instead of raising: non-dict records,
    missing/None/non-string ``id`` or ``category`` (treated as ""), a
    None ``messages`` field, and message entries that are not dicts or lack
    ``role``.

    Args:
        dataset: Iterable of record dicts.

    Returns:
        ``(harmful_texts, benign_texts)`` as two lists of strings, in input
        order.
    """
    harmful_texts = []
    benign_texts = []

    for record in dataset:
        if not isinstance(record, dict):
            continue

        record_id = record.get('id')
        category = record.get('category')
        # `in` on None / ints raises TypeError; normalise to "" instead.
        if not isinstance(record_id, str):
            record_id = ''
        if not isinstance(category, str):
            category = ''

        if 'harmful' in record_id or category == 'self_harm':
            is_harmful = True
        elif 'benign' in category:
            is_harmful = False
        else:
            continue

        # Only the first user turn is used.
        user_content = ""
        for msg in record.get('messages') or []:
            if isinstance(msg, dict) and msg.get('role') == 'user':
                user_content = msg.get('content')
                break

        if not isinstance(user_content, str) or not user_content:
            continue

        if is_harmful:
            harmful_texts.append(user_content)
        else:
            benign_texts.append(user_content)

    print(f"✅ Parsed Data: {len(harmful_texts)} Harmful vs {len(benign_texts)} Benign")
    return harmful_texts, benign_texts

def main_train(data_path="../agent_align_data_v3.json", save_path="./models/safety_projector_metric.pth"):
    """Run the full pipeline: load JSON -> split -> embed and mine triplets -> train.

    Args:
        data_path: JSON file whose top-level value is the record list
            consumed by ``parse_agent_align_data``. Defaults to the path that
            was previously hard-coded.
        save_path: Where ``train_on_agent_align`` writes the best checkpoint.

    Returns:
        The trained SafetyProjector. Returns None if the data file is missing
        or either class (harmful / benign) has no samples.
    """
    # 1. Load data
    if not os.path.exists(data_path):
        print(f"❌ Data file not found: {data_path}")
        return None

    print(f"📂 Loading data from {data_path}...")
    with open(data_path, "r", encoding='utf-8') as f:
        data = json.load(f)

    harmful_texts, benign_texts = parse_agent_align_data(data)

    if not harmful_texts:
        print("❌ No data found.")
        return None
    if not benign_texts:
        # Triplet mining needs at least one benign sample to use as a negative.
        print("❌ No benign data found; cannot mine negatives.")
        return None

    # 2. Initialise the sentence-embedding model
    print("🔧 Initializing embedding model...")
    # SentenceTransformer documents `device` as a string ('cuda' / 'cpu').
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    embedding_model = SentenceTransformer('all-MiniLM-L6-v2', device=device)
    input_dim = embedding_model.get_sentence_embedding_dimension()

    def embedding_func(text_list):
        return embedding_model.encode(text_list, show_progress_bar=False)

    # 3. Mine triplets
    triplets = prepare_training_triplets(harmful_texts, benign_texts, embedding_func)

    # 4. Train
    model = train_on_agent_align(
        triplets=triplets,
        input_dim=input_dim,
        batch_size=32,
        epochs=15,  # metric learning usually converges quickly
        lr=1e-3,
        margin=0.5,
        save_path=save_path,
    )

    print("🎉 Training completed!")
    return model

if __name__ == "__main__":
    # Script entry point: run main_train() with its built-in data and
    # checkpoint paths when this file is executed directly (not on import).
    main_train()