import openai
import threading
import time
from typing import List, Dict
from concurrent.futures import ThreadPoolExecutor

class GPT_Client:
    """Thread-pooled wrapper around an OpenAI-compatible chat-completions API.

    Each prompt is sent as a single ``user`` message. Token usage reported by
    the server is accumulated across all requests made through this instance.

    The instance owns a ``ThreadPoolExecutor``; call :meth:`close` (or use the
    instance as a context manager) to release its worker threads when done.
    """

    def __init__(self, url, api_key: str, model: str = "gpt-4o-mini", max_workers: int = 16, max_tokens=None):
        """Create the API client, the usage counters and the worker pool.

        Args:
            url: Base URL of the OpenAI-compatible endpoint.
            api_key: API key used to authenticate requests.
            model: Model name sent with every request.
            max_workers: Maximum number of requests running concurrently.
            max_tokens: Optional completion-token cap; ``None`` means the
                parameter is not sent at all.
        """
        self.api_key = api_key
        self.model = model
        # The module-level settings are kept because code outside this class
        # may rely on them; this class itself only uses ``self.client``.
        openai.api_key = self.api_key
        openai.base_url = url
        self.start_time = time.time()  # instantiation time, see get_elapsed_time()
        self.client = openai.OpenAI(
            base_url=url,
            api_key=api_key
        )
        self.usage = {
            "completion_tokens": 0,
            "prompt_tokens": 0,
            "total_tokens": 0
        }
        # Upper bound on the number of requests in flight at once.
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
        self.lock = threading.Lock()  # guards self.usage
        self.max_tokens = max_tokens

    def __enter__(self):
        """Return the instance so it can be used in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Shut the worker pool down when leaving the ``with`` block."""
        self.close()

    def close(self) -> None:
        """Wait for pending requests and release the worker threads."""
        self.executor.shutdown(wait=True)

    def get_elapsed_time(self) -> float:
        """Return the number of seconds since this object was instantiated."""
        return time.time() - self.start_time

    def fetch_response(self, prompt: str) -> str:
        """Send ``prompt`` as one user message and return the reply text.

        The token usage reported in the response is added to the running
        totals. Any exception raised by the underlying client propagates.
        """
        request = {
            "model": self.model,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": 0.5,
            "timeout": 60,
        }
        # Only send max_tokens when the caller configured one.
        if self.max_tokens is not None:
            request["max_tokens"] = self.max_tokens
        response = self.client.chat.completions.create(**request)
        self.update_usage(response.usage)
        return response.choices[0].message.content

    def update_usage(self, usage):
        """Add one response's token counts to the running totals.

        Args:
            usage: Object exposing ``completion_tokens``, ``prompt_tokens``
                and ``total_tokens`` attributes, or ``None`` when the server
                reported no usage (in which case nothing is added).
        """
        if usage is None:
            # Missing usage data must not discard an otherwise valid reply.
            return
        with self.lock:
            self.usage['completion_tokens'] += usage.completion_tokens
            self.usage['prompt_tokens'] += usage.prompt_tokens
            self.usage['total_tokens'] += usage.total_tokens

    def get_usage(self) -> Dict[str, int]:
        """Return a snapshot of the accumulated token usage.

        A copy is taken under the lock so the caller never sees a partially
        updated total and cannot modify the internal counters.
        """
        with self.lock:
            return dict(self.usage)

    def ask_question(self, prompt: str) -> str:
        """Return the model's reply to a single prompt."""
        return self.fetch_response(prompt)

    def ask_questions(self, prompts: List[str]) -> List[str]:
        """Answer all prompts concurrently; replies are in ``prompts`` order.

        If any request fails, its exception is raised when that result is
        collected.
        """
        # Submit everything first so the requests overlap, then collect the
        # results in submission order.
        futures = [self.executor.submit(self.ask_question, prompt) for prompt in prompts]
        return [future.result() for future in futures]

if __name__ == "__main__":
    # Smoke test: list the models available to the configured API key.
    #
    # The key is read from the environment; a secret must never be committed
    # to source control. Any key that was previously hard-coded here should be
    # treated as leaked and revoked.
    #
    # Example use of GPT_Client (note that the endpoint URL comes first):
    #     client = GPT_Client(url, api_key, "gpt-4o-mini", max_workers=16)
    #     # Too many workers can exceed the provider's rate limit.
    #     for response in client.ask_questions(["who are you?", "how can you help?"]):
    #         print(response)
    #     print("Token usage:", client.get_usage())
    #     print(f"Object has been alive for {client.get_elapsed_time()} seconds")
    import os

    api_key = os.environ.get("OPENAI_API_KEY")
    if not api_key:
        raise SystemExit("Set the OPENAI_API_KEY environment variable before running this script.")

    client = openai.OpenAI(api_key=api_key)
    print(client.models.list())
    