import argparse
import json
import os
import time

import openai
import requests
from requests.exceptions import JSONDecodeError
from tqdm import tqdm


def run(inputs, prompts, output_dir, uri, *, max_retries=3, request_timeout=300, **kwargs):
    """Send every (input, prompt) combination to an OpenAI endpoint and cache each reply.

    For each ``(idx, task_input)`` pair in *inputs* and each entry of *prompts*,
    the text ``prefix + task_input + suffix`` is POSTed to *uri*. The decoded JSON
    reply is written to ``<output_dir>/<idx>/<prompt_id>.json``. Pairs whose output
    file already exists are skipped, so an interrupted run resumes where it stopped.

    Args:
        inputs: Iterable of ``(idx, task_input)`` pairs of strings; ``idx`` names
            the per-input sub-directory.
        prompts: Mapping ``prompt_id -> (prefix, suffix)`` wrapped around each input.
        output_dir: Directory that receives ``params.json`` and the per-input folders.
        uri: API endpoint. The chat-completions URL gets a ``messages`` payload;
            any other URL gets a ``prompt`` payload.
        max_retries: Attempts per request before giving up on it (keyword-only).
        request_timeout: Seconds ``requests`` waits on the connection for each
            attempt (keyword-only).
        **kwargs: Extra request-body fields (e.g. model, n, temperature, max_tokens).
            They are also saved to ``params.json`` and printed.
    """
    os.makedirs(output_dir, exist_ok=True)

    with open(os.path.join(output_dir, "params.json"), "w", encoding="utf-8") as f:
        json.dump(kwargs, f)
    for k, v in kwargs.items():
        print(k+':'+str(v))

    # Prefer credentials from the environment so real keys need not live in source;
    # the original placeholders remain as fallbacks.
    openai.organization = os.environ.get("OPENAI_ORGANIZATION", "your-organization-key")
    openai.api_key = os.environ.get("OPENAI_API_KEY", "your-api-key")
    headers = {'Content-Type': 'application/json', 'Authorization': 'Bearer {}'.format(openai.api_key)}
    print("Sending Requests to OpenAI...")
    for idx, task_input in tqdm(inputs):
        folder_dir = os.path.join(output_dir, idx)
        os.makedirs(folder_dir, exist_ok=True)

        for prompt_id, (prefix, suffix) in prompts.items():
            out_path = os.path.join(folder_dir, "{}.json".format(prompt_id))
            if os.path.exists(out_path):
                continue  # answered in an earlier run
            data = _build_request_body(uri, prefix + task_input + suffix, kwargs)
            result = _post_with_retries(uri, headers, data, max_retries, request_timeout)
            if result is None:
                # Write nothing, so a later run retries this prompt instead of
                # treating a failure as a cached answer.
                continue
            with open(out_path, "w", encoding="utf-8") as f:
                json.dump(result, f)


def _build_request_body(uri, prompt, params):
    """Return the JSON-encoded request body for *prompt*, shaped for the endpoint *uri*."""
    if uri == "https://api.openai.com/v1/chat/completions":
        return json.dumps(
            {
                "messages": [
                    {"role": "system", "content": "You are a helpful assistant at paraphrasing"},
                    {"role": "user", "content": prompt}
                ],
                **params
            })
    return json.dumps({"prompt": prompt, **params})


def _post_with_retries(uri, headers, data, max_retries, timeout):
    """POST *data* to *uri*, retrying on failure; return the decoded JSON reply or None."""
    attempts = max(1, max_retries)
    for attempt in range(1, attempts + 1):
        try:
            response = requests.post(uri, headers=headers, data=data, timeout=timeout)
            # HTTP errors (e.g. 429 rate limits, 5xx) are failures: their bodies
            # must not be cached as if they were completions.
            response.raise_for_status()
            return response.json()
        except (requests.exceptions.RequestException, JSONDecodeError) as err:
            if attempt == attempts:
                print("Request failed after {} attempts: {}".format(attempts, err))
                return None
            print("Request failed ({}). Retrying...".format(err))
            time.sleep(2 ** (attempt - 1))  # simple exponential backoff
    return None


def main():
    """Parse CLI arguments, load the selected inputs, and run the negation prompts.

    Reads ``input_text`` from items ``[start:end]`` of the JSON list at
    ``--input_dir`` and keeps only the prompt templates listed in ``--use_prompt``.
    Results go under ``<output_dir>/<input file name without trailing .json>``.
    """
    argparser = argparse.ArgumentParser()
    argparser.add_argument("--output_dir", type=str, default="./results/negation")
    # NOTE: despite its name, this is the path to a JSON *file*, not a directory.
    argparser.add_argument("--input_dir", type=str, default="./data/Zero_Shot/t0.json")

    argparser.add_argument("--model", type=str, default="gpt-4")
    argparser.add_argument("--n", type=int, default=1)
    argparser.add_argument("--temperature", type=float, default=0.3)
    argparser.add_argument("--max_tokens", type=int, default=512)
    argparser.add_argument("--start", type=int, default=174)
    argparser.add_argument("--end", type=int, default=274)
    argparser.add_argument("--use_prompt", type=int, default=[1, 2, 3], nargs="+")

    args = argparser.parse_args()

    print("Used prompt ids: ", args.use_prompt)
    # gpt-4* and gpt-3.5-turbo* are chat models. The "-instruct" variants and other
    # models use the legacy completions endpoint.
    is_chat_model = (
        args.model.startswith(("gpt-4", "gpt-3.5-turbo"))
        and "instruct" not in args.model
    )
    uri = "https://api.openai.com/v1/chat/completions" if is_chat_model else "https://api.openai.com/v1/completions"

    params = {
        "model": args.model,
        "n": args.n,
        "temperature": args.temperature,
        "max_tokens": args.max_tokens,
    }

    start = args.start
    end = args.end

    prompts = {
        1: ("Here's an input utterance:\n\n", "\n\n\nNow, your task is to add a negation to the utterance that completely changes its meaning.\nHere's the new utterance:\n\n"),
        2: ("You are given an utterance which is a combination of task instruction and the actual input. Your job is to add a negation to the task instruction and leave the input unchanged. Here's the utterance to be negated:\n\n\n", "\n\n\nNow, generate the new utterance:\n\n\n"),
        3: ("You are provided with the utterance of a specific task and I need you to add negation to it. The actual input, question, and examples in the task should not be changed. You should only add a negation word to the instruction. Task:\n\n\n", "\n\n\nThe paraphrased utterance:\n\n\n")
    }

    selected_ids = set(args.use_prompt)
    unknown_ids = sorted(selected_ids - set(prompts))
    if unknown_ids:
        print("Warning: ignoring unknown prompt ids: ", unknown_ids)
    prompts = {key: value for key, value in prompts.items() if key in selected_ids}

    # JSON is UTF-8 by specification; don't depend on the locale's default encoding.
    with open(args.input_dir, "r", encoding="utf-8") as f:
        records = json.load(f)
    inputs = [
        (str(start + offset), item["input_text"])
        for offset, item in enumerate(records[start:end])
    ]

    # basename handles the platform's path separators. Strip only a trailing ".json";
    # the old str.replace removed ".json" anywhere in the name.
    dataset_name = os.path.basename(args.input_dir)
    if dataset_name.endswith(".json"):
        dataset_name = dataset_name[:-len(".json")]
    run(inputs, prompts, os.path.join(args.output_dir, dataset_name), uri, **params)


if __name__ == "__main__":
    # Run the CLI only when this file is executed directly, not when it is imported.
    main()
