import numpy as np

def indicator_matrix(input, n_classes=None):
    """
    Build a dense one-hot (indicator) matrix from a vector of class labels.

    input: sequence / 1-D array of integer labels, one per row.
    n_classes: number of columns; when None it is inferred as max(input) + 1.

    Returns a float array of shape (len(input), n_classes) holding a single 1
    per row, in the column given by that row's label.
    """
    width = np.max(input) + 1 if n_classes is None else n_classes
    n_rows = len(input)
    onehot = np.zeros((n_rows, width))
    row_ids = np.arange(n_rows)
    onehot[row_ids, input] = 1
    return onehot

def get_dist_matrix(X, centers):
    """
    Compute the matrix of SQUARED Euclidean distances between points and centers.

    Uses the expansion ||x - c||^2 = ||x||^2 + ||c||^2 - 2 x.c, which needs only
    one matrix product instead of materializing all (N, K, d) differences.

    X: (N, d) points.
    centers: (K, d) centers.

    Returns an (N, K) array whose [i, j] entry is ||X[i] - centers[j]||^2.
    """
    sq_x = np.sum(X ** 2, axis=1)[:, np.newaxis]        # (N, 1)
    sq_c = np.sum(centers ** 2, axis=1)[np.newaxis, :]  # (1, K)
    cross = X @ centers.T                                # (N, K)
    dist = sq_x + sq_c - 2 * cross
    # The expanded form suffers catastrophic cancellation when x is close to c
    # (especially for large-magnitude coordinates), yielding tiny negative
    # values; a squared distance can never be negative, so clamp at zero.
    return np.maximum(dist, 0)

def compute_imbalance(F):
    """
    Measure how far the global sensitive-group distribution is from uniform.

    F: one-hot sensitive-attribute matrix, one column per group.

    Returns the sum over groups of (share_g - 1/m)^2, where share_g is the
    fraction of all counted rows in group g and m is the number of groups.
    Returns 0 when F contains no ones at all. Used to scale the initial p1.
    """
    per_group = np.sum(F, axis=0)
    grand_total = np.sum(per_group)
    if grand_total == 0:
        return 0
    uniform_share = 1.0 / len(per_group)
    deviation = per_group / grand_total - uniform_share
    return np.sum(deviation ** 2)

def compute_mean_min_distance(dis):
    """
    Average, over rows of the distance matrix ``dis``, of each row's smallest entry.

    With ``dis`` of shape (N, K) this is the mean distance of every point to
    its closest center.
    """
    nearest = np.min(dis, axis=1)
    return nearest.mean()

def auto_adjust_p1(F, dis, base_p1):
    """
    Rescale the fairness-penalty weight to the magnitude of the data.

    F: one-hot sensitive-attribute matrix (passed to compute_imbalance).
    dis: point-to-center distance matrix (passed to compute_mean_min_distance).
    base_p1: user-supplied base weight.

    Returns base_p1 * (1 + global imbalance) * mean nearest-center distance.
    If that product is below 1e-6 (e.g. every point sits on a center), the
    unscaled base_p1 is returned instead as a fallback.
    """
    imbalance = compute_imbalance(F)
    mean_dist = compute_mean_min_distance(dis)
    scaled = base_p1 * (1 + imbalance) * mean_dist
    return base_p1 if scaled < 1e-6 else scaled

def fair_kmeans(input_X, k, centers, F, base_p1, base_p2=None, max_iter=200):
    '''
    Fair K-Means solved by alternating center/potential updates and assignments.

    Every point i is assigned to the cluster j minimizing

        ||x_i - c_j||^2 + p1 * V[j, g_i]

    where g_i is the sensitive group of point i and

        V[j, g] = m * n_{jg} - n_j

    (n_{jg}: members of group g in cluster j, n_j: size of cluster j,
    m: number of groups). V[j, g] is positive when group g is over-represented
    in cluster j relative to a uniform 1/m share, which discourages further
    members of g from joining j.

    input_X: (N, d+1) data; column 0 is a label and is NOT used, columns 1:
        are the features.
    k: number of clusters.
    centers: initial centers (k, d). Copied (as float); the caller's array is
        not modified.
    F: sensitive-attribute one-hot matrix (N, m).
    base_p1: base weight of the fairness penalty; it is rescaled internally by
        auto_adjust_p1 and by the average cluster size N / k.
    base_p2: deprecated and unused; kept only for backward compatibility.
    max_iter: maximum number of update/assign rounds.

    Returns (labels, centers, sumd):
        labels: (N,) cluster index of every point.
        centers: (k, d) float array of final centers.
        sumd: (k,) for each cluster, the sum of squared Euclidean distances
            from its members to that cluster's own center.
    '''
    X = input_X[:, 1:]  # feature columns only
    N = X.shape[0]
    # Force a float copy: with integer initial centers the in-place mean
    # update below would otherwise be silently truncated to integers.
    centers = np.array(centers, dtype=float)

    group_indices = np.argmax(F, axis=1)  # sensitive group id of each point
    m = F.shape[1]  # number of sensitive groups

    dist_matrix = get_dist_matrix(X, centers)
    real_p1 = auto_adjust_p1(F, dist_matrix, base_p1)
    # Potentials scale with cluster size (~N / k); normalize so the penalty
    # stays comparable to the geometric distances.
    avg_cluster_size = N / k
    real_p1 = real_p1 / (avg_cluster_size + 1e-6)

    # Plain nearest-center assignment to start from.
    labels = np.argmin(dist_matrix, axis=1)

    for _ in range(max_iter):
        # === M-step: centers and fairness potentials from the current labels ===
        # This must come before the assignment step: re-assigning against the
        # unchanged initial centers with zero potentials would reproduce the
        # initial labels and stop the loop before any update ever happened.
        label_onehot = indicator_matrix(labels, n_classes=k)
        cluster_counts = np.sum(label_onehot, axis=0)   # (k,)
        cluster_group_counts = label_onehot.T @ F       # (k, m)
        centers_sum = label_onehot.T @ X                # (k, d)

        valid = cluster_counts > 0
        centers[valid] = centers_sum[valid] / cluster_counts[valid][:, np.newaxis]

        n_empty = int(np.count_nonzero(~valid))
        if n_empty:
            # Re-seed empty clusters at random data points; sample without
            # replacement when possible so two empty clusters don't collapse
            # onto the same point.
            random_idx = np.random.choice(N, size=n_empty, replace=n_empty > N)
            centers[~valid] = X[random_idx]

        V_potential = m * cluster_group_counts - cluster_counts[:, np.newaxis]

        # === E-step: batch assignment minimizing distance + p1 * potential ===
        fair_cost = V_potential[:, group_indices].T     # (N, k)
        total_cost = get_dist_matrix(X, centers) + real_p1 * fair_cost
        new_labels = np.argmin(total_cost, axis=1)

        if np.array_equal(new_labels, labels):
            break
        labels = new_labels

    # Per-cluster sum of each member's squared distance to its OWN center.
    # (Because of the fairness term a point's assigned center need not be its
    # nearest one, so the row-wise minimum would be the wrong quantity.)
    final_dist = get_dist_matrix(X, centers)
    assigned_dist = final_dist[np.arange(N), labels]
    sumd = np.bincount(labels, weights=assigned_dist, minlength=k).astype(float)

    return labels, centers, sumd