import openai
import threading
import time
from typing import List, Dict
from concurrent.futures import ThreadPoolExecutor

class GPT_Client:
    """Thin wrapper around the OpenAI chat-completions API with thread-pooled batching.

    Every request is a single-turn user message. Token usage from all
    requests is accumulated in ``self.usage`` under ``self.lock``, so worker
    threads started by :meth:`ask_questions` can update it concurrently.

    The instance owns a ``ThreadPoolExecutor``. Call :meth:`close`, or use
    the instance as a context manager, to shut its worker threads down.
    """

    def __init__(
        self,
        url,
        api_key: str,
        model: str = "gpt-4o-mini",
        max_workers: int = 16,
        *,
        temperature: float = 0.5,
        max_tokens: int = 512,
        timeout: float = 60,
    ):
        """Configure the ``openai`` module and create the worker pool.

        Args:
            url: Base URL of the endpoint; assigned to ``openai.base_url``.
            api_key: API key; assigned to ``openai.api_key``.
            model: Model name sent with every request.
            max_workers: Number of threads used by :meth:`ask_questions`.
                Values that are too large can exceed the provider's rate limits.
            temperature: Sampling temperature sent with every request.
            max_tokens: Maximum number of generated tokens per request.
            timeout: Per-request timeout passed to the SDK.
        """
        self.api_key = api_key
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.timeout = timeout
        # NOTE: these are module-level (process-wide) settings. Constructing a
        # second GPT_Client with a different url/api_key overwrites them for
        # every instance in the process.
        openai.api_key = self.api_key
        openai.base_url = url
        # Wall-clock creation time is kept for callers that read it. Elapsed
        # time is measured on the monotonic clock instead, so system clock
        # adjustments cannot make it jump or go negative.
        self.start_time = time.time()
        self._start_monotonic = time.monotonic()
        self.usage = {
            "completion_tokens": 0,
            "prompt_tokens": 0,
            "total_tokens": 0,
        }
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
        # Guards self.usage, which worker threads update concurrently.
        self.lock = threading.Lock()

    def get_elapsed_time(self) -> float:
        """Return the number of seconds since this client was constructed."""
        return time.monotonic() - self._start_monotonic

    def fetch_response(self, prompt: str) -> str:
        """Send *prompt* as a single user message and return the reply text.

        The token usage reported in the response is added to the running
        totals via :meth:`update_usage`.

        Raises:
            Whatever the OpenAI SDK raises (network errors, timeouts, API
            errors); nothing is caught here.
        """
        response = openai.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": prompt}],
            temperature=self.temperature,
            timeout=self.timeout,
            max_tokens=self.max_tokens,
        )
        self.update_usage(response.usage)
        # NOTE(review): the SDK types message.content as optional, so this may
        # be None despite the -> str hint. Confirm callers handle that case.
        return response.choices[0].message.content

    def update_usage(self, usage: object):
        """Add one response's token counts to the running totals.

        Args:
            usage: The ``usage`` object of a chat-completion response. It is
                read through its ``completion_tokens``, ``prompt_tokens`` and
                ``total_tokens`` attributes. ``None`` (no usage reported) is
                ignored rather than raising ``AttributeError``.
        """
        if usage is None:
            return
        with self.lock:
            self.usage["completion_tokens"] += usage.completion_tokens
            self.usage["prompt_tokens"] += usage.prompt_tokens
            self.usage["total_tokens"] += usage.total_tokens

    def get_usage(self) -> Dict[str, int]:
        """Return a snapshot copy of the accumulated token counts.

        The copy is taken under the lock, so the three counters are mutually
        consistent. Mutating the returned dict does not affect the totals.
        """
        with self.lock:
            return dict(self.usage)

    def ask_question(self, prompt: str) -> str:
        """Answer a single prompt synchronously on the calling thread."""
        return self.fetch_response(prompt)

    def ask_questions(self, prompts: List[str]) -> List[str]:
        """Answer *prompts* concurrently on the thread pool.

        Returns:
            The replies, in the same order as *prompts*.

        Raises:
            The first exception, in prompt order, raised by any request.
            ``Executor.map`` then cancels requests that have not started yet,
            so no further tokens are spent on a batch whose results are lost.
        """
        return list(self.executor.map(self.ask_question, prompts))

    def close(self) -> None:
        """Shut down the worker pool, waiting for in-flight requests to finish."""
        self.executor.shutdown(wait=True)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # Always release the worker threads. Returning None lets any
        # exception propagate.
        self.close()

if __name__ == "__main__":
    # Smoke test: list the models visible to the given API key.
    #
    # Example use of GPT_Client (the constructor takes the endpoint URL first):
    #     client = GPT_Client(url, api_key, model="gpt-4o-mini", max_workers=16)
    #     responses = client.ask_questions([prompt1, prompt2])
    #     print("Token usage:", client.get_usage())
    #     print(f"Object has been alive for {client.get_elapsed_time()} seconds")
    import argparse

    parser = argparse.ArgumentParser(
        description="List the models available to an OpenAI API key."
    )
    parser.add_argument("--api_key", type=str, required=True, help="API key to use.")
    parser.add_argument(
        "--url",
        type=str,
        default=None,
        help="Optional base URL of an OpenAI-compatible endpoint; "
        "when omitted the SDK chooses its default endpoint.",
    )
    args = parser.parse_args()
    client = openai.OpenAI(api_key=args.api_key, base_url=args.url)
    # Call the API once and print that result.
    models = client.models.list()
    print(models)
    