import numpy as np
from typing import List
import os
from loguru import logger
from .base_clusterer import Clusterer
from .cluster_data import ClusterData
from .cluster_based_ref import ClusterBasedRef
from src.utils import save_cache, load_cache
class KMeansClusterer(Clusterer):
    """Clusterer that delegates to ``ClusterBasedRef`` and caches its output on disk."""

    def cluster(self, corpus_emb_1: np.ndarray, d0: np.ndarray, d1: np.ndarray) -> List[ClusterData]:
        """Build the reference-set list, reusing cached clusters when allowed.

        When ``self.args.load_pretrained_clusters`` is truthy and both cache
        files (cluster data list and labels) exist, the cached values are
        loaded and handed to ``ClusterBasedRef.cluster``. Otherwise
        ``ClusterBasedRef.cluster`` is run without arguments and its cluster
        data list and labels are written to the cache.

        Args:
            corpus_emb_1: Forwarded unchanged to ``ClusterBasedRef``.
            d0: Forwarded unchanged to ``ClusterBasedRef``.
            d1: Forwarded unchanged to ``ClusterBasedRef``.

        Returns:
            The third element of the tuple returned by
            ``ClusterBasedRef.cluster`` (``ref_set_list``).
            NOTE(review): the annotation says ``List[ClusterData]``; confirm
            that this matches the element type of ``ref_set_list``.
        """
        num_clusters = self.args.num_clusters
        cache_dir = "./cached_output/cluster_data_list"

        # Both cache keys share one prefix identifying the run configuration.
        key_prefix = (
            f"{self.args.dataset_name}_{self.args.model_name_1}_"
            f"{self.args.d0_ratio:.2f}_{num_clusters}"
        )
        load_cluster_data_list_key = f"{key_prefix}_cluster_data_list"
        labels_key = f"{key_prefix}_labels"

        # The existence check must use the same keys that save_cache/load_cache
        # are called with; previously the labels check looked for
        # "<cluster_data_list_key>_labels.npy", which is never written, so the
        # cache was never reused.
        # Assumes save_cache stores "<dir>/<key>.npy" -- TODO confirm in src.utils.
        cache_available = self.args.load_pretrained_clusters and all(
            os.path.exists(f"{cache_dir}/{key}.npy")
            for key in (load_cluster_data_list_key, labels_key)
        )

        if cache_available:
            logger.info("Loading pretrained clusters...")
            cluster_data_list = load_cache(cache_dir, load_cluster_data_list_key)
            labels = load_cache(cache_dir, labels_key)
            _, _, ref_set_list = ClusterBasedRef(
                corpus_emb_1, d0, d1, num_clusters=num_clusters
            ).cluster(cluster_data_list, labels)
        else:
            logger.info(f"Running K-means clustering with {num_clusters} clusters...")
            cluster_data_list, labels, ref_set_list = ClusterBasedRef(
                corpus_emb_1, d0, d1, num_clusters=num_clusters
            ).cluster()
            save_cache(cache_dir, load_cluster_data_list_key, cluster_data_list)
            save_cache(cache_dir, labels_key, labels)
        return ref_set_list
