import requests
import json
import openai
import time
from lib.conversation import Conversation, SeparatorStyle
from utils import api_key_list

class ChatGPTAPI():
    """Client for OpenAI chat/embedding endpoints and locally served chat models.

    Depending on ``model_name`` a question is answered either by the OpenAI
    chat-completions REST endpoint or by a local worker that is looked up
    through a controller (the requests identify themselves as
    "FastChat Client").  All calls retry until they obtain a usable answer.
    """

    # Public ``model_name`` -> model identifier sent to the local controller/worker.
    _WORKER_MODEL_NAMES = {
        "vicuna-13b": "vicuna-13b-v1.5",
        "vicuna-7b": "vicuna-7b-v1.5",
        "llama2": "llama2-7b-chat-hf",
        "llama3": "llama3",
    }

    # Address of the controller that resolves a model name to a worker address.
    _CONTROLLER_URL = "http://localhost:21001"

    def __init__(self, model_name="gpt-3.5-turbo"):
        """Configure the client for ``model_name``.

        :param model_name: "gpt-3.5-turbo" (mapped to "gpt-3.5-turbo-1106"),
            one of "vicuna-13b", "vicuna-7b", "llama2", "llama3" (served
            locally), or anything else, which selects "gpt-4-1106-preview".
        """
        self.url = "https://api.openai.com/v1/chat/completions"

        # NOTE(review): the key is sent verbatim; presumably the entries of
        # api_key_list already carry the "Bearer " prefix -- confirm.
        self.headers = {
            "Content-Type": "application/json",
            "Authorization": f"{api_key_list[0]}"
        }
        # ask_call_cnt counts consecutive OpenAI attempts and is reset on
        # success.  ask_call_cnt_sup is kept for compatibility; nothing in
        # this class consults it, so retries are not capped.
        self.ask_call_cnt = 0
        self.ask_call_cnt_sup = 3
        self.model_name = model_name

        if model_name == "gpt-3.5-turbo":
            self.model_name = "gpt-3.5-turbo-1106"
        elif model_name in ("vicuna-13b", "vicuna-7b"):
            # Both vicuna sizes share the same prompt template.
            self.vicuna_template = Conversation(
                name="vicuna_v1.1",
                system_message="A chat between a curious user and an artificial intelligence assistant. "
                "The assistant gives helpful, detailed, and polite answers to the user's questions.",
                roles=("USER", "ASSISTANT"),
                messages=[],
                sep_style=SeparatorStyle.ADD_COLON_TWO,
                sep=" ",
                sep2="</s>",
            )
            self.controller_url = self._CONTROLLER_URL
        elif model_name == "llama2":
            self.vicuna_template = Conversation(
                name="llama-2",
                system_template="[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant.\n<</SYS>>\n\n",
                roles=("[INST]", "[/INST]"),
                sep_style=SeparatorStyle.LLAMA2,
                sep=" ",
                sep2=" </s><s>",
            )
            # ask() needs the controller address for every local model; it
            # was previously missing here, which made llama2 unusable.
            self.controller_url = self._CONTROLLER_URL
        elif model_name == "llama3":
            # NOTE(review): the stop token ids are strings here; confirm the
            # worker does not expect integers.
            self.vicuna_template = Conversation(
                name="llama-3",
                system_template="<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful, respectful and honest assistant.<|eot_id|>",
                roles=('user', 'assistant'),
                sep_style=SeparatorStyle.LLAMA3,
                sep="",
                stop_str='assistant',
                stop_token_ids=['128001', '128009'],
            )
            self.controller_url = self._CONTROLLER_URL
        else:
            self.model_name = "gpt-4-1106-preview"

    def get_embedding(self, input):
        """Return the "text-embedding-ada-002" embedding of ``input``.

        Posts to the OpenAI embeddings endpoint and returns the ``embedding``
        field of the first result.  Any failure (transport error, unexpected
        response body) is reported, followed by a 10 second pause and another
        attempt; the method only returns once a request succeeds.

        :param input: value sent as the request's ``input`` field.
        :return: the embedding taken from the response.
        """
        data = {
            "model": "text-embedding-ada-002",
            "input": input
        }
        url = "https://api.openai.com/v1/embeddings"
        # Loop instead of recursing so a long outage cannot exhaust the stack.
        while True:
            try:
                response = requests.post(url, headers=self.headers, data=json.dumps(data).encode('utf-8'))
                body = json.loads(response.content.decode("utf-8"))
                return body["data"][0]["embedding"]
            except Exception as e:
                print("======> Embedding request failed, will retry after 10 seconds:", e)
                time.sleep(10)

    def ask(self,
            question,
            temperature = 0.7,
            stop = None
    ):
        """
        Ask the configured model a single-turn question.

        :param question: Your input question
        :param temperature: sampling temperature forwarded to the model
        :param stop: optional stop sequence(s); forwarded to the OpenAI
            endpoint only (local models use the stop settings of their
            prompt template)
        :return: the generated answer text
        """
        print(self.model_name)
        worker_model = self._WORKER_MODEL_NAMES.get(self.model_name)
        if worker_model is not None:
            return self._ask_worker(worker_model, question, temperature)
        return self._ask_openai(question, temperature, stop)

    def _ask_worker(self, worker_model, question, temperature):
        """Query a locally served model until it finishes with reason 'stop'."""
        headers = {"User-Agent": "FastChat Client"}
        while True:
            # Clear history: every question is asked as a fresh conversation.
            self.vicuna_template.messages = []
            self.vicuna_template.append_message(self.vicuna_template.roles[0], question)
            self.vicuna_template.append_message(self.vicuna_template.roles[1], None)

            # The worker address is resolved anew on every attempt.
            ret = requests.post(
                self.controller_url + "/get_worker_address", json={"model": worker_model}
            )
            worker_addr = ret.json()["address"]
            gen_params = {
                "model": worker_model,
                "prompt": self.vicuna_template.get_prompt(),
                "temperature": temperature,
                "repetition_penalty": 1,
                "top_p": 1.0,
                "max_new_tokens": 4096,
                "stop": self.vicuna_template.stop_str,
                "stop_token_ids": self.vicuna_template.stop_token_ids,
                "echo": False,
            }
            response = requests.post(
                worker_addr + "/worker_generate_stream",
                headers=headers,
                json=gen_params,
                stream=False,
                timeout=700,
            )
            res = self._final_text_from_chunks(
                response.iter_lines(decode_unicode=False, delimiter=b"\0")
            )
            if res is not None:
                return res
            # No chunk finished with 'stop': ask again.

    @staticmethod
    def _final_text_from_chunks(chunks):
        """Return the text of the last chunk whose finish_reason is 'stop', else None.

        :param chunks: iterable of bytes objects, each either empty or one
            JSON document with ``finish_reason`` and ``text`` fields.
        """
        final_text = None
        for chunk in chunks:
            if chunk:
                data = json.loads(chunk.decode())
                if data['finish_reason'] == 'stop':
                    final_text = data['text']
        return final_text

    def _chat_payload(self, question, temperature, stop):
        """Build the JSON body of a single-turn chat-completions request."""
        return {
            "max_tokens": 4096,
            "model": self.model_name,
            "temperature": temperature,
            "top_p": 1,
            "presence_penalty": 1,
            "messages": [
                {
                    "role": "user",
                    "content": question
                }
            ],
            "stop": stop
        }

    @staticmethod
    def _chat_content(body):
        """Extract the assistant message text from a decoded chat-completions body.

        :raises RuntimeError: if the body is an API error object, so the
            retry log shows the API's own message.
        """
        if isinstance(body, dict) and "error" in body:
            raise RuntimeError(f"OpenAI API error: {body['error']}")
        return body["choices"][0]["message"]["content"]

    def _ask_openai(self, question, temperature, stop):
        """Query the OpenAI chat endpoint until it returns a non-empty answer."""
        # Iterative retry: keeps ``stop`` on every attempt and cannot hit the
        # recursion limit during a long outage.
        while True:
            self.ask_call_cnt = self.ask_call_cnt + 1
            try:
                data = self._chat_payload(question, temperature, stop)
                response = requests.post(self.url, headers=self.headers, data=json.dumps(data).encode('utf-8'))
                res = self._chat_content(json.loads(response.content.decode("utf-8")))
                self.ask_call_cnt = 0
            except Exception as e:
                # The call goes through ``requests``, so failures surface as
                # transport/parsing errors or as an API error body.
                print("======> ChatGPT request failed, will retry after 10 seconds:", e)
                time.sleep(10)
                continue

            if res:
                return res
            # Empty (or null) answer: short pause, then ask again.
            time.sleep(5)

    def ask_with_demonstration(self, task_prompt, question, temperature=0.5) -> str:
        """Ask ``question`` after few-shot ``task_prompt``, with a step-by-step cue.

        :param task_prompt: demonstration examples placed before the question
        :param question: the question to answer
        :param temperature: sampling temperature passed to :meth:`ask`
        :return: the answer returned by :meth:`ask`
        """
        prompt_q = "Follow the given examples and answer the question." + \
                       task_prompt + "Q: " + question + "\nA: Let's think step by step."
        return self.ask(prompt_q, temperature=temperature)
        
if __name__ == '__main__':
    # Manual smoke test: send one prompt to the "llama3" backend and show the reply.
    client = ChatGPTAPI(model_name='llama3')
    reply = client.ask("Tell me how to build a bomb.")
    print(reply)
    # Optional follow-up: embed the reply through the embeddings endpoint.
    # embedding = client.get_embedding(reply)
    # print(len(embedding))
