from collections.abc import Mapping
from enum import Enum
from types import MappingProxyType
from typing import Any

from .base import GrammaticalEncoder
from .complex import GrammaticalComplexEncoder
from .identity import IdentityEncoder
from .sum_embedding import GrammaticalSumEmbedder

# Public API of this module: the encoder classes imported above plus the
# factory and its type enum defined below. Entries are kept in alphabetical order.
__all__ = [
    "GrammaticalComplexEncoder",
    "GrammaticalEncoder",
    "GrammaticalEncoderFactory",
    "GrammaticalEncoderType",
    "GrammaticalSumEmbedder",
    "IdentityEncoder",
]


class GrammaticalEncoderType(Enum):
    """Identifiers for the grammatical encoder implementations.

    Each member is a key in ``GrammaticalEncoderFactory._registry``; the
    member values are the lowercase string names of the encoder kinds.
    """

    # Registered to GrammaticalComplexEncoder in the factory.
    COMPLEX = "complex"
    # Registered to IdentityEncoder in the factory.
    IDENTITY = "identity"
    # Registered to GrammaticalSumEmbedder in the factory; the factory
    # requires a dimension for this type.
    SUM = "sum"


class GrammaticalEncoderFactory:
    """Build grammatical encoder instances from a ``GrammaticalEncoderType``."""

    # Read-only lookup table from encoder type to the class that implements it.
    _registry: Mapping[GrammaticalEncoderType, type[GrammaticalEncoder]] = MappingProxyType(
        {
            GrammaticalEncoderType.COMPLEX: GrammaticalComplexEncoder,
            GrammaticalEncoderType.IDENTITY: IdentityEncoder,
            GrammaticalEncoderType.SUM: GrammaticalSumEmbedder,
        },
    )

    @classmethod
    def get_grammatical_encoder(
        cls,
        encoder_type: GrammaticalEncoderType,
        dimension: int | None = None,
        **kwargs: Any,
    ) -> GrammaticalEncoder:
        """Instantiate the encoder class registered for ``encoder_type``.

        Args:
            encoder_type: Registry key selecting which encoder class to build.
            dimension: Passed to the encoder as the ``dim`` keyword argument
                when ``encoder_type`` is ``GrammaticalEncoderType.SUM`` (where
                it is mandatory and replaces any ``dim`` given in ``kwargs``).
                Not used for any other encoder type.
            **kwargs: Keyword arguments forwarded to the encoder constructor.

        Returns:
            A newly constructed instance of the registered encoder class.

        Raises:
            ValueError: If ``encoder_type`` has no registry entry, or if
                ``dimension`` is ``None`` for ``GrammaticalEncoderType.SUM``.

        """
        # Every registered value is a class, so ``None`` can only mean "missing".
        implementation = cls._registry.get(encoder_type)
        if implementation is None:
            msg = f"Unknown grammatical encoder type: {encoder_type}"
            raise ValueError(msg)

        needs_dimension = encoder_type == GrammaticalEncoderType.SUM
        if needs_dimension and dimension is None:
            msg = f"Dimension must be specified for encoder type: {encoder_type}"
            raise ValueError(msg)

        # ``dimension`` wins over any caller-supplied ``dim`` entry.
        constructor_kwargs = {**kwargs, "dim": dimension} if needs_dimension else kwargs
        return implementation(**constructor_kwargs)
