#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Triplet Loss 调试脚本
用于分析为什么 triplet loss 为 0
"""

import torch
import torch.nn.functional as F
import numpy as np
from fastreid.modeling.losses.triplet_loss import triplet_loss, hard_example_mining, weighted_example_mining
from fastreid.modeling.losses.utils import euclidean_dist, cosine_dist

def debug_triplet_loss():
    """Recompute the triplet loss step by step to see why it can come out as 0.

    Uses random features (``torch.randn``) for two identities. For each
    configuration in ``test_configs`` it prints:

    * the range of the distance matrix;
    * the positive/negative mask counts;
    * the mined anchor-positive (``dist_ap``) and anchor-negative
      (``dist_an``) distances;
    * the resulting loss.

    When the loss is exactly 0 it also prints two counts: how many anchors
    already satisfy ``dist_an - dist_ap >= margin``, and how many anchors
    have at least one positive other than themselves plus at least one
    negative.
    """
    print("=== Triplet Loss 调试 ===")

    # Simulated features and labels.
    batch_size = 8
    feat_dim = 768

    features = torch.randn(batch_size, feat_dim)
    targets = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])  # 2 identities, 4 samples each

    print(f"特征形状: {features.shape}")
    print(f"标签: {targets}")

    # Parameter combinations to compare.
    test_configs = [
        {"margin": 0.3, "norm_feat": False, "hard_mining": True},
        {"margin": 0.3, "norm_feat": False, "hard_mining": False},
        {"margin": 0.1, "norm_feat": False, "hard_mining": True},
        {"margin": 0.3, "norm_feat": True, "hard_mining": True},
    ]

    for i, config in enumerate(test_configs):
        print(f"\n--- 配置 {i+1}: {config} ---")

        # Pairwise distance matrix: cosine when norm_feat is set, else Euclidean.
        if config["norm_feat"]:
            dist_mat = cosine_dist(features, features)
        else:
            dist_mat = euclidean_dist(features, features)

        print(f"距离矩阵形状: {dist_mat.shape}")
        print(f"距离矩阵范围: [{dist_mat.min():.4f}, {dist_mat.max():.4f}]")

        # Positive / negative masks (float, as passed to the mining helpers).
        # is_pos includes the diagonal, so the positive-pair count printed
        # below also counts the N self-pairs.
        N = dist_mat.size(0)
        is_pos = targets.view(N, 1).expand(N, N).eq(targets.view(N, 1).expand(N, N).t()).float()
        is_neg = targets.view(N, 1).expand(N, N).ne(targets.view(N, 1).expand(N, N).t()).float()

        print(f"正样本掩码形状: {is_pos.shape}")
        print(f"正样本对数量: {is_pos.sum().item()}")
        print(f"负样本对数量: {is_neg.sum().item()}")

        # Hard mining or weighted mining, depending on the config.
        if config["hard_mining"]:
            dist_ap, dist_an = hard_example_mining(dist_mat, is_pos, is_neg)
        else:
            dist_ap, dist_an = weighted_example_mining(dist_mat, is_pos, is_neg)

        print(f"正样本距离范围: [{dist_ap.min():.4f}, {dist_ap.max():.4f}]")
        print(f"负样本距离范围: [{dist_an.min():.4f}, {dist_an.max():.4f}]")

        # Ranking target: dist_an should be larger than dist_ap.
        y = torch.ones_like(dist_an)

        if config["margin"] > 0:
            loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=config["margin"])
        else:
            loss = F.soft_margin_loss(dist_an - dist_ap, y)

        print(f"Triplet Loss: {loss.item():.6f}")

        # Explain a zero loss.
        if loss.item() == 0:
            print("*** 损失为 0 的原因分析 ***")
            # With y == 1, margin_ranking_loss averages
            # max(0, dist_ap - dist_an + margin). An anchor therefore
            # contributes 0 exactly when dist_an - dist_ap >= margin,
            # so count those anchors rather than summing the gaps.
            margin_satisfied = (dist_an - dist_ap >= config["margin"]).sum()
            print(f"满足 margin 条件的样本数: {margin_satisfied.item()}")

            # An anchor is usable only if it has a positive other than itself
            # and at least one negative. The diagonal of is_pos is always 1,
            # so exclude it before checking for positives.
            not_self = ~torch.eye(N, dtype=torch.bool, device=is_pos.device)
            has_pos = (is_pos.bool() & not_self).any(dim=1)
            has_neg = is_neg.bool().any(dim=1)
            valid_pairs = has_pos & has_neg
            print(f"有有效正负样本对的 anchor 数量: {valid_pairs.sum().item()}")

def debug_real_data():
    """Check the triplet loss on synthetic features clustered by identity.

    Every sample of an identity is replaced by one shared base vector plus
    Gaussian noise scaled by 0.1. The function then prints the Euclidean
    distances from each identity's first sample to its other samples.
    Finally it prints the loss from ``triplet_loss`` with margin 0.3,
    no feature normalisation and hard mining.
    """
    print("\n=== 真实数据调试 ===")

    # A more realistic setup.
    batch_size = 8
    feat_dim = 768

    features = torch.randn(batch_size, feat_dim)
    targets = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])

    # Sample indices for each identity, derived from the labels instead of a
    # hard-coded stride, so this stays correct if `targets` is changed.
    groups = [
        (targets == label).nonzero(as_tuple=True)[0].tolist()
        for label in targets.unique()
    ]

    # Make samples of the same identity close to each other.
    for members in groups:
        # clone(): indexing returns a view into `features`. Without the copy,
        # overwriting the first member would also change the base vector used
        # for the remaining members.
        base_feat = features[members[0]].clone()
        for j in members:
            features[j] = base_feat + 0.1 * torch.randn_like(base_feat)

    print("调整后的特征相似度:")
    dist_mat = euclidean_dist(features, features)
    for members in groups:
        anchor = members[0]
        for j in members[1:]:
            print(f"  身份 {targets[anchor].item()} 样本间距离: {dist_mat[anchor, j]:.4f}")

    # Loss through the library function.
    loss = triplet_loss(features, targets, margin=0.3, norm_feat=False, hard_mining=True)
    print(f"真实数据 Triplet Loss: {loss.item():.6f}")

if __name__ == "__main__":
    # Run the random-feature walkthrough first, then the clustered-feature check.
    for _debug_step in (debug_triplet_loss, debug_real_data):
        _debug_step()
