from typing import Dict, List
from tenacity import retry, stop_after_attempt,stop_after_delay, wait_exponential, retry_if_exception_type

from src.configs.provider import OpenAIEmbedderConfig, AzureEmbedderConfig
from src.providers.base import BaseEmbedder

import logging

logger = logging.getLogger(__name__)

def log_attempt(retry_state):
    """Log a failed attempt just before tenacity sleeps and retries it.

    Used as the ``before_sleep`` hook of the ``@retry`` decorators in this module.

    Args:
        retry_state: tenacity ``RetryCallState``; ``attempt_number`` is read, and
            ``outcome`` (when present) supplies the exception that triggered the retry.
    """
    outcome = getattr(retry_state, "outcome", None)
    # Include the failing exception; without it the cause of repeated retries is
    # invisible until tenacity finally raises RetryError.
    exc = outcome.exception() if outcome is not None else None
    # Lazy %-formatting: the message is only built if INFO is enabled.
    logger.info(
        "retry_state.attempt_number: %s, exception: %r",
        retry_state.attempt_number,
        exc,
    )

class OpenAIEmbedder(BaseEmbedder):
    """Embedder backed by the ``openai.OpenAI`` client's embeddings endpoint.

    ``config.base_url`` is passed straight to the client, so a non-default
    endpoint can be targeted.
    """

    def __init__(self, config: OpenAIEmbedderConfig):
        # Deferred import: ``openai`` is only needed once an embedder is constructed.
        from openai import OpenAI
        self.config = config
        self.client = OpenAI(
            api_key=config.api_key,
            base_url=config.base_url,
        )

    # NOTE(review): every exception is retried, including non-transient ones
    # (bad credentials, invalid requests). Consider narrowing this to
    # rate-limit / connection / 5xx errors.
    @retry(
        stop=(stop_after_delay(300) | stop_after_attempt(20)),
        wait=wait_exponential(multiplier=2, min=4, max=30),
        retry=retry_if_exception_type(Exception),
        before_sleep=log_attempt,
    )
    def embed(self, texts: List[str]) -> List[List[float]]:
        """Embed ``texts`` with ``config.model_name_or_path`` in a single request.

        Any exception is retried with exponential backoff (4-30 s between
        attempts) until 20 attempts or ~300 s have elapsed. After that, tenacity
        raises ``RetryError``.

        Args:
            texts: strings to embed.

        Returns:
            One embedding per returned item, in the order of ``response.data``.
            An empty ``texts`` yields ``[]`` without calling the API.
        """
        # The endpoint rejects an empty ``input``. Because every exception is
        # retried, sending it would block for minutes before failing.
        if not texts:
            return []
        response = self.client.embeddings.create(
            model=self.config.model_name_or_path,
            input=texts,
        )
        return [r.embedding for r in response.data]

class AzureEmbedder(BaseEmbedder):
    """Embedder backed by the ``openai.AzureOpenAI`` client's embeddings endpoint.

    ``config.base_url`` is used as the Azure endpoint, and ``config.api_version``
    selects the API version.
    """

    def __init__(self, config: AzureEmbedderConfig):
        # Deferred import: ``openai`` is only needed once an embedder is constructed.
        from openai import AzureOpenAI
        self.config = config
        self.client = AzureOpenAI(
            api_key=config.api_key,
            azure_endpoint=config.base_url,
            api_version=config.api_version,
        )

    # NOTE(review): every exception is retried, including non-transient ones
    # (bad credentials, invalid requests). Consider narrowing this to
    # rate-limit / connection / 5xx errors.
    @retry(
        stop=(stop_after_delay(300) | stop_after_attempt(20)),
        wait=wait_exponential(multiplier=2, min=4, max=30),
        retry=retry_if_exception_type(Exception),
        before_sleep=log_attempt,
    )
    def embed(self, texts: List[str]) -> List[List[float]]:
        """Embed ``texts`` with ``config.model_name_or_path`` in a single request.

        Any exception is retried with exponential backoff (4-30 s between
        attempts) until 20 attempts or ~300 s have elapsed. After that, tenacity
        raises ``RetryError``.

        Args:
            texts: strings to embed.

        Returns:
            One embedding per returned item, in the order of ``response.data``.
            An empty ``texts`` yields ``[]`` without calling the API.
        """
        # The endpoint rejects an empty ``input``. Because every exception is
        # retried, sending it would block for minutes before failing.
        if not texts:
            return []
        response = self.client.embeddings.create(
            model=self.config.model_name_or_path,
            input=texts,
        )
        return [r.embedding for r in response.data]
        