from typing import List
from llama_index.core.embeddings import BaseEmbedding

from .embedding_service import EmbeddingService


class SiliconFlowEmbeddingAdapter(BaseEmbedding):
    """Expose an ``EmbeddingService`` through llama_index's ``BaseEmbedding`` API.

    Every embedding request is delegated to
    ``embedding_service.get_embeddings(texts, batch_size=...)``. Each returned
    vector is converted to a plain ``List[float]``.
    """

    def __init__(
        self,
        embedding_service: EmbeddingService,
        embed_batch_size: int = 100,
    ):
        """Wrap *embedding_service*.

        Args:
            embedding_service: Backend used to compute embeddings.
            embed_batch_size: Batch size passed to ``BaseEmbedding`` and forwarded
                to the service for multi-text requests. Defaults to 100, the
                value this adapter has always used.
        """
        super().__init__(
            embed_batch_size=embed_batch_size,
            callback_manager=None,
        )
        # BaseEmbedding is a pydantic model and ``_embedding_service`` is not a
        # declared field. The attribute is presumably set via object.__setattr__
        # to sidestep pydantic's attribute handling. NOTE(review): a pydantic
        # ``PrivateAttr`` declaration would be the more idiomatic alternative.
        object.__setattr__(self, "_embedding_service", embedding_service)

    @classmethod
    def class_name(cls) -> str:
        """Return the identifier llama_index uses for this embedding class."""
        return "SiliconFlowEmbeddingAdapter"

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Async variant of :meth:`_get_query_embedding`.

        NOTE(review): the synchronous service call runs directly inside this
        coroutine, so it blocks the event loop until the service returns.
        """
        return self._get_query_embedding(query)

    def _get_query_embedding(self, query: str) -> List[float]:
        """Return the embedding vector for a single search *query*."""
        # Batch size is irrelevant for a single input; 10 is kept so the
        # service receives the same arguments as before.
        return self._embed_many([query], batch_size=10)[0]

    def _get_text_embedding(self, text: str) -> List[float]:
        """Return the embedding vector for a single document *text*."""
        return self._embed_many([text], batch_size=10)[0]

    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Async variant of :meth:`_get_text_embeddings`.

        NOTE(review): like the query variant, this blocks the event loop while
        the synchronous service call runs.
        """
        return self._get_text_embeddings(texts)

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Return one embedding vector per entry of *texts*, in input order."""
        return self._embed_many(texts, batch_size=self.embed_batch_size)

    def _embed_many(self, texts: List[str], batch_size: int) -> List[List[float]]:
        """Embed *texts* via the service and return plain float lists.

        Raises:
            ValueError: if the service returns a different number of vectors
                than inputs, which would otherwise misalign the results.
        """
        if not texts:
            # Nothing to embed; avoid a pointless round-trip to the service.
            return []
        vectors = self._embedding_service.get_embeddings(texts, batch_size=batch_size)
        if len(vectors) != len(texts):
            raise ValueError(
                f"Embedding service returned {len(vectors)} embeddings "
                f"for {len(texts)} inputs"
            )
        return [self._to_float_list(vector) for vector in vectors]

    @staticmethod
    def _to_float_list(vector) -> List[float]:
        """Convert one returned vector to a ``list``.

        Array-like objects exposing ``tolist()`` (e.g. numpy arrays) use it.
        Any other iterable is copied with ``list()``.
        """
        tolist = getattr(vector, "tolist", None)
        return tolist() if callable(tolist) else list(vector)
