import torch
from typing import List, Tuple
from .cluster_data import ClusterData
from sklearn.cluster import KMeans
from loguru import logger
from tqdm import trange
class ClusterBasedRef:
    """Cluster reference embeddings with KMeans and bind target embeddings to them.

    The reference rows of ``all_emb`` (selected by ``ref_index``) are grouped
    into clusters.  Every target row (selected by ``target_index``) is then
    attached to the cluster of its nearest reference row, as measured by
    ``torch.cdist``.
    """

    def __init__(self, all_emb: torch.Tensor, ref_index: torch.Tensor, target_index: torch.Tensor, num_clusters: int = 100):
        """Store the embeddings and index sets on one common device.

        Args:
            all_emb: Embedding matrix.  A non-tensor value is converted with
                ``torch.tensor`` and placed on CUDA when available, else CPU.
            ref_index: Indices into ``all_emb`` selecting the reference rows.
            target_index: Indices into ``all_emb`` selecting the target rows.
            num_clusters: Upper bound on the number of KMeans clusters.
        """
        if not isinstance(all_emb, torch.Tensor):
            fallback_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            all_emb = torch.tensor(all_emb, device=fallback_device)
        # Everything follows the device of the embeddings: putting the indices
        # on "cuda if available" regardless of where `all_emb` lives makes the
        # indexing fail when `all_emb` is a CPU tensor on a CUDA machine.
        device = all_emb.device
        self.all_emb = all_emb
        self.ref_index = self._to_device_tensor(ref_index, device)
        self.target_index = self._to_device_tensor(target_index, device)
        self.num_clusters = num_clusters
        # loguru interpolates with str.format-style braces, not printf-style %d.
        logger.info("Initialized ClusterBasedRef with {} clusters.", num_clusters)

    @staticmethod
    def _to_device_tensor(values, device, dtype=None):
        """Return ``values`` as a tensor on ``device`` (existing tensors keep their dtype)."""
        if isinstance(values, torch.Tensor):
            return values.to(device)
        return torch.tensor(values, device=device, dtype=dtype)

    def cluster_ref(self) -> Tuple[List[ClusterData], torch.Tensor]:
        """Run KMeans over the reference embeddings.

        Returns:
            A tuple ``(cluster_data_list, labels)``.  ``cluster_data_list`` has
            one ``ClusterData`` per cluster, holding the ``ref_index`` entries
            assigned to it, an empty ``bound_index`` and ``diameter=0.0``.
            ``labels`` is an int64 tensor with the cluster id of each
            reference row, in ``ref_index`` order.

        Raises:
            ValueError: If there are no reference embeddings to cluster.
        """
        ref_emb = self.all_emb[self.ref_index]
        num_refs = ref_emb.shape[0]
        if num_refs == 0:
            raise ValueError("Cannot cluster an empty set of reference embeddings.")
        # KMeans cannot produce more clusters than there are samples.
        num_clusters = min(self.num_clusters, num_refs)
        kmeans = KMeans(n_clusters=num_clusters, n_init='auto')
        # detach() so embeddings that require grad can still be exported to numpy.
        labels = kmeans.fit_predict(ref_emb.detach().cpu().numpy())
        labels = torch.tensor(labels, device=self.all_emb.device, dtype=torch.int64)
        cluster_data_list = []
        for cluster_id in range(num_clusters):
            member_positions = torch.where(labels == cluster_id)[0]
            cluster_data_list.append(
                ClusterData(
                    ref_index=self.ref_index[member_positions].tolist(),
                    bound_index=[],
                    diameter=0.0,
                )
            )
        return cluster_data_list, labels

    def find_closest_vectors(self, emb: torch.Tensor, labels: torch.Tensor, cluster_data_list: List[ClusterData], batch_size: int = 100) -> List[ClusterData]:
        """Bind every target vector to the cluster of its nearest reference vector.

        Args:
            emb: Unused; kept for backward compatibility.  Distances are always
                computed from ``self.all_emb``.
            labels: Cluster id of each reference row, in ``ref_index`` order.
            cluster_data_list: Clusters to update; entry ``j`` corresponds to
                cluster id ``j``.
            batch_size: Number of reference vectors compared per ``torch.cdist``
                call.

        Returns:
            The same ``cluster_data_list`` object, with each ``bound_index``
            replaced by the sorted, de-duplicated target indices assigned to
            that cluster.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer.")
        device = self.all_emb.device
        # Labels must share the device of the indices that are used to look them up.
        labels = self._to_device_tensor(labels, device, dtype=torch.int64)
        target_vectors = self.all_emb[self.target_index]
        ref_vectors = self.all_emb[self.ref_index]
        num_targets = target_vectors.shape[0]
        num_refs = ref_vectors.shape[0]
        num_batches = (num_refs + batch_size - 1) // batch_size
        min_cluster = torch.zeros(num_targets, dtype=torch.int64, device=device)
        min_distance = torch.full((num_targets,), float('inf'), device=device, dtype=torch.float64)
        # Only the resulting assignments are used, so no autograd graph is needed.
        with torch.no_grad():
            for batch in trange(num_batches):
                start = batch * batch_size
                end = min(start + batch_size, num_refs)
                distances = torch.cdist(target_vectors, ref_vectors[start:end])
                batch_min_distance, batch_argmin = torch.min(distances, dim=1)
                # Convert batch-local positions to positions within ref_vectors.
                candidate_cluster = labels[batch_argmin + start]
                # Strict '<' keeps the earliest reference on ties across batches.
                improved = batch_min_distance < min_distance
                min_cluster[improved] = candidate_cluster[improved]
                min_distance[improved] = batch_min_distance[improved].to(min_distance.dtype)
        for cluster_id, cluster_data in enumerate(cluster_data_list):
            members = self.target_index[min_cluster == cluster_id].tolist()
            # sorted() makes the de-duplicated output order deterministic.
            cluster_data.bound_index = sorted(set(members))
        return cluster_data_list

    def cluster(self, cluster_data_list: List[ClusterData] = None, labels: torch.Tensor = None) -> Tuple[List[ClusterData], torch.Tensor, List[ClusterData]]:
        """Cluster the references (when needed) and bind the targets to the clusters.

        Args:
            cluster_data_list: Existing clusters to reuse, or ``None``.
            labels: Cluster id of each reference row, or ``None``.  When either
                argument is ``None`` both are recomputed with ``cluster_ref``.

        Returns:
            A tuple ``(cluster_data_list, labels, updated_cluster_data_list)``.
            The first and third elements are the same list object, because
            ``find_closest_vectors`` updates the clusters in place.
        """
        if cluster_data_list is None or labels is None:
            cluster_data_list, labels = self.cluster_ref()
        updated_cluster_data_list = self.find_closest_vectors(self.all_emb, labels, cluster_data_list)
        return cluster_data_list, labels, updated_cluster_data_list
