# topology_tools.py

import numpy as np
import torch
import torch.nn.functional as F

# ================================================================================================
# Part 1: Core Homology Calculation Tools
# ================================================================================================

class UnionFind:
    '''
    Disjoint-set (union-find) structure over integer vertices 0 .. n_vertices-1.

    `find` performs full path compression. There is no union by rank/size:
    `merge` always attaches the root of `u` underneath the root of `v`.
    '''
    def __init__(self, n_vertices):
        # _parent[i] == i marks i as the root (representative) of its set.
        self._parent = np.arange(n_vertices, dtype=int)

    def find(self, u):
        '''Return the root of the set containing `u`, compressing the path to it.

        Implemented iteratively: without union by rank, `merge` can build
        parent chains as long as the number of vertices. A recursive
        implementation would then hit Python's recursion limit.
        '''
        parent = self._parent
        root = u
        while parent[root] != root:
            root = parent[root]
        # Second pass: point every node on the path directly at the root.
        while u != root:
            next_u = parent[u]
            parent[u] = root
            u = next_u
        return root

    def merge(self, u, v):
        '''Union the sets containing `u` and `v` (root of u goes under root of v).'''
        root_u = self.find(u)
        root_v = self.find(v)
        if root_u != root_v:
            self._parent[root_u] = root_v

class PersistentHomologyCalculation:
    '''
    Compute 0-dimensional persistence pairs from a distance matrix.

    For a Vietoris-Rips filtration, 0-dim persistence reduces to Kruskal's
    algorithm. Edges are visited in ascending weight order, and every edge
    that joins two previously separate components is recorded as a pair.
    These edges form a minimum spanning tree of the point set.
    '''
    def __call__(self, matrix):
        '''
        Args:
            matrix: square 2-D NumPy array of pairwise distances. Only the
                strict upper triangle is read, so it is assumed symmetric.

        Returns:
            (pairs_0, pairs_1):
            - pairs_0 is an int64 array of shape (k, 2), with k <= n-1. Each
              row (i, j), i < j, is an edge that merged two components, in
              the order it was accepted.
            - pairs_1 is always an empty array; higher dimensions are not
              computed.
        '''
        n_vertices = matrix.shape[0]
        uf = UnionFind(n_vertices)
        rows, cols = np.triu_indices_from(matrix, k=1)
        # Stable sort so equal-weight edges keep their upper-triangle order.
        edge_order = np.argsort(matrix[rows, cols], kind='stable')
        persistence_pairs = []
        for edge_index in edge_order:
            u = rows[edge_index]
            v = cols[edge_index]
            if uf.find(u) == uf.find(v):
                continue
            uf.merge(u, v)
            # triu_indices with k=1 guarantees u < v, so no reordering is needed.
            persistence_pairs.append((u, v))
            if len(persistence_pairs) == n_vertices - 1:
                # Spanning tree complete: every remaining edge would be skipped.
                break
        # reshape(-1, 2) gives an empty result the same (k, 2) int layout
        # as a non-empty one.
        pairs_0 = np.array(persistence_pairs, dtype=np.int64).reshape(-1, 2)
        return pairs_0, np.array([])

# ================================================================================================
# Part 2: Helper functions
# ================================================================================================

def compute_distance_matrix(x, p=2):
    """
    Compute pairwise distances between feature vectors.

    Args:
        x: 2-D tensor (n, d) or 3-D batched tensor (b, n, d).
        p: order of the p-norm, forwarded to `torch.cdist`.

    Returns:
        Tensor of shape (n, n) for 2-D input or (b, n, n) for 3-D input.

    Raises:
        ValueError: if `x` is neither 2-D nor 3-D.
    """
    if x.dim() not in (2, 3):
        raise ValueError("Input tensor must have 2 or 3 dimensions")
    # torch.cdist accepts both unbatched (n, d) and batched (b, n, d) inputs.
    # No reshaping is needed, and `view(n, -1)` would fail on zero-element tensors.
    return torch.cdist(x, x, p=p)

def get_pairings(distances):
    """Return the 0-dimensional persistence pairs of a distance-matrix tensor.

    The tensor is detached and copied to host memory as a NumPy array before
    being handed to `PersistentHomologyCalculation`. Only its first output
    (the 0-dim pairs) is returned; the second output is discarded.
    """
    host_matrix = distances.detach().cpu().numpy()
    zero_dim_pairs, _higher_dim = PersistentHomologyCalculation()(host_matrix)
    return zero_dim_pairs

# ================================================================================================
# Part 3: Supervised Topological Loss (for Stage 2)
# ================================================================================================

def _intra_class_loss(features, labels, unique_labels):
    """Average over classes with >= 2 samples of the summed 0-dim death times within each class."""
    device = features.device
    total = 0.0
    num_classes_used = 0
    for label in unique_labels:
        class_indices = torch.where(labels == label)[0]
        if len(class_indices) <= 1:
            continue
        class_features = features[class_indices]
        distance_matrix = compute_distance_matrix(class_features.unsqueeze(0)).squeeze(0)
        pairs_0 = get_pairings(distance_matrix)
        if pairs_0.size == 0:
            continue
        # Each row of pairs_0 holds the two endpoint indices of a merging edge.
        # Its distance is the death time of a 0-dim feature.
        pair_indices = torch.from_numpy(pairs_0).to(device, non_blocking=True)
        death_times = distance_matrix[pair_indices[:, 0], pair_indices[:, 1]]
        total = total + torch.sum(death_times)
        num_classes_used += 1
    if num_classes_used == 0:
        return features.new_zeros(())
    return total / num_classes_used


def _inter_class_loss(features, labels):
    """Negative mean length of merging edges over the whole batch whose endpoints have different labels."""
    device = features.device
    full_distance_matrix = compute_distance_matrix(features.unsqueeze(0)).squeeze(0)
    full_pairs_0 = get_pairings(full_distance_matrix)
    if full_pairs_0.size == 0:
        return 0.0
    # Convert the NumPy pairs to a tensor on the features' device before using
    # them to index GPU tensors.
    pair_indices = torch.from_numpy(full_pairs_0).to(device, non_blocking=True)
    endpoint_a = pair_indices[:, 0]
    endpoint_b = pair_indices[:, 1]
    inter_class_mask = labels[endpoint_a] != labels[endpoint_b]
    inter_a = endpoint_a[inter_class_mask]
    inter_b = endpoint_b[inter_class_mask]
    if len(inter_a) == 0:
        return 0.0
    inter_death_times = full_distance_matrix[inter_a, inter_b]
    return -torch.mean(inter_death_times)


def calculate_supervised_topological_loss(features: torch.Tensor, labels: torch.Tensor, lambda_inter: float = 0.5):
    """
    (Stage 2) Supervised topological loss with an intra-class compactness term
    and an inter-class separation term.

    Args:
        features: 2-D feature tensor, one row per sample.
        labels: per-sample class labels; presumably a 1-D integer tensor of
            length features.shape[0] (TODO confirm). It is moved to
            features.device if it lives elsewhere.
        lambda_inter: weight of the inter-class term.

    Returns:
        Scalar loss = intra + lambda_inter * inter, where:
        - intra: for each class with at least 2 samples, sum the lengths of
          that class's merging (minimum-spanning-tree) edges; then average
          over those classes. It is 0 if no class qualifies.
        - inter: the negative mean length of whole-batch merging edges that
          connect samples of different labels. It is 0 with fewer than two
          distinct labels.

        The inter term grows more negative as classes move apart, so it is
        unbounded below.
    """
    device = features.device
    # Label lookups are indexed with tensors created on features.device.
    # Indexing a tensor with indices on another device raises, so align
    # the devices up front.
    labels = labels.to(device)
    unique_labels = torch.unique(labels)

    avg_intra_loss = _intra_class_loss(features, labels, unique_labels)
    total_inter_loss = _inter_class_loss(features, labels) if len(unique_labels) > 1 else 0.0

    return avg_intra_loss + lambda_inter * total_inter_loss