from abc import ABC, abstractmethod
from typing import List
import requests
import time
import numpy as np
from tqdm import tqdm
from openai import OpenAI

from .config import RetrievalConfig


class EmbeddingService(ABC):
    """Abstract interface for services that turn texts into embedding vectors.

    Concrete subclasses must implement :meth:`get_embeddings` and
    :meth:`get_embedding_dim`; the class cannot be instantiated until both
    are provided.
    """

    def __init__(self, config: RetrievalConfig):
        """Keep ``config`` on the instance so subclasses can read it."""
        self.config = config

    @abstractmethod
    def get_embeddings(self, texts: List[str], batch_size: int = 100) -> np.ndarray:
        """Return the embeddings of ``texts`` as a NumPy array.

        Args:
            texts: Strings to embed.
            batch_size: Number of texts to process per batch.
        """

    @abstractmethod
    def get_embedding_dim(self) -> int:
        """Return the length of a single embedding vector."""


class OpenAIEmbeddingService(EmbeddingService):
    """Embedding service that calls an embeddings endpoint through the ``OpenAI`` client.

    Requests are best-effort: a batch whose request fails is not retried and
    its rows are filled with zeros, so the result always has one row per
    input text.
    """

    def __init__(self, config: RetrievalConfig):
        """Build the API client from ``config.api_key`` and ``config.base_url``."""
        super().__init__(config)
        self.client = OpenAI(
            api_key=config.api_key,
            base_url=config.base_url
        )
        # Embedding dimension of each known model.
        self.model_dims = {
            "text-embedding-3-small": 1536,
            "text-embedding-3-large": 3072,
            "text-embedding-ada-002": 1536,
        }

    def get_embeddings(self, texts: List[str], batch_size: int = 100) -> np.ndarray:
        """Embed ``texts`` in batches and return one row per text.

        Args:
            texts: Strings to embed.
            batch_size: Maximum number of texts sent per API request.

        Returns:
            A 2-D array with ``len(texts)`` rows. Rows that belong to a batch
            whose request failed are all zeros.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        # A negative step would make the range below empty and silently
        # return no embeddings at all, so reject it up front.
        if batch_size < 1:
            raise ValueError(f"batch_size must be at least 1, got {batch_size}")

        # One entry per input text; ``None`` marks a row whose batch failed.
        rows = []

        if self.config.verbose:
            print("Getting embeddings")
        progress_bar = tqdm(
            range(0, len(texts), batch_size),
            desc="Getting embeddings",
            disable=not self.config.verbose,
        )

        for i in progress_bar:
            batch_texts = texts[i:i + batch_size]

            try:
                response = self.client.embeddings.create(
                    model=self.config.embedding_model,
                    input=batch_texts
                )
                batch_embeddings = [item.embedding for item in response.data]
                # A reply with the wrong number of items would shift every
                # later row against its text, so treat it as a failed batch.
                if len(batch_embeddings) != len(batch_texts):
                    raise ValueError(
                        f"expected {len(batch_texts)} embeddings, "
                        f"got {len(batch_embeddings)}"
                    )
            except Exception as e:
                # Deliberate best-effort: report (when verbose) and keep going.
                if self.config.verbose:
                    print(f"Error getting embeddings for batch {i}: {e}")
                rows.extend([None] * len(batch_texts))
            else:
                rows.extend(batch_embeddings)
                # Short pause between successful requests
                # (presumably to stay under API rate limits).
                time.sleep(0.05)

        return self._rows_to_array(rows)

    def _rows_to_array(self, rows) -> np.ndarray:
        """Turn collected rows into a 2-D array, zero-filling the failed ones."""
        # Prefer the dimension the API actually returned: for a model that is
        # missing from ``model_dims`` the table default may not match, and
        # rows of different widths cannot form a rectangular array.
        dim = next(
            (len(row) for row in rows if row is not None),
            self.get_embedding_dim(),
        )
        if not rows:
            return np.zeros((0, dim))
        zero_row = [0.0] * dim
        return np.array([zero_row if row is None else row for row in rows])

    def get_embedding_dim(self) -> int:
        """Return the dimension of the configured model (1536 if unknown)."""
        return self.model_dims.get(self.config.embedding_model, 1536)


class SiliconFlowEmbeddingService(EmbeddingService):
    """Embedding service that posts to ``{siliconflow_base_url}/embeddings``.

    Requests are best-effort: a batch whose request fails is not retried and
    its rows are filled with zeros, so the result always has one row per
    input text.
    """

    # Seconds to wait for the server before a request counts as failed;
    # without a timeout ``requests.post`` may block indefinitely.
    _REQUEST_TIMEOUT_SECONDS = 60

    def __init__(self, config: RetrievalConfig):
        """Prepare the HTTP headers and the per-model lookup tables."""
        super().__init__(config)
        self.headers = {
            "Authorization": f"Bearer {config.siliconflow_api_key}",
            "Content-Type": "application/json"
        }
        # Embedding dimension of each known model.
        self.model_dims = {
            "BAAI/bge-large-en-v1.5": 1024,
            "BAAI/bge-m3": 1024,
            "Pro/BAAI/bge-m3": 1024,
            "Qwen/Qwen3-Embedding-0.6B": 1024,
        }

        # Per-model cap applied to the caller's requested batch size.
        self.max_batch_sizes = {
            "BAAI/bge-large-en-v1.5": 100,
            "BAAI/bge-m3": 64,
            "Pro/BAAI/bge-m3": 64,
            "Qwen/Qwen3-Embedding-0.6B": 100,
        }

    def get_embeddings(self, texts: List[str], batch_size: int = 100) -> np.ndarray:
        """Embed ``texts`` in batches and return one row per text.

        Args:
            texts: Strings to embed.
            batch_size: Requested number of texts per request; capped at the
                configured model's entry in ``max_batch_sizes`` (100 when the
                model is not listed).

        Returns:
            A 2-D array with ``len(texts)`` rows. Rows that belong to a batch
            whose request failed are all zeros.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        # A negative step would make the range below empty and silently
        # return no embeddings at all, so reject it up front.
        if batch_size < 1:
            raise ValueError(f"batch_size must be at least 1, got {batch_size}")

        max_batch_size = self.max_batch_sizes.get(self.config.embedding_model, 100)
        actual_batch_size = min(batch_size, max_batch_size)

        if self.config.verbose and actual_batch_size != batch_size:
            print(f"Using actual batch size: {actual_batch_size}")
        progress_bar = tqdm(
            range(0, len(texts), actual_batch_size),
            desc="Getting embeddings",
            disable=not self.config.verbose,
        )

        # One entry per input text; ``None`` marks a row whose batch failed.
        rows = []

        for i in progress_bar:
            batch_texts = texts[i:i + actual_batch_size]

            try:
                batch_embeddings = self._request_batch(batch_texts)
            except Exception as e:
                # Deliberate best-effort: report (when verbose) and keep going.
                if self.config.verbose:
                    print(f"Error getting embeddings for batch {i}: {e}")
                rows.extend([None] * len(batch_texts))
            else:
                rows.extend(batch_embeddings)
                # Short pause between successful requests
                # (presumably to stay under API rate limits).
                time.sleep(0.1)

        return self._rows_to_array(rows)

    def _request_batch(self, batch_texts):
        """Post one batch and return its embeddings, raising on any bad reply."""
        payload = {
            "model": self.config.embedding_model,
            "input": batch_texts
        }

        response = requests.post(
            f"{self.config.siliconflow_base_url}/embeddings",
            json=payload,
            headers=self.headers,
            timeout=self._REQUEST_TIMEOUT_SECONDS,
        )

        if response.status_code != 200:
            raise Exception(f"API call failed: {response.status_code} - {response.text}")

        result = response.json()
        data = result.get('data') if isinstance(result, dict) else None
        if not data:
            raise Exception("Abnormal response data format")
        # A reply with the wrong number of items would shift every later row
        # against its text, so treat it as a failed batch.
        if len(data) != len(batch_texts):
            raise Exception(
                f"Abnormal response data format: expected {len(batch_texts)} "
                f"embeddings, got {len(data)}"
            )
        return [item['embedding'] for item in data]

    def _rows_to_array(self, rows) -> np.ndarray:
        """Turn collected rows into a 2-D array, zero-filling the failed ones."""
        # Prefer the dimension the API actually returned: for a model that is
        # missing from ``model_dims`` the table default may not match, and
        # rows of different widths cannot form a rectangular array.
        dim = next(
            (len(row) for row in rows if row is not None),
            self.get_embedding_dim(),
        )
        if not rows:
            return np.zeros((0, dim))
        zero_row = [0.0] * dim
        return np.array([zero_row if row is None else row for row in rows])

    def get_embedding_dim(self) -> int:
        """Return the dimension of the configured model (1024 if unknown)."""
        return self.model_dims.get(self.config.embedding_model, 1024)


def create_embedding_service(config: RetrievalConfig) -> EmbeddingService:
    """Instantiate the embedding service named by ``config.embedding_service``.

    The name is compared case-insensitively. ``"siliconflow"`` selects
    :class:`SiliconFlowEmbeddingService`; every other value selects
    :class:`OpenAIEmbeddingService`.
    """
    wants_siliconflow = config.embedding_service.lower() == "siliconflow"
    service_cls = (
        SiliconFlowEmbeddingService if wants_siliconflow else OpenAIEmbeddingService
    )
    return service_cls(config)
