import torch
from sklearn.cluster import KMeans

def cluster_embeddings(embeddings, n_clusters, random_state=None):
    """Assign each embedding vector to one of ``n_clusters`` K-means clusters.

    Args:
        embeddings: Samples to cluster, one row per sample (K-means expects a
            2-D input). May be a ``torch.Tensor`` on any device, including one
            that requires grad, or any array-like accepted by scikit-learn.
        n_clusters: Number of clusters to fit.
        random_state: Optional seed forwarded to ``KMeans`` so the clustering
            is reproducible. ``None`` keeps scikit-learn's default
            (non-deterministic) behaviour.

    Returns:
        The array of integer cluster labels from ``KMeans.fit_predict``, one
        per input row.
    """
    if torch.is_tensor(embeddings):
        # .numpy() refuses tensors that track gradients, so detach first.
        data = embeddings.detach()
        if data.dtype == torch.bfloat16:
            # numpy has no bfloat16; upcast so the conversion succeeds.
            data = data.float()
        data = data.cpu().numpy()
    else:
        data = embeddings
    kmeans = KMeans(n_clusters=n_clusters, random_state=random_state)
    return kmeans.fit_predict(data)

def initialize_cache_centers(text_embeds, visual_embeds, cluster_ids, n_clusters, device):
    """Build L2-normalized per-cluster mean vectors for text and visual embeddings.

    Args:
        text_embeds: Tensor whose rows are text embeddings; its second dimension
            sets the width of the text cache.
        visual_embeds: Tensor whose rows are visual embeddings; its second
            dimension sets the width of the visual cache.
        cluster_ids: Cluster label per row. This can be a tensor, a numpy array
            (e.g. the output of ``cluster_embeddings``) or a plain sequence of ints.
        n_clusters: Number of cache rows to produce (labels ``0 .. n_clusters-1``).
        device: Device on which the cache tensors are allocated.

    Returns:
        ``(text_cache, visual_cache)``: tensors of shape
        ``(n_clusters, text_embeds.shape[1])`` and
        ``(n_clusters, visual_embeds.shape[1])``. Row ``i`` is the mean of the
        rows labelled ``i``, normalized to unit length. A cluster with no
        members keeps an all-zero row, and ``F.normalize`` leaves it zero.
    """
    # Accept lists/numpy labels uniformly; a raw Python list would otherwise
    # compare to ``i`` as a single bool and break the masking below.
    cluster_ids = torch.as_tensor(cluster_ids)

    # Allocate directly on the target device instead of creating on the host
    # and copying over.
    text_cache = torch.zeros(n_clusters, text_embeds.shape[1], device=device)
    visual_cache = torch.zeros(n_clusters, visual_embeds.shape[1], device=device)

    for i in range(n_clusters):
        mask = cluster_ids == i
        if not mask.any():
            continue  # empty cluster: leave its row as zeros
        text_cache[i] = text_embeds[mask].mean(dim=0)
        visual_cache[i] = visual_embeds[mask].mean(dim=0)

    # Unit-normalize each centroid row.
    text_cache = torch.nn.functional.normalize(text_cache, dim=-1)
    visual_cache = torch.nn.functional.normalize(visual_cache, dim=-1)

    return text_cache, visual_cache
