import typing

from scripts.baseline_kmeans.discretizers.interface import ClusteringDiscretizer
from scripts.baseline_kmeans.discretizers.faiss_kmeans import (
    ClusteringDiscretizerFaissKMeans,
)

# Supported clustering backends. The Literal is the single source of truth:
# the runtime list below is derived from it so the two can never drift apart.
ClusteringModels = typing.Literal["scikit", "cuml", "faiss"]
# Plain list of the backend names (e.g. for argparse `choices=`), in Literal order.
clustering_model_choices = list(typing.get_args(ClusteringModels))


def get_clustering_discretizer(
    model: ClusteringModels, **kwargs
) -> "ClusteringDiscretizer":
    """
    Factory function to get the appropriate ClusteringDiscretizer implementation.

    Args:
        model (ClusteringModels): The backend name: 'scikit', 'cuml' or 'faiss'.
        kwargs: Additional parameters passed unchanged to the selected
            ClusteringDiscretizer constructor.

    Returns:
        An instance of the ClusteringDiscretizer implementation for ``model``.

    Raises:
        ValueError: If ``model`` is not one of the supported backends.
    """

    # There are import conflicts between FAISS and scikit-learn, hence the
    # scikit and cuML implementations are imported locally, only when requested.
    # NOTE(review): the FAISS implementation is imported at module level, so it
    # is loaded even when another backend is selected — confirm this is intended.

    if model == "scikit":
        from scripts.baseline_kmeans.discretizers.scikit_kmeans import (
            ClusteringDiscretizerScikitKMeans,
        )

        return ClusteringDiscretizerScikitKMeans(**kwargs)
    if model == "cuml":
        from scripts.baseline_kmeans.discretizers.cuml_kmeans import (
            ClusteringDiscretizerCuMLKMeans,
        )

        return ClusteringDiscretizerCuMLKMeans(**kwargs)
    if model == "faiss":
        return ClusteringDiscretizerFaissKMeans(**kwargs)

    # Tell the caller what would have been accepted; derived from the Literal
    # so the message stays in sync when backends are added.
    supported = ", ".join(typing.get_args(ClusteringModels))
    raise ValueError(
        f"Unsupported algorithm: {model}. Expected one of: {supported}"
    )
