import os
import openai
import logging
import jsonlines
import time
import random
import requests
import json
from tqdm import tqdm


def generate_chat_completion(messages, model="gpt-4", temperature=1, max_tokens=None, API_KEY=None, timeout=None):
    """Request a chat completion from the OpenAI REST API and return its text.

    Talks to the HTTP endpoint directly with ``requests`` instead of the
    ``openai`` client.

    Args:
        messages: Chat messages, sent unchanged as the ``messages`` field.
        model: Model name sent as the ``model`` field.
        temperature: Sampling temperature sent as the ``temperature`` field.
        max_tokens: Included in the request only when not ``None``.
        API_KEY: Key placed in the ``Authorization: Bearer`` header.
        timeout: Passed through to ``requests.post``; ``None`` (the default)
            waits indefinitely, matching the previous behaviour.

    Returns:
        The ``content`` of the first choice's message.

    Raises:
        RuntimeError: if the API answers with a non-200 status. The message
            carries the status code and raw response body, instead of the
            opaque ``KeyError`` that indexing an error payload would produce.
    """
    API_ENDPOINT = "https://api.openai.com/v1/chat/completions"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {API_KEY}",
    }

    data = {
        "model": model,
        "messages": messages,
        "temperature": temperature,
    }

    if max_tokens is not None:
        data["max_tokens"] = max_tokens

    response = requests.post(API_ENDPOINT, headers=headers, data=json.dumps(data), timeout=timeout)
    # Error payloads have no "choices" key; surface the real API error instead.
    if response.status_code != 200:
        raise RuntimeError(f"Error {response.status_code}: {response.text}")
    return response.json()["choices"][0]["message"]["content"]


def _call_with_retries(request, max_attempts):
    """Return ``request()``, retrying transient OpenAI errors with backoff.

    Non-transient errors (e.g. invalid request, authentication) propagate
    immediately; the last transient error is re-raised once ``max_attempts``
    calls have failed.
    """
    transient_errors = (
        openai.error.RateLimitError,
        openai.error.ServiceUnavailableError,
        openai.error.APIError,
        openai.error.APIConnectionError,
        openai.error.Timeout,
        openai.error.TryAgain,
    )
    for attempt in range(1, max_attempts + 1):
        try:
            return request()
        except transient_errors as e:
            if attempt == max_attempts:
                raise
            # Exponential backoff capped at 60 s, plus jitter so retries from
            # concurrent callers do not line up.
            delay = min(2 ** attempt, 60) + random.random()
            logging.warning(
                "OpenAI request failed (attempt %d/%d): %s; retrying in %.1f s",
                attempt, max_attempts, e, delay,
            )
            time.sleep(delay)


def openai_set(openai_org, openai_key, model_name, one_prompt, max_attempts=5):
    """Send ``one_prompt`` to an OpenAI model and return the generated text.

    Uses the legacy module-level ``openai`` client: the organization (when
    given) and the API key are assigned globally on the ``openai`` module
    before the request is made.

    Args:
        openai_org: Organization ID to switch to, or ``None`` to leave the
            module's current organization setting untouched.
        openai_key: API key assigned to ``openai.api_key``.
        model_name: ``"text-davinci-002"`` / ``"text-davinci-003"`` (sent to
            ``openai.Completion``) or ``"gpt-3.5-turbo"`` / ``"gpt-4"`` (sent
            to ``openai.ChatCompletion``).
        one_prompt: The user prompt text.
        max_attempts: Maximum number of API calls made when transient errors
            occur. Must be at least 1.

    Returns:
        The stripped completion text for the davinci models, or the first
        choice's message content for the chat models.

    Raises:
        ValueError: if ``model_name`` is unsupported or ``max_attempts`` < 1.
        openai.error.OpenAIError: non-transient errors immediately, or the
            last transient error once all attempts are used up.
    """
    completion_models = ("text-davinci-003", "text-davinci-002")
    chat_models = ("gpt-3.5-turbo", "gpt-4")
    if model_name not in completion_models + chat_models:
        raise ValueError(f"Unsupported model_name: {model_name!r}")
    if max_attempts < 1:
        raise ValueError(f"max_attempts must be >= 1, got {max_attempts}")

    if openai_org is not None:
        openai.organization = openai_org
        logging.warning(f"Switching to organization: {openai_org} for OpenAI API key.")
    openai.api_key = openai_key

    if model_name in completion_models:
        response = _call_with_retries(
            lambda: openai.Completion.create(
                engine=model_name,
                prompt=one_prompt,
                temperature=0.5,
                top_p=0.5,
                n=1,
                max_tokens=512,
            ),
            max_attempts,
        )
        return response.choices[0].text.strip()

    messages = [
        {
            "role": "system",
            "content": "You are ChatGPT, a large language model trained by OpenAI. "
        },
        {
            "role": "user",
            "content": one_prompt
        }
    ]
    response = _call_with_retries(
        lambda: openai.ChatCompletion.create(
            model=model_name,
            messages=messages,
            temperature=1,
            top_p=0.5,
            n=1,
            max_tokens=256,
        ),
        max_attempts,
    )
    return response['choices'][0]['message']['content']


def get_all_prompt_file(filePath):
    """List the paths of all entries located directly inside ``filePath``.

    Every name reported by ``os.listdir`` (files and sub-directories alike)
    is joined to ``filePath`` with a literal ``'/'``. The order is whatever
    ``os.listdir`` yields, which is arbitrary.
    """
    separator = '/'
    return [filePath + separator + entry for entry in os.listdir(filePath)]

def read_llm_output(full_filepath):
    """Collect the ``instruction`` field of every record in a JSON Lines file.

    The file is decoded as UTF-8. A record without an ``instruction`` key
    raises ``KeyError``. Malformed lines raise whatever ``jsonlines.Reader``
    raises.
    """
    with open(full_filepath, mode='r', encoding='utf-8') as handle:
        reader = jsonlines.Reader(handle)
        return [record['instruction'] for record in reader]

# Directory whose files each hold JSON Lines records with an "instruction" field.
filePath = 'generated/biasattack'
# Organization ID handed to openai_set(); None leaves the client's setting as is.
openai_org = None
# API keys to rotate across (one is picked at random per request). Supply them
# through OPENAI_API_KEYS (comma-separated) or OPENAI_API_KEY so keys never
# have to be committed; the "***" placeholder is used when neither is set.
_env_openai_keys = os.environ.get("OPENAI_API_KEYS") or os.environ.get("OPENAI_API_KEY", "")
openai_key_list = [key.strip() for key in _env_openai_keys.split(",") if key.strip()] or ["***"]
# Models to run, in order. Other values openai_set() accepts:
# "gpt-3.5-turbo", "text-davinci-002", "text-davinci-003".
model_name_choices = ["gpt-4"]
full_path_list = get_all_prompt_file(filePath)


def use_openai(model_name, output_dir='llm_output'):
    """Query ``model_name`` with every instruction from ``full_path_list``.

    For each input file, every instruction goes through ``openai_set`` with a
    key picked at random from ``openai_key_list``. Each
    ``{"instruction", "response"}`` pair is appended as one JSON line to
    ``<output_dir>/<model_name>_<input file base name>``.

    Args:
        model_name: Model name passed to ``openai_set``.
        output_dir: Directory for the result files; created if missing.
    """
    os.makedirs(output_dir, exist_ok=True)
    for full_path in full_path_list:
        instruction_list = read_llm_output(full_path)
        print("Processing file name: ", full_path)
        print("Number of instructions: ", len(instruction_list))
        # Use the full base name. A fixed-length tail slice could include part
        # of the directory (producing a missing sub-directory) or map
        # different inputs onto the same output file.
        save_file = os.path.join(output_dir, f"{model_name}_{os.path.basename(full_path)}")
        with jsonlines.open(save_file, mode='a') as outf:
            for instruct in tqdm(instruction_list):
                time.sleep(0.4)  # fixed pause between requests (presumably to ease API rate limits)
                openai_key = random.choice(openai_key_list)
                response = openai_set(openai_org, openai_key, model_name, instruct)
                outf.write({"instruction": instruct, "response": response})

# Script driver: run use_openai() once for each configured model, in list order.
for model_name_ in model_name_choices:
    print(f"Now work on: {model_name_}")
    x = use_openai(model_name_)