
import torch

from cdc.losses.losses import SCANLoss, ConfidenceBasedCE
import torch.nn.functional as F
from collections import Counter
from torch.utils.data import Subset, DataLoader
import time

def freeze_backbone(model):
    """Disable gradient tracking for every parameter of ``model.module.backbone``.

    Args:
        model: Wrapper object exposing the network as ``model.module`` with a
            ``backbone`` sub-module.
    """
    backbone = model.module.backbone
    for weight in backbone.parameters():
        weight.requires_grad_(False)
    print("[INFO] Backbone frozen.")

def select_high_confidence_samples(model, dataloader, top_ratio=0.1, threshold=0.3):
    """
    Return the dataset indices and pseudo-labels of confidently predicted samples.

    The model is put in eval mode and run over ``dataloader`` without gradients.
    A sample is kept when the maximum softmax probability of the first output
    head is strictly greater than ``threshold``. If no sample passes, the
    ``top_ratio`` fraction of samples with the highest confidence is returned
    instead.

    Args:
        model: Callable returning a dict with an ``'output'`` list of logits
            when called with ``forward_pass='return_all'``.
        dataloader: Iterable of batches where ``batch[0]`` holds the images and
            ``batch[2]`` the sample indices (a tensor).
        top_ratio: Fraction of samples used by the fallback selection.
        threshold: Minimum (exclusive) confidence for a sample to be selected.

    Returns:
        top_sample_indices (Tensor): indices of the selected samples, taken
            from ``batch[2]``.
        top_pseudo_labels (Tensor): predicted class of each selected sample.
    """
    model.eval()
    all_indices, all_confidences, all_pseudo_labels = [], [], []

    with torch.no_grad():
        for batch in dataloader:
            images = batch[0]
            indices = batch[2]
            outputs = model(images.cuda(non_blocking=True),
                            forward_pass='return_all')['output'][0]
            probs = F.softmax(outputs, dim=1)
            confs, preds = torch.max(probs, dim=1)

            all_indices.append(indices)
            all_confidences.append(confs.cpu())
            all_pseudo_labels.append(preds.cpu())

    all_indices = torch.cat(all_indices)
    all_confidences = torch.cat(all_confidences)
    all_pseudo_labels = torch.cat(all_pseudo_labels)

    # as_tuple=True always yields a 1-D index tensor. The previous
    # nonzero(...).squeeze() collapsed a single hit into a 0-dim tensor, on
    # which len() raises TypeError.
    top_indices = torch.nonzero(all_confidences > threshold, as_tuple=True)[0]
    print(f"[Confidence Filter] Selected {len(top_indices)} samples with confidence > {threshold}")

    if top_indices.numel() == 0:
        # Nothing passed the threshold: fall back to the top-ratio selection.
        num_top = int(top_ratio * len(all_confidences))
        top_indices = torch.argsort(all_confidences, descending=True)[:num_top]

    top_sample_indices = all_indices[top_indices]
    top_pseudo_labels = all_pseudo_labels[top_indices]

    return top_sample_indices, top_pseudo_labels

def train_with_pseudo_labels(cfg, model, dataset, optimizer, sample_indices, pseudo_labels):
    """
    Run one pass of cross-entropy training on high-confidence pseudo-labelled samples.

    The backbone is frozen via ``freeze_backbone`` before the pass, so only the
    remaining trainable parameters handled by ``optimizer`` are updated.

    Args:
        cfg: Config dict; reads ``cfg['optimizer']['batch_size']`` and the
            optional ``cfg['num_workers']`` (default 4).
        model: Wrapped network; ``model(images)[0]`` is used as the main-head logits.
        dataset: The original training dataset (not a dataloader). Each item
            must be a dict with ``'image'`` and ``'index'`` entries.
        optimizer: Optimizer stepped once per batch.
        sample_indices (Tensor): indices of the selected samples.
        pseudo_labels (Tensor): label of each selected sample, same order.

    Returns:
        None. Prints the average loss; returns early when nothing was selected.
    """
    index_list = torch.as_tensor(sample_indices).reshape(-1).tolist()
    label_list = torch.as_tensor(pseudo_labels).reshape(-1).tolist()

    if not index_list:
        # A shuffled DataLoader over an empty subset (and the average below)
        # would fail, so there is nothing to train on.
        print("[Pseudo-label Training] No samples selected, skipping.")
        return

    # Lookup table: sample index -> pseudo-label.
    # NOTE(review): assumes item['index'] equals the position used to index
    # `dataset` — confirm against the dataset implementation.
    label_map = dict(zip(index_list, label_list))

    # Plain Python ints so the dataset's __getitem__ never receives a tensor.
    subset = Subset(dataset, index_list)

    def pseudo_label_collate_fn(batch):
        images = torch.stack([item['image'] for item in batch])
        targets = torch.tensor([label_map[int(item['index'])] for item in batch])
        return images, targets

    loader = DataLoader(subset,
                        batch_size=cfg['optimizer']['batch_size'],
                        shuffle=True,
                        num_workers=cfg.get('num_workers', 4),
                        pin_memory=True,
                        collate_fn=pseudo_label_collate_fn)

    model.train()

    # NOTE(review): freezing only disables gradients; layers with running
    # statistics inside the backbone still update them in train mode.
    freeze_backbone(model)

    total_loss = 0
    for images, targets in loader:
        images = images.cuda(non_blocking=True)
        targets = targets.cuda(non_blocking=True)

        optimizer.zero_grad()
        outputs = model(images)[0]  # main head output
        loss = F.cross_entropy(outputs, targets)
        loss.backward()
        optimizer.step()

        # Weight by batch size so the final value is a per-sample average.
        total_loss += loss.item() * images.size(0)

    avg_loss = total_loss / len(loader.dataset)
    print(f"[Pseudo-label Training] Top {len(loader.dataset)} samples, avg loss: {avg_loss:.4f}")

def scan_train(cfg, clustering_stats, train_loader, model, criterion, optimizer, pseudo_labels, update_cluster_head_only=False, ce=False):
    """
    Train one epoch with the SCAN loss, optionally adding a pseudo-label CE term.

    For each batch the SCAN loss is computed per head on the anchor/neighbor
    outputs and summed over heads. When ``ce`` is true, the cross-entropy
    between the first head's logits and the samples' pseudo-labels is added.
    Previously this term was multiplied by an all-zero target and therefore
    always evaluated to 0.

    Args:
        cfg: Config dict; reads ``cfg['method_kwargs']['entropy_weight']``.
        clustering_stats: Unused; kept for interface compatibility.
        train_loader: Yields dict batches with ``'image'``, ``'neighbor'`` and
            (when ``ce`` is true) ``'index'``.
        model: Network returning a sequence of per-head logits.
        criterion: Unused; a ``SCANLoss`` is built from ``cfg`` instead.
        optimizer: Optimizer stepped once per batch.
        pseudo_labels: Per-sample labels indexed by ``batch['index']``; only
            used when ``ce`` is true.
        update_cluster_head_only: Unused; kept for interface compatibility.
        ce: Whether to add the pseudo-label cross-entropy term.

    Returns:
        float: wall-clock seconds spent in the epoch.
    """
    model.train()  # Update BN
    scan_criterion = SCANLoss(cfg['method_kwargs']['entropy_weight'])
    # Convert once instead of rebuilding the label tensor for every batch.
    pseudo_all = torch.as_tensor(pseudo_labels).long() if ce else None

    epoch_start = time.time()  # start timing
    for batch in train_loader:
        # Forward pass
        anchors = batch['image'].cuda(non_blocking=True)
        neighbors = batch['neighbor'].cuda(non_blocking=True)

        anchors_output = model(anchors)
        neighbors_output = model(neighbors)

        # Loss for every head; element 0 of the criterion output is the total loss.
        head_losses = [scan_criterion(out_a, out_n)[0]
                       for out_a, out_n in zip(anchors_output, neighbors_output)]
        loss = torch.sum(torch.stack(head_losses, dim=0))

        if ce:
            pseudo = pseudo_all[batch['index'].cpu()].cuda()
            loss = loss + F.cross_entropy(anchors_output[0], pseudo)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return time.time() - epoch_start


from cdc.methods.dyn_train import SampleMasterTracker
def scan_train_sample(cfg, train_loader, model, criterion, optimizer, pseudo_labels, tracker:SampleMasterTracker, ce=False, stabilityloss=False):
    """
    Train one epoch with the SCAN loss while optionally tracking per-sample stability.

    For each batch the SCAN loss is computed per head and summed. When ``ce``
    is true, the cross-entropy between the first head's logits and the
    pseudo-labels is added (previously this term used an all-zero target and
    was always 0). When ``stabilityloss`` is true, ``1 - cosine_similarity``
    between the first-head outputs of the weak and strong views is reported to
    ``tracker`` together with the predictions; it is not added to the loss.

    Args:
        cfg: Config dict; reads ``cfg['method_kwargs']['entropy_weight']``.
        train_loader: Yields dict batches with ``'image'``,
            ``'image_augmented'``, ``'neighbor'`` and ``'index'``.
        model: Network returning a sequence of per-head logits.
        criterion: Unused; a ``SCANLoss`` is built from ``cfg`` instead.
        optimizer: Optimizer stepped once per batch.
        pseudo_labels: Per-sample labels indexed by ``batch['index']``; only
            used when ``ce`` is true.
        tracker: Receives ``update(...)`` per batch (when ``stabilityloss``)
            and ``step()`` once at the end of the epoch.
        ce: Whether to add the pseudo-label cross-entropy term.
        stabilityloss: Whether to report per-sample statistics to ``tracker``.
    """
    model.train()  # Update BN
    scan_criterion = SCANLoss(cfg['method_kwargs']['entropy_weight'])
    # Convert once instead of rebuilding the label tensor for every batch.
    pseudo_all = torch.as_tensor(pseudo_labels).long() if ce else None

    for batch in train_loader:
        # Forward pass
        anchors = batch['image'].cuda(non_blocking=True)
        anchors_strong = batch['image_augmented'].cuda(non_blocking=True)
        neighbors = batch['neighbor'].cuda(non_blocking=True)
        index_cpu = batch['index'].cpu()

        anchors_output = model(anchors)
        # Kept unconditional: this forward pass runs in train mode exactly as before.
        anchors_strong_output = model(anchors_strong)
        neighbors_output = model(neighbors)

        # Loss for every head; element 0 of the criterion output is the total loss.
        head_losses = [scan_criterion(out_a, out_n)[0]
                       for out_a, out_n in zip(anchors_output, neighbors_output)]
        loss = torch.sum(torch.stack(head_losses, dim=0))

        if ce:
            pseudo = pseudo_all[index_cpu].cuda()
            loss = loss + F.cross_entropy(anchors_output[0], pseudo)

        if stabilityloss:
            feature_stability = F.cosine_similarity(anchors_output[0], anchors_strong_output[0], dim=1)
            stability_loss = 1 - feature_stability
            prob = F.softmax(anchors_output[0], dim=1)   # [batch_size, num_class]
            label = torch.argmax(prob, dim=1)
            # NOTE(review): `confidences` receives the full probability rows,
            # not the per-sample maximum — confirm this is what the tracker expects.
            tracker.update(
                indices=index_cpu.tolist(),
                confidences=prob.detach().tolist(),
                labels=label.tolist(),
                losses=stability_loss.detach().tolist()
            )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    tracker.step()
    print(f"Removed={len(tracker.removed)}, restored={len(tracker.restore_log)}")


def selflabel_train(cfg, clustering_stats, train_loader, model, criterion, optimizer, epoch, ema=None):
    """
    Self-labeling on confident samples for one epoch.

    The first-head logits of the plain image are computed without gradients
    and passed, together with the logits of the augmented image, to
    ``ConfidenceBasedCE``. The resulting loss is back-propagated through the
    augmented branch only.

    The former per-batch remapping of ``batch['target']`` through
    ``clustering_stats['hungarian_match']`` produced a value that was never
    used and has been removed.

    Args:
        cfg: Config dict; reads ``cfg['method_kwargs']['threshold']`` and
            ``cfg['method_kwargs']['apply_class_balancing']``.
        clustering_stats: Unused; kept for interface compatibility.
        train_loader: Yields dict batches with ``'image'`` and ``'image_augmented'``.
        model: Network returning a sequence of per-head logits.
        criterion: Unused; the loss is built from ``cfg`` instead.
        optimizer: Optimizer stepped once per batch.
        epoch: Unused; kept for interface compatibility.
        ema: Unused; kept for interface compatibility.
    """
    model.train()
    # Build the loss once rather than on every iteration.
    confidence_ce = ConfidenceBasedCE(cfg['method_kwargs']['threshold'],
                                      cfg['method_kwargs']['apply_class_balancing'])

    for batch in train_loader:
        images = batch['image'].cuda(non_blocking=True)
        images_augmented = batch['image_augmented'].cuda(non_blocking=True)

        # The plain view only provides targets, so no graph is needed for it.
        with torch.no_grad():
            output = model(images)[0]
        output_augmented = model(images_augmented)[0]

        loss = confidence_ce(output, output_augmented)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

def scan_train_longtail(cfg, train_loader, model, optimizer, epoch,
                        tail_neighbors_idx, medium_neighbors_idx, pseudo_labels):
    """
    SCAN training with long-tail aware expert branches and soft neighbor-enhanced targets.

    Per batch, the summed per-head SCAN loss is combined with two soft-target
    cross-entropy terms computed on ``model.module.classify_tail`` and
    ``model.module.classify_medium`` applied to the anchor features.

    Each soft target puts ``1 - epsilon`` on the sample's pseudo-label and
    spreads ``epsilon`` evenly over the columns listed for that sample in
    ``tail_neighbors_idx`` / ``medium_neighbors_idx``; rows are then
    normalised to sum to 1.

    A per-batch class-count tensor that was computed but never used has been
    removed.

    Args:
        cfg: Config dict; reads ``cfg['epsilon']``,
            ``cfg['method_kwargs']['entropy_weight']`` and
            ``cfg['backbone']['nclusters']``.
        train_loader: Yields dict batches with ``'image'``, ``'neighbor'``, ``'index'``.
        model: Wrapped network supporting ``forward_pass='return_all'`` and
            returning a dict with ``'features'`` and ``'output'``.
        optimizer: Optimizer stepped once per batch.
        epoch: Unused; kept for interface compatibility.
        tail_neighbors_idx: Lookup from sample index to column indices of the
            tail target (presumably class ids — verify against caller).
        medium_neighbors_idx: Same as above, for the medium target.
        pseudo_labels: Per-sample labels indexed by ``batch['index']``.
    """
    model.train()

    epsilon = cfg['epsilon']
    num_class = cfg['backbone']['nclusters']
    scan_criterion = SCANLoss(cfg['method_kwargs']['entropy_weight'])
    # Convert once instead of rebuilding the label tensor for every batch.
    pseudo_all = torch.as_tensor(pseudo_labels).long()

    for batch in train_loader:
        anchors = batch['image'].cuda(non_blocking=True)
        neighbors = batch['neighbor'].cuda(non_blocking=True)
        index_cpu = batch['index'].cpu()
        # Plain ints: avoids one device synchronisation per sample below.
        sample_ids = index_cpu.tolist()

        anchors_res = model(anchors, forward_pass='return_all')
        anchors_feats, anchors_outputs = anchors_res['features'], anchors_res['output']
        neighbors_outputs = model(neighbors, forward_pass='return_all')['output']

        total_loss = 0
        for out_a, out_n in zip(anchors_outputs, neighbors_outputs):
            total_loss += scan_criterion(out_a, out_n)[0]

        #### Step 1: base confidence (1 - epsilon) on the pseudo-label column
        batch_size = anchors.shape[0]
        pseudo = pseudo_all[index_cpu].cuda()
        base_target = torch.zeros(batch_size, num_class, device=pseudo.device)
        base_target.scatter_(1, pseudo.view(-1, 1), 1 - epsilon)

        #### Step 2: neighbor-enhanced soft targets for the tail / medium experts
        target_tail = base_target.clone()
        target_medium = base_target.clone()

        for j, sample_id in enumerate(sample_ids):
            tail_neigh = tail_neighbors_idx[sample_id]
            medium_neigh = medium_neighbors_idx[sample_id]

            if len(tail_neigh) > 0:
                target_tail[j, tail_neigh] += epsilon / len(tail_neigh)
            if len(medium_neigh) > 0:
                target_medium[j, medium_neigh] += epsilon / len(medium_neigh)

        target_tail = target_tail / target_tail.sum(dim=1, keepdim=True)
        target_medium = target_medium / target_medium.sum(dim=1, keepdim=True)

        #### Step 3: outputs of the expert heads
        logits_tail = model.module.classify_tail(anchors_feats)
        logits_medium = model.module.classify_medium(anchors_feats)

        loss_tail = -torch.sum(F.log_softmax(logits_tail, dim=1) * target_tail) / batch_size
        loss_medium = -torch.sum(F.log_softmax(logits_medium, dim=1) * target_medium) / batch_size

        #### Step 4: combine the losses
        loss = total_loss + loss_tail + loss_medium

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        
def scan_train_LMv1(cfg, train_loader, model, optimizer, epoch,
                        tail_neighbors_idx, medium_neighbors_idx, pseudo_labels, ratio=1):
    """
    SCAN training plus a count-adjusted, label-smoothed tail-head loss.

    Per batch the loss is ``scan_loss + ratio * loss_tail``. ``loss_tail`` is a
    soft-target cross-entropy on ``model.module.classify_tail(features)`` whose
    logits are shifted by the log of the per-class pseudo-label counts
    (classes absent from ``pseudo_labels`` count as 1 to avoid ``log(0)``).

    The class counts depend only on ``pseudo_labels`` and are now computed
    once per call instead of once per batch.

    Args:
        cfg: Config dict; reads ``cfg['epsilon']``,
            ``cfg['method_kwargs']['entropy_weight']`` and
            ``cfg['backbone']['nclusters']``.
        train_loader: Yields dict batches with ``'image'``, ``'neighbor'``, ``'index'``.
        model: Wrapped network supporting ``forward_pass='return_all'``.
        optimizer: Optimizer stepped once per batch.
        epoch: Unused; kept for interface compatibility.
        tail_neighbors_idx: Lookup from sample index to column indices that
            share the ``epsilon`` mass (presumably class ids — verify against caller).
        medium_neighbors_idx: Unused; kept for interface compatibility.
        pseudo_labels (Tensor): per-sample labels indexed by ``batch['index']``.
        ratio: Weight of the tail loss.
    """
    model.train()

    epsilon = cfg['epsilon']
    num_class = cfg['backbone']['nclusters']
    scan_criterion = SCANLoss(cfg['method_kwargs']['entropy_weight'])

    # Loop-invariant quantities derived from the pseudo-labels.
    pseudo_all = torch.as_tensor(pseudo_labels).long()
    class_counts = Counter(pseudo_labels.cpu().numpy())  # {cluster_id: count}
    spc = torch.tensor([class_counts.get(c, 1) for c in range(num_class)],
                       dtype=torch.float32).cuda()
    log_prior = spc.log()

    for batch in train_loader:
        anchors = batch['image'].cuda(non_blocking=True)
        neighbors = batch['neighbor'].cuda(non_blocking=True)
        index_cpu = batch['index'].cpu()
        # Plain ints: avoids one device synchronisation per sample below.
        sample_ids = index_cpu.tolist()

        anchors_res = model(anchors, forward_pass='return_all')
        anchors_feats, anchors_outputs = anchors_res['features'], anchors_res['output']
        # NOTE(review): these two calls only matter when inspecting
        # anchors_feats.grad while debugging; kept to preserve behaviour.
        anchors_feats.requires_grad_(True)
        anchors_feats.retain_grad()

        neighbors_outputs = model(neighbors, forward_pass='return_all')['output']

        total_loss = 0
        for out_a, out_n in zip(anchors_outputs, neighbors_outputs):
            total_loss += scan_criterion(out_a, out_n)[0]

        #### Soft target: (1 - epsilon) on the pseudo-label column ...
        batch_size = anchors.shape[0]
        pseudo = pseudo_all[index_cpu].cuda()
        target_tail = torch.zeros(batch_size, num_class, device=pseudo.device)
        target_tail.scatter_(1, pseudo.view(-1, 1), 1 - epsilon)

        #### ... plus epsilon spread evenly over the listed neighbor columns.
        for j, sample_id in enumerate(sample_ids):
            tail_neigh = tail_neighbors_idx[sample_id]
            if len(tail_neigh) > 0:
                target_tail[j, tail_neigh] += epsilon / len(tail_neigh)

        target_tail = target_tail / target_tail.sum(dim=1, keepdim=True)

        logits_tail = model.module.classify_tail(anchors_feats)
        # NOTE(review): debugging aid as above; kept to preserve behaviour.
        logits_tail.requires_grad_(True)
        logits_tail.retain_grad()

        adj_tail = logits_tail + 1.0 * log_prior

        loss_tail = -torch.sum(F.log_softmax(adj_tail, dim=1) * target_tail) / batch_size

        #### Combine the losses
        loss = total_loss + ratio * loss_tail

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
def scan_train_LMv2(cfg, train_loader, model, optimizer, epoch,
                        tail_neighbors_idx, medium_neighbors_idx, pseudo_labels):
    """
    Train with the count-adjusted, label-smoothed tail-head loss only.

    ``loss_tail`` is a soft-target cross-entropy on
    ``model.module.classify_tail(features)`` whose logits are shifted by the
    log of the per-class pseudo-label counts (classes absent from
    ``pseudo_labels`` count as 1 to avoid ``log(0)``).

    The class counts depend only on ``pseudo_labels`` and are now computed
    once per call instead of once per batch.

    Args:
        cfg: Config dict; reads ``cfg['epsilon']``,
            ``cfg['method_kwargs']['entropy_weight']`` and
            ``cfg['backbone']['nclusters']``.
        train_loader: Yields dict batches with ``'image'``, ``'neighbor'``, ``'index'``.
        model: Wrapped network supporting ``forward_pass='return_all'``.
        optimizer: Optimizer stepped once per batch.
        epoch: Unused; kept for interface compatibility.
        tail_neighbors_idx: Lookup from sample index to column indices that
            share the ``epsilon`` mass (presumably class ids — verify against caller).
        medium_neighbors_idx: Unused; kept for interface compatibility.
        pseudo_labels (Tensor): per-sample labels indexed by ``batch['index']``.
    """
    model.train()

    epsilon = cfg['epsilon']
    num_class = cfg['backbone']['nclusters']
    scan_criterion = SCANLoss(cfg['method_kwargs']['entropy_weight'])

    # Loop-invariant quantities derived from the pseudo-labels.
    pseudo_all = torch.as_tensor(pseudo_labels).long()
    class_counts = Counter(pseudo_labels.cpu().numpy())  # {cluster_id: count}
    spc = torch.tensor([class_counts.get(c, 1) for c in range(num_class)],
                       dtype=torch.float32).cuda()
    log_prior = spc.log()

    for batch in train_loader:
        anchors = batch['image'].cuda(non_blocking=True)
        neighbors = batch['neighbor'].cuda(non_blocking=True)
        index_cpu = batch['index'].cpu()
        # Plain ints: avoids one device synchronisation per sample below.
        sample_ids = index_cpu.tolist()

        anchors_res = model(anchors, forward_pass='return_all')
        anchors_feats, anchors_outputs = anchors_res['features'], anchors_res['output']
        neighbors_outputs = model(neighbors, forward_pass='return_all')['output']

        # NOTE(review): the SCAN loss is evaluated but is not part of the
        # optimised loss below. Kept as in the original — confirm whether it
        # should be added or the neighbor pass dropped.
        total_loss = 0
        for out_a, out_n in zip(anchors_outputs, neighbors_outputs):
            total_loss += scan_criterion(out_a, out_n)[0]

        #### Soft target: (1 - epsilon) on the pseudo-label column ...
        batch_size = anchors.shape[0]
        pseudo = pseudo_all[index_cpu].cuda()
        target_tail = torch.zeros(batch_size, num_class, device=pseudo.device)
        target_tail.scatter_(1, pseudo.view(-1, 1), 1 - epsilon)

        #### ... plus epsilon spread evenly over the listed neighbor columns.
        for j, sample_id in enumerate(sample_ids):
            tail_neigh = tail_neighbors_idx[sample_id]
            if len(tail_neigh) > 0:
                target_tail[j, tail_neigh] += epsilon / len(tail_neigh)

        target_tail = target_tail / target_tail.sum(dim=1, keepdim=True)

        logits_tail = model.module.classify_tail(anchors_feats)
        adj_tail = logits_tail + 1.0 * log_prior

        loss_tail = -torch.sum(F.log_softmax(adj_tail, dim=1) * target_tail) / batch_size

        loss = loss_tail

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
              
def scan_train_LMv3(cfg, train_loader, model, optimizer, epoch,
                        tail_neighbors_idx, medium_neighbors_idx, pseudo_labels):
    """
    SCAN training plus a hard pseudo-label cross-entropy on the tail head.

    Per batch the loss is the summed per-head SCAN loss plus the mean
    cross-entropy between ``model.module.classify_tail(features)`` and the
    samples' pseudo-labels. No smoothing or count adjustment is applied.

    The previous implementation built the same one-hot target with a Python
    loop and one device synchronisation per sample; ``F.cross_entropy`` on the
    label indices computes the same quantity.

    Args:
        cfg: Config dict; reads ``cfg['method_kwargs']['entropy_weight']``.
        train_loader: Yields dict batches with ``'image'``, ``'neighbor'``, ``'index'``.
        model: Wrapped network supporting ``forward_pass='return_all'``.
        optimizer: Optimizer stepped once per batch.
        epoch: Unused; kept for interface compatibility.
        tail_neighbors_idx: Unused; kept for interface compatibility.
        medium_neighbors_idx: Unused; kept for interface compatibility.
        pseudo_labels: Per-sample labels indexed by ``batch['index']``.
    """
    model.train()

    scan_criterion = SCANLoss(cfg['method_kwargs']['entropy_weight'])
    # Convert once instead of rebuilding the label tensor for every batch.
    pseudo_all = torch.as_tensor(pseudo_labels).long()

    for batch in train_loader:
        anchors = batch['image'].cuda(non_blocking=True)
        neighbors = batch['neighbor'].cuda(non_blocking=True)
        index_cpu = batch['index'].cpu()

        anchors_res = model(anchors, forward_pass='return_all')
        anchors_feats, anchors_outputs = anchors_res['features'], anchors_res['output']
        neighbors_outputs = model(neighbors, forward_pass='return_all')['output']

        total_loss = 0
        for out_a, out_n in zip(anchors_outputs, neighbors_outputs):
            total_loss += scan_criterion(out_a, out_n)[0]

        #### Hard pseudo-label cross-entropy on the tail head
        pseudo = pseudo_all[index_cpu].cuda()
        logits_tail = model.module.classify_tail(anchors_feats)
        loss_tail = F.cross_entropy(logits_tail, pseudo)

        loss = loss_tail + total_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        


import numpy as np
from sklearn.neighbors import NearestNeighbors

def compute_density_weights(features, labels, k=10, alpha=0.5, eps=1e-6):
    """
    Compute per-sample weights that are lower in dense regions of each cluster.

    For every cluster with more than ``k`` members, the mean distance of each
    sample to its ``k`` nearest neighbors inside the cluster is computed; the
    weight is ``exp(alpha * mean_distance)`` normalised to sum to 1 within the
    cluster. Clusters with at most ``k`` members get weight 1.0 and distance
    1.0 for all their samples.

    Args:
        features: Tensor [N, D] (may live on the GPU).
        labels: Tensor [N] of cluster ids (may live on the GPU).
        k: Number of neighbors used for the density estimate.
        alpha: Scale applied to the mean distance inside the exponential.
        eps: Small constant guarding the per-cluster normalisation.

    Returns:
        weights (Tensor): float32 tensor [N] on the device of ``features``.
        counts (ndarray): 50-bin histogram of the mean distances min-max
            normalised to [0, 1]. When all distances are equal they are all
            placed in the first bin.
    """
    device = features.device
    features_cpu = features.detach().cpu().numpy()
    labels_cpu = torch.as_tensor(labels).detach().cpu().numpy()

    num_samples = len(features_cpu)
    weights = np.zeros(num_samples, dtype=np.float32)
    dis = np.zeros(num_samples, dtype=np.float32)

    for c in np.unique(labels_cpu):
        idx = np.where(labels_cpu == c)[0]

        if len(idx) <= k:  # cluster too small: skip density weighting
            weights[idx] = 1.0
            dis[idx] = 1.0
            continue

        cluster_feats = features_cpu[idx]
        nbrs = NearestNeighbors(n_neighbors=k + 1, algorithm="auto").fit(cluster_feats)
        distances, _ = nbrs.kneighbors(cluster_feats)
        # Drop the first column: each point's nearest neighbor is itself (distance 0).
        avg_dist = distances[:, 1:].mean(axis=1)

        cluster_weights = np.exp(alpha * avg_dist)
        cluster_weights = cluster_weights / (cluster_weights.sum() + eps)

        weights[idx] = cluster_weights
        dis[idx] = avg_dist

    # Min-max normalise the distances. With no samples, or when every distance
    # is identical, the plain formula would fail or produce NaN and make
    # np.histogram raise, so fall back to all zeros.
    spread = dis.max() - dis.min() if num_samples > 0 else 0.0
    if spread > 0:
        dis_norm = (dis - dis.min()) / spread
    else:
        dis_norm = np.zeros_like(dis)
    # The normalised values span exactly [0, 1], so fixing the range keeps the
    # regular case unchanged while giving the degenerate case stable bins.
    counts, _ = np.histogram(dis_norm, bins=50, range=(0.0, 1.0))

    # Back to torch, on the original device.
    return torch.tensor(weights, dtype=torch.float32, device=device), counts

def weighted_cluster_centers(features, labels, n_clusters, weights):
    """
    Recompute cluster centers as weighted means of their members.

    Args:
        features: Tensor [N, D].
        labels: Tensor [N] of cluster ids.
        n_clusters: Number of rows in the returned center matrix.
        weights: Tensor [N] of per-sample weights.

    Returns:
        Tensor [n_clusters, D] on the device of ``features``; rows of clusters
        without members stay zero.
    """
    feat_dim = features.size(1)
    centers = torch.zeros((n_clusters, feat_dim), device=features.device)
    for cluster_id in range(n_clusters):
        members = torch.nonzero(labels == cluster_id, as_tuple=True)[0]
        if members.numel() == 0:
            continue
        member_w = weights[members].unsqueeze(1)  # [Nc, 1]
        weighted_sum = (features[members] * member_w).sum(0)
        centers[cluster_id] = weighted_sum / (member_w.sum() + 1e-6)
    return centers


from cdc.utils.torch_clustering import PyTorchKMeans
def init_head_singlelayer_bias(cfg, model, features, k=10, alpha=1.0):
    """
    Initialise ``model.module.cluster_head[0]`` from density-weighted KMeans centers.

    The features are clustered with KMeans, per-sample density weights are
    computed and rescaled to have mean ~1, and the weighted cluster centers
    become the weight of the first cluster head (its bias is zeroed).

    Args:
        cfg: Config dict; reads ``cfg['backbone']['nclusters']``.
        model: Wrapped network exposing ``model.module.cluster_head``.
        features: Tensor [N, D] of sample features.
        k: Number of neighbors for the density estimate.
        alpha: Density-weight exponent scale.

    Returns:
        list: ``[low_idx, mid_idx, high_idx]`` — sample indices sorted by
        descending density weight and split into three groups. Group sizes
        are the sample counts in distance-histogram bins 0-16 (``high``),
        17-29 (``mid``) and 30-49 (``low``).
    """
    n_clusters = cfg['backbone']['nclusters']

    # NOTE(review): the original computed a z-scored, L2-normalised copy of the
    # features and immediately overwrote it with the raw features, so clustering
    # has always run on the raw features. That behaviour is kept here —
    # confirm whether the normalised features were meant to be used.
    features_raw = features.detach()

    # KMeans clustering
    kmeans = PyTorchKMeans(init='k-means++', n_clusters=n_clusters, verbose=False, random_state=0)
    class_label = torch.as_tensor(kmeans.fit_predict(features_raw), device=features.device)

    density_weights2, counts = compute_density_weights(features_raw, class_label, k=k, alpha=alpha)
    # Rescale so the weights average to ~1 over all samples.
    density_weights2 = density_weights2 * (len(density_weights2) / (density_weights2.sum() + 1e-6))
    W1 = weighted_cluster_centers(features_raw, class_label, n_clusters, density_weights2)

    # Initialise the classification head
    with torch.no_grad():
        torch.nn.init.zeros_(model.module.cluster_head[0].bias)
        model.module.cluster_head[0].weight.data = W1.clone()

    print("cluster_head[0] initialized with shape:", W1.shape)

    density_weights2_np = density_weights2.detach().cpu().numpy()

    # Group sizes taken from fixed ranges of the distance histogram.
    counts_high = int(np.sum(counts[:17]))
    counts_mid = int(np.sum(counts[17:30]))
    counts_low = int(np.sum(counts[30:]))

    sorted_idx = np.argsort(-density_weights2_np)
    high_end = counts_high
    mid_end = high_end + counts_mid
    low_end = mid_end + counts_low
    high_idx = sorted_idx[:high_end]
    mid_idx = sorted_idx[high_end:mid_end]
    low_idx = sorted_idx[mid_end:low_end]

    return [low_idx, mid_idx, high_idx]


def init_head_doublelayer_bias(cfg, model, features, k=10, alpha=1.0):
    """
    Initialise the two-layer projector heads from KMeans prototypes.

    A first KMeans with D clusters on the normalised features gives ``W1``
    [D, D]. The features are projected through ``W1`` and ReLU, and a second
    KMeans with K clusters on that projection gives density-weighted centers
    ``W2``. ``W1`` initialises ``instance_projector[0]`` and
    ``cluster_projector[0]``; ``W2`` initialises ``cluster_projector[2]``.
    All corresponding biases are zeroed.

    Fixes over the previous version: the tuple returned by
    ``compute_density_weights`` is unpacked (it was used as a tensor), all
    five bins are returned (the ``return`` sat inside the loop), and the
    standard-deviation division is guarded with an epsilon.

    Args:
        cfg: Config dict; reads ``cfg['backbone']['nclusters']``.
        model: Wrapped network exposing ``model.module.instance_projector``
            and ``model.module.cluster_projector``.
        features: Tensor [N, D] of sample features.
        k: Number of neighbors for the density estimate.
        alpha: Density-weight exponent scale.

    Returns:
        list: five index arrays, one per 20-percentile band of the density
        weights, ordered from lowest to highest weight.
    """
    features = features.detach()
    features_zscore = (features - features.mean(1, keepdim=True)) / (features.std(1, keepdim=True) + 1e-6)
    features_zscore = F.normalize(features_zscore, dim=1)  # [N, D]

    # First-layer prototypes
    KMeans_D = PyTorchKMeans(init='k-means++', n_clusters=features.shape[1], verbose=False)
    KMeans_D.fit_predict(features_zscore)
    W1 = torch.as_tensor(KMeans_D.cluster_centers_, dtype=features.dtype, device=features.device)  # [D, D]

    # NOTE(review): the projection uses the raw features, not the normalised
    # ones that were clustered — kept as in the original; confirm intent.
    H = torch.relu(torch.mm(features, W1.T))

    # Second-layer prototypes
    K = cfg['backbone']['nclusters']
    KMeans_K = PyTorchKMeans(init='k-means++', n_clusters=K, verbose=False)
    cluster_labels = torch.as_tensor(KMeans_K.fit_predict(H), device=H.device)

    density_weights2, _ = compute_density_weights(H, cluster_labels, k=k, alpha=alpha)
    # Rescale so the weights average to ~1 over all samples.
    density_weights2 = density_weights2 * (len(density_weights2) / (density_weights2.sum() + 1e-6))
    W2 = weighted_cluster_centers(H, cluster_labels, K, density_weights2)

    with torch.no_grad():
        # First layer
        torch.nn.init.zeros_(model.module.instance_projector[0].bias)
        model.module.instance_projector[0].weight.data = W1.clone()
        # cluster_projector shares W1 as its first layer
        torch.nn.init.zeros_(model.module.cluster_projector[0].bias)
        model.module.cluster_projector[0].weight.data = W1.clone()
        # Second layer holds the classification prototypes
        torch.nn.init.zeros_(model.module.cluster_projector[2].bias)
        model.module.cluster_projector[2].weight.data = W2.clone()

    # Split the samples into five bands by density-weight percentile.
    density_weights2_np = density_weights2.detach().cpu().numpy()
    percentiles = np.percentile(density_weights2_np, [20, 40, 60, 80])
    indices_per_bin = []
    for i in range(5):
        if i == 0:
            mask = density_weights2_np <= percentiles[i]
        elif i == 4:
            mask = density_weights2_np > percentiles[i - 1]
        else:
            mask = (density_weights2_np > percentiles[i - 1]) & (density_weights2_np <= percentiles[i])

        indices_per_bin.append(np.where(mask)[0])

    return indices_per_bin
    

def init_head_doublelayer_bias_tcl(cfg, model, features, k=10, alpha=1.0):
    """
    Initialise the TCL-style instance and cluster projectors from KMeans prototypes.

    A first KMeans with D clusters on the normalised features gives ``W1``
    [D, D]. The features are projected through ``W1`` and ReLU. On that
    projection, a KMeans with ``nclusters`` clusters yields density-weighted
    centers ``W2`` and a KMeans with ``feat_dim`` clusters yields ``W3``.
    ``W1`` initialises layer 2 of both projectors, ``W3`` initialises
    ``instance_projector[5]`` and ``W2`` initialises ``cluster_projector[5]``.
    All corresponding biases are zeroed.

    Fixes over the previous version: the tuple returned by
    ``compute_density_weights`` is unpacked (it was used as a tensor) and
    ``W3`` is converted to a tensor the same way ``W1`` is.

    Args:
        cfg: Config dict; reads ``cfg['backbone']['nclusters']`` and
            ``cfg['backbone']['feat_dim']``.
        model: Wrapped network exposing ``model.module.instance_projector``
            and ``model.module.cluster_projector``.
        features: Tensor [N, D] of sample features.
        k: Number of neighbors for the density estimate.
        alpha: Density-weight exponent scale.

    Returns:
        list: five index arrays, one per 20-percentile band of the density
        weights, ordered from lowest to highest weight.
    """
    # ---------- Feature standardisation ----------
    features_zscore = (features - features.mean(1, keepdim=True)) / (features.std(1, keepdim=True) + 1e-6)
    features_zscore = F.normalize(features_zscore, dim=1)  # [N, D]

    # ---------- First-layer KMeans ----------
    KMeans_D = PyTorchKMeans(init='k-means++', n_clusters=features.shape[1], verbose=False)
    KMeans_D.fit_predict(features_zscore)
    W1 = torch.as_tensor(KMeans_D.cluster_centers_, dtype=torch.float32, device=features.device)  # [D, D]

    # ---------- Project and activate ----------
    # NOTE(review): the projection uses the raw features, not the normalised
    # ones that were clustered — kept as in the original; confirm intent.
    H = torch.relu(torch.mm(features, W1.T))   # [N, D]

    # ---------- Second-layer KMeans ----------
    K = cfg['backbone']['nclusters']
    KMeans_K = PyTorchKMeans(init='k-means++', n_clusters=K, verbose=False)
    cluster_labels = torch.as_tensor(KMeans_K.fit_predict(H), device=H.device)

    K2 = cfg['backbone']['feat_dim']
    KMeans_K2 = PyTorchKMeans(init='k-means++', n_clusters=K2, verbose=False)
    KMeans_K2.fit_predict(H)
    W3 = torch.as_tensor(KMeans_K2.cluster_centers_, dtype=torch.float32, device=features.device)

    # Density weights, rescaled so they average to ~1 over all samples.
    density_weights2, _ = compute_density_weights(H, cluster_labels, k=k, alpha=alpha)
    density_weights2 = density_weights2 * (len(density_weights2) / (density_weights2.sum() + 1e-6))

    # Weighted cluster centers W2
    W2 = weighted_cluster_centers(H, cluster_labels, K, density_weights2)

    with torch.no_grad():
        # ---------- Initialise instance_projector ----------
        # Layout per the original author's note:
        # [BN, ReLU, Linear(hidden→hidden), BN, ReLU, Linear(hidden→feature_dim)]
        torch.nn.init.zeros_(model.module.instance_projector[2].bias)
        model.module.instance_projector[2].weight.data = W1.clone()  # first linear layer
        torch.nn.init.zeros_(model.module.instance_projector[5].bias)
        model.module.instance_projector[5].weight.data = W3.clone()  # second linear layer (feat_dim prototypes)

        # ---------- Initialise cluster_projector ----------
        # Layout per the original author's note:
        # [BN, ReLU, Linear(hidden→hidden), BN, ReLU, Linear(hidden→cluster_num)]
        torch.nn.init.zeros_(model.module.cluster_projector[2].bias)
        model.module.cluster_projector[2].weight.data = W1.clone()  # first linear layer
        torch.nn.init.zeros_(model.module.cluster_projector[5].bias)
        model.module.cluster_projector[5].weight.data = W2.clone()  # last layer: classification prototypes

    # ---------- Split samples into percentile bands ----------
    density_weights2_np = density_weights2.detach().cpu().numpy()
    percentiles = np.percentile(density_weights2_np, [20, 40, 60, 80])

    indices_per_bin = []
    for i in range(5):
        if i == 0:
            mask = density_weights2_np <= percentiles[i]
        elif i == 4:
            mask = density_weights2_np > percentiles[i - 1]
        else:
            mask = (density_weights2_np > percentiles[i - 1]) & (density_weights2_np <= percentiles[i])

        indices_per_bin.append(np.where(mask)[0])  # indices in the current band

    return indices_per_bin