import torch
from sklearn.cluster import KMeans
from typing import Dict, List, Tuple
from sklearn.metrics.pairwise import cosine_similarity
import logging


def cluster_embeddings(embeddings: torch.Tensor, n_clusters: int):
    """Assign each embedding row to a K-means cluster.

    Args:
        embeddings: Tensor of embeddings. It is detached from the autograd
            graph, moved to the CPU and converted to a NumPy array before
            being handed to scikit-learn.
        n_clusters: Number of clusters requested from ``KMeans``.

    Returns:
        The value returned by ``KMeans.fit_predict``: one cluster index per
        input row.

    Note:
        No ``random_state`` is passed to ``KMeans``, so the assignment is not
        pinned to a fixed seed by this function.
    """
    # ``Tensor.numpy()`` refuses tensors that require grad; detaching first
    # makes the function usable on outputs of a model that is being trained.
    features = embeddings.detach().cpu().numpy()
    return KMeans(n_clusters=n_clusters).fit_predict(features)
def initialize_cache_centers(
    text_embeds: torch.Tensor,
    visual_embeds: torch.Tensor,
    cluster_ids: torch.Tensor,
    cache_size: int,
    device: str
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build L2-normalized text and visual cache centers from cluster means.

    Each distinct value in ``cluster_ids`` is used directly as a row index
    into the caches, and that row is set to the mean of the embeddings that
    belong to the cluster. Rows whose text center is still (near) zero
    afterwards -- in particular every slot no cluster was written to -- are
    filled with random Gaussian vectors in both caches. Finally every row is
    L2-normalized.

    Args:
        text_embeds: 2-D tensor of text embeddings, one row per sample.
        visual_embeds: 2-D tensor of visual embeddings, one row per sample.
            Its feature width may differ from ``text_embeds``.
        cluster_ids: Cluster index per sample. A tensor, or anything
            ``torch.as_tensor`` accepts (e.g. the NumPy array returned by
            scikit-learn). Values are used as row indices, so they must be
            valid indices for a cache of ``cache_size`` rows.
        cache_size: Number of rows in each returned cache.
        device: Device on which the caches are allocated.

    Returns:
        ``(text_cache, visual_cache)``, each with ``cache_size`` rows.
    """
    text_dim = text_embeds.shape[1]
    visual_dim = visual_embeds.shape[1]
    text_cache = torch.zeros((cache_size, text_dim), device=device)
    visual_cache = torch.zeros((cache_size, visual_dim), device=device)

    # Accept NumPy label arrays as well as tensors; a no-op for tensors.
    cluster_ids = torch.as_tensor(cluster_ids)

    # Every value yielded by torch.unique occurs at least once, so the mask
    # is never empty and the mean is always defined.
    for cluster_id in torch.unique(cluster_ids).tolist():
        members = cluster_ids == cluster_id
        text_cache[cluster_id] = text_embeds[members].mean(dim=0)
        visual_cache[cluster_id] = visual_embeds[members].mean(dim=0)

    # Fill uninitialized slots with random vectors. The text row decides for
    # both caches. Counting must happen here: once the rows are filled and
    # normalized, no row has a near-zero norm any more.
    num_random_text = 0
    num_random_visual = 0
    for slot in range(cache_size):
        if torch.norm(text_cache[slot]) < 1e-5:
            text_cache[slot] = torch.randn(text_dim, device=device)
            visual_cache[slot] = torch.randn(visual_dim, device=device)
            num_random_text += 1
            num_random_visual += 1

    # Normalize for better similarity comparison
    text_cache = torch.nn.functional.normalize(text_cache, dim=-1)
    visual_cache = torch.nn.functional.normalize(visual_cache, dim=-1)

    logging.info(
        "[InitCache] Randomly filled %d text slots and %d visual slots.",
        num_random_text,
        num_random_visual,
    )

    return text_cache, visual_cache





