from openai import OpenAI
from openai.types.chat.chat_completion import ChatCompletion
from typing import Literal, Any, Iterable
from pathlib import Path
import json
import threading
from collections import deque


class ChatClient:

    def __init__(self, api_key: str, base_url: str,
                 model: str | None = None,
                 **options):
 
        self.api_key = api_key
        self.base_url = base_url
        self.model = model
        self.options = options

    def chat(self, messages: list, model: str | None = None, **kwargs):
        """Get a chat completion from the OpenAI API."""

        if model is None:
            model = self.model
        if model is None:
            raise ValueError("Model must be specified if not set in the constructor.")

        client = OpenAI(
            base_url=self.base_url,
            api_key=self.api_key
        )

        response = client.chat.completions.create(
            model=model,
            messages=messages,
            **(self.options | kwargs)
        )
        assert isinstance(response, ChatCompletion), f"Invalid response type: {type(response)}"
        return response
    
    def completion(self, prompt: str, model: str | None = None, **kwargs):
        """Get a prompt completion from the OpenAI API."""

        return self.chat([{"role": "user", "content": prompt}], model=model, **kwargs)
    
    def completions(
        self, 
        prompts: list[str],
        model: str | None = None,
        num_threads: int = 1,
        **kwargs
    ) -> Iterable[ChatCompletion | Exception]:
        """Get completions for a list of prompts using multiple threads.
        If num_threads is 1, it will use the same method as completion().
        Exceptions that occur during API calls will be yielded as well.
        """

        # a list of prompts that are being processed by threads
        results: dict[int, Any] = {}
        
        next_prompt_idx: int = 0  # next prompt to process
        next_result_idx: int = 0  # next result to yield

        if num_threads == 1:
            for i in range(len(prompts)):
                try:
                    output = self.completion(prompts[i], model=model, **kwargs)
                except Exception as e:
                    output = e
            yield output


        def worker_fn(i: int):
            try:
                output = self.completion(prompts[i], model=model, **kwargs)
            except Exception as e:
                output = e
            results[i] = output
        
        workers: dict[int, threading.Thread] = {}

        while next_prompt_idx < len(prompts) or len(workers) > 0:

            # remove terminated workers
            for i in list(workers.keys()):
                if not workers[i].is_alive():
                    workers[i].join()
                    del workers[i]
            
            # if we have more prompts to process and we have space for more workers
            while next_prompt_idx < len(prompts) and len(workers) < num_threads:
                # start a new worker thread
                workers[next_prompt_idx] = threading.Thread(
                    target=worker_fn,
                    args=(next_prompt_idx,),
                )
                workers[next_prompt_idx].start()
                next_prompt_idx += 1

            # if next result is ready, yield it
            if next_result_idx in results:
                yield results[next_result_idx]
                del results[next_result_idx]
                next_result_idx += 1

        # yield remaining results
        for i in range(next_result_idx, len(prompts)):
            if i in results:
                yield results[i]
                del results[i]
            else:
                yield AssertionError("Result not found in results dictionary. (Might be a bug)")


# Model names that build_chat_client() recognises; used to annotate its
# ``model_name`` parameter so editors and type checkers can suggest them.
type _LLMName = Literal["gpt-4o", "gpt-4o-mini", "o1", "o1-mini", "o3", "o3-mini",
                        "DeepSeek-R1"]


def build_chat_client(
    model_name: _LLMName | str,
    config_file: str | Path,
    **options,
):
    """Create a ``ChatClient`` for a known model from a JSON config file.

    The config file must contain a JSON object providing the keys
    ``"openai-base-url"`` and ``"openai-api-key"``. DeepSeek model names are
    sent to the API with a ``"deepseek-ai/"`` prefix; the other supported
    names are used as given.

    Args:
        model_name: One of the supported model names.
        config_file: Path to the JSON configuration file.
        **options: Default request options handed to ``ChatClient``.

    Returns:
        A ``ChatClient`` configured with the credentials and model.

    Raises:
        ValueError: If the config is not a JSON object or the model name
            is not supported.
        KeyError: If a required key is missing from the config.
    """
    plain_models = ("gpt-4o", "gpt-4o-mini", "o1", "o1-mini", "o3", "o3-mini")
    prefixed_models = ("DeepSeek-R1",)

    with open(config_file, "rt") as config_stream:
        settings = json.load(config_stream)

    if not isinstance(settings, dict):
        raise ValueError(f"Invalid config file: {config_file}")

    # Resolve the model identifier sent to the API before touching the
    # credentials, so an unknown model is reported first.
    if model_name in plain_models:
        api_model = model_name
    elif model_name in prefixed_models:
        api_model = "deepseek-ai/" + model_name
    else:
        raise ValueError(f"Invalid model name: {model_name}")

    endpoint = settings["openai-base-url"]
    secret = settings["openai-api-key"]

    return ChatClient(
        api_key=secret,
        base_url=endpoint,
        model=api_model,
        **options
    )
