from __future__ import annotations

import abc
from dataclasses import dataclass

import numpy as np

from softmatcha.typing import Vector


@dataclass
class Embedding(abc.ABC):
    """Abstract base class wrapping a table of embedding vectors.

    Concrete subclasses provide :meth:`load`; :meth:`build` calls it and
    wraps the resulting table in an instance of the subclass.

    Attributes:
        embeddings (Vector): Embedding vectors, indexed by token.
    """

    embeddings: Vector

    def __len__(self) -> int:
        """Return the length of the embedding table."""
        table = self.embeddings
        return len(table)

    @classmethod
    def build(cls, name_or_path: str, mmap: bool = False) -> Embedding:
        """Construct an instance from a model name or path.

        Args:
            name_or_path (str): Model name or path, forwarded to ``load``.
            mmap (bool): Open the embedding file via mmap; forwarded to
              ``load`` as a keyword argument.

        Returns:
            Embedding: An instance of ``cls`` holding the loaded table.
        """
        table = cls.load(name_or_path, mmap=mmap)
        return cls(table)

    @classmethod
    @abc.abstractmethod
    def load(cls, name_or_path: str, mmap: bool = False) -> Vector:
        """Read the embedding table; subclasses must implement this.

        Args:
            name_or_path (str): Model name or path.
            mmap (bool): Open the embedding file via mmap.

        Returns:
            Vector: The embedding table.
        """

    def __call__(self, tokens: list[int]) -> Vector:
        """Look up the embedding vectors of the given tokens.

        Args:
            tokens (list[int]): Input tokens, used to index ``embeddings``.

        Returns:
            Vector: Token embeddings. For an empty ``tokens`` this is an
              empty one-dimensional ``float32`` array.
        """
        has_tokens = len(tokens) != 0
        if has_tokens:
            return self.embeddings[tokens]
        # Nothing to look up: hand back an empty float32 array without
        # touching the table.
        return np.array([], dtype=np.float32)
