"""Lightweight, compute-only clustering API."""

from __future__ import annotations

from typing import TYPE_CHECKING

from sklearn.cluster import KMeans

if TYPE_CHECKING:
    from collections.abc import Sequence

    import numpy as np

    from .transforms import EmbeddingTransform


def fit_predict(
    X: np.ndarray,  # noqa: N803
    *,
    method: str = "kmeans",
    n_clusters: int = 20,
    seed: int | None = None,
    transforms: Sequence[EmbeddingTransform] | None = None,
    return_model: bool = False,
    **kwargs,
) -> np.ndarray | tuple[np.ndarray, object]:
    """Cluster embeddings and return labels (optionally the fitted model).

    This function performs no visualization and imports no plotting libraries.

    Args:
        X: 2D array of shape (n_samples, n_features). Objects without an
            ``ndim`` attribute (e.g. nested lists) are converted with
            ``numpy.asarray`` first; objects that already expose ``ndim``
            are used as-is.
        method: Clustering method, matched case-insensitively. Currently only
            "kmeans" is supported.
        n_clusters: Number of clusters for k-means.
        seed: Random seed for determinism; forwarded as ``random_state``.
        transforms: Optional sequence of embedding transforms to apply in order.
            Each is called as ``t.transform(x)`` and its result feeds the next.
        return_model: If True, also return the fitted model instance.
        **kwargs: Extra keyword args forwarded to the estimator constructor.
            Must not repeat ``n_clusters`` or ``random_state``; those are
            already supplied from the arguments above.

    Returns:
        labels or (labels, model) if return_model is True.

    Raises:
        ValueError: If X is not 2D or ``method`` is not supported.

    """
    # Local import: at module level numpy is only imported for type checking.
    import numpy as np

    # Array-likes such as nested lists have no ``ndim``; convert them so the
    # shape check reports a clear ValueError rather than an AttributeError.
    # Inputs that already expose ``ndim`` (ndarrays, sparse matrices,
    # DataFrames) are passed through unchanged to avoid altering them.
    x_proc = X if hasattr(X, "ndim") else np.asarray(X)

    if x_proc.ndim != 2:
        msg = "X must be a 2D array of shape (n_samples, n_features)"
        raise ValueError(msg)

    if transforms is not None:
        for t in transforms:
            x_proc = t.transform(x_proc)

    if method.lower() == "kmeans":
        model = KMeans(n_clusters=n_clusters, random_state=seed, **kwargs)
    else:
        msg = f"Unsupported method: {method}"
        raise ValueError(msg)

    labels = model.fit_predict(x_proc)
    return (labels, model) if return_model else labels


# Public API of this module: the names exported by ``from <module> import *``.
__all__ = ["fit_predict"]
