import random
import openai
from openai import OpenAI
import os
import time
import json
import datetime

class LLM:
    """Dispatch prompts to a local OpenChat server or to OpenAI-compatible models.

    After construction ``self.call_llm`` is bound to the backend-specific
    method selected by ``mode``; both backends share the signature
    ``call_llm(prompt, big_model=..., temperature=...)`` and return the
    reply text.
    """

    # Chat-completions endpoint of the local OpenChat server.
    OPENCHAT_URL = "http://localhost:18888/v1/chat/completions"

    def __init__(self, mode='openai') -> None:
        """Select the backend.

        Args:
            mode: ``'openchat'`` for the local server, ``'openai'`` for the
                two ``gpt_agent`` instances (big and small model).

        Raises:
            ValueError: if ``mode`` is not supported, or if the OpenAI
                backend is selected and no API key is configured.
        """
        if mode == 'openchat':
            # Responses are no longer staged in this directory; it is still
            # created in case other code expects it to exist.
            os.makedirs('logs', exist_ok=True)
            self.call_llm = self.call_llm_openchat
        elif mode == 'openai':
            api_key_list = [

            ]
            if not api_key_list:
                # Fall back to the environment so keys need not be committed.
                env_key = os.environ.get("OPENAI_API_KEY")
                if env_key:
                    api_key_list.append(env_key)
            if not api_key_list:
                raise ValueError(
                    "No API key configured: fill api_key_list or set OPENAI_API_KEY"
                )
            self.agent_big = gpt_agent(random.choice(api_key_list), api_key_list, model_name='chatgpt-4o-latest')
            self.agent_small = gpt_agent(random.choice(api_key_list), api_key_list, model_name='gpt-4o-mini')
            self.call_llm = self.call_llm_openai
        else:
            # Fail here rather than with an AttributeError on the first call_llm.
            raise ValueError(f"Unsupported mode: {mode!r} (expected 'openchat' or 'openai')")

    def call_llm_openchat(self, prompt, big_model=None, temperature=None):
        """Send ``prompt`` to the local OpenChat server and return the reply text.

        ``big_model`` and ``temperature`` are accepted for signature parity
        with ``call_llm_openai`` and are ignored.

        The request body is built with ``json.dumps`` and piped to ``curl``
        on stdin, so the prompt never passes through a shell and is sent
        unmodified (single quotes included).

        Raises:
            json.JSONDecodeError: if the server returned no valid JSON
                (for example when it is unreachable).
            KeyError: if the JSON reply has no ``choices`` entry.
        """
        import subprocess  # local import keeps this block self-contained

        payload = json.dumps({
            "model": "openchat_3.5",
            "messages": [{"role": "user", "content": prompt}],
        }).encode("utf-8")
        completed = subprocess.run(
            [
                "curl", "--silent", self.OPENCHAT_URL,
                "-H", "Content-Type: application/json",
                "--data-binary", "@-",  # read the body from stdin
            ],
            input=payload,
            capture_output=True,
            check=False,
        )
        data = json.loads(completed.stdout)
        return data["choices"][0]["message"]["content"]

    def call_llm_openai(self, prompt, big_model=False, temperature=0.0):
        """Ask the big or the small OpenAI-compatible model and return its reply.

        Args:
            prompt: user message.
            big_model: use ``self.agent_big`` when true, else ``self.agent_small``.
            temperature: sampling temperature forwarded to the agent.
        """
        agent = self.agent_big if big_model else self.agent_small
        return agent.ask(prompt, temperature=temperature)

class gpt_agent():
    """Chat-completion client for an OpenAI-compatible gateway with key rotation.

    Each instance owns an ``OpenAI`` client. When a request fails, a new key
    is drawn at random from ``api_key_list`` and the client is rebuilt so the
    new key is actually used for the next request.
    """

    # OpenAI-compatible gateway every client of this class talks to.
    _BASE_URL = "https://api.shubiaobiao.cn/v1"
    # Seconds to wait before retrying after a rate-limit or timeout error.
    _RETRY_DELAY = 10

    def __init__(self, api_key:str, api_key_list, model_name="gpt-3.5-turbo"):
        """Create the client.

        Args:
            api_key: key used for the first requests.
            api_key_list: pool of keys to rotate through on failures.
            model_name: model identifier sent with every request.
        """
        self.client = OpenAI(api_key=api_key, base_url=self._BASE_URL)
        self.api_key = api_key
        self.ask_call_cnt = 0  # consecutive attempts since the last success
        self.ask_call_cnt_sup = 100000  # give up once the counter exceeds this
        self.model_name = model_name
        self.api_key_list = api_key_list

    def ask(self, question, temperature=0.0, stop=None) -> str:
        """Send ``question`` as a single user message and return the reply text.

        Rate-limit and timeout errors rotate the key, sleep and retry with the
        same arguments. Authentication errors and any other exception rotate
        the key and return ``"No answer!"``.

        Args:
            question: user message content.
            temperature: sampling temperature sent with the request.
            stop: accepted for interface compatibility; not sent to the API.

        Returns:
            The model reply, or ``"No answer!"`` on failure or when the
            attempt counter exceeds ``ask_call_cnt_sup``.
        """
        res = "No answer!"
        messages = [{"role": "user", "content": question}]

        # Retry in a loop (not by recursion) so long rate-limit streaks cannot
        # exhaust the interpreter's recursion limit, and so the caller's
        # temperature is kept on every attempt.
        while True:
            self.ask_call_cnt = self.ask_call_cnt + 1
            if self.ask_call_cnt > self.ask_call_cnt_sup:
                print("======> Achieve call count limit, Return!")
                self._random_key()
                self.ask_call_cnt = 0
                return res

            try:
                rsp = self.client.chat.completions.create(
                    model=self.model_name,
                    messages=messages,
                    temperature=temperature,
                )
                res = rsp.choices[0].message.content
                self.ask_call_cnt = 0
                return res
            except openai.AuthenticationError as e:
                self._random_key()
                print("======> openai.AuthenticationError", e)
                return res
            except openai.RateLimitError as e:
                # NOTE(review): this prints the full API key to stdout.
                print(f"======> {self.api_key} <===== \nAchieve ChatGPT rate limit, sleep!", e)
                self._random_key()
                time.sleep(self._RETRY_DELAY)
            except openai.APITimeoutError:
                print('======> Timeout error: will retry after 10 seconds')
                self._random_key()
                time.sleep(self._RETRY_DELAY)
            except Exception as e:
                print("======> Exception occurs!", e)
                self._random_key()
                # Inspect the exception text itself: arbitrary exceptions have
                # no ``.error`` attribute.
                if "HTTPSConnectionPool" in str(e):
                    time.sleep(60)
                return res

    def _random_key(self) -> None:
        """Switch to a random key from the pool and rebuild the client with it."""
        if not self.api_key_list:
            # Nothing to rotate to; keep the current key and client.
            return
        self.api_key = random.choice(self.api_key_list)
        openai.api_key = self.api_key  # kept for code using the module-level client
        # The client captured its key at construction, so it must be rebuilt
        # for the rotation to take effect.
        self.client = OpenAI(api_key=self.api_key, base_url=self._BASE_URL)


class deepseek_agent():
    """Chat-completion client for the DeepSeek API with key rotation.

    Uses the same ``OpenAI`` client class as ``gpt_agent``, pointed at the
    DeepSeek endpoint. The client is per-instance, so constructing this
    class does not change global ``openai`` settings.
    """

    # DeepSeek's OpenAI-compatible endpoint.
    _BASE_URL = "https://api.deepseek.com"
    # Seconds to wait before retrying after a rate-limit or server error.
    _RETRY_DELAY = 10

    def __init__(self, api_key:str, api_key_list, model_name="deepseek-chat"):
        """Create the client.

        Args:
            api_key: key used for the first requests.
            api_key_list: pool of keys to rotate through on failures.
            model_name: model identifier sent with every request.
        """
        self.client = OpenAI(api_key=api_key, base_url=self._BASE_URL)
        self.api_key = api_key
        self.ask_call_cnt = 0  # consecutive attempts since the last success
        self.ask_call_cnt_sup = 100000  # give up once the counter exceeds this
        self.model_name = model_name
        self.api_key_list = api_key_list

    def ask(self, question, temperature=0.0, stop=None) -> str:
        """Send ``question`` as a single user message and return the reply text.

        Rate-limit and server errors rotate the key, sleep and retry with the
        same arguments. Authentication errors and any other exception rotate
        the key and return ``"No answer!"``.

        Args:
            question: user message content.
            temperature: accepted for interface compatibility; not sent to
                the API.
            stop: optional stop sequence(s); sent only when not ``None``.

        Returns:
            The model reply, or ``"No answer!"`` on failure or when the
            attempt counter exceeds ``ask_call_cnt_sup``.
        """
        res = "No answer!"
        request = {
            "model": self.model_name,
            "messages": [{"role": "user", "content": question}],
        }
        if stop is not None:
            request["stop"] = stop

        # Retry in a loop (not by recursion) so long rate-limit streaks cannot
        # exhaust the interpreter's recursion limit, and so the caller's
        # ``stop`` is kept on every attempt.
        while True:
            self.ask_call_cnt = self.ask_call_cnt + 1
            if self.ask_call_cnt > self.ask_call_cnt_sup:
                print("======> Achieve call count limit, Return!")
                self._random_key()
                self.ask_call_cnt = 0
                return res

            try:
                rsp = self.client.chat.completions.create(**request)
                res = rsp.choices[0].message.content
                self.ask_call_cnt = 0
                return res
            except openai.AuthenticationError as e:
                self._random_key()
                print("======> openai.error.AuthenticationError", e)
                return res
            except openai.RateLimitError as e:
                # NOTE(review): this prints the full API key to stdout.
                print(f"======> {self.api_key} <===== \nAchieve ChatGPT rate limit, sleep!", e)
                self._random_key()
                time.sleep(self._RETRY_DELAY)
            except openai.InternalServerError:
                # SDK >= 1.0 replacement for the removed ServiceUnavailableError.
                print('======> Service unavailable error: will retry after 10 seconds')
                self._random_key()
                time.sleep(self._RETRY_DELAY)
            except Exception as e:
                print("======> Exception occurs!", e)
                self._random_key()
                # Inspect the exception text itself: arbitrary exceptions have
                # no ``.error`` attribute.
                if "HTTPSConnectionPool" in str(e):
                    time.sleep(60)
                return res

    def _random_key(self) -> None:
        """Switch to a random key from the pool and rebuild the client with it."""
        if not self.api_key_list:
            # Nothing to rotate to; keep the current key and client.
            return
        self.api_key = random.choice(self.api_key_list)
        # The client captured its key at construction, so it must be rebuilt
        # for the rotation to take effect.
        self.client = OpenAI(api_key=self.api_key, base_url=self._BASE_URL)

if __name__ == "__main__":
    # Manual smoke test: send a one-word greeting through the OpenAI backend
    # using the small model and print whatever comes back.
    demo_llm = LLM(mode='openai')
    print(demo_llm.call_llm("hello", big_model=False))