import numpy as np
from typing import List, Optional
import os
from loguru import logger
from .base_clusterer import Clusterer
from .cluster_data import ClusterData
from .hypergraph_optimizer import HypergraphOptimizer
from src.utils import save_cache, load_cache
class HypergraphClusterer(Clusterer):
    """Clusterer that assigns non-reference embeddings to reference embeddings via a
    ``HypergraphOptimizer``, caching the resulting reference-set list on disk."""

    def cluster(self, corpus_emb_1: np.ndarray, d0: np.ndarray, d1: np.ndarray) -> List[ClusterData]:
        """Compute (or load from cache) the reference-set list for one reference split.

        Args:
            corpus_emb_1: Embedding matrix that ``d0`` and ``d1`` index into
                (presumably one row per corpus item — TODO confirm).
            d0: Indices into ``corpus_emb_1`` selecting the reference vectors.
            d1: Indices into ``corpus_emb_1`` selecting the non-reference vectors.

        Returns:
            The ``ref_set_list`` produced by
            ``HypergraphOptimizer.compute_top_d_neighbors_and_diameters`` (optionally
            merged by diameter), or the previously cached copy when
            ``self.args.load_pretrained_clusters`` is set and the cache file exists.
            NOTE(review): annotated as ``List[ClusterData]``; confirm that is what the
            optimizer actually returns.
        """
        args = self.args
        cache_key = (
            f"{args.dataset_name}_{args.model_name_1}_{args.model_name_2}_"
            f"{args.d0_ratio:.2f}_{args.num_clusters}__ref_set_list"
        )
        # NOTE(review): the key omits cluster_merge_strategy, cluster_merge_cluster_num
        # and cluster_per_group, all of which change the result computed below, so runs
        # differing only in those settings share one cache file. Kept as-is so existing
        # caches still load; consider adding them to the key.
        cache_dir = "./cached_output/cluster_data_list"
        # Assumes save_cache/load_cache store entries as "<cache_dir>/<cache_key>.npy"
        # — TODO confirm against src.utils.
        cache_path = os.path.join(cache_dir, f"{cache_key}.npy")

        if args.load_pretrained_clusters and os.path.exists(cache_path):
            logger.info("Loading pretrained hypergraph clusters...")
            return load_cache(cache_dir, cache_key)

        logger.info("Running hypergraph clustering...")
        # Normalize once; a missing (None) strategy means "no merging" rather than a crash.
        merge_strategy = (args.cluster_merge_strategy or "").lower()

        # Gather the non-reference rows once; fancy indexing copies the sub-matrix.
        non_reference_vectors = corpus_emb_1[d1]
        cluster_assignments = None
        hypergraph_optimizer = HypergraphOptimizer(
            reference=corpus_emb_1[d0],
            non_reference=non_reference_vectors,
            d_prime=args.cluster_per_group,
            reference_index_list=d0,
            non_reference_index_list=d1,
        )

        if merge_strategy == "pre_cluster":
            logger.info(f"Pre-clustering with {args.cluster_merge_cluster_num} clusters...")
            # Reuse the rows gathered above instead of re-indexing corpus_emb_1.
            cluster_assignments, centers = hypergraph_optimizer.pre_cluster(
                non_reference_vectors, args.cluster_merge_cluster_num
            )
            # Downstream neighbor search then runs on the pre-cluster centers.
            non_reference_vectors = centers

        # The third return value (reference-to-reference distances) is not used here.
        ref_set_list, distance, _ = hypergraph_optimizer.compute_top_d_neighbors_and_diameters(
            non_reference_vectors,
            cluster_assignments=cluster_assignments,
        )

        if merge_strategy == "diameter":
            logger.info(f"Merging clusters by diameter to {args.cluster_merge_cluster_num} clusters...")
            ref_set_list = hypergraph_optimizer.merge_by_diameter(
                ref_set_list, distance, args.cluster_merge_cluster_num
            )

        save_cache(cache_dir, cache_key, ref_set_list)
        return ref_set_list
