from __future__ import annotations

import os.path
from dataclasses import dataclass

import numpy as np

from softmatcha.typing import Vector

from . import register
from .base import Embedding


@register("gensim")
@dataclass
class EmbeddingGensim(Embedding):
    """Embedding backend wrapping gensim models such as GloVe.

    Attributes:
        embeddings (Vector): Embedding vectors.
    """

    @classmethod
    def load(cls, name_or_path: str, mmap: bool = False) -> Vector:
        """Read the embedding table of a gensim model from disk.

        The model directory is obtained from
        ``softmatcha.utils.gensim.download_gensim_model``. The table is then
        read from the ``embedding.npy`` file inside that directory.

        Args:
            name_or_path (str): Model name or path.
            mmap (bool): If true, open the ``.npy`` file memory-mapped in
                read-only mode instead of loading it fully into memory.

        Returns:
            Vector: The embedding table.
        """
        # Imported lazily: the gensim helper is only needed when this backend
        # is actually used.
        from softmatcha.utils import gensim as gensim_utils

        model_dir = gensim_utils.download_gensim_model(name_or_path)
        table_path = os.path.join(model_dir, "embedding.npy")

        # ``None`` reads the whole array; ``"r"`` maps it read-only.
        load_mode = None
        if mmap:
            load_mode = "r"

        return np.load(table_path, mmap_mode=load_mode)
