from .base_clusterer import Clusterer
from .cluster_data import ClusterData
import numpy as np
from typing import List
import os
from src.utils import load_cache, save_cache
from loguru import logger
from sklearn.cluster import KMeans
from typing import List, Dict, Tuple
import numpy as np
from tqdm import trange, tqdm
from dataclasses import dataclass
import torch
from typing import List, Tuple, Optional
class HypergraphOptimizer:
    """Build and merge neighborhoods of reference vectors.

    Every non-reference vector (or pre-clustered centroid) is attached to its
    ``d_prime`` nearest reference vectors. The resulting neighborhoods can then
    be de-duplicated, linked into an overlap graph and greedily merged.
    """

    def __init__(self, reference: torch.Tensor, non_reference: torch.Tensor, 
                 d_prime: int, reference_index_list: List[int], 
                 non_reference_index_list: List[int]):
        """Store the vectors on the available device together with their ids.

        Args:
            reference: Reference vectors, one per row.
            non_reference: Non-reference vectors, one per row.
            d_prime: Number of nearest reference vectors kept per neighborhood.
            reference_index_list: External id of each row of ``reference``.
            non_reference_index_list: External id of each row of
                ``non_reference``.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.reference = torch.as_tensor(reference, device=device)
        self.non_reference = torch.as_tensor(non_reference, device=device)
        self.d_prime = d_prime
        self.reference_index_list = reference_index_list
        self.non_reference_index_list = non_reference_index_list
        # Both caches are filled lazily by the helpers below.
        self.r_index2range: Optional[dict] = None
        self.r2r_distances: Optional[torch.Tensor] = None

    def _reference_positions(self) -> dict:
        """Return the map from external reference id to its row position."""
        # The attribute always exists (it starts as None), so the check has to
        # be on its value; a hasattr() test would never build the map.
        if getattr(self, "r_index2range", None) is None:
            self.r_index2range = {
                ref_id: row for row, ref_id in enumerate(self.reference_index_list)
            }
        return self.r_index2range

    def _reference_distances(self) -> torch.Tensor:
        """Return the reference-to-reference distances, computing them once."""
        if getattr(self, "r2r_distances", None) is None:
            self.r2r_distances = torch.cdist(self.reference, self.reference)
        return self.r2r_distances

    def pre_cluster(self, vectors: torch.Tensor, num_clusters: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run k-means on ``vectors``.

        Args:
            vectors: Tensor or array of row vectors to cluster.
            num_clusters: Number of k-means clusters.

        Returns:
            ``(cluster_assignments, centers)``, both placed on the device of
            ``self.non_reference``.
        """
        logger.info(f"Pre-clustering vectors into {num_clusters} clusters")
        kmeans = KMeans(n_clusters=num_clusters, n_init='auto')
        # detach() lets tensors that track gradients be converted to NumPy.
        data = vectors.detach().cpu().numpy() if torch.is_tensor(vectors) else vectors
        labels = kmeans.fit_predict(data)
        device = self.non_reference.device
        centers = torch.tensor(kmeans.cluster_centers_, device=device)
        cluster_assignments = torch.tensor(labels, device=device)
        return cluster_assignments, centers

    def compute_top_d_neighbors_and_diameters(self, 
                                            non_reference_vectors: Optional[torch.Tensor] = None,
                                            cluster_assignments: Optional[torch.Tensor] = None) -> Tuple[List[ClusterData], List[float], torch.Tensor]:
        """Create one neighborhood per non-reference vector.

        Args:
            non_reference_vectors: Vectors to attach; defaults to
                ``self.non_reference``.
            cluster_assignments: Optional 1-D tensor. When given, row ``i`` of
                ``non_reference_vectors`` is bound to every id of
                ``non_reference_index_list`` whose assignment equals ``i``.

        Returns:
            ``(neighborhoods, diameters, r2r_distances)``. ``diameters[k]`` is
            the largest pairwise reference distance inside ``neighborhoods[k]``
            (``0.0`` when no reference was selected).
        """
        if non_reference_vectors is None:
            non_reference_vectors = self.non_reference
        non_reference_vectors = torch.as_tensor(non_reference_vectors, device=self.reference.device)
        r2r_distances = torch.cdist(self.reference, self.reference)
        self.r2r_distances = r2r_distances
        neighborhoods = []
        diameters = []
        for i in range(non_reference_vectors.shape[0]):
            if cluster_assignments is not None:
                # flatten() always yields a list, whatever the member count.
                members = (cluster_assignments == i).nonzero().flatten().tolist()
                bound_indices = [self.non_reference_index_list[idx] for idx in members]
            else:
                bound_indices = [self.non_reference_index_list[i]]
            distances = torch.cdist(non_reference_vectors[i].unsqueeze(0), self.reference)
            # Index the single row instead of squeeze(): with exactly one
            # reference vector squeeze() would give a 0-d tensor that cannot
            # be sliced.
            top_d_indices = torch.argsort(distances[0])[:self.d_prime].tolist()
            if top_d_indices:
                max_diameter = torch.max(r2r_distances[top_d_indices, :][:, top_d_indices]).item()
            else:
                # torch.max rejects an empty selection.
                max_diameter = 0.0
            neighborhoods.append(ClusterData(
                ref_index=[self.reference_index_list[idx] for idx in top_d_indices],
                bound_index=bound_indices,
                diameter=max_diameter,
            ))
            diameters.append(max_diameter)
        return neighborhoods, diameters, r2r_distances

    def merge_and_sort_neighborhoods(self, neighborhoods: List[ClusterData], 
                                   diameters: List[float]) -> List[ClusterData]:
        """Fuse neighborhoods with identical reference sets and sort them.

        The first neighborhood seen for a reference set is kept and receives
        the ``bound_index`` entries of later duplicates (it is mutated in
        place).

        Args:
            neighborhoods: Neighborhoods to de-duplicate.
            diameters: Unused; kept for interface compatibility.

        Returns:
            The surviving neighborhoods, largest ``diameter`` first.
        """
        unique = {}
        for neighborhood in neighborhoods:
            key = frozenset(neighborhood.ref_index)
            kept = unique.get(key)
            if kept is None:
                unique[key] = neighborhood
            else:
                kept.bound_index.extend(neighborhood.bound_index)
        return sorted(unique.values(), key=lambda x: x.diameter, reverse=True)

    def construct_hypergraph(self, sorted_neighborhoods: List[ClusterData]) -> List[Tuple[int, int]]:
        """Link every pair of neighborhoods that share a reference id.

        Args:
            sorted_neighborhoods: Neighborhoods to compare pairwise.

        Returns:
            Directed edges ``(i, j)`` and ``(j, i)`` for each overlapping pair,
            in discovery order. Empty input yields an empty list.
        """
        # Plain sets are built once per neighborhood. A zero-padded index
        # tensor would make the padding value look like a shared reference.
        ref_sets = [set(neighborhood.ref_index) for neighborhood in sorted_neighborhoods]
        edges = []
        for i in tqdm(range(len(ref_sets)), desc="Constructing hypergraph"):
            current = ref_sets[i]
            for j in range(i):
                if not current.isdisjoint(ref_sets[j]):
                    edges.append((i, j))
                    edges.append((j, i))
        return edges

    def optimize_neighborhood_sets(self, sorted_neighborhoods: List[ClusterData], 
                                 edges: List[Tuple[int, int]], 
                                 diameters: List[float], 
                                 r2r_distances: torch.Tensor) -> List[ClusterData]:
        """Greedily merge linked neighborhoods that do not grow the diameter.

        For each edge ``(i, j)`` neighborhood ``j`` is absorbed into ``i`` when
        the diameter of the union does not exceed the current diameter of
        ``i``. Absorbing neighborhoods are mutated in place; the input list
        itself is left untouched.

        Args:
            sorted_neighborhoods: Neighborhoods addressed by ``edges``.
            edges: Index pairs into ``sorted_neighborhoods``.
            diameters: Not read or modified. Each cluster's own ``diameter``
                is used instead, because this list is indexed in the order
                produced before merging/sorting and therefore does not line up
                with ``sorted_neighborhoods``.
            r2r_distances: Pairwise distances between reference vectors.

        Returns:
            The neighborhoods that were not absorbed, in their original order.
        """
        positions = self._reference_positions()
        working = list(sorted_neighborhoods)
        for i, j in tqdm(edges, desc="Optimizing neighborhood sets"):
            target, source = working[i], working[j]
            if target is None or source is None:
                continue
            # Linked neighborhoods overlap by construction; keep each shared
            # reference id once while preserving first-seen order.
            merged_refs = list(dict.fromkeys([*target.ref_index, *source.ref_index]))
            rows = [positions[idx] for idx in merged_refs]
            new_diameter = torch.max(r2r_distances[rows, :][:, rows]).item()
            if new_diameter <= target.diameter:
                target.ref_index = merged_refs
                target.bound_index.extend(source.bound_index)
                target.diameter = new_diameter
                working[j] = None
        return [s for s in working if s is not None]

    def merge_by_diameter(self, neighborhoods: List[ClusterData], 
                         distance: List[float], 
                         num_clusters: int) -> List[ClusterData]:
        """Agglomerate clusters until ``num_clusters`` remain.

        The pair with the smallest mean reference-to-reference distance is
        merged repeatedly; the surviving cluster is mutated in place and its
        ``diameter`` is recomputed from the merged reference set.

        Args:
            neighborhoods: Clusters to merge; ``None`` entries are skipped.
            distance: Unused; kept for interface compatibility.
            num_clusters: Target number of clusters, clamped to
                ``[1, len(clusters)]``.

        Returns:
            The merged clusters (empty when there was nothing to merge).
        """
        clusters = [c for c in neighborhoods if c is not None]
        if not clusters:
            return clusters
        positions = self._reference_positions()
        r2r = self._reference_distances()

        def to_rows(ref_index) -> torch.Tensor:
            return torch.tensor([positions[idx] for idx in ref_index],
                                dtype=torch.long, device=r2r.device)

        # Row tensors are cached per cluster so they are not rebuilt for every
        # pairwise distance.
        rows = [to_rows(c.ref_index) for c in clusters]

        def linkage(a: int, b: int) -> float:
            return torch.mean(r2r[rows[a]][:, rows[b]]).item()

        num_clusters = max(1, min(num_clusters, len(clusters)))
        n = len(clusters)
        dist_matrix = torch.full((n, n), float('inf'), device=r2r.device)
        for i in trange(n, desc="Computing initial distances"):
            for j in range(i + 1, n):
                dist_matrix[i, j] = dist_matrix[j, i] = linkage(i, j)
        while len(clusters) > num_clusters:
            size = dist_matrix.shape[0]
            i, j = divmod(torch.argmin(dist_matrix).item(), size)
            min_dist = dist_matrix[i, j].item()
            if min_dist == float('inf'):
                break
            # Always absorb the higher index into the lower one so that
            # removing row/column j never shifts the position of i.
            if i > j:
                i, j = j, i
            merged_refs = list(dict.fromkeys([*clusters[i].ref_index, *clusters[j].ref_index]))
            merged_rows = to_rows(merged_refs)
            clusters[i].ref_index = merged_refs
            clusters[i].bound_index.extend(clusters[j].bound_index)
            clusters[i].diameter = torch.max(r2r[merged_rows][:, merged_rows]).item()
            rows[i] = merged_rows
            clusters.pop(j)
            rows.pop(j)
            keep = torch.ones(size, dtype=torch.bool, device=dist_matrix.device)
            keep[j] = False
            dist_matrix = dist_matrix[keep][:, keep]
            for k in range(len(clusters)):
                if k != i:
                    dist_matrix[i, k] = dist_matrix[k, i] = linkage(i, k)
            logger.info(f"Merged clusters: remaining {len(clusters)} clusters, min_dist: {min_dist:.4f}")
        return clusters

    def hypergraph_construction_and_optimization(self) -> List[ClusterData]:
        """Run the full pipeline on ``self.non_reference``.

        Returns:
            The neighborhoods left after de-duplication and greedy merging.
        """
        neighborhoods, diameters, r2r_distances = self.compute_top_d_neighbors_and_diameters()
        logger.info(f"Initial neighborhoods: {len(neighborhoods)}")
        sorted_neighborhoods = self.merge_and_sort_neighborhoods(neighborhoods, diameters)
        logger.info(f"Merged neighborhoods: {len(sorted_neighborhoods)}")
        edges = self.construct_hypergraph(sorted_neighborhoods)
        logger.info(f"Hypergraph edges: {len(edges)}")
        optimized_sets = self.optimize_neighborhood_sets(sorted_neighborhoods, edges, diameters, r2r_distances)
        logger.info(f"Final optimized sets: {len(optimized_sets)}")
        return optimized_sets
class AvgLinkageCluster:
    """Size-capped average-linkage clustering of reference vectors.

    Reference vectors (rows ``d0`` of ``corpus_emb``) are merged bottom-up by
    average pairwise distance. The non-reference vectors (rows ``d1``) are then
    bound to the cluster of their nearest reference vector.
    """

    def __init__(self, num_clusters: int, min_cluster_size: int, corpus_emb: torch.Tensor, d0: List[int], d1: List[int]):
        """Store the clustering parameters.

        Args:
            num_clusters: Stored on the instance; not consulted by the merge
                loop (see the review note in ``cluster``).
            min_cluster_size: Size threshold that decides when a merged group
                is emitted as a cluster.
            corpus_emb: Embedding matrix indexed by ``d0`` and ``d1``.
            d0: Row indices of the reference vectors.
            d1: Row indices of the non-reference vectors.
        """
        self.num_clusters = num_clusters
        self.min_cluster_size = min_cluster_size
        self.corpus_emb = corpus_emb
        # The methods fancy-index these with lists and boolean masks, which
        # plain Python lists do not support.
        self.d0 = np.asarray(d0)
        self.d1 = np.asarray(d1)

    def cluster(self) -> List[ClusterData]:
        """Cluster the reference vectors and bind the non-reference ones.

        The closest pair of active groups is merged repeatedly. A merged group
        larger than ``min_cluster_size`` is emitted as a cluster (with the
        merge distance as its ``diameter``) and retired; otherwise it becomes a
        new active group. Groups still active when the queue is exhausted are
        packed into clusters of at least ``min_cluster_size`` points, and any
        remainder is appended to the last cluster.

        Returns:
            The clusters, with ``bound_index`` filled by
            ``bound_non_ref_set``.
        """
        import heapq

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        input_tensor = torch.as_tensor(self.corpus_emb[self.d0], device=device)
        num_points = input_tensor.shape[0]
        clusters = [[i] for i in range(num_points)]
        dist_matrix = torch.cdist(input_tensor, input_tensor)
        active_nodes = set(range(num_points))
        ref_set_list = []
        logger.info("Initializing priority queue for avg linkage clustering")
        # For singletons the average linkage is the pairwise distance itself,
        # so one transfer per row replaces one device op per pair.
        distance_queue = []
        for i in trange(num_points):
            row = dist_matrix[i, i + 1:].tolist()
            distance_queue.extend(
                (dist, i, i + 1 + offset) for offset, dist in enumerate(row)
            )
        heapq.heapify(distance_queue)
        # NOTE(review): ``num_clusters`` does not stop this loop; merging runs
        # until the queue is exhausted. Confirm whether a stop criterion was
        # intended.
        # The bar counts points retired from the active set.
        with tqdm(total=num_points, desc="Clustering progress") as pbar:
            while distance_queue:
                min_dist, merge_i, merge_j = heapq.heappop(distance_queue)
                if merge_i not in active_nodes or merge_j not in active_nodes:
                    # Stale entry: one side was already merged or emitted.
                    continue
                merged_indices = clusters[merge_i] + clusters[merge_j]
                active_nodes.remove(merge_i)
                active_nodes.remove(merge_j)
                if len(merged_indices) > self.min_cluster_size:
                    ref_set_list.append(ClusterData(
                        ref_index=self.d0[merged_indices],
                        bound_index=[],
                        diameter=float(min_dist)
                    ))
                    pbar.update(2)
                    continue
                new_node_idx = len(clusters)
                clusters.append(merged_indices)
                for node in active_nodes:
                    avg_dist = dist_matrix[merged_indices][:, clusters[node]].mean()
                    # New nodes always get the highest id, so (node, new) is
                    # already ordered.
                    heapq.heappush(distance_queue, (float(avg_dist), node, new_node_idx))
                active_nodes.add(new_node_idx)
                pbar.update(1)
        # NOTE(review): leftovers are emitted at ``>= min_cluster_size`` while
        # the loop above uses ``>``; kept as is, confirm which is intended.
        pending = []
        for node_idx in sorted(active_nodes):
            pending.extend(clusters[node_idx])
            if len(pending) >= self.min_cluster_size:
                ref_set_list.append(ClusterData(
                    ref_index=self.d0[pending],
                    bound_index=[],
                    diameter=0.0
                ))
                pending = []
        if pending:
            if ref_set_list:
                # Too few points for a cluster of their own: fold them into
                # the most recent cluster.
                combined_indices = np.concatenate([ref_set_list[-1].ref_index, self.d0[pending]])
                ref_set_list[-1] = ClusterData(
                    ref_index=combined_indices,
                    bound_index=[],
                    diameter=0.0
                )
            else:
                ref_set_list.append(ClusterData(
                    ref_index=self.d0[pending],
                    bound_index=[],
                    diameter=0.0
                ))
        logger.info(f"Created {len(ref_set_list)} clusters")
        return self.bound_non_ref_set(ref_set_list)

    def bound_non_ref_set(self, ref_set_list: List[ClusterData], batch_size: int = 100) -> List[ClusterData]:
        """Bind each non-reference vector to its nearest reference's cluster.

        Args:
            ref_set_list: Clusters whose ``ref_index`` entries are taken from
                ``d0``. Their ``bound_index`` is overwritten.
            batch_size: Number of non-reference vectors compared per distance
                computation.

        Returns:
            ``ref_set_list`` with ``bound_index`` populated.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
            KeyError: If a ``ref_index`` entry is not in ``d0``, or if the
                nearest reference of some vector belongs to no cluster.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be at least 1")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        d0 = np.asarray(self.d0)
        d1 = np.asarray(self.d1)
        non_ref_vectors = torch.as_tensor(self.corpus_emb[d1], device=device)
        ref_vectors = torch.as_tensor(self.corpus_emb[d0], device=device)
        ref_to_seq = {ref_id: seq for seq, ref_id in enumerate(d0)}
        # Position in d0 -> cluster number (-1 = in no cluster), stored as a
        # tensor so whole batches can be looked up at once.
        seq_to_cluster = [-1] * len(d0)
        for cluster_idx, cluster in enumerate(ref_set_list):
            for ref_idx in cluster.ref_index:
                seq_to_cluster[ref_to_seq[ref_idx]] = cluster_idx
        seq_to_cluster = torch.tensor(seq_to_cluster, dtype=torch.int64, device=device)
        num_non_ref = non_ref_vectors.shape[0]
        min_cluster = torch.zeros(num_non_ref, dtype=torch.int64, device=device)
        for start in trange(0, num_non_ref, batch_size, desc="Assigning non-reference vectors to clusters"):
            end = min(start + batch_size, num_non_ref)
            distances = torch.cdist(non_ref_vectors[start:end], ref_vectors)
            _, closest_indices = torch.min(distances, dim=1)
            batch_clusters = seq_to_cluster[closest_indices]
            unassigned = (batch_clusters < 0).nonzero().flatten()
            if unassigned.numel():
                raise KeyError(int(closest_indices[unassigned[0]]))
            min_cluster[start:end] = batch_clusters
        # Convert once to NumPy so the masks index the NumPy id array directly.
        assigned = min_cluster.cpu().numpy()
        for j, cluster_data in enumerate(ref_set_list):
            cluster_data.bound_index = d1[assigned == j].tolist()
        return ref_set_list
class HypergraphClusterer(Clusterer):
    """Clusterer that builds reference neighborhoods with ``HypergraphOptimizer``."""

    def cluster(self, corpus_emb_1: np.ndarray, d0: np.ndarray, d1: np.ndarray) -> List[ClusterData]:
        """Return the cluster list, from cache when allowed and available.

        Args:
            corpus_emb_1: Embedding matrix indexed by ``d0`` and ``d1``.
            d0: Indices of the reference vectors.
            d1: Indices of the non-reference vectors.

        Returns:
            The loaded or freshly computed clusters. Fresh results are written
            to the cache before being returned.
        """
        args = self.args
        cache_dir = "./cached_output/cluster_data_list"
        # NOTE(review): the key does not include cluster_merge_strategy,
        # cluster_per_group or cluster_merge_cluster_num, so runs that differ
        # only in those settings appear to share one cache entry - confirm.
        cache_key = f"{args.dataset_name}_{args.model_name_1}_{args.model_name_2}_{args.d0_ratio:.2f}_{args.num_clusters}__ref_set_list"

        # Cache hit: nothing to compute.
        if args.load_pretrained_clusters and os.path.exists(f"{cache_dir}/{cache_key}.npy"):
            return load_cache(cache_dir, cache_key)

        non_reference_vectors = corpus_emb_1[d1]
        cluster_assignments = None
        optimizer = HypergraphOptimizer(
            reference=corpus_emb_1[d0],
            non_reference=non_reference_vectors,
            d_prime=args.cluster_per_group,
            reference_index_list=d0,
            non_reference_index_list=d1,
        )

        # With "pre_cluster" the k-means centers stand in for the raw vectors.
        if args.cluster_merge_strategy.lower() == "pre_cluster":
            cluster_assignments, non_reference_vectors = optimizer.pre_cluster(
                corpus_emb_1[d1], args.cluster_merge_cluster_num
            )

        computed = optimizer.compute_top_d_neighbors_and_diameters(
            non_reference_vectors, cluster_assignments=cluster_assignments
        )
        ref_set_list, distance = computed[0], computed[1]

        if args.cluster_merge_strategy.lower() == "diameter":
            ref_set_list = optimizer.merge_by_diameter(
                ref_set_list, distance, args.cluster_merge_cluster_num
            )

        save_cache(cache_dir, cache_key, ref_set_list)
        return ref_set_list
