import json
import random
from openai import OpenAI
import openai
from fastchat.conversation import get_conv_template
import argparse
import requests
import time
from tqdm import tqdm
import os

class APIModel:
    """Client for a chat-completions style HTTP endpoint, called via ``requests``.

    ``api_url`` and ``api_key`` are empty here; presumably they are filled in
    by hand for a given deployment (TODO confirm). The response is parsed as
    ``choices[0].message.content``, i.e. an OpenAI-compatible schema.
    """

    def __init__(self, model_name: str, timeout=None):
        """Create a client for ``model_name``.

        Args:
            model_name: Model identifier sent in every request body.
            timeout: Seconds passed to ``requests.post``; ``None`` (the
                ``requests`` default) waits indefinitely.
        """
        self.model_name = model_name
        self.api_url = ""
        self.api_key = ""
        self.timeout = timeout
        # Built once here, so the Authorization header reflects api_key as it
        # is at construction time.
        self.headers = {
            "accept": "application/json",
            "content-type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }

    def _build_payload(self, messages, **kwargs):
        """Return the JSON request body: one user turn plus any extra fields.

        Extra keyword arguments (e.g. ``temperature``, ``max_tokens``) are
        merged in after ``model``/``messages`` and therefore override them.
        """
        return {
            "model": self.model_name,
            "messages": [{"role": "user", "content": messages}],
            **kwargs,
        }

    def generate_response(self, messages, **kwargs):
        """Send ``messages`` as a single user turn and return the reply text.

        Args:
            messages: Content of the single user message.
            **kwargs: Extra request-body fields forwarded to the endpoint.

        Returns:
            ``choices[0].message.content`` from the JSON response.

        Raises:
            requests.HTTPError: If the server answers with a 4xx/5xx status.
        """
        response = requests.post(
            self.api_url,
            json=self._build_payload(messages, **kwargs),
            headers=self.headers,
            timeout=self.timeout,
        )
        if not response.ok:
            # Show the server's error body before raising; the HTTPError
            # message alone only carries the status and URL.
            print(response.text)
            response.raise_for_status()
        result = response.json()["choices"][0]["message"]["content"]

        # Fixed pause after every call, presumably to stay under a rate limit.
        time.sleep(3)
        return result

class OpenaiModel:
    """Chat-completions client built on the ``openai`` SDK, with retries.

    Connection errors, rate-limit errors and other non-2xx status errors are
    retried up to ``max_retries`` attempts, ``retry_delay`` seconds apart.
    Any other exception propagates immediately.
    """

    max_retries = 8  # total attempts per generate_response call
    retry_delay = 4  # seconds to wait between failed attempts

    def __init__(self, model_name: str):
        self.model_name = model_name
        self.api_key = ""

        self.client = OpenAI(
            api_key=self.api_key,
            base_url='',
        )

        # Not read anywhere in this class; kept in case outside code uses it.
        self.conversation = get_conv_template('chatgpt')

    def _build_request(self, messages, **kwargs):
        """Return keyword arguments for ``client.chat.completions.create``.

        ``messages`` becomes a single user turn. ``temperature`` defaults to 0;
        any ``kwargs`` (including ``temperature``) override the defaults.
        """
        return {
            "model": self.model_name,
            "messages": [{"role": "user", "content": messages}],
            "temperature": 0,
            **kwargs,
        }

    def generate_response(self, messages, **kwargs):
        """Send ``messages`` as one user turn and return the reply text.

        Args:
            messages: Content of the single user message.
            **kwargs: Extra arguments forwarded to ``chat.completions.create``.

        Returns:
            ``response.choices[0].message.content`` of the first successful call.

        Raises:
            RuntimeError: If every attempt failed with a retryable error; the
                last such error is chained as ``__cause__``.
        """
        request = self._build_request(messages, **kwargs)
        last_error = None
        for attempt in range(self.max_retries):
            try:
                response = self.client.chat.completions.create(**request)
            except openai.APIConnectionError as e:
                print("The server could not be reached")
                print(e.__cause__)  # an underlying Exception, likely raised within httpx.
                last_error = e
            except openai.RateLimitError as e:
                print("A 429 status code was received; we should back off a bit.")
                last_error = e
            except openai.APIStatusError as e:
                print("Another non-200-range status code was received")
                print(e.status_code)
                print(e.response)
                last_error = e
            else:
                return response.choices[0].message.content
            # No point sleeping after the final attempt.
            if attempt < self.max_retries - 1:
                time.sleep(self.retry_delay)

        raise RuntimeError(
            f"{self.model_name}: no response after {self.max_retries} attempts"
        ) from last_error

# Ordered (substring, model id) pairs used by resolve_model_name; the first
# substring found in the requested name wins.
_MODEL_ALIASES = (
    ("llama-2", "meta-llama/Llama-2-7b-chat-hf"),
    ("llama-3", "meta-llama/Llama-3-8b-chat-hf"),
    ("vicuna", "lmsys/vicuna-7b-v1.5"),
    ("mistral", "mistralai/Mistral-7B-Instruct-v0.1"),
    ("qwen", "Qwen/Qwen1.5-7B-Chat"),
    ("gemma", "google/gemma-7b-it"),
    ("claude", "claude-3-sonnet-20240229"),
    ("gpt-3.5", "gpt-3.5-turbo"),
    ("gpt-4", "gpt-4"),
)


def resolve_model_name(model_name):
    """Map a loose model name (e.g. ``"llama-2"``) to a concrete model id.

    Matching is a case-insensitive substring test against ``_MODEL_ALIASES``,
    checked in order.

    Raises:
        ValueError: If ``model_name`` matches no known alias.
    """
    lowered = model_name.lower()
    for alias, model_id in _MODEL_ALIASES:
        if alias in lowered:
            return model_id
    raise ValueError(f"Unsupported model name: {model_name!r}")


def getResponse(prompt, model_name):
    """Send ``prompt`` to the model selected by ``model_name``; return its reply.

    GPT models go through ``OpenaiModel`` (the openai SDK); every other
    resolved model goes through ``APIModel`` (raw HTTP).

    Raises:
        ValueError: If ``model_name`` matches no known alias.
    """
    model = resolve_model_name(model_name)
    model_cls = OpenaiModel if "gpt" in model else APIModel
    return model_cls(model).generate_response(prompt)