import torch
import torch.nn.functional as F
import wandb
from cdc.utils.torch_clustering import PyTorchKMeans
from collections import Counter
from cdc.utils.evaluate_utils import get_predictions, hungarian_evaluate
import torch.nn as nn
import torch.nn.functional as F
from sklearn.decomposition import PCA
import numpy as np
import random
from typing import Tuple
import torch.nn.init as init


def orth_train(W, n_samples, scale=5, epochs=2000, use_relu=False, device=None):
    """Refine a prototype matrix so that its rows become well separated.

    Two trainable copies of ``W`` are made: ``Z`` (queries) and ``W_``
    (class weights). For ``epochs`` SGD steps, row ``i`` of the L2-normalised
    ``Z`` (optionally passed through ReLU first) is classified against the
    L2-normalised rows of ``W_`` using cosine logits multiplied by ``scale``,
    with target class ``i``. Minimising this cross-entropy pulls each row of
    ``W_`` towards its own query and away from the other queries.

    Args:
        W: [n_samples, D] tensor. It is never modified; detached copies are
            optimised, so ``W`` may carry autograd history.
        n_samples: number of rows of ``W`` (the label range 0..n_samples-1).
        scale: multiplier applied to the cosine-similarity logits.
        epochs: number of optimisation steps (LR cosine-annealed from 0.1 to 0).
        use_relu: apply ReLU to ``Z`` before normalisation.
        device: where to run the optimisation. Defaults to CUDA when it is
            available (the previous hard-coded behaviour), else ``W.device``.

    Returns:
        The optimised ``W_`` as a detached tensor on ``device``.
    """
    if device is None:
        device = torch.device("cuda") if torch.cuda.is_available() else W.device

    # detach() first so both copies are leaf tensors: setting requires_grad on
    # (or optimising) a non-leaf clone of a grad-tracking W would raise.
    Z = W.detach().clone().to(device).requires_grad_(True)
    W_ = W.detach().clone().to(device).requires_grad_(True)
    labels = torch.arange(n_samples, device=device)

    optimizer = torch.optim.SGD([Z, W_], lr=0.1, momentum=0.9, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=0)
    criterion = torch.nn.CrossEntropyLoss()

    # Callers may be running under torch.no_grad(); this inner optimisation
    # needs autograd regardless.
    with torch.enable_grad():
        for _ in range(epochs):
            z = F.relu(Z) if use_relu else Z
            logits = F.linear(F.normalize(z, dim=1), F.normalize(W_, dim=1))
            loss = criterion(logits * scale, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()
    return W_.detach()

def initialize_weights(cfg, model, cali_mlp, features, val_dataloader):
    """Initialise all heads from two-stage KMeans prototypes, then evaluate once.

    Stage 1 clusters the row-standardised, L2-normalised features into 512
    prototypes; their centres become the weight of layer 0 of each head.
    Stage 2 projects the raw features through those prototypes and the
    model's ``cluster_head[0][1]`` / ``cluster_head[0][2]`` modules (BN and
    ReLU according to the original comments). It then standardises the result
    and clusters it into ``cfg['backbone']['nclusters']`` groups; their centres
    become the weight of layer 3. Both matrices are refined with
    ``orth_train``, then copied with zeroed biases into ``cluster_head[0]``,
    ``classify_tail``, ``classify_medium`` and ``cali_mlp``'s
    ``calibration_head``.

    Args:
        cfg: config dict; reads ``['backbone']['nclusters']``,
            ``['cdc_checkpoint']`` and ``['cluster_eval']['plot_title']``.
        model: wrapped model (accessed through ``.module``) exposing
            ``cluster_head``, ``classify_tail`` and ``classify_medium``.
        cali_mlp: wrapped model (``.module``) exposing ``calibration_head``.
        features: [N, D] feature tensor.
        val_dataloader: loader passed to ``get_predictions`` for evaluation.

    NOTE(review): if ``model`` is in training mode, ``cluster_head[0][1]``
    (BN per the original comments) may update its running statistics here.
    """
    def _standardize_rows(x, eps=1e-6):
        # Per-row z-score then L2 normalisation. eps keeps constant rows (e.g. an
        # all-zero post-ReLU activation) at 0 instead of 0/0 = NaN; for any
        # non-constant row the L2 normalisation cancels eps's effect.
        x = (x - x.mean(1, keepdim=True)) / (x.std(1, keepdim=True) + eps)
        return F.normalize(x, dim=1)

    n_clusters = cfg['backbone']['nclusters']

    # Stage 1: 512 prototypes -> first Linear layer.
    KMeans_512 = PyTorchKMeans(init='k-means++', n_clusters=512, verbose=False, random_state=0)
    KMeans_512.fit_predict(_standardize_rows(features))
    W1 = KMeans_512.cluster_centers_

    # Stage 2: mirror the head's Linear(W1) -> cluster_head[0][1] -> cluster_head[0][2].
    H = torch.mm(features, W1.T)
    H = model.module.cluster_head[0][1](H).detach().clone()
    H = model.module.cluster_head[0][2](H).detach().clone()

    KMeans_c = PyTorchKMeans(init='k-means++', n_clusters=n_clusters, verbose=False, random_state=0)
    KMeans_c.fit_predict(_standardize_rows(H))
    W2 = KMeans_c.cluster_centers_

    W1_modi = orth_train(W1, 512, use_relu=True)
    W2_modi = orth_train(W2, n_clusters, use_relu=True)

    # Diagnostic: histogram of predicted clusters before and after refinement.
    for first, second in ((W1, W2), (W1_modi, W2_modi)):
        O = torch.mm(torch.mm(features, first.T), second.T)
        print(F.softmax(O, dim=1).max(1)[1].unique(return_counts=True))

    # Resolve every target head up front so a missing attribute fails before
    # any weights are overwritten.
    heads = (
        model.module.cluster_head[0],
        model.module.classify_tail,
        model.module.classify_medium,
        cali_mlp.module.calibration_head,
    )
    for head in heads:
        for layer_idx, weight in ((0, W1_modi), (3, W2_modi)):
            torch.nn.init.zeros_(head[layer_idx].bias)
            head[layer_idx].weight.data = weight.clone()

    predictions = get_predictions(cfg, val_dataloader, model)
    clustering_stats = hungarian_evaluate(cfg, cfg['cdc_checkpoint'], 0, 0,
                                          predictions, title=cfg['cluster_eval']['plot_title'],
                                          compute_confusion_matrix=False)
    print(clustering_stats)
    
import torch
import torch.nn.functional as F
import numpy as np
from sklearn.neighbors import NearestNeighbors

def compute_density_weights(features, labels, k=10, alpha=0.5, eps=1e-6):
    """Compute per-sample weights from local density inside each cluster.

    For every cluster with more than ``k`` members, each sample's mean
    Euclidean distance to its ``k`` nearest neighbours within the cluster is
    computed. Its weight is ``exp(alpha * mean_dist)``, normalised to sum to
    one within the cluster, so for ``alpha > 0`` samples in denser regions get
    smaller weights. Clusters with ``k`` or fewer members skip the density
    term and get uniform weights ``1 / cluster_size``. These also sum to one,
    so all clusters' weights share the same scale.

    Args:
        features: [N, D] tensor (may live on GPU).
        labels: [N] tensor of cluster ids (may live on GPU).
        k: neighbours per sample, excluding the sample itself.
        alpha: multiplier applied to the mean neighbour distance before exp().
        eps: added to each cluster's normaliser.

    Returns:
        float32 tensor of shape [N] on ``features.device``.
    """
    device = features.device
    features_np = features.detach().cpu().numpy()
    labels_np = labels.detach().cpu().numpy()

    weights = np.zeros(len(features_np), dtype=np.float32)

    for c in np.unique(labels_np):
        idx = np.flatnonzero(labels_np == c)

        if len(idx) <= k:
            # Too few members for a k-NN estimate: no density weighting, just
            # uniform weights on the same (sum-to-one) scale as larger clusters.
            weights[idx] = 1.0 / len(idx)
            continue

        cluster_feats = features_np[idx]
        nbrs = NearestNeighbors(n_neighbors=k + 1, algorithm="auto").fit(cluster_feats)
        distances, _ = nbrs.kneighbors(cluster_feats)
        # Column 0 is each sample's distance to itself (0), so drop it.
        avg_dist = distances[:, 1:].mean(axis=1)

        # Shift by the max before exp() so large alpha * distance cannot
        # overflow to inf (inf / inf = NaN); the shift cancels in the
        # normalisation up to the eps term.
        logits = alpha * avg_dist
        cluster_weights = np.exp(logits - logits.max())
        weights[idx] = cluster_weights / (cluster_weights.sum() + eps)

    # Back to torch, on the caller's device.
    return torch.tensor(weights, dtype=torch.float32, device=device)

def weighted_cluster_centers(features, labels, n_clusters, weights):
    """Recompute cluster centres as per-cluster weighted means.

    Args:
        features: [N, D] tensor of samples.
        labels: [N] tensor of cluster ids; ids outside ``range(n_clusters)``
            are ignored.
        n_clusters: number of centres to produce.
        weights: [N] tensor of per-sample weights.

    Returns:
        [n_clusters, D] tensor on ``features.device`` in torch's default float
        dtype. A cluster with no members keeps an all-zero centre.
    """
    dim = features.size(1)
    out = torch.zeros((n_clusters, dim), device=features.device)
    for cluster_id in range(n_clusters):
        member_mask = labels == cluster_id
        if not member_mask.any():
            continue  # empty cluster: leave its zero centre
        member_w = weights[member_mask].unsqueeze(1)  # [Nc, 1]
        weighted_sum = (features[member_mask] * member_w).sum(0)
        out[cluster_id] = weighted_sum / (member_w.sum() + 1e-6)
    return out


def initialize_weights_bias(cfg, model, cali_mlp, features, val_dataloader, k=10, alpha=1.0, target_class= 3):
    """Initialise heads with a density-weighted second layer and bin samples by density weight.

    Steps:
      1. KMeans (512) on row-standardised, L2-normalised features; the
         centres are the first-layer weights ``W1``.
      2. Project features through ``W1`` and ``cluster_head[0][1]`` /
         ``cluster_head[0][2]``, then standardise the result (``H_zscore``).
      3. KMeans (``cfg['backbone']['nclusters']``) on ``H_zscore``.
      4. ``W2`` = per-cluster centres re-weighted by ``compute_density_weights``
         (rescaled so the weights average to ~1 over all samples).
      5. Split sample indices into 5 bins by the 5/35/65/95th percentiles
         of those weights.
      6. Refine ``W1``/``W2`` with ``orth_train`` and copy them, with zeroed
         biases, into ``cluster_head[0]`` and ``calibration_head`` (layers 0, 3).
      7. Evaluate once with ``get_predictions`` / ``hungarian_evaluate``.

    Args:
        cfg: config dict; reads ``['backbone']['nclusters']``,
            ``['cdc_checkpoint']`` and ``['cluster_eval']['plot_title']``.
        model / cali_mlp: wrapped models (accessed via ``.module``).
        features: [N, D] feature tensor.
        val_dataloader: loader used for the final evaluation.
        k, alpha: forwarded to ``compute_density_weights``.
        target_class: currently unused (its per-class filtering is disabled);
            kept for interface compatibility.

    Returns:
        List of 5 numpy index arrays, from lowest to highest density-weight bin.
    """
    def _standardize_rows(x, eps=1e-6):
        # Per-row z-score then L2 normalisation; eps keeps constant rows at 0
        # instead of NaN, and is cancelled by the normalisation otherwise.
        x = (x - x.mean(1, keepdim=True)) / (x.std(1, keepdim=True) + eps)
        return F.normalize(x, dim=1)

    n_clusters = cfg['backbone']['nclusters']

    # Step 1: 512 prototypes.
    KMeans_512 = PyTorchKMeans(init='k-means++', n_clusters=512, verbose=False, random_state=0)
    KMeans_512.fit_predict(_standardize_rows(features))
    W1 = KMeans_512.cluster_centers_

    # Step 2: through the head's cluster_head[0][1] (BN) and [0][2] (ReLU).
    H = torch.mm(features, W1.T)
    H = model.module.cluster_head[0][1](H).detach().clone()
    H = model.module.cluster_head[0][2](H).detach().clone()
    H_zscore = _standardize_rows(H)

    # Step 3: KMeans into the final number of clusters.
    KMeans_c = PyTorchKMeans(init='k-means++', n_clusters=n_clusters, verbose=False, random_state=0)
    # as_tensor avoids torch.tensor's copy-construct warning when a tensor is returned.
    class_label = torch.as_tensor(KMeans_c.fit_predict(H_zscore), device=features.device)

    # Step 4: density-weighted second-layer centres.
    density_weights2 = compute_density_weights(H_zscore, class_label, k=k, alpha=alpha)
    density_weights2 = density_weights2 * (len(density_weights2) / (density_weights2.sum() + 1e-6))
    W2 = weighted_cluster_centers(H_zscore, class_label, n_clusters, density_weights2)

    # Step 5: 5 bins split at the 5/35/65/95th percentiles:
    # (-inf, p5], (p5, p35], (p35, p65], (p65, p95], (p95, inf).
    weights_np = density_weights2.detach().cpu().numpy()
    percentiles = np.percentile(weights_np, [5, 35, 65, 95])
    indices_per_bin = [np.where(weights_np <= percentiles[0])[0]]
    for lo, hi in zip(percentiles[:-1], percentiles[1:]):
        indices_per_bin.append(np.where((weights_np > lo) & (weights_np <= hi))[0])
    indices_per_bin.append(np.where(weights_np > percentiles[-1])[0])

    # Step 6: refine and write into cluster_head and calibration_head.
    W1_modi = orth_train(W1, 512, use_relu=True)
    W2_modi = orth_train(W2, n_clusters, use_relu=True)

    for head in (model.module.cluster_head[0], cali_mlp.module.calibration_head):
        for layer_idx, weight in ((0, W1_modi), (3, W2_modi)):
            torch.nn.init.zeros_(head[layer_idx].bias)
            head[layer_idx].weight.data = weight.clone()

    # Step 7: one evaluation pass.
    predictions = get_predictions(cfg, val_dataloader, model)
    clustering_stats = hungarian_evaluate(cfg, cfg['cdc_checkpoint'], 0, 0, predictions,
                                          title=cfg['cluster_eval']['plot_title'],
                                          compute_confusion_matrix=False)
    print(clustering_stats)

    return indices_per_bin

# NOTE: disabled earlier version of `initialize_weights` (KMeans without
# random_state, no classify_tail/classify_medium copies). It is wrapped in a
# module-level string literal, which is evaluated and discarded at import
# time, so it has no runtime effect.
""" def initialize_weights(cfg, model, cali_mlp, features, val_dataloader):
    features_zscore = (features - features.mean(1).reshape(-1, 1)) / features.std(1).reshape(-1, 1)
    features_zscore = F.normalize(features_zscore, dim=1)

    KMeans_512 = PyTorchKMeans(init='k-means++', n_clusters=512, verbose=False)
    proto_label = KMeans_512.fit_predict(features_zscore)
    W1 = KMeans_512.cluster_centers_

    #linear(512,512)
    H = torch.mm(features, W1.T)
    #BN
    H = model.module.cluster_head[0][1](H).detach().clone()
    #relu
    H = model.module.cluster_head[0][2](H).detach().clone()

    H_zscore = (H - H.mean(1).reshape(-1, 1)) / H.std(1).reshape(-1, 1)
    H_zscore = F.normalize(H_zscore, dim=1)

    KMeans_c = PyTorchKMeans(init='k-means++', n_clusters=cfg['backbone']['nclusters'], verbose=False)
    class_label = KMeans_c.fit_predict(H_zscore)
    W2 = KMeans_c.cluster_centers_

    W1_modi = orth_train(W1, 512, use_relu=True)
    W2_modi = orth_train(W2, cfg['backbone']['nclusters'], use_relu=True)

    O = torch.mm(torch.mm(features, W1.T), W2.T)
    print(F.softmax(O, dim=1).max(1)[1].unique(return_counts=True))
    O = torch.mm(torch.mm(features, W1_modi.T) , W2_modi.T)
    print(F.softmax(O, dim=1).max(1)[1].unique(return_counts=True))

    torch.nn.init.zeros_(model.module.cluster_head[0][0].bias)
    torch.nn.init.zeros_(model.module.cluster_head[0][3].bias)
    
    model.module.cluster_head[0][0].weight.data = W1_modi.clone()
    model.module.cluster_head[0][3].weight.data = W2_modi.clone()

    torch.nn.init.zeros_(cali_mlp.module.calibration_head[0].bias)
    torch.nn.init.zeros_(cali_mlp.module.calibration_head[3].bias)
    
    cali_mlp.module.calibration_head[0].weight.data = W1_modi.clone()
    cali_mlp.module.calibration_head[3].weight.data = W2_modi.clone()


    predictions = get_predictions(cfg, val_dataloader, model)
    clustering_stats = hungarian_evaluate(cfg, cfg['cdc_checkpoint'], 0, 0,
                                        predictions, title=cfg['cluster_eval']['plot_title'],
                                        compute_confusion_matrix=False)
    print(clustering_stats) """
    
def flexrand_select(features, predictions, top_ratio=0.4, gamma=0.3, balanced_per_class=True, seed=42):
    """FlexRand sampling: draw equal-sized random subsets from the easy and hard regions.

    Samples are ranked by softmax confidence (descending). The first
    ``gamma`` fraction of the ranking is the easy pool and the last ``gamma``
    fraction is the hard pool. ``int(n * top_ratio) // 2`` samples are drawn
    from each pool, or the whole pool when it is smaller. The pools overlap
    when ``gamma > 0.5``, in which case an index may be selected twice.

    Args:
        features: [N, D] feature tensor.
        predictions: [N, C] logits (softmax is applied here).
        top_ratio: target fraction of samples to select (per class when balanced).
        gamma: fraction of the ranking forming each of the easy and hard pools.
        balanced_per_class: rank and sample within each predicted class
            separately; classes whose budget ``int(n_c * top_ratio)`` is below
            2 are skipped.
        seed: seed applied to Python's global ``random`` module on every call.

    Returns:
        (features_sub, top_indices): the selected feature rows and their
        indices as a sorted int64 tensor.
    """
    random.seed(seed)
    probs = torch.softmax(predictions, dim=1)
    confidences, pred_classes = torch.max(probs, dim=1)
    n_clusters = predictions.shape[1]

    def _take(pool, k):
        # Random k-subset of pool, or all of pool if it holds fewer than k items.
        pool = pool.tolist()
        return random.sample(pool, k) if len(pool) >= k else pool

    def _easy_plus_hard(sorted_pool, num_each):
        # sorted_pool: 1-D index tensor ordered by descending confidence.
        num_total = len(sorted_pool)
        easy_pool = sorted_pool[:int(num_total * gamma)]
        hard_pool = sorted_pool[int(num_total * (1 - gamma)):]
        # Easy is sampled before hard to keep the RNG call order stable.
        sampled_easy = _take(easy_pool, num_each)
        sampled_hard = _take(hard_pool, num_each)
        return sampled_easy + sampled_hard

    top_indices = []
    if balanced_per_class:
        for c in range(n_clusters):
            # as_tuple=True keeps this 1-D even for a single-member class
            # (nonzero().squeeze() would yield a 0-d tensor and crash below).
            cls_indices = (pred_classes == c).nonzero(as_tuple=True)[0]
            if cls_indices.numel() == 0:
                continue

            order = torch.argsort(confidences[cls_indices], descending=True)
            sorted_cls_indices = cls_indices[order]

            num_select = int(len(sorted_cls_indices) * top_ratio)
            if num_select < 2:
                continue  # need at least one sample from each of the two pools

            top_indices.extend(_easy_plus_hard(sorted_cls_indices, num_select // 2))
    else:
        sorted_idx = torch.argsort(confidences, descending=True)
        num_select = int(len(sorted_idx) * top_ratio)
        # Same pool-size guard as the balanced branch: random.sample would
        # raise if num_each exceeded the pool size.
        top_indices = _easy_plus_hard(sorted_idx, num_select // 2)

    top_indices = torch.tensor(top_indices, dtype=torch.long).sort().values
    features_sub = features[top_indices]
    print("FlexRand sample:", len(top_indices))

    return features_sub, top_indices
    
def select_flexrand_middle_all(
    features: torch.Tensor,
    predictions: torch.Tensor,
    p_low: float = 0.1,
    p_high: float = 0.9,
    easy_ratio: float = 0.3,
    hard_ratio: float = 0.3,
    balanced_per_class: bool = True,
    seed: int = 42
) -> Tuple[torch.Tensor, torch.Tensor]:
    """FlexRand variant: keep the whole middle region plus random parts of the easy/hard regions.

    Samples are ranked by softmax confidence in descending order. Positions
    ``[0, p_low)`` of the ranking (the most confident samples) form the easy
    region, ``[p_low, p_high)`` the middle region and ``[p_high, 1]`` (the
    least confident) the hard region. Every middle sample is kept;
    ``int(len(easy) * easy_ratio)`` easy and ``int(len(hard) * hard_ratio)``
    hard samples are drawn at random.

    Args:
        features: [N, D] feature tensor.
        predictions: [N, C] logits (softmax is applied here).
        p_low: end of the easy region as a fraction of the ranking, e.g. 0.1.
        p_high: start of the hard region as a fraction of the ranking, e.g. 0.9.
        easy_ratio: fraction of the easy region to sample.
        hard_ratio: fraction of the hard region to sample.
        balanced_per_class: rank and split within each predicted class.
        seed: seed applied to Python's ``random`` and torch's global RNG.

    Returns:
        (features_sub, selected_indices): selected feature rows and their
        indices as a sorted int64 tensor.
    """
    random.seed(seed)
    torch.manual_seed(seed)

    probs = torch.softmax(predictions, dim=1)
    confidences, pred_classes = torch.max(probs, dim=1)
    n_clusters = predictions.shape[1]

    def _split_and_sample(sorted_pool):
        # sorted_pool: 1-D index tensor ordered by descending confidence.
        n = len(sorted_pool)
        low_idx = int(n * p_low)
        high_idx = int(n * p_high)

        easy_pool = sorted_pool[:low_idx].tolist()
        middle_pool = sorted_pool[low_idx:high_idx].tolist()
        hard_pool = sorted_pool[high_idx:].tolist()

        num_easy = int(len(easy_pool) * easy_ratio)
        num_hard = int(len(hard_pool) * hard_ratio)

        # Easy is sampled before hard to keep the RNG call order stable.
        sampled_easy = random.sample(easy_pool, min(num_easy, len(easy_pool))) if easy_pool else []
        sampled_hard = random.sample(hard_pool, min(num_hard, len(hard_pool))) if hard_pool else []
        return sampled_easy + middle_pool + sampled_hard

    selected_indices = []
    if balanced_per_class:
        for c in range(n_clusters):
            # as_tuple=True keeps this 1-D even for a single-member class
            # (nonzero().squeeze() would yield a 0-d tensor and crash below).
            cls_indices = (pred_classes == c).nonzero(as_tuple=True)[0]
            if cls_indices.numel() == 0:
                continue
            order = torch.argsort(confidences[cls_indices], descending=True)
            selected_indices.extend(_split_and_sample(cls_indices[order]))
    else:
        sorted_idx = torch.argsort(confidences, descending=True)
        selected_indices = _split_and_sample(sorted_idx)

    selected_indices = torch.tensor(selected_indices, dtype=torch.long).sort().values
    features_sub = features[selected_indices]
    print(f"Selected samples: Easy~{easy_ratio}, Middle~ALL, Hard~{hard_ratio} -> Total: {len(selected_indices)}")

    return features_sub, selected_indices
    
def init_head_with_confident_samples(model, cali_mlp, features, predictions, n_clusters, top_ratio=0.5, confidence_offset=0.0, balanced_per_class=False):
    """Initialise the cluster and calibration heads with two-stage KMeans on confident samples.

    Samples are ranked by the softmax confidence of ``predictions``. The
    ranking is taken per predicted class when ``balanced_per_class``, or
    globally otherwise. A slice starting ``confidence_offset`` of the way into
    the ranking and spanning ``top_ratio`` of it is selected and used to fit
    512 prototypes (first-layer weights). *All* samples are then projected
    through those prototypes and ``cluster_head[0][1]`` /
    ``cluster_head[0][2]`` and clustered into ``n_clusters`` groups
    (second-layer weights). Both matrices are refined with ``orth_train`` and
    copied, with zeroed biases, into layers 0 and 3 of
    ``model.module.cluster_head[0]`` and ``cali_mlp.module.calibration_head``.

    Args:
        model: wrapped model (accessed via ``.module``) providing ``cluster_head``.
        cali_mlp: wrapped model providing ``calibration_head``.
        features: [N, D] features of all samples.
        predictions: [N, C] logits (softmax is applied here).
        n_clusters: number of final clusters.
        top_ratio: fraction of each class (or of all samples) to select.
        confidence_offset: fraction of the ranking to skip before selecting.
        balanced_per_class: select within each predicted class separately;
            each non-empty class then contributes at least one sample unless
            the offset runs past the end of its ranking.

    Returns:
        None; the heads are modified in place.
    """
    def _standardize_rows(x, eps=1e-6):
        # Per-row z-score then L2 normalisation; eps keeps constant rows at 0
        # instead of NaN, and is cancelled by the normalisation otherwise.
        x = (x - x.mean(1, keepdim=True)) / (x.std(1, keepdim=True) + eps)
        return F.normalize(x, dim=1)

    probs = torch.softmax(predictions, dim=1)
    confidences, pred_classes = torch.max(probs, dim=1)

    if balanced_per_class:
        selected_indices = []
        for c in range(n_clusters):
            # as_tuple=True keeps this 1-D even for a single-member class
            # (nonzero().squeeze() would yield a 0-d tensor and len() would fail).
            cls_indices = (pred_classes == c).nonzero(as_tuple=True)[0]
            if cls_indices.numel() == 0:
                continue
            # Take ranks offset .. offset + top_k of this class.
            sorted_indices = torch.argsort(confidences[cls_indices], descending=True)
            offset = int(len(sorted_indices) * confidence_offset)
            top_k = max(1, int(len(sorted_indices) * top_ratio))
            end = min(offset + top_k, len(sorted_indices))
            selected_indices.append(cls_indices[sorted_indices[offset:end]])
        top_indices = torch.cat(selected_indices)
    else:
        sorted_indices = torch.argsort(confidences, descending=True)
        offset = int(len(sorted_indices) * confidence_offset)
        num_top = int(len(sorted_indices) * top_ratio)
        end = min(offset + num_top, len(sorted_indices))
        top_indices = sorted_indices[offset:end]

    features_sub = features[top_indices]
    print("ini sample: ", len(top_indices))
    features_sub = _standardize_rows(features_sub)

    # Stage 1: prototypes from the confident subset only.
    KMeans_512 = PyTorchKMeans(init='k-means++', n_clusters=512, verbose=False, random_state=0)
    KMeans_512.fit_predict(features_sub)
    W1 = KMeans_512.cluster_centers_

    # Stage 2: all samples through Linear(W1) -> cluster_head[0][1] (BN) -> [0][2] (ReLU).
    H = torch.mm(features, W1.T)
    H = model.module.cluster_head[0][1](H).detach().clone()
    H = model.module.cluster_head[0][2](H).detach().clone()

    KMeans_c = PyTorchKMeans(init='k-means++', n_clusters=n_clusters, verbose=False, random_state=0)
    KMeans_c.fit_predict(_standardize_rows(H))
    W2 = KMeans_c.cluster_centers_

    W1_modi = orth_train(W1, 512, use_relu=True)
    W2_modi = orth_train(W2, n_clusters, use_relu=True)

    # Diagnostic: histogram of predicted clusters before and after refinement.
    for first, second in ((W1, W2), (W1_modi, W2_modi)):
        O = torch.mm(torch.mm(features, first.T), second.T)
        print(F.softmax(O, dim=1).max(1)[1].unique(return_counts=True))

    with torch.no_grad():
        for head in (model.module.cluster_head[0], cali_mlp.module.calibration_head):
            for layer_idx, weight in ((0, W1_modi), (3, W2_modi)):
                torch.nn.init.zeros_(head[layer_idx].bias)
                head[layer_idx].weight.data = weight.clone()
    
def train_cali(cfg, train_dataloader, cali_mlp, model, optimizer_cali, optimizer_all, epoch, start_epoch):
    """Run one epoch alternating clustering-head and calibration-head updates.

    For each batch:
      1. Without gradients, compute backbone features and head logits for the
         ``val`` and ``image`` views, plus calibration logits for ``image``.
      2. Pseudo-labels are the calibration head's argmax. For each class,
         the top ``per_class_selected_num`` samples by that class's
         calibration probability are considered. The first
         ``int(mean_confidence * per_class_selected_num)`` of them are kept.
      3. KMeans the L1-normalised ``val`` features into
         ``super_cluster_num`` groups; each sample's soft target is the mean
         head softmax (``val`` view) over its group.
      4. For each sub-batch, take one ``optimizer_all`` step with
         cross-entropy on the selected augmented samples. Then take one
         ``optimizer_cali`` step with soft-target cross-entropy plus
         ``w_en`` * sum(p * log p) of the batch-mean calibration distribution.
    Epoch-mean losses are logged to wandb.

    ``epoch`` and ``start_epoch`` are accepted for interface compatibility but unused.

    NOTE(review): assumes every batch holds at least
    ``sub_steps * sub_batch_size`` samples (e.g. ``drop_last=True``); a
    sub-batch with no selected samples yields a NaN cross-entropy value.
    """
    method_kwargs = cfg['method_kwargs']
    selected_num = method_kwargs['per_class_selected_num']
    cluster_num = method_kwargs['super_cluster_num']
    sub_batch_size = cfg['optimizer']['sub_batch_size']
    sub_steps = int(cfg['optimizer']['batch_size'] / sub_batch_size)
    # torch.arange replaces the deprecated, end-inclusive torch.range; same
    # int64 values 0 .. sub_steps * sub_batch_size - 1.
    sub_idxs = torch.arange(sub_steps * sub_batch_size).reshape(sub_steps, -1)

    loss_clu, loss_cali = [], []
    loss_ces, loss_ens, loss_coss = [], [], []
    for batch in train_dataloader:
        model.zero_grad()
        optimizer_all.zero_grad()
        images = batch['image'].cuda(non_blocking=True)
        images_augmented = batch['image_augmented'].cuda(non_blocking=True)
        images_val = batch['val'].cuda(non_blocking=True)

        model.train()
        cali_mlp.train()
        with torch.no_grad():
            feature_val = model(images_val, forward_pass='backbone')
            output_clu_val = model(feature_val, forward_pass='head')[0]

            feature_weak = model(images, forward_pass='backbone')
            output_clu = model(feature_weak, forward_pass='head')[0]
            output_cali = cali_mlp(feature_weak, forward_pass='calibration')
        feature_norm1 = F.normalize(feature_val, p=1, dim=1)

        cali_softmax = F.softmax(output_cali, dim=1)
        cali_prob, proto_pseudo = torch.max(cali_softmax, dim=1)

        # Per-class confident-sample selection.
        selected_idx = torch.zeros(len(cali_softmax), dtype=torch.bool).cuda()
        for label_idx in range(output_clu.shape[1]):
            per_label_mask = cali_softmax[:, label_idx].sort(descending=True)[1][:selected_num]
            sel = int(cali_prob[per_label_mask].mean() * selected_num)
            selected_idx[per_label_mask[:sel]] = True

        # Soft targets from super-clusters of the val-view features.
        KMeans_all = PyTorchKMeans(init='k-means++', n_clusters=cluster_num, verbose=False)
        split_all = KMeans_all.fit_predict(feature_norm1)
        val_softmax = F.softmax(output_clu_val, dim=1)
        target_dict = torch.stack([val_softmax[split_all == i].mean(0) for i in range(cluster_num)])
        super_target = target_dict[split_all]

        for sub_idx in sub_idxs:
            # Clustering-head step.
            output_aug = model(images_augmented[sub_idx])[0]
            sub_proto_pseudo, sub_selected_idx = proto_pseudo[sub_idx], selected_idx[sub_idx]
            loss_ce = F.cross_entropy(output_aug[sub_selected_idx], sub_proto_pseudo[sub_selected_idx])
            loss = loss_ce
            loss_ces.append(loss_ce.detach())
            loss_clu.append(loss.detach())

            optimizer_all.zero_grad()
            loss.backward()
            optimizer_all.step()

            # Calibration-head step.
            output_cali = cali_mlp(feature_val[sub_idx], forward_pass='calibration')
            loss_cos = (-super_target[sub_idx] * F.log_softmax(output_cali, dim=1)).sum(1).mean()
            x_ = torch.mean(F.softmax(output_cali, dim=1), 0)
            loss_entropy = torch.sum(x_ * torch.log(x_))

            loss = loss_cos + method_kwargs['w_en'] * loss_entropy

            loss_cali.append(loss.detach())
            loss_coss.append(loss_cos.detach())
            loss_ens.append(loss_entropy.detach())

            optimizer_cali.zero_grad()
            loss.backward()
            optimizer_cali.step()
    wandb.log({
        "loss_clu":torch.stack(loss_clu).mean(),
        "loss_cali":torch.stack(loss_cali).mean(),
        "loss_ces":torch.stack(loss_ces).mean(),
        "loss_cos":torch.stack(loss_coss).mean(),
        "loss_ens":torch.stack(loss_ens).mean(),
    })

def train_cali_longtail(cfg, train_dataloader, cali_mlp, model, optimizer_cali, optimizer_all, epoch, start_epoch, pseudo_labels, medium_neighbors_idx, tail_neighbors_idx):
    """Run one epoch of ``train_cali``-style training with extra long-tail loss terms.

    Identical to ``train_cali`` except that the clustering-head loss adds two
    logit-adjusted soft cross-entropy terms ("tail" and "medium"):
      * targets put ``1 - epsilon`` on the sample's pseudo-label and spread
        ``epsilon`` evenly over the class ids in
        ``tail_neighbors_idx[sample_index]`` / ``medium_neighbors_idx[...]``,
        then are renormalised per row;
      * logits are shifted by ``1.0 * log(count)`` (tail) or
        ``0.5 * log(count)`` (medium), where ``count`` is each class's
        frequency in ``pseudo_labels`` (missing classes count as 1).

    Args:
        cfg: config dict; reads ``epsilon``, ``backbone.nclusters``,
            ``method_kwargs`` and ``optimizer`` settings.
        pseudo_labels: tensor of pseudo-labels used only for the class counts.
        medium_neighbors_idx / tail_neighbors_idx: indexable by the dataset
            index in ``batch['index']``; each entry is a collection of class
            ids (used as column indices into the target matrix).
        epoch, start_epoch: accepted for interface compatibility but unused.

    NOTE(review): ``output_tail``/``output_medium`` currently alias
    ``output_clu``, which is computed under ``torch.no_grad()``. The tail and
    medium terms therefore add no gradient and only change the logged
    ``loss_clu``; confirm that is intended. A sub-batch with no selected
    samples yields NaN loss values.
    """
    epsilon = cfg['epsilon']
    num_classes = cfg['backbone']['nclusters']
    method_kwargs = cfg['method_kwargs']
    selected_num = method_kwargs['per_class_selected_num']
    cluster_num = method_kwargs['super_cluster_num']
    sub_batch_size = cfg['optimizer']['sub_batch_size']
    sub_steps = int(cfg['optimizer']['batch_size'] / sub_batch_size)
    # torch.arange replaces the deprecated, end-inclusive torch.range; same
    # int64 values 0 .. sub_steps * sub_batch_size - 1.
    sub_idxs = torch.arange(sub_steps * sub_batch_size).reshape(sub_steps, -1)

    # Class-frequency prior from the global pseudo-labels. It is constant for
    # the whole epoch, so it is built once instead of on every sub-step.
    # Missing classes default to 1 so log() stays finite.
    spc_dict = Counter(pseudo_labels.cpu().numpy())
    spc = torch.tensor([spc_dict.get(i, 1) for i in range(num_classes)], dtype=torch.float32).cuda()
    log_spc = spc.log()

    def _smoothed_targets(labels, sample_ids, neighbors_idx):
        # Row j: 1 - epsilon on labels[j], epsilon / len(neigh) added on each
        # neighbour class of sample_ids[j], then rows renormalised to sum to 1.
        targets = torch.zeros(len(labels), num_classes).cuda()
        for row, (label, sample_id) in enumerate(zip(labels.tolist(), sample_ids.tolist())):
            targets[row, label] = 1 - epsilon
            neigh = neighbors_idx[sample_id]
            if len(neigh) > 0:
                targets[row, neigh] += epsilon / len(neigh)
        return targets / targets.sum(dim=1, keepdim=True)

    loss_clu, loss_cali = [], []
    loss_ces, loss_ens, loss_coss = [], [], []

    for batch in train_dataloader:
        model.zero_grad()
        optimizer_all.zero_grad()
        images = batch['image'].cuda(non_blocking=True)
        images_augmented = batch['image_augmented'].cuda(non_blocking=True)
        images_val = batch['val'].cuda(non_blocking=True)
        images_index = batch['index'].cuda(non_blocking=True)

        model.train()
        cali_mlp.train()
        with torch.no_grad():
            feature_val = model(images_val, forward_pass='backbone')
            output_clu_val = model(feature_val, forward_pass='head')[0]

            feature_weak = model(images, forward_pass='backbone')
            output_clu = model(feature_weak, forward_pass='head')[0]
            # Dedicated tail/medium experts are disabled; both reuse the head logits.
            output_tail = output_clu
            output_medium = output_clu

            output_cali = cali_mlp(feature_weak, forward_pass='calibration')

        feature_norm1 = F.normalize(feature_val, p=1, dim=1)

        cali_softmax = F.softmax(output_cali, dim=1)
        cali_prob, proto_pseudo = torch.max(cali_softmax, dim=1)

        # Per-class confident-sample selection.
        selected_idx = torch.zeros(len(cali_softmax), dtype=torch.bool).cuda()
        for label_idx in range(output_clu.shape[1]):
            per_label_mask = cali_softmax[:, label_idx].sort(descending=True)[1][:selected_num]
            sel = int(cali_prob[per_label_mask].mean() * selected_num)
            selected_idx[per_label_mask[:sel]] = True

        # Soft targets from super-clusters of the val-view features.
        KMeans_all = PyTorchKMeans(init='k-means++', n_clusters=cluster_num, verbose=False)
        split_all = KMeans_all.fit_predict(feature_norm1)
        val_softmax = F.softmax(output_clu_val, dim=1)
        target_dict = torch.stack([val_softmax[split_all == i].mean(0) for i in range(cluster_num)])
        super_target = target_dict[split_all]

        for sub_idx in sub_idxs:
            output_aug = model(images_augmented[sub_idx])[0]
            sub_proto_pseudo, sub_selected_idx = proto_pseudo[sub_idx], selected_idx[sub_idx]

            # Label-smoothed targets for the tail / medium terms, keyed by the
            # samples' global dataset indices.
            sub_indices = images_index[sub_idx]
            target_tail = _smoothed_targets(sub_proto_pseudo, sub_indices, tail_neighbors_idx)
            target_medium = _smoothed_targets(sub_proto_pseudo, sub_indices, medium_neighbors_idx)

            target_tail_sel = target_tail[sub_selected_idx]
            target_medium_sel = target_medium[sub_selected_idx]
            adj_tail = output_tail[sub_idx][sub_selected_idx] + log_spc
            adj_medium = output_medium[sub_idx][sub_selected_idx] + 0.5 * log_spc

            loss_tail = -torch.sum(F.log_softmax(adj_tail, dim=1) * target_tail_sel) / target_tail_sel.shape[0]
            loss_medium = -torch.sum(F.log_softmax(adj_medium, dim=1) * target_medium_sel) / target_medium_sel.shape[0]

            loss_ce = F.cross_entropy(output_aug[sub_selected_idx], sub_proto_pseudo[sub_selected_idx])
            loss = loss_ce + loss_tail + loss_medium
            loss_ces.append(loss_ce.detach())
            loss_clu.append(loss.detach())

            optimizer_all.zero_grad()
            loss.backward()
            optimizer_all.step()

            # Calibration-head step.
            output_cali = cali_mlp(feature_val[sub_idx], forward_pass='calibration')
            loss_cos = (-super_target[sub_idx] * F.log_softmax(output_cali, dim=1)).sum(1).mean()
            x_ = torch.mean(F.softmax(output_cali, dim=1), 0)
            loss_entropy = torch.sum(x_ * torch.log(x_))

            loss = loss_cos + method_kwargs['w_en'] * loss_entropy

            loss_cali.append(loss.detach())
            loss_coss.append(loss_cos.detach())
            loss_ens.append(loss_entropy.detach())

            optimizer_cali.zero_grad()
            loss.backward()
            optimizer_cali.step()
    wandb.log({
        "loss_clu":torch.stack(loss_clu).mean(),
        "loss_cali":torch.stack(loss_cali).mean(),
        "loss_ces":torch.stack(loss_ces).mean(),
        "loss_cos":torch.stack(loss_coss).mean(),
        "loss_ens":torch.stack(loss_ens).mean(),
    })
       
def _high_confidence_centers(points, labels, centers, n_clusters, top_k_percent):
    """Recompute cluster centres from each cluster's most central members.

    For every cluster id ``c`` in ``range(n_clusters)``:

    - The members of ``c`` are ranked by squared Euclidean distance to ``centers[c]``.
    - The closest ``max(1, int(n_members * top_k_percent))`` of them are averaged.
    - Clusters with no members keep their original centre.

    Args:
        points: 2-D tensor of samples, one row per sample.
        labels: 1-D tensor of cluster ids aligned with ``points``.
        centers: 2-D tensor of current centres, indexed by cluster id.
        n_clusters: number of cluster ids to process.
        top_k_percent: fraction (not a percentage) of members to keep, e.g. 0.5.

    Returns:
        Tensor shaped like ``centers`` holding the refined centres.
    """
    sq_dist = torch.sum((points - centers[labels]) ** 2, dim=1)
    refined = torch.zeros_like(centers)
    for c in range(n_clusters):
        members = (labels == c).nonzero(as_tuple=True)[0]
        if members.numel() == 0:
            refined[c] = centers[c]  # fallback: no members, keep the k-means centre
            continue
        n_keep = max(1, int(members.numel() * top_k_percent))
        closest = members[torch.argsort(sq_dist[members])[:n_keep]]
        refined[c] = points[closest].mean(dim=0)
    return refined


def initialize_weights_v4(cfg, model, cali_mlp, features, top_k_percent=0.5):
    """Initialize head weights from two-stage k-means, refined with central samples.

    Steps:

    1. Cluster row-standardized ``features`` into 512 prototypes.
    2. Project the features through ``linear(W1)`` followed by the cluster head's
       BN (``cluster_head[0][1]``) and ReLU (``cluster_head[0][2]``).
    3. Cluster the result into ``cfg['backbone']['nclusters']`` classes.
    4. Recompute each stage's centres from the ``top_k_percent`` fraction of
       members closest to them, then pass them through ``orth_train``.
    5. Copy the results into layers 0 and 3 of ``cluster_head[0]``,
       ``classify_tail``, ``classify_medium`` and ``calibration_head``, zeroing
       their biases.

    Args:
        cfg: config dict; reads ``cfg['backbone']['nclusters']``.
        model: model whose ``.module`` exposes ``cluster_head``,
            ``classify_tail`` and ``classify_medium``.
        cali_mlp: module whose ``.module`` exposes ``calibration_head``.
        features: 2-D tensor of features, one row per sample.
        top_k_percent: fraction (0-1) of each cluster's members used to
            recompute its centre.
    """
    print('Initializing weights V4 with high confidence samples...')
    n_classes = cfg['backbone']['nclusters']

    # Row-wise z-score, then L2-normalize each sample.
    features_zscore = (features - features.mean(1).reshape(-1, 1)) / features.std(1).reshape(-1, 1)
    features_zscore = F.normalize(features_zscore, dim=1)

    # Stage 1: 512 prototypes.
    KMeans_512 = PyTorchKMeans(init='k-means++', n_clusters=512, verbose=False, random_state=0)
    proto_label = KMeans_512.fit_predict(features_zscore)
    W1_initial = KMeans_512.cluster_centers_

    # linear(W1) -> BN -> ReLU, using the cluster head's own BN/ReLU modules.
    H = torch.mm(features, W1_initial.T)
    H = model.module.cluster_head[0][1](H).detach().clone()  # BN
    H = model.module.cluster_head[0][2](H).detach().clone()  # ReLU

    H_zscore = (H - H.mean(1).reshape(-1, 1)) / H.std(1).reshape(-1, 1)
    H_zscore = F.normalize(H_zscore, dim=1)

    # Stage 2: final classes.
    KMeans_c = PyTorchKMeans(init='k-means++', n_clusters=n_classes, verbose=False, random_state=0)
    class_label = KMeans_c.fit_predict(H_zscore)
    W2_initial = KMeans_c.cluster_centers_

    # Replace each centre by the mean of its most central (high-confidence) members.
    W1_modi = _high_confidence_centers(features_zscore, proto_label, W1_initial, 512, top_k_percent)
    W2_modi = _high_confidence_centers(H_zscore, class_label, W2_initial, n_classes, top_k_percent)

    W1_modi = orth_train(W1_modi, 512, use_relu=True)
    W2_modi = orth_train(W2_modi, n_classes, use_relu=True)

    torch.nn.init.zeros_(model.module.cluster_head[0][0].bias)
    torch.nn.init.zeros_(model.module.cluster_head[0][3].bias)
    torch.nn.init.zeros_(model.module.classify_tail[0].bias)
    torch.nn.init.zeros_(model.module.classify_tail[3].bias)
    torch.nn.init.zeros_(model.module.classify_medium[0].bias)
    torch.nn.init.zeros_(model.module.classify_medium[3].bias)

    model.module.cluster_head[0][0].weight.data = W1_modi.clone()
    model.module.cluster_head[0][3].weight.data = W2_modi.clone()
    model.module.classify_tail[0].weight.data = W1_modi.clone()
    model.module.classify_tail[3].weight.data = W2_modi.clone()
    model.module.classify_medium[0].weight.data = W1_modi.clone()
    model.module.classify_medium[3].weight.data = W2_modi.clone()

    torch.nn.init.zeros_(cali_mlp.module.calibration_head[0].bias)
    torch.nn.init.zeros_(cali_mlp.module.calibration_head[3].bias)

    cali_mlp.module.calibration_head[0].weight.data = W1_modi.clone()
    cali_mlp.module.calibration_head[3].weight.data = W2_modi.clone()

# Hierarchical clustering
from scipy.cluster.hierarchy import linkage, fcluster
from scipy.spatial.distance import pdist

def hierarchical_merge(W1, target_clusters):
    """Merge the rows of ``W1`` into ``target_clusters`` centres with Ward linkage.

    Pairwise Euclidean distances between the rows of ``W1`` feed a Ward
    hierarchical clustering. The tree is cut with ``fcluster(...,
    criterion='maxclust')``.

    Returns:
        Tensor of shape ``(target_clusters, D)``. Row ``j`` is the mean of the
        rows of ``W1`` assigned to flat cluster ``j + 1``. If ``fcluster``
        produces fewer flat clusters than requested, the missing ids yield a
        NaN row (mean over zero rows).
    """
    tree = linkage(pdist(W1.detach().cpu().numpy(), metric='euclidean'), method='ward')
    flat_ids = fcluster(tree, target_clusters, criterion='maxclust')

    # fcluster numbers flat clusters from 1.
    merged = [
        W1[torch.tensor(flat_ids == cid)].mean(dim=0)
        for cid in range(1, target_clusters + 1)
    ]
    return torch.stack(merged, dim=0)

def initialize_weights_v5(cfg, model, cali_mlp, features, val_dataloader):
    """Initialize cluster/calibration heads with a Hungarian-matched orthogonal W1.

    Steps:

    1. Cluster row-standardized features into 512 prototypes.
    2. Build a random orthogonal matrix and reorder its rows to match the
       prototypes (``match_clusters_hungarian``); that becomes W1.
    3. Cluster ``features @ W1.T`` (passed through the head's BN and ReLU)
       into ``cfg['backbone']['nclusters']`` classes to get W2.
    4. Refine both with ``orth_train``.
    5. Print the argmax class histogram before and after refinement.
    6. Copy the refined weights into layers 0 and 3 of ``cluster_head[0]``
       and ``calibration_head``, zeroing biases.

    ``val_dataloader`` is accepted but not used.
    """
    n_classes = cfg['backbone']['nclusters']
    bn_layer = model.module.cluster_head[0][1]
    relu_layer = model.module.cluster_head[0][2]

    def _row_standardize(x):
        # Per-row z-score followed by per-row L2 normalization.
        z = (x - x.mean(1).reshape(-1, 1)) / x.std(1).reshape(-1, 1)
        return F.normalize(z, dim=1)

    # Stage 1: 512 prototypes (only the centres are used).
    proto_kmeans = PyTorchKMeans(init='k-means++', n_clusters=512, verbose=False, random_state=0)
    proto_kmeans.fit_predict(_row_standardize(features))
    kmeans_centers = proto_kmeans.cluster_centers_

    # Swap the k-means centres for an orthogonal basis whose rows are matched to them.
    random_basis = torch.empty_like(kmeans_centers)
    torch.nn.init.orthogonal_(random_basis)
    W1 = match_clusters_hungarian(random_basis, kmeans_centers).clone()

    # linear(W1) -> BN -> ReLU
    hidden = torch.mm(features, W1.T)
    hidden = bn_layer(hidden).detach().clone()
    hidden = relu_layer(hidden).detach().clone()

    # Stage 2: final classes.
    class_kmeans = PyTorchKMeans(init='k-means++', n_clusters=n_classes, verbose=False, random_state=0)
    class_kmeans.fit_predict(_row_standardize(hidden))
    W2 = class_kmeans.cluster_centers_

    W1_modi = orth_train(W1, 512, use_relu=True)
    W2_modi = orth_train(W2, n_classes, use_relu=True)

    # Debug output: predicted-class histogram before and after orth_train.
    for first, second in ((W1, W2), (W1_modi, W2_modi)):
        logits = torch.mm(torch.mm(features, first.T), second.T)
        print(F.softmax(logits, dim=1).max(1)[1].unique(return_counts=True))

    for head in (model.module.cluster_head[0], cali_mlp.module.calibration_head):
        for layer_idx, weight in ((0, W1_modi), (3, W2_modi)):
            torch.nn.init.zeros_(head[layer_idx].bias)
            head[layer_idx].weight.data = weight.clone()

from scipy.optimize import linear_sum_assignment

def match_clusters_hungarian(W_orthogonal: torch.Tensor, W_kmeans: torch.Tensor) -> torch.Tensor:
    """Reorder the rows of ``W_orthogonal`` to best match ``W_kmeans`` one-to-one.

    Both matrices are L2-normalized row-wise. The row-to-row assignment that
    maximizes the total cosine similarity is solved with the Hungarian
    algorithm (``scipy.optimize.linear_sum_assignment``).

    Args:
        W_orthogonal: ``(n_clusters, feature_dim)`` matrix, e.g. an orthogonal init.
        W_kmeans: ``(n_clusters, feature_dim)`` k-means cluster centres.

    Returns:
        ``(n_clusters, feature_dim)`` tensor whose row ``i`` is the
        L2-normalized row of ``W_orthogonal`` assigned to ``W_kmeans[i]``.
        Note the rows are normalized, not the raw ``W_orthogonal`` rows.

    Raises:
        ValueError: if the two inputs have different shapes.
    """
    if W_orthogonal.shape != W_kmeans.shape:
        raise ValueError(
            f"shape mismatch: W_orthogonal {tuple(W_orthogonal.shape)} "
            f"vs W_kmeans {tuple(W_kmeans.shape)}"
        )

    # Match in cosine space.
    W_orth = F.normalize(W_orthogonal, dim=1)
    W_km = F.normalize(W_kmeans, dim=1)

    # detach(): .numpy() refuses tensors that are part of an autograd graph.
    sim_matrix = torch.matmul(W_km, W_orth.T).detach().cpu().numpy()

    # linear_sum_assignment minimizes cost, so negate the similarity. For a
    # square matrix the row indices are 0..n-1 in order, so col_ind[i] is the
    # W_orth row matched to W_kmeans[i].
    _, col_ind = linear_sum_assignment(-sim_matrix)

    return W_orth[col_ind]


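# Long-tail initialization. The triple-quoted string below is an earlier, disabled
# version of initialize_weights_longtail_v1 kept for reference; it is never executed.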
""" from sklearn.metrics import pairwise_distances
    
def initialize_weights_longtail_v1(cfg, probs, model, cali_mlp, features):
    model.eval()
    cali_mlp.eval()
    features_np = features.detach().cpu().numpy()
    probs_np = probs.detach().cpu().numpy()
    pred_classes = np.argmax(probs_np, axis=1)
    n_classes = probs_np.shape[1]
    
    class_centers = []
    class_counts = []
    all_class_features = []

    # 1. 计算每类的原始类中心和样本数量
    for i in range(n_classes):
        mask = pred_classes == i
        class_features = features_np[mask]
        if len(class_features) == 0:
            center = np.zeros(features_np.shape[1])
        else:
            center = np.mean(class_features, axis=0)
        class_centers.append(center)
        class_counts.append(len(class_features))
        all_class_features.append(class_features)

    class_centers = np.array(class_centers)
    class_counts = np.array(class_counts)
    minority_threshold = np.percentile(class_counts, 30)  # 少数类阈值（低于30%分位）
    minority_classes = np.where(class_counts <= minority_threshold)[0]

    print(f"[Minority classes]: {minority_classes}")
    
    # 2. 对每个少数类做软样本合成
    for a in minority_classes:
        center_a = class_centers[a]
        dists = np.linalg.norm(class_centers - center_a, axis=1)
        neighbor_ids = np.argsort(dists)[1:4]  # 最邻近的3个簇（不包括自己）

        aggregated_features = [all_class_features[a]]  # 原始a类样本
        aggregated_weights = [np.ones(len(all_class_features[a]))]  # 权重为1

        d_avg = np.mean(np.linalg.norm(all_class_features[a] - center_a, axis=1)) + 1e-6

        for b in neighbor_ids:
            features_b = all_class_features[b]
            dist_b = np.linalg.norm(features_b - center_a, axis=1)
            top_k = int(0.0001 * len(dist_b))  # 只选最近10%
            topk_idx = np.argsort(dist_b)[:top_k]
            selected_feats = features_b[topk_idx]
            selected_dists = dist_b[topk_idx]

            weights = np.exp(-selected_dists / d_avg)  # 距离越近，权重越高
            aggregated_features.append(selected_feats)
            aggregated_weights.append(weights)

        agg_feats = np.vstack(aggregated_features)
        agg_weights = np.concatenate(aggregated_weights)
        agg_weights = agg_weights / np.sum(agg_weights)  # normalize

        new_center_a = np.average(agg_feats, axis=0, weights=agg_weights)
        class_centers[a] = new_center_a  # 更新类中心

    
    # 3. 赋值初始化
    class_centers = torch.tensor(class_centers, dtype=torch.float32).to(features.device)
    
    #cali_mlp.fc.weight.data.copy_(class_centers)
    cali_mlp.module.calibration_head[3].weight.data = class_centers.clone()
    model.module.cluster_head[0][3].weight.data = class_centers.clone() """

import torch
import numpy as np
from cdc.utils.torch_clustering import PyTorchKMeans


def _refine_small_cluster_centers(points, labels, centers, n_clusters, neighbor_k, top_ratio, ep, eps=1e-12):
    """Re-estimate centres of below-average-size clusters by borrowing nearby samples.

    A cluster whose size is below the mean cluster size gets a new centre:
    the weighted average of

    (a) all of its own members, weight 1 each, and
    (b) the ``top_ratio`` fraction of samples, pooled from its ``neighbor_k``
        most cosine-similar clusters, that lie closest (cosine distance) to
        its centre. Each is weighted ``mean_own_dist / dist * ep``.

    All other clusters keep their centre.

    Args:
        points: 2-D tensor of samples, one row per sample.
        labels: 1-D tensor of cluster ids aligned with ``points``.
        centers: 2-D tensor of cluster centres indexed by cluster id.
        n_clusters: number of cluster ids.
        neighbor_k: maximum number of neighbouring clusters to borrow from.
        top_ratio: fraction of the neighbours' pooled samples to borrow.
        ep: scale factor applied to borrowed-sample weights.
        eps: floor on borrowed-sample distances (avoids division by zero).

    Returns:
        Tensor of shape ``(n_clusters, D)`` with the refined centres.
    """
    members = [(labels == c).nonzero(as_tuple=True)[0] for c in range(n_clusters)]
    cluster_sizes = [len(m) for m in members]

    # Dynamic threshold: clusters smaller than the mean size are refined.
    avg_cluster_size = sum(cluster_sizes) / len(cluster_sizes)
    dynamic_threshold = avg_cluster_size * 1

    print(f"平均簇大小: {avg_cluster_size:.2f}, 动态小簇阈值: {dynamic_threshold:.2f}")

    # A cluster must never be selected as its own neighbour.
    n_neighbors = min(neighbor_k, n_clusters - 1)

    refined = []
    for c in range(n_clusters):
        own_idx = members[c]
        if len(own_idx) >= dynamic_threshold:
            refined.append(centers[c])  # large enough: keep the k-means centre
            continue

        center_c = centers[c].unsqueeze(0)  # [1, D]
        own_pts = points[own_idx]

        # Neighbouring clusters = the most cosine-similar centres, excluding c.
        neighbor_ids = []
        if n_neighbors > 0:
            sim = F.cosine_similarity(center_c, centers)
            sim[c] = float('-inf')
            neighbor_ids = torch.topk(sim, n_neighbors, largest=True).indices.tolist()
        pooled = [points[members[j]] for j in neighbor_ids]
        if pooled:
            neighbor_pts = torch.cat(pooled, dim=0)
        else:
            neighbor_pts = points.new_zeros((0, points.size(1)))

        # Borrow the neighbour samples closest (cosine distance) to this centre.
        n_borrow = min(int(len(neighbor_pts) * top_ratio), len(neighbor_pts))
        if n_borrow > 0:
            cand_dist = 1 - F.cosine_similarity(neighbor_pts, center_c)
            borrow_dist, borrow_idx = torch.topk(cand_dist, n_borrow, largest=False)
            borrowed = neighbor_pts[borrow_idx]
        else:
            borrow_dist = points.new_zeros(0)
            borrowed = neighbor_pts[:0]

        if len(own_pts) > 0:
            dist_avg = (1 - F.cosine_similarity(own_pts, center_c)).mean()
        else:
            # Empty cluster: a mean over zero samples would be NaN. With no
            # weight-1 members, any positive scale cancels after normalization.
            dist_avg = 1.0
        # Closer borrowed samples get larger weights; clamp avoids inf at distance 0.
        borrow_w = dist_avg / borrow_dist.clamp_min(eps) * ep

        all_pts = torch.cat([own_pts, borrowed], dim=0)
        all_w = torch.cat(
            [torch.ones(len(own_pts), device=points.device, dtype=borrow_w.dtype), borrow_w], dim=0
        )
        total = all_w.sum()
        if all_pts.shape[0] == 0 or not torch.isfinite(total) or total <= 0:
            refined.append(centers[c])  # nothing usable to average: keep the centre
            continue

        refined.append((all_pts * (all_w / total).unsqueeze(1)).sum(dim=0))

    return torch.stack(refined, dim=0)


def initialize_weights_longtail_v1(cfg, model, cali_mlp, features, val_dataloader, neighbor_k=5, top_ratio=0.1, ep=0.2):
    """Two-stage k-means head initialization with small-cluster centre refinement.

    Steps:

    1. Cluster row-standardized features into 512 prototypes (W1).
    2. Project the features through ``linear(W1)`` followed by the cluster
       head's BN and ReLU.
    3. Cluster the result into ``cfg['backbone']['nclusters']`` classes (W2).
    4. Re-estimate centres of smaller-than-average classes by borrowing close
       samples from neighbouring classes.
    5. Refine W1 and W2 with ``orth_train`` and print the class histograms.
    6. Write the weights into layers 0 and 3 of ``cluster_head[0]`` and
       ``calibration_head``, zeroing biases.

    ``val_dataloader`` is accepted but not used. ``neighbor_k``, ``top_ratio``
    and ``ep`` control the borrowing (see ``_refine_small_cluster_centers``).
    """
    n_classes = cfg['backbone']['nclusters']

    # 1. Row-wise z-score, then L2-normalize.
    features_zscore = (features - features.mean(1, keepdim=True)) / features.std(1, keepdim=True)
    features_zscore = F.normalize(features_zscore, dim=1)

    # 2. Stage-1 clustering: KMeans(512); only the centres are used.
    KMeans_512 = PyTorchKMeans(init='k-means++', n_clusters=512, verbose=False)
    KMeans_512.fit_predict(features_zscore)
    W1 = KMeans_512.cluster_centers_

    # 3. linear -> BN -> ReLU
    H = torch.mm(features, W1.T)
    H = model.module.cluster_head[0][1](H).detach().clone()
    H = model.module.cluster_head[0][2](H).detach().clone()

    # 4. Stage-2 clustering: KMeans(nclusters)
    H_zscore = (H - H.mean(1, keepdim=True)) / H.std(1, keepdim=True)
    H_zscore = F.normalize(H_zscore, dim=1)
    KMeans_c = PyTorchKMeans(init='k-means++', n_clusters=n_classes, verbose=False)
    class_label = KMeans_c.fit_predict(H_zscore)
    W2 = KMeans_c.cluster_centers_  # shape: [C, D]

    # 5-6. Re-estimate centres of small clusters.
    W2_refined = _refine_small_cluster_centers(
        H_zscore, class_label, W2, n_classes, neighbor_k, top_ratio, ep
    )

    # 7. Orthogonalize the weights.
    W1_modi = orth_train(W1, 512, use_relu=True)
    W2_modi = orth_train(W2_refined, n_classes, use_relu=True)

    # 8. Print the predicted-class distribution before / after.
    O = torch.mm(torch.mm(features, W1.T), W2.T)
    print("原簇中心分布：", F.softmax(O, dim=1).max(1)[1].unique(return_counts=True))

    O = torch.mm(torch.mm(features, W1_modi.T), W2_modi.T)
    print("优化后中心分布：", F.softmax(O, dim=1).max(1)[1].unique(return_counts=True))

    # 9. Initialize model parameters.
    torch.nn.init.zeros_(model.module.cluster_head[0][0].bias)
    torch.nn.init.zeros_(model.module.cluster_head[0][3].bias)
    model.module.cluster_head[0][0].weight.data = W1_modi.clone()
    model.module.cluster_head[0][3].weight.data = W2_modi.clone()

    torch.nn.init.zeros_(cali_mlp.module.calibration_head[0].bias)
    torch.nn.init.zeros_(cali_mlp.module.calibration_head[3].bias)
    cali_mlp.module.calibration_head[0].weight.data = W1_modi.clone()
    cali_mlp.module.calibration_head[3].weight.data = W2_modi.clone()


import torch
import torch.nn.functional as F
import numpy as np
from sklearn.neighbors import NearestNeighbors

def compute_density_weights(features, labels, k=10, alpha=1.0, eps=1e-6):
    """Per-sample weights that down-weight dense regions inside each cluster.

    For every cluster id in ``labels`` with more than ``k`` members:

    - A k-NN search (sklearn ``NearestNeighbors`` with ``k + 1`` neighbours,
      dropping the first, zero-distance hit) gives each member its mean
      distance ``d`` to its ``k`` nearest neighbours.
    - Members receive ``exp(alpha * d)`` divided by ``(cluster sum + eps)``,
      so samples in sparser regions get larger weights.

    Clusters with ``<= k`` members get a weight of 1.0 per sample. These are
    not normalized like the larger clusters.

    Args:
        features: ``[N, D]`` tensor (any device).
        labels: ``[N]`` tensor of cluster ids (any device).
        k: number of neighbours used for the density estimate.
        alpha: sharpness of the distance-to-weight mapping.
        eps: added to each cluster's weight sum before normalizing.

    Returns:
        float32 tensor of shape ``[N]`` on ``features.device``.
    """
    target_device = features.device
    X = features.detach().cpu().numpy()
    y = labels.detach().cpu().numpy()
    out = np.zeros(len(X), dtype=np.float32)

    for cluster_id in np.unique(y):
        members = np.flatnonzero(y == cluster_id)
        if len(members) <= k:
            # Too few points for a k-NN density estimate.
            out[members] = 1.0
            continue

        pts = X[members]
        knn = NearestNeighbors(n_neighbors=k + 1, algorithm="auto").fit(pts)
        knn_dist, _ = knn.kneighbors(pts)
        # Column 0 is the zero-distance hit (the query point itself), so skip it.
        mean_knn_dist = knn_dist[:, 1:].mean(axis=1)

        # Smaller distance -> denser region -> smaller weight.
        raw = np.exp(alpha * mean_knn_dist)
        out[members] = raw / (raw.sum() + eps)

    return torch.tensor(out, dtype=torch.float32, device=target_device)

def weighted_cluster_centers(features, labels, n_clusters, weights):
    """Weighted mean of ``features`` for each cluster id in ``range(n_clusters)``.

    Row ``c`` of the result is ``sum_i w_i * x_i / (sum_i w_i + 1e-6)``, taken
    over the samples with ``labels == c``. Clusters with no members stay
    all-zero. The output uses the default dtype on ``features.device``.
    """
    centers = torch.zeros((n_clusters, features.size(1)), device=features.device)
    for cluster_id in range(n_clusters):
        in_cluster = labels == cluster_id
        if not bool(in_cluster.any()):
            continue
        w = weights[in_cluster][:, None]  # [Nc, 1]
        centers[cluster_id] = (features[in_cluster] * w).sum(0) / (w.sum() + 1e-6)
    return centers


def initialize_weights_bias(cfg, model, cali_mlp, features, val_dataloader, k=10, alpha=0.5):
    """Initialize heads with density-weighted stage-2 centres, then evaluate once.

    Steps:

    1. W1 comes from KMeans(512) on row-standardized features.
    2. ``features @ W1.T`` is passed through the cluster head's BN and ReLU,
       then clustered into ``cfg['backbone']['nclusters']`` classes.
    3. W2 is recomputed as a density-weighted mean per class
       (``compute_density_weights`` with ``k``/``alpha``, then
       ``weighted_cluster_centers``).
    4. Both are refined with ``orth_train`` and written to layers 0 and 3 of
       ``cluster_head[0]`` and ``calibration_head``, with biases zeroed.
    5. The model is evaluated on ``val_dataloader`` via
       ``get_predictions`` / ``hungarian_evaluate``.

    Returns:
        The stats object returned by ``hungarian_evaluate`` (also printed).
    """
    n_classes = cfg['backbone']['nclusters']

    # Feature preprocessing: row-wise z-score, then L2-normalize.
    features_zscore = (features - features.mean(1, keepdim=True)) / features.std(1, keepdim=True)
    features_zscore = F.normalize(features_zscore, dim=1)

    # Step 1: KMeans(512). W1 is taken directly from the k-means centres, so the
    # stage-1 labels are not needed (a density-weighted W1 was tried and disabled).
    KMeans_512 = PyTorchKMeans(init='k-means++', n_clusters=512, verbose=False, random_state=0)
    KMeans_512.fit_predict(features_zscore)
    W1 = KMeans_512.cluster_centers_

    # Step 2: linear(W1) -> cluster_head BN -> ReLU
    H = torch.mm(features, W1.T)
    H = model.module.cluster_head[0][1](H).detach().clone()
    H = model.module.cluster_head[0][2](H).detach().clone()

    H_zscore = (H - H.mean(1, keepdim=True)) / H.std(1, keepdim=True)
    H_zscore = F.normalize(H_zscore, dim=1)

    # Step 3: KMeans into the final number of classes. as_tensor avoids the
    # copy-construct warning when fit_predict already returns a tensor.
    KMeans_c = PyTorchKMeans(init='k-means++', n_clusters=n_classes, verbose=False, random_state=0)
    class_label = torch.as_tensor(KMeans_c.fit_predict(H_zscore), device=features.device)

    # Step 4: density-weighted W2.
    density_weights2 = compute_density_weights(H_zscore, class_label, k=k, alpha=alpha)
    # Rescale so the weights average ~1. NOTE(review): weighted_cluster_centers
    # divides by each cluster's weight sum, so this global rescale has almost no
    # effect on W2.
    density_weights2 = density_weights2 * (len(density_weights2) / (density_weights2.sum() + 1e-6))
    W2 = weighted_cluster_centers(H_zscore, class_label, n_classes, density_weights2)

    # Step 5: orthogonalization.
    W1_modi = orth_train(W1, 512, use_relu=True)
    W2_modi = orth_train(W2, n_classes, use_relu=True)

    # Step 6: write cluster_head and calibration_head.
    torch.nn.init.zeros_(model.module.cluster_head[0][0].bias)
    torch.nn.init.zeros_(model.module.cluster_head[0][3].bias)
    model.module.cluster_head[0][0].weight.data = W1_modi.clone()
    model.module.cluster_head[0][3].weight.data = W2_modi.clone()

    torch.nn.init.zeros_(cali_mlp.module.calibration_head[0].bias)
    torch.nn.init.zeros_(cali_mlp.module.calibration_head[3].bias)
    cali_mlp.module.calibration_head[0].weight.data = W1_modi.clone()
    cali_mlp.module.calibration_head[3].weight.data = W2_modi.clone()

    # Step 7: evaluate once.
    predictions = get_predictions(cfg, val_dataloader, model)
    clustering_stats = hungarian_evaluate(cfg, cfg['cdc_checkpoint'], 0, 0, predictions,
                                          title=cfg['cluster_eval']['plot_title'],
                                          compute_confusion_matrix=False)
    print(clustering_stats)

    return clustering_stats