from abc import ABC, abstractmethod
import numpy as np
from typing import List
import argparse
from dataclasses import dataclass
from .cluster_data import ClusterData
@dataclass()
class ClusteringConfig:
    """Plain settings container for a clustering run.

    The first six fields are required (positional or keyword). The remaining
    four fall back to the defaults declared below.
    """

    # --- required settings ------------------------------------------------
    num_clusters: int
    reduced_dim: int
    dataset_name: str
    model_name_1: str
    model_name_2: str
    d0_ratio: float  # NOTE(review): exact meaning of the "d0" ratio not visible here — confirm

    # --- optional settings (defaults) --------------------------------------
    load_pretrained_clusters: bool = False
    cluster_per_group: int = 10
    cluster_merge_strategy: str = "diameter"
    cluster_merge_cluster_num: int = 100
class Clusterer(ABC):
    """Abstract base class for clustering strategies.

    Subclasses must implement :meth:`cluster`. :meth:`validate_inputs` is a
    shared sanity check that subclasses can call on their arguments.
    """

    def __init__(self, args: argparse.Namespace):
        """Store the parsed command-line arguments.

        Args:
            args: Parsed arguments. ``validate_inputs`` reads
                ``args.num_clusters``; subclasses may read other attributes.
        """
        self.args = args

    @abstractmethod
    def cluster(self, corpus_emb_1: np.ndarray, d0: np.ndarray, d1: np.ndarray) -> List[ClusterData]:
        """Cluster the corpus and return the resulting clusters.

        Args:
            corpus_emb_1: Corpus embeddings (layout presumably one row per
                corpus item — TODO confirm).
            d0: Indices of the reference-point subset of the corpus.
            d1: Indices of the remaining corpus items.

        Returns:
            A list of ``ClusterData`` objects (presumably one per cluster).

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError("Concrete clustering classes must implement cluster()")

    def validate_inputs(self, corpus_emb_1: np.ndarray, d0: np.ndarray, d1: np.ndarray) -> None:
        """Check that ``corpus_emb_1``, ``d0`` and ``d1`` are mutually consistent.

        ``d0`` and ``d1`` appear to be meant to split the corpus between them.
        This method enforces the following:

        * their lengths add up to ``len(corpus_emb_1)``;
        * they share no index;
        * neither contains a repeated index.

        It does not check that the indices are in range.

        Raises:
            ValueError: If any argument is not a numpy array, if the lengths
                do not add up, if ``d0`` and ``d1`` overlap, if either
                contains a duplicate index, or if ``d0`` has fewer entries
                than ``self.args.num_clusters``.
        """
        # ValueError (rather than the more conventional TypeError) is kept
        # for the type checks so existing callers catching ValueError still work.
        for name, value in (("corpus_emb_1", corpus_emb_1), ("d0", d0), ("d1", d1)):
            if not isinstance(value, np.ndarray):
                raise ValueError(f"{name} must be a numpy array")

        if len(d0) + len(d1) != len(corpus_emb_1):
            raise ValueError("Sum of d0 and d1 lengths must equal corpus_emb_1 length")

        # Vectorised overlap test. Unlike set(d0) & set(d1), it does not box
        # every element into a Python object, and it works for any array shape.
        if np.intersect1d(d0, d1).size > 0:
            raise ValueError("d0 and d1 must not have overlapping indices")

        # Repeated indices could otherwise satisfy the length check above while
        # leaving some corpus items unaccounted for.
        if np.unique(d0).size != d0.size or np.unique(d1).size != d1.size:
            raise ValueError("d0 and d1 must not contain duplicate indices")

        if len(d0) < self.args.num_clusters:
            raise ValueError("Number of reference points must be >= number of clusters")
