import openai
import os
import numpy as np
import time
import requests

# Credentials and paths are taken from the environment at import time.
openai.organization = os.getenv("OPENAI_ORG")
# OPENAI_AK_KEY is the name this project has historically used; fall back to
# the conventional OPENAI_API_KEY so that a standard setup is not overwritten
# with None when only the conventional variable is defined.
openai.api_key = os.getenv("OPENAI_AK_KEY") or os.getenv("OPENAI_API_KEY")
# Prefix under which the Llama-2 chat checkpoints live (None when unset).
llama_path = os.getenv("LLAMA_PATH")

class OpenAILLM(object):
    """Wrapper around the OpenAI chat-completion endpoint.

    Attributes set by the constructor:
        model: the public model alias given by the caller.
        inner_model: the pinned model snapshot actually sent to the API.
        name: short tag for the model ('gpt35' or 'gpt4').
        temp: sampling temperature passed to the API.
        long_name: ``name`` followed by ``_t=<temp>``, e.g. 'gpt4_t=0'.
    """

    # Public alias -> (pinned snapshot sent to the API, short tag).
    _MODELS = {
        'gpt-3.5-turbo': ('gpt-3.5-turbo-0613', 'gpt35'),
        'gpt-4': ('gpt-4-0613', 'gpt4'),
    }

    def __init__(self, model, temp=0, seed=None):
        """Configure the wrapper for one of the supported models.

        Args:
            model: 'gpt-3.5-turbo' or 'gpt-4'.
            temp: sampling temperature (number or numeric string).
            seed: accepted for interface parity with the other wrappers;
                it is not used by this class.

        Raises:
            ValueError: if ``model`` is not a supported alias, or ``temp``
                cannot be parsed as a number.
        """
        try:
            self.inner_model, self.name = self._MODELS[model]
        except KeyError:
            raise ValueError(
                f"Unsupported OpenAI model {model!r}; "
                f"expected one of {sorted(self._MODELS)}"
            ) from None
        self.model = model

        # Keep integral temperatures as ints (so long_name stays e.g. "_t=0")
        # but do not truncate fractional ones such as 0.7 down to 0.
        temp_value = float(temp)
        self.temp = int(temp_value) if temp_value.is_integer() else temp_value

        self.long_name = self.name + f"_t={self.temp}"

    def predict(self, system, prompt):
        """Send one chat request and return the assistant's reply text.

        Args:
            system: optional system message; omitted from the request when
                None.
            prompt: the user message.

        Returns:
            The ``content`` of the first choice in the API response.

        Rate-limit, timeout and service-unavailable errors are printed and
        retried after a 30 second pause, without an upper bound on attempts.
        Any other exception propagates to the caller.
        """
        messages = []
        if system is not None:
            messages.append({"role": "system",
                             "content": system})
        messages.append({"role": "user",
                         "content": prompt})
        while True:
            try:
                response = openai.ChatCompletion.create(
                    model=self.inner_model,
                    messages=messages,
                    temperature=self.temp,
                    request_timeout=60
                )
                output = response['choices'][0]['message']['content']
                if self.name == 'gpt4':
                    # Pause after every successful gpt-4 call before
                    # handing the result back.
                    time.sleep(4)
                return output
            except (openai.error.RateLimitError, openai.error.Timeout, openai.error.ServiceUnavailableError) as err:
                print(err, flush=True)
                print("[Debug] openai error, sleeping for 30 seconds", flush=True)
                time.sleep(30)

class LlamaLLM(object):
    """Client for Llama-2 chat models served by a local HTTP server.

    Attributes set by the constructor:
        model: checkpoint location built from the module-level ``llama_path``
            and the chat checkpoint's directory name.
        name: short tag for the model ('llama7b', 'llama13b', 'llama70b').
        temp: sampling temperature sent to the server.
        long_name: ``name`` followed by ``_t=<temp>``, e.g. 'llama7b_t=0'.
    """

    # Public alias -> (chat checkpoint name under llama_path, short tag).
    _MODELS = {
        'Llama-2-7b-hf': ('Llama-2-7b-chat-hf', 'llama7b'),
        'Llama-2-13b-hf': ('Llama-2-13b-chat-hf', 'llama13b'),
        'Llama-2-70b-hf': ('Llama-2-70b-chat-hf', 'llama70b'),
    }

    # Generation endpoint of the local server.
    _ENDPOINT = "http://127.0.0.1:5000/generate"

    def __init__(self, model, temp=0, seed=None):
        """Configure the client for one of the supported Llama-2 models.

        Args:
            model: 'Llama-2-7b-hf', 'Llama-2-13b-hf' or 'Llama-2-70b-hf'.
            temp: sampling temperature (number or numeric string).
            seed: accepted for interface parity with the other wrappers;
                it is not used by this class.

        Raises:
            ValueError: if ``model`` is not a supported alias, or ``temp``
                cannot be parsed as a number.
            RuntimeError: if the LLAMA_PATH environment variable was not set
                when the module was imported.
        """
        try:
            checkpoint, self.name = self._MODELS[model]
        except KeyError:
            raise ValueError(
                f"Unsupported Llama model {model!r}; "
                f"expected one of {sorted(self._MODELS)}"
            ) from None
        if llama_path is None:
            raise RuntimeError("LLAMA_PATH environment variable is not set")
        # os.path.join gives the same result as plain concatenation when
        # LLAMA_PATH ends with a separator, and inserts one when it does not.
        self.model = os.path.join(llama_path, checkpoint)

        # Keep integral temperatures as ints (so long_name stays e.g. "_t=0")
        # but do not truncate fractional ones such as 0.7 down to 0.
        temp_value = float(temp)
        self.temp = int(temp_value) if temp_value.is_integer() else temp_value

        self.long_name = self.name + f"_t={self.temp}"

    def predict(self, system, prompt):
        """Send one generation request and return the server's reply.

        Args:
            system: optional system message; omitted from the history when
                None.
            prompt: the user message.

        Returns:
            The value of the "response" field of the server's JSON reply
            (None if the field is absent).

        Raises:
            Exception: if the server answers with a status other than 200.
        """
        messages = []
        if system is not None:
            messages.append({"role": "system",
                             "content": system})
        messages.append({"role": "user",
                         "content": prompt})

        # The user prompt is sent both as the last history entry and as
        # "message".
        # NOTE(review): confirm the server does not apply it twice.
        data = {
            "history": messages,
            "message": prompt,
            "model": self.model,
            "temperature": self.temp
            }

        # No timeout is set: this call blocks until the server answers.
        response = requests.post(self._ENDPOINT, json=data)
        if response.status_code == 200:
            return response.json().get("response")

        # The error body is not guaranteed to be a JSON object; fall back to
        # the raw text so the real server error is not masked by a decode
        # failure.
        try:
            detail = response.json().get('error')
        except (ValueError, AttributeError):
            detail = f"HTTP {response.status_code}: {response.text}"
        raise Exception(f"Error from server: {detail}")
