# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations

from typing import Any

from camel.embeddings.base import BaseEmbedding


class SentenceTransformerEncoder(BaseEmbedding[str]):
    r"""This class provides functionalities to generate text
    embeddings using `Sentence Transformers`.

    References:
        https://www.sbert.net/
    """

    def __init__(
        self,
        model_name: str = "intfloat/e5-large-v2",
        **kwargs: Any,
    ):
        r"""Initializes the :obj:`SentenceTransformerEncoder` class
        with the specified transformer model.

        Args:
            model_name (str, optional): The name of the model to use.
                (default: :obj:`intfloat/e5-large-v2`)
            **kwargs (optional): Additional arguments of
                :class:`SentenceTransformer`, such as :obj:`prompts` etc.
        """
        # Imported lazily so `sentence_transformers` is only needed when this
        # encoder is actually instantiated.
        from sentence_transformers import SentenceTransformer

        self.model = SentenceTransformer(model_name, **kwargs)

    def embed_list(
        self,
        objs: list[str],
        **kwargs: Any,
    ) -> list[list[float]]:
        r"""Generates embeddings for the given texts using the model.

        Args:
            objs (list[str]): The texts for which to generate the
                embeddings.
            **kwargs (Any): Extra keyword arguments forwarded to
                :meth:`SentenceTransformer.encode`. ``normalize_embeddings``
                defaults to :obj:`True` but may be overridden here.

        Returns:
            list[list[float]]: A list that represents the generated embedding
                as a list of floating-point numbers.

        Raises:
            ValueError: If :obj:`objs` is empty.
            TypeError: If the model does not return a :class:`numpy.ndarray`
                (for example when ``convert_to_tensor=True`` is passed).
        """
        from numpy import ndarray

        if not objs:
            raise ValueError("Input text list is empty")
        # Normalize by default, but let callers override the flag; passing it
        # both explicitly and via **kwargs would raise "got multiple values".
        kwargs.setdefault("normalize_embeddings", True)
        embeddings = self.model.encode(objs, **kwargs)
        # Explicit check instead of `assert`, which is stripped under -O.
        if not isinstance(embeddings, ndarray):
            raise TypeError(
                "Expected `encode` to return a numpy.ndarray, got "
                f"{type(embeddings).__name__}"
            )
        return embeddings.tolist()

    def get_output_dim(self) -> int:
        r"""Returns the output dimension of the embeddings.

        Returns:
            int: The dimensionality of the embeddings.

        Raises:
            ValueError: If the underlying model does not report an integer
                embedding dimension.
        """
        output_dim = self.model.get_sentence_embedding_dimension()
        # The library may report no dimension (e.g. None); fail loudly rather
        # than returning a non-int. An `assert` would vanish under -O.
        if not isinstance(output_dim, int):
            raise ValueError(
                "The model did not report an integer embedding dimension, "
                f"got {output_dim!r}"
            )
        return output_dim
