import time
from openai import OpenAI
from concurrent.futures import ThreadPoolExecutor, as_completed

class LLMModel:
    """Thin wrapper around an OpenAI-compatible chat client plus an embedding client.

    Example ``base_url`` values for OpenAI-compatible chat endpoints:
    "https://api.deepseek.com/v1", "https://api.openai.com/v1/",
    "https://dashscope.aliyuncs.com/compatible-mode/v1".
    """

    def __init__(self, api_key, embedding_key=None, model="gpt-3.5-turbo",
                 base_url="https://api.openai.com/v1/", use_web_search=False,
                 enable_thinking=None, embedding_base_url="https://api.openai.com/v1/"):
        """Create the chat and embedding clients.

        Args:
            api_key: API key for the chat endpoint at ``base_url``.
            embedding_key: API key for the embedding endpoint; defaults to ``api_key``.
            model: Default chat model name.
            base_url: Base URL of the OpenAI-compatible chat endpoint.
            use_web_search: When true, every chat call is routed to
                ``gpt-4o-search-preview`` with ``web_search_options`` set.
            enable_thinking: Default for the ``enable_thinking`` flag that is sent
                in ``extra_body``; ``None`` sends nothing.
            embedding_base_url: Base URL of the embedding endpoint (the OpenAI URL
                by default, matching the previous hard-coded value).
        """
        self.api_key = api_key
        self.model = model
        self.client = OpenAI(api_key=self.api_key, base_url=base_url)
        if embedding_key is None:
            embedding_key = api_key
        self.embedding_client = OpenAI(api_key=embedding_key, base_url=embedding_base_url)
        self.use_web_search = use_web_search
        self.enable_thinking = enable_thinking

    def LLM_response(self, prompt, gen_kwargs=None, model=None, full_response=False, enable_thinking=None):
        """Send one chat-completion request and return the reply.

        Args:
            prompt: A user message string, or a list of chat message dicts that is
                passed through unchanged.
            gen_kwargs: Extra keyword arguments for ``chat.completions.create``.
                The dict is copied before use, so the caller's dict (and any dict
                shared between threads) is never modified.
            model: Model name; defaults to ``self.model``. Ignored when
                ``self.use_web_search`` is set.
            full_response: Return the raw completion object instead of the text.
            enable_thinking: Per-call override of ``self.enable_thinking``. When the
                resolved value equals True or False it is sent as
                ``extra_body["enable_thinking"]``, merged into any caller-supplied
                ``extra_body``.

        Returns:
            ``completion.choices[0].message.content``, or the whole completion
            object when ``full_response`` is true.

        Raises:
            ValueError: If ``prompt`` is neither a string nor a list.
        """
        if model is None:
            model = self.model
        if enable_thinking is None:
            enable_thinking = self.enable_thinking

        if isinstance(prompt, str):
            input_messages = [{"role": "user", "content": prompt}]
        elif isinstance(prompt, list):
            input_messages = prompt
        else:
            print("prompt must be a string or a list of messages, current type: ", type(prompt))
            raise ValueError("prompt must be a string or a list of messages")

        # Private copy: mutating the caller's dict (or a shared default) would leak
        # these options into unrelated later calls and across worker threads.
        request_kwargs = dict(gen_kwargs) if gen_kwargs else {}

        if self.use_web_search:
            model = "gpt-4o-search-preview"
            request_kwargs["web_search_options"] = {}

        # Test the resolved per-call value (not self.enable_thinking) so an explicit
        # argument is honoured even when the instance default is None.
        if enable_thinking in (True, False):
            # Merge rather than replace so caller-supplied extra_body keys survive.
            extra_body = dict(request_kwargs.get("extra_body") or {})
            extra_body["enable_thinking"] = bool(enable_thinking)
            request_kwargs["extra_body"] = extra_body

        completion = self.client.chat.completions.create(
            model=model,
            messages=input_messages,
            **request_kwargs,
        )

        if full_response:
            return completion
        return completion.choices[0].message.content

    def LLM_response_async(self, prompts, gen_kwargs=None, model=None, max_workers=20,
                           full_response=False, max_retries=None, current_retry=0,
                           enable_thinking=None):
        """Run LLM_response over many prompts on a thread pool, retrying failures.

        Despite the name this is thread-based rather than asyncio: requests are
        submitted to a ThreadPoolExecutor and the call blocks until all finish.

        Args:
            prompts: List whose items are prompt strings or message lists.
            gen_kwargs: Forwarded to every LLM_response call (never mutated).
            model: Model name; defaults to ``self.model``.
            max_workers: Thread-pool size.
            full_response: Forwarded to LLM_response; also makes failed entries
                carry the exception object instead of an ``"Error: ..."`` string.
            max_retries: If a positive int, failed prompts are re-run for up to
                this many additional rounds.
            current_retry: Internal recursion counter; callers leave it at 0.
            enable_thinking: Forwarded to LLM_response for every prompt.

        Returns:
            ``(results, all_success)`` where ``results[i]`` is
            ``(success, data_or_error, prompts[i])`` in input order.

        Raises:
            ValueError: If ``prompts`` is not a list of strings / message lists.
        """
        if model is None:
            model = self.model

        if not isinstance(prompts, list) or not all(isinstance(p, (str, list)) for p in prompts):
            raise ValueError("prompts must be a list of strings or a list of list of messages")

        results = [None] * len(prompts)
        all_success = True

        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            future_to_index = {
                executor.submit(
                    self.LLM_response,
                    prompt=prompt,
                    gen_kwargs=gen_kwargs,
                    model=model,
                    full_response=full_response,
                    enable_thinking=enable_thinking,
                ): index
                for index, prompt in enumerate(prompts)
            }

            # Futures finish out of order; the mapping restores input order.
            for future in as_completed(future_to_index):
                index = future_to_index[future]
                try:
                    data = future.result()
                except Exception as exc:
                    print(f'Prompt at index {index} generated an exception: {exc}')
                    results[index] = (False, exc if full_response else f"Error: {str(exc)}", prompts[index])
                    all_success = False
                else:
                    results[index] = (True, data, prompts[index])

        can_retry = isinstance(max_retries, int) and max_retries > 0 and current_retry < max_retries
        if not all_success and can_retry:
            print(f"Retry attempt {current_retry + 1} of {max_retries}")
            failed_indices = [i for i, (success, _, _) in enumerate(results) if not success]

            # Recursive call re-runs only the failures; it handles further rounds itself.
            retry_results, all_success = self.LLM_response_async(
                prompts=[prompts[i] for i in failed_indices],
                gen_kwargs=gen_kwargs,
                model=model,
                max_workers=max_workers,
                full_response=full_response,
                max_retries=max_retries,
                current_retry=current_retry + 1,
                enable_thinking=enable_thinking,
            )

            for original_index, retry_result in zip(failed_indices, retry_results):
                results[original_index] = retry_result

        return results, all_success

    def Embedding_response(self, input_texts, model="text-embedding-3-large"):
        """Embed one or more texts with the embedding client.

        Args:
            input_texts: A single string or a list of strings.
            model: Embedding model name.

        Returns:
            A list of embedding vectors taken from ``response.data`` in the order
            returned. An empty list input returns ``[]`` without making a request.

        Raises:
            ValueError: If ``input_texts`` is not a string or a list of strings.
        """
        if isinstance(input_texts, str):
            input_texts = [input_texts]
        elif isinstance(input_texts, list):
            if not all(isinstance(text, str) for text in input_texts):
                raise ValueError("All items in input_texts list must be strings")
        else:
            print("input_texts must be a string or a list of strings, current type: ", type(input_texts))
            raise ValueError("input_texts must be a string or a list of strings")

        if not input_texts:
            # Nothing to embed; skip the network round-trip.
            return []

        response = self.embedding_client.embeddings.create(
            model=model,
            input=input_texts,
        )

        return [item.embedding for item in response.data]