"""Embedding transformations."""

from abc import ABC, abstractmethod
from typing import Optional, Union

import numpy as np
from sklearn.decomposition import PCA


class EmbeddingTransform(ABC):
    """Interface shared by every embedding transformation in this module.

    Concrete subclasses must implement :meth:`transform`, which maps an
    array of embeddings to a new array.
    """

    @abstractmethod
    def transform(self, embedding: np.ndarray) -> np.ndarray:
        """Map ``embedding`` to its transformed representation.

        Args:
            embedding (np.ndarray): Embeddings to be transformed.

        Returns:
            np.ndarray: The result of the transformation.

        """


class L2Normalization(EmbeddingTransform):
    """Applies L2 normalization to embeddings.

    Each vector along the last axis is scaled to unit Euclidean length, so
    both a single embedding of shape ``(dim,)`` and a batch of shape
    ``(n, dim)`` are supported.
    """

    def transform(
        self,
        embedding: np.ndarray,
    ) -> np.ndarray:
        """L2 normalize embeddings.

        Args:
            embedding (np.ndarray): A single 1-D embedding or a 2-D batch of
                embeddings (one per row). Array-likes are converted with
                ``np.asarray``.

        Returns:
            np.ndarray: Array of the same shape whose vectors along the last
            axis have unit L2 norm. All-zero vectors stay all zeros.

        """
        embedding = np.asarray(embedding)
        # Normalize along the last axis. This is identical to ``axis=1`` for
        # 2-D input but also accepts a single 1-D embedding.
        norm = np.linalg.norm(embedding, axis=-1, keepdims=True)
        # Clamp tiny/zero norms to avoid division by zero; 0 / 1e-12 == 0,
        # so zero vectors are returned unchanged.
        norm = np.maximum(norm, 1e-12)
        return embedding / norm


class PCATransformation(EmbeddingTransform):
    """Applies PCA to embeddings.

    Notes:
        - This transformer fits a fresh PCA instance on every call to ``transform``.
          It is intended for one-shot transforms in analysis pipelines, not for
          incremental/online use. Outputs of separate calls are projections
          onto separately fitted bases.
        - Results are deterministic given the same inputs and ordering. The
          ``random_state`` seed is forwarded to scikit-learn, so the randomized
          SVD solver (which ``svd_solver="auto"`` may pick for large inputs)
          does not draw from NumPy's global RNG.

    """

    def __init__(
        self,
        n_components: Union[int, float, str, None],
        random_state: Optional[int] = 0,
    ) -> None:
        """Initialize the PCA transformation.

        Args:
            n_components (int | float | str | None): Forwarded unchanged to
                ``sklearn.decomposition.PCA``. An int keeps that many
                components. A float in (0, 1) keeps enough components to
                explain that fraction of the variance. Other values accepted
                by scikit-learn (e.g. ``"mle"`` or ``None``) are passed through.
            random_state (int | None): Seed forwarded to ``PCA`` for its
                randomized/ARPACK solvers. Defaults to 0 so repeated calls are
                reproducible. Pass ``None`` to use NumPy's global RNG instead.

        """
        self.n_components = n_components
        self.random_state = random_state

    def transform(
        self,
        embedding: np.ndarray,
    ) -> np.ndarray:
        """Apply PCA to embeddings.

        Args:
            embedding (np.ndarray): The embeddings to transform, presumably
                shaped ``(n_samples, n_features)`` as scikit-learn expects.

        Returns:
            np.ndarray: The embeddings projected onto the principal
            components fitted on ``embedding`` itself.

        """
        # A new estimator per call: fit and project in one step (see Notes).
        pca = PCA(n_components=self.n_components, random_state=self.random_state)
        return pca.fit_transform(embedding)
