import torch
import numpy as np
from typing import List, Optional, Set
from loguru import logger
from tqdm import tqdm, trange
import heapq
from .base_clusterer import Clusterer
from .cluster_data import ClusterData
class AvgLinkageCluster:
    """Greedy average-linkage clustering over the reference rows ``d0`` of an embedding matrix.

    Active clusters are merged in order of increasing average pairwise Euclidean
    distance. When a merge would produce more than ``min_cluster_size`` members,
    the merged cluster is emitted as a final :class:`ClusterData` and both inputs
    leave the active set. Clusters still active after the merge loop are packed
    into groups of at least ``min_cluster_size`` members. Finally, every
    non-reference row ``d1`` is attached to the cluster of its nearest reference row.

    ``num_clusters`` is stored on the instance, but the merge loop does not use it
    as a stopping criterion.
    """

    def __init__(self, num_clusters: int, min_cluster_size: int,
                 corpus_emb: torch.Tensor, d0: List[int], d1: List[int]):
        """
        Args:
            num_clusters: Stored as ``self.num_clusters`` (not consulted by the merge loop).
            min_cluster_size: Size threshold that triggers emitting a merged cluster.
            corpus_emb: Embedding matrix; rows are selected by the values in ``d0``/``d1``.
            d0: Row ids of reference vectors (the ones that get clustered).
            d1: Row ids of non-reference vectors (attached to clusters afterwards).
        """
        self.num_clusters = num_clusters
        self.min_cluster_size = min_cluster_size
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.corpus_emb = torch.as_tensor(corpus_emb, device=self.device)
        # Fancy indexing below (d0[list], d1[bool mask]) requires numpy arrays.
        self.d0 = self._as_index_array(d0)
        self.d1 = self._as_index_array(d1)

    @staticmethod
    def _as_index_array(indices) -> np.ndarray:
        """Return ``indices`` as a numpy array; empty input gets an integer dtype so it can index tensors."""
        arr = np.asarray(indices)
        return arr.astype(np.int64) if arr.size == 0 else arr

    def cluster(self) -> List[ClusterData]:
        """Cluster the ``d0`` rows and attach the ``d1`` rows to the resulting clusters.

        Returns:
            One ClusterData per final cluster. ``ref_index`` holds values taken from
            ``d0`` and ``bound_index`` holds the ``d1`` values whose nearest
            reference row belongs to that cluster.
        """
        input_tensor = self.corpus_emb[self.d0]
        num_points = input_tensor.shape[0]
        # clusters[k] lists positions (into d0) of the members of node k; new nodes are appended.
        clusters: List[List[int]] = [[i] for i in range(num_points)]
        dist_matrix = torch.cdist(input_tensor, input_tensor)
        active_nodes: Set[int] = set(range(num_points))
        ref_set_list: List[ClusterData] = []

        logger.info("Initializing priority queue for average linkage clustering")
        distance_queue = self._initial_distance_queue(dist_matrix)

        # Every accepted merge removes at least one active node, so n - 1 bounds the merge count.
        pbar = tqdm(total=max(num_points - 1, 0), desc="Clustering progress")
        while True:
            popped = self._pop_valid_pair(distance_queue, active_nodes)
            if popped is None:
                break
            min_dist, merge_i, merge_j = popped
            pbar.update(1)
            merged_indices = clusters[merge_i] + clusters[merge_j]
            active_nodes.remove(merge_i)
            active_nodes.remove(merge_j)

            if len(merged_indices) > self.min_cluster_size:
                # Large enough: emit as a final cluster and stop growing it.
                ref_set_list.append(ClusterData(
                    ref_index=self.d0[merged_indices],
                    bound_index=[],
                    diameter=float(min_dist),
                ))
                continue

            new_node_idx = len(clusters)
            clusters.append(merged_indices)
            # The merged cluster's distance rows are shared by every average computed below.
            new_rows = dist_matrix[merged_indices]
            for node in active_nodes:
                # Every existing node index is smaller than new_node_idx, so the pair is (node, new).
                avg_dist = new_rows[:, clusters[node]].mean()
                heapq.heappush(distance_queue, (float(avg_dist), node, new_node_idx))
            active_nodes.add(new_node_idx)
        pbar.close()

        self._pack_remaining(sorted(active_nodes), clusters, ref_set_list)
        logger.info(f"Created {len(ref_set_list)} clusters")
        return self.bound_non_ref_set(ref_set_list)

    @staticmethod
    def _initial_distance_queue(dist_matrix: torch.Tensor) -> list:
        """Return a heap of ``(distance, i, j)`` entries for every singleton pair ``i < j``.

        The average distance between two singletons is simply ``dist_matrix[i, j]``,
        so all pairs are gathered in one tensor operation instead of one kernel per pair.
        """
        n = dist_matrix.shape[0]
        rows, cols = torch.triu_indices(n, n, offset=1, device=dist_matrix.device)
        queue = list(zip(dist_matrix[rows, cols].tolist(), rows.tolist(), cols.tolist()))
        heapq.heapify(queue)
        return queue

    @staticmethod
    def _pop_valid_pair(distance_queue: list, active_nodes: Set[int]) -> Optional[tuple]:
        """Pop heap entries until one joins two still-active nodes; return it, or None if the heap is exhausted."""
        while distance_queue:
            entry = heapq.heappop(distance_queue)
            # Entries that reference an already-merged node are stale and are discarded.
            if entry[1] in active_nodes and entry[2] in active_nodes:
                return entry
        return None

    def _pack_remaining(self, node_ids: List[int], clusters: List[List[int]],
                        ref_set_list: List[ClusterData]) -> None:
        """Pack the leftover nodes, in the given order, into clusters of at least ``min_cluster_size`` members.

        A final group that is too small is folded into the last emitted cluster
        (or emitted on its own if there is none). ``ref_set_list`` is modified in place.
        """
        remaining_indices: List[int] = []
        for node_idx in node_ids:
            remaining_indices.extend(clusters[node_idx])
            if len(remaining_indices) >= self.min_cluster_size:
                ref_set_list.append(ClusterData(
                    ref_index=self.d0[remaining_indices],
                    bound_index=[],
                    diameter=0.0,
                ))
                remaining_indices = []
        if not remaining_indices:
            return
        if ref_set_list:
            combined_indices = np.concatenate([ref_set_list[-1].ref_index, self.d0[remaining_indices]])
            ref_set_list[-1] = ClusterData(
                ref_index=combined_indices,
                bound_index=[],
                diameter=0.0,
            )
        else:
            ref_set_list.append(ClusterData(
                ref_index=self.d0[remaining_indices],
                bound_index=[],
                diameter=0.0,
            ))

    def bound_non_ref_set(self, ref_set_list: List[ClusterData],
                          batch_size: int = 100) -> List[ClusterData]:
        """Assign each non-reference row (``d1``) to the cluster of its nearest reference row.

        Args:
            ref_set_list: Clusters whose ``ref_index`` values are drawn from ``d0``.
            batch_size: Number of non-reference rows compared against all references per ``cdist`` call.

        Returns:
            ``ref_set_list``, with ``bound_index`` of every element overwritten in place.

        Raises:
            ValueError: if there are non-reference rows but no reference rows.
            KeyError: if a nearest reference row is not covered by any cluster.
        """
        non_ref_vectors = self.corpus_emb[self.d1]
        ref_vectors = self.corpus_emb[self.d0]
        num_non_ref = non_ref_vectors.shape[0]
        if num_non_ref and ref_vectors.shape[0] == 0:
            raise ValueError("cannot assign non-reference vectors: there are no reference vectors (d0 is empty)")

        # d0 value -> row position in ref_vectors (the last occurrence wins on duplicates).
        ref_to_seq = {ref_idx: pos for pos, ref_idx in enumerate(self.d0)}
        # Row position -> cluster id; -1 marks positions not covered by any cluster.
        seq_to_cluster = [-1] * len(self.d0)
        for cluster_idx, cluster in enumerate(ref_set_list):
            for ref_idx in cluster.ref_index:
                seq_to_cluster[ref_to_seq[ref_idx]] = cluster_idx
        seq_to_cluster_t = torch.tensor(seq_to_cluster, dtype=torch.int64, device=self.device)

        num_batches = (num_non_ref + batch_size - 1) // batch_size
        min_cluster = torch.zeros(num_non_ref, dtype=torch.int64, device=self.device)
        for i in trange(num_batches, desc="Assigning non-reference vectors to clusters"):
            start = i * batch_size
            end = min(start + batch_size, num_non_ref)
            distances = torch.cdist(non_ref_vectors[start:end], ref_vectors)
            _, closest_indices = torch.min(distances, dim=1)
            # One gather maps the whole batch from reference position to cluster id.
            min_cluster[start:end] = seq_to_cluster_t[closest_indices]

        if bool((min_cluster < 0).any()):
            raise KeyError("a nearest reference vector is not covered by any cluster in ref_set_list")

        assignments = min_cluster.cpu().numpy()
        for j, cluster_data in enumerate(ref_set_list):
            cluster_data.bound_index = self.d1[assignments == j].tolist()
        return ref_set_list
class AvgLinkageClusterer(Clusterer):
    """Clusterer front-end that delegates to :class:`AvgLinkageCluster`.

    ``self.args.num_clusters`` and ``self.args.reduced_dim`` are forwarded to
    AvgLinkageCluster as ``num_clusters`` and ``min_cluster_size``.
    """

    def cluster(self, corpus_emb_1: np.ndarray, d0: np.ndarray, d1: np.ndarray) -> List[ClusterData]:
        """Cluster the ``d0`` rows of ``corpus_emb_1`` and attach the ``d1`` rows.

        After clustering, the method asserts that the union of all ``ref_index``
        arrays equals ``d0`` and that the union of all ``bound_index`` lists equals
        ``d1`` (each compared after sorting).

        Returns:
            The list of ClusterData produced by AvgLinkageCluster.
        """
        args = self.args
        logger.info(f"Running average linkage clustering with {args.num_clusters} clusters...")
        clusterer = AvgLinkageCluster(
            num_clusters=args.num_clusters,
            min_cluster_size=args.reduced_dim,
            corpus_emb=corpus_emb_1,
            d0=d0,
            d1=d1,
        )
        clusters = clusterer.cluster()

        # Sanity checks: every reference id and every bound id is accounted for.
        covered_refs = np.concatenate([c.ref_index for c in clusters])
        assert np.array_equal(np.sort(covered_refs), np.sort(np.array(d0))), \
            "The combined ref_indices do not match d0"
        covered_bounds = np.concatenate([c.bound_index for c in clusters])
        assert np.array_equal(np.sort(covered_bounds), np.sort(np.array(d1))), \
            "The combined bound_indices do not match d1"

        logger.info("Average linkage clustering completed successfully")
        return clusters
