from __future__ import annotations

from dataclasses import dataclass

from softmatcha.typing import Vector

from . import register
from .base import Embedding


@register("transformers")
@dataclass
class EmbeddingTransformers(Embedding):
    """EmbeddingTransformers class wraps encoder models of huggingface/transformers.

    The embedding table is taken from the weight of the model's input
    embedding layer (``get_input_embeddings()``).

    Attributes:
        embeddings (Vector): Embedding vectors.
    """

    @classmethod
    def load(cls, name_or_path: str, mmap: bool = False) -> Vector:
        """Load an embedding table.

        The whole model is instantiated with ``AutoModel.from_pretrained`` and
        only the weight of its input embedding layer is kept.

        Args:
            name_or_path (str): Model name or path.
            mmap (bool): Open the embedding file via mmap. Not read by this
              backend (weights always come from ``from_pretrained``);
              presumably kept to match the ``Embedding.load`` signature.

        Returns:
            Vector: The embedding table, normalized along its last axis by
              ``softmatcha.modules.functional.normalize``.
        """
        import torch
        from transformers import AutoModel

        import softmatcha.modules.functional as F

        model = AutoModel.from_pretrained(name_or_path)
        # `.cpu()` is a no-op for CPU tensors but keeps `.numpy()` from
        # failing if the weights were placed on another device.
        weight = model.get_input_embeddings().weight.detach().cpu()
        # NumPy has no bfloat16 dtype, so `.numpy()` raises on bf16 weights;
        # upcast only that case and leave every other dtype untouched.
        if weight.dtype == torch.bfloat16:
            weight = weight.float()
        return F.normalize(weight.numpy(), axis=-1)
