"""Clustering analysis module (compute-only; viz moved to pathfmtools.viz)."""

from __future__ import annotations

import logging
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Literal

import numpy as np
import pandas as pd
from sklearn.cluster import KMeans

from pathfmtools.image.slide_group import SlideGroup

if TYPE_CHECKING:
    import torch

    from pathfmtools.analysis.transforms import EmbeddingTransform
    from pathfmtools.analysis.zeroshot_classification import ZeroShotPatchClassifier
    from pathfmtools.image import Slide

logger = logging.getLogger(__name__)


class PatchEmbeddingClusterer(ABC):
    """Abstract base class for clustering patch embeddings.

    A clusterer always operates on a ``SlideGroup``; a single ``Slide`` is wrapped in a
    one-element group. Concrete subclasses implement ``cluster_feature_embeddings`` and are
    expected to store the flat cluster assignments in ``_all_cluster_assignments`` and the number
    of clusters in ``_n_clusters`` so that the shared helpers of this class can be used.
    """

    def __init__(
        self,
        slide: Slide | None = None,
        slide_group: SlideGroup | None = None,
        embedding_type: Literal[
            "patch_feature_embeddings",
            "patch_zeroshot_embeddings",
        ] = "patch_feature_embeddings",
    ) -> None:
        """Initialize the patch embedding clusterer.

        Args:
            slide (Slide | None, optional): A single slide to cluster. Defaults to None.
            slide_group (SlideGroup | None, optional): A group of slides to cluster. Defaults to
                None.
            embedding_type (Literal["patch_feature_embeddings", "patch_zeroshot_embeddings"]):
                The type of patch embeddings to use for clustering.

        Raises:
            ValueError: If both or neither of ``slide`` and ``slide_group`` are provided.

        """
        # Exactly one source must be given: both set or both unset is an error.
        if (slide is None) == (slide_group is None):
            msg = "Exactly one of (slide, slide_group) must be provided."
            raise ValueError(msg)
        if slide_group is not None:
            self.slide_group = slide_group
        else:
            self.slide_group = SlideGroup(slide_list=[slide])  # type: ignore[reportArgumentType]
        self.embedding_type = embedding_type

        # _all_cluster_assignments is a 1D numpy array of cluster assignments for all patches
        # in all slides. The cluster assignments are indexed in the same order as the concatenated
        # embeddings.
        self._all_cluster_assignments: np.ndarray | None = None
        # _slide_cluster_assignments is a dictionary keyed by slide ID with values that are
        # 1D numpy arrays of cluster assignments for the patches of a single slide. The cluster
        # assignments are indexed in the same order as the slide's patch embeddings. Flat cluster
        # assignments are grouped into this representation by the SlideGroup object.
        self._slide_cluster_assignments: dict[str, np.ndarray] | None = None
        # _centroid_classification_df is a pandas DataFrame containing zero-shot classification
        # results for the cluster centroids.
        self._centroid_classification_df: pd.DataFrame | None = None

        self._n_clusters: int | None = None

    @property
    def labels(self) -> dict[str, np.ndarray]:
        """Get the per-slide cluster labels.

        The flat assignments are grouped per slide by the slide group on first access and the
        result is cached in ``_slide_cluster_assignments``. Code that replaces
        ``_all_cluster_assignments`` must reset that cache to None, otherwise the previously
        grouped labels keep being returned.

        Raises:
            ValueError: If no cluster assignments have been computed yet.

        Returns:
            dict[str, np.ndarray]: Cluster assignments keyed by slide ID.

        """
        if self._slide_cluster_assignments is not None:
            return self._slide_cluster_assignments
        if self._all_cluster_assignments is None:
            msg = "Labels not set. Call fit() first."
            raise ValueError(msg)
        self._slide_cluster_assignments = self.slide_group.map_vals_to_source_patches(
            self._all_cluster_assignments,
            embedding_type=self.embedding_type,  # type: ignore[reportArgumentType]
        )
        return self._slide_cluster_assignments  # type: ignore[reportReturnStatementType]

    @abstractmethod
    def cluster_feature_embeddings(
        self,
        model_name: str,
        embedding_transforms: list[EmbeddingTransform] | None = None,
        **kwargs,
    ) -> tuple[Any, ...]:
        """Cluster the patch embeddings."""

    def get_slide_cluster_proportions(self) -> dict[str, np.ndarray]:
        """Get the cluster proportions for each slide.

        Raises:
            ValueError: If the number of clusters or the cluster labels have not been set.

        Returns:
            dict[str, np.ndarray]: For each slide ID, a float array of length ``n_clusters`` whose
                entry ``i`` is the fraction of the slide's patches assigned to cluster ``i``. A
                slide without any patches gets an all-zero array.

        """
        if self._n_clusters is None:
            msg = "Number of clusters not set. Call fit() first."
            raise ValueError(msg)
        cluster_ids = range(self._n_clusters)
        slide_cluster_proportions: dict[str, np.ndarray] = {}
        for slide_id, patch_label_arr in self.labels.items():
            n_patches = len(patch_label_arr)
            if n_patches == 0:
                # Avoid 0/0 (NaN plus a RuntimeWarning) for slides that contributed no patches.
                slide_cluster_proportions[slide_id] = np.zeros(self._n_clusters, dtype=float)
                continue
            counts = np.array(
                [np.count_nonzero(patch_label_arr == cluster_id) for cluster_id in cluster_ids],
                dtype=float,
            )
            slide_cluster_proportions[slide_id] = counts / n_patches
        return slide_cluster_proportions

    def get_concatenated_embedding_array(
        self,
        model_name: str,
    ) -> np.ndarray:
        """Concatenate patch embeddings for all slides in the slide group into a single array.

        Args:
            model_name (str): The name of the model to use for embedding.

        Returns:
            np.ndarray: A 2D numpy array of patch embeddings with shape (n_patches, embedding_dim).

        """
        return self.slide_group.get_concatenated_embedding_array(
            model_name=model_name,
            embedding_type=self.embedding_type,  # type: ignore[reportArgumentType]
        )


class KMeansPatchClusterer(PatchEmbeddingClusterer):
    """Performs K-means clustering on patch embeddings."""

    def __init__(
        self,
        slide: Slide | None = None,
        slide_group: SlideGroup | None = None,
        embedding_type: Literal[
            "patch_feature_embeddings",
            "patch_zeroshot_embeddings",
        ] = "patch_feature_embeddings",
        zero_shot_classifier: ZeroShotPatchClassifier | None = None,
        centroids: np.ndarray | None = None,
    ) -> None:
        """Initialize the KMeansPatchClusterer.

        Args:
            slide (Slide | None, optional): A single slide to cluster. Defaults to None.
            slide_group (SlideGroup | None, optional): A group of slides to cluster. Defaults to
                None.
            embedding_type (Literal["patch_feature_embeddings", "patch_zeroshot_embeddings"]):
                The type of patch embeddings to use for clustering.
            zero_shot_classifier (ZeroShotPatchClassifier | None, optional): A zero-shot
                classifier to use for classifying the cluster centroids. Defaults to None.
            centroids (np.ndarray | None, optional): The centroids to use for clustering. Defaults
                to None. If provided, the KMeans model will be initialized with these centroids.

        """
        super().__init__(slide=slide, slide_group=slide_group, embedding_type=embedding_type)
        self.zero_shot_classifier = zero_shot_classifier
        self.centroids = centroids
        # Fitted KMeans estimator; stays None until cluster_feature_embeddings() has succeeded.
        self.model: KMeans | None = None

    def cluster_feature_embeddings(
        self,
        model_name: str,
        embedding_transforms: list[EmbeddingTransform] | None = None,
        k: int = 20,
        seed: int | None = None,
        **kwargs,
    ) -> tuple[KMeans, np.ndarray, np.ndarray]:
        """Apply K-means clustering to the patch embeddings.

        Optionally apply transformations to the patch embeddings prior to clustering. Each
        transform is applied in list order through its ``transform`` method; the available
        transforms live in ``pathfmtools.analysis.transforms``. Note that for L2-normalized
        vectors the squared Euclidean distance used by K-means is proportional to the cosine
        distance.

        If ``centroids`` was passed to the constructor, it is used as the initial cluster centers
        (a single initialization run). The centroids must then have ``k`` rows and live in the
        same feature space as the (transformed) embeddings.

        Args:
            model_name (str): The name of the model to use for embedding.
            embedding_transforms (list[EmbeddingTransform] | None, optional): The transformations
                to apply to the patch embeddings prior to clustering. Defaults to None.
            k (int): The number of clusters to form.
            seed (int | None, optional): The random seed to use for reproducibility. Defaults to
                None.
            **kwargs: Extra keyword args (ignored, only present for compatibility).

        Raises:
            ValueError: If initial centroids were provided and are not a 2D array with ``k`` rows.

        Returns:
            tuple[KMeans, np.ndarray, np.ndarray]: A tuple containing the KMeans model, the
                (potentially transformed) concatenated embedding array, and the cluster
                assignments.

        """
        if self.centroids is None:
            model = KMeans(n_clusters=k, random_state=seed)
        else:
            initial_centroids = np.asarray(self.centroids)
            if initial_centroids.ndim != 2 or initial_centroids.shape[0] != k:
                msg = (
                    f"centroids must be a 2D array with k={k} rows, "
                    f"got shape {initial_centroids.shape}."
                )
                raise ValueError(msg)
            # With explicit initial centers, repeated initializations would all start from the
            # same point, so a single run is requested.
            model = KMeans(n_clusters=k, init=initial_centroids, n_init=1, random_state=seed)

        embedding_array = self.get_concatenated_embedding_array(model_name=model_name)
        for embedding_transform in embedding_transforms or []:
            embedding_array = embedding_transform.transform(embedding_array)
        model.fit(embedding_array)

        # Only publish state once fitting has succeeded, so a failed run leaves the previous
        # results intact.
        self.model = model
        self._n_clusters = k
        self._all_cluster_assignments = model.labels_  # type: ignore[reportAttributeAccessIssue]
        # Invalidate the per-slide grouping cached by the ``labels`` property; otherwise a second
        # clustering run would keep returning the labels of the first one.
        self._slide_cluster_assignments = None

        return model, embedding_array, self._all_cluster_assignments

    def classify_cluster_centroids(
        self,
        model_name: str,
        classes: dict[str, list[str]],
        device: torch.device,
    ) -> pd.DataFrame:
        """Classify the cluster centroids using a zero-shot classifier.

        Args:
            model_name (str): The name of the model to use for embedding.
            classes (dict[str, list[str]]): A dictionary mapping class groups to lists of class
                names. Class groups represent a category of possible classes (e.g. "Cancer Status"),
                and class names represent the specific classes that the zero-shot classifier will
                classify the cluster centroids into for each class group (e.g. "Positive" and
                "Negative").
            device (torch.device): The torch device to use for classification.

        Raises:
            ValueError: If a zero-shot classifier is not provided, or if no clustering has been
                run yet.

        Returns:
            pd.DataFrame: A dataframe with one row per (cluster, class group, class name) and the
                columns ``cluster_idx``, ``class_group``, ``class_name``, ``probability`` and
                ``logit``.

        """
        if self.zero_shot_classifier is None:
            msg = "Classes provided but no zero-shot classifier provided."
            # Not inside an exception handler, so there is no traceback to attach.
            logger.error(msg)
            raise ValueError(msg)

        model = getattr(self, "model", None)
        if model is None:
            msg = "No fitted KMeans model. Call cluster_feature_embeddings() first."
            logger.error(msg)
            raise ValueError(msg)

        classification_results = []
        for centroid_idx, centroid in enumerate(model.cluster_centers_):
            # Since the centroid embedding is computed as the mean of the patch embeddings,
            # it may not be normalized. Normalize it before using it for zero-shot
            # classification. An all-zero centroid is passed through unchanged to avoid NaNs.
            norm = np.linalg.norm(centroid)
            centroid_embedding_normalized = centroid / norm if norm > 0 else centroid

            for class_group, class_list in classes.items():
                classification_dict = self.zero_shot_classifier.classify(
                    model_name=model_name,
                    classes=class_list,
                    patch_embedding=centroid_embedding_normalized,
                    device=device,
                )
                classification_results.extend(
                    {
                        "cluster_idx": centroid_idx,
                        "class_group": class_group,
                        "class_name": class_name,
                        "probability": classification_dict["probabilities"][class_name],
                        "logit": classification_dict["logits"][class_name],
                    }
                    for class_name in class_list
                )

        classification_df = pd.DataFrame(classification_results)
        self._centroid_classification_df = classification_df

        return classification_df
