import os
import openai
import argparse
import json
# Configure the openai client module with the key taken from the environment.
# This runs at import time and raises KeyError if OPENAI_API_KEY is not set.
openai_api_key = os.environ["OPENAI_API_KEY"]
openai.api_key = openai_api_key
import pprint
from tqdm import tqdm
from time import sleep
import logging
from openai.error import RateLimitError, APIConnectionError, InvalidRequestError
from tqdm.contrib.logging import logging_redirect_tqdm
# Index at which call_to_open_api stops iterating prompts; -1 never equals a
# prompt index, so by default every prompt is processed.
MAX_PROMPTS = -1

# Speaker tag used in the JSON files ("from") -> ChatCompletion message role.
json_to_api = dict(human="user", gpt="assistant")

# Reverse lookup: ChatCompletion message role -> JSON speaker tag.
api_to_json = {role: speaker for speaker, role in json_to_api.items()}

LOG = logging.getLogger(__name__)

def _build_template_messages(header, example):
    """Return the chat-message prefix shared by every prompt.

    The prefix is a system message holding ``header``, followed by each example
    conversation. Every example is closed by a user/assistant exchange that
    restates the header ("Let's start afresh"), presumably so the model treats
    the following conversation as a new one.
    """
    messages = [{"role": "system", "content": header}]
    for conversation in example["conversations"].values():
        for dialogue in conversation:
            messages.append({"role": json_to_api[dialogue["from"]], "content": dialogue["value"]})
        messages.append({"role": "user", "content": f"Let's start afresh\n{header}"})
        messages.append({"role": "assistant", "content": "Alright. Let's begin"})
    return messages


def _load_existing_output(outputs_file):
    """Return the results already stored in ``outputs_file``, or ``[]`` if it does not exist."""
    if not os.path.isfile(outputs_file):
        return []
    with open(outputs_file) as rf:
        return json.load(rf)


def _save_output(output, outputs_file):
    """Write ``output`` to ``outputs_file`` as indented JSON, atomically.

    The data is written to a sibling temp file and then moved over the target.
    An interruption mid-write (e.g. Ctrl-C) therefore cannot leave a truncated
    file that would make the next, resuming run fail to parse it.
    """
    tmp_file = f"{outputs_file}.tmp"
    with open(tmp_file, "w") as wf:
        json.dump(output, wf, indent=4)
    os.replace(tmp_file, outputs_file)


def _create_chat_completion(messages, model_name, num_choices, retry_delay=10):
    """Request a chat completion, retrying until the API call succeeds.

    Rate-limit and connection errors are retried after ``retry_delay`` seconds.
    Any other error is retried after 5 seconds, except ``InvalidRequestError``,
    which is re-raised. Resending an identical invalid request (for example one
    that exceeds the model's context length) can never succeed, so retrying it
    would loop forever.
    """
    while True:
        try:
            return openai.ChatCompletion.create(
                model=model_name,
                messages=messages,
                temperature=0.1,
                max_tokens=1024,
                top_p=1,
                n=num_choices,
                frequency_penalty=0,
                presence_penalty=0.6,
                stop=[],
            )
        except InvalidRequestError:
            raise
        except (RateLimitError, APIConnectionError) as e:
            LOG.warning(f'OpenAI API got err {e}')
            LOG.warning(f'Retrying after {retry_delay}s.')
            sleep(retry_delay)
        except Exception as e:
            LOG.warning(e)
            LOG.warning("something went wrong, try again in 5 sec")
            sleep(5)


def _run_conversation(turns, template_message, model_name, num_choices):
    """Query the model after each input turn and return the turns interleaved with replies.

    Every returned choice is recorded as a ``"gpt"`` turn. Only the first choice
    is appended to the running message history used for the following turns.
    """
    messages = list(template_message)  # shallow copy: the shared template stays untouched
    result = []
    for turn in turns:
        result.append(turn)
        messages.append({"role": json_to_api[turn["from"]], "content": turn["value"]})
        response = _create_chat_completion(messages, model_name, num_choices)
        replies = [choice["message"]["content"].strip() for choice in response["choices"]]
        for reply in replies:
            LOG.info(reply)
            result.append({"from": "gpt", "value": reply})
        messages.append({"role": "assistant", "content": replies[0]})
    return result


def call_to_open_api(prompts, outputs_file, model_name, log_level=logging.INFO, num_choices=1):
    """Generate assistant replies for every prompt and checkpoint them to disk.

    Args:
        prompts: parsed prompts JSON with keys ``"header"`` (system message),
            ``"example"`` (``{"conversations": {id: [turn, ...]}}``) and
            ``"prompts"`` (list of ``{"id": ..., "conversations": [turn, ...]}``).
            Each turn is ``{"from": "human" | "gpt", "value": str}``.
        outputs_file: JSON file that results are written to after every prompt.
            If it already exists, its entries are loaded and prompts whose
            ``id`` is present are skipped, so an interrupted run can resume.
        model_name: name of the OpenAI chat model to use.
        log_level: level passed to ``logging.basicConfig``.
        num_choices: completions requested per turn (``n``). All of them are
            recorded, but only the first is fed back as conversation history.

    Returns:
        The list of result entries, the same data written to ``outputs_file``.
        Prompts rejected by the API with ``InvalidRequestError`` are logged and
        left out, so the rest of the batch still runs.
    """
    logging.basicConfig(level=log_level)
    template_message = _build_template_messages(prompts["header"], prompts["example"])
    prompt_list = prompts["prompts"]

    output = _load_existing_output(outputs_file)
    already_done = {entry["id"] for entry in output}
    if already_done:
        LOG.info(already_done)

    try:
        with logging_redirect_tqdm():
            for i, prompt in tqdm(list(enumerate(prompt_list))):
                # With the default MAX_PROMPTS of -1 this never triggers (no limit).
                if i == MAX_PROMPTS:
                    break
                prompt_id = prompt["id"]
                if prompt_id in already_done:
                    continue
                LOG.info(f"Generating for {prompt_id}")
                try:
                    output_conversation = _run_conversation(
                        prompt["conversations"], template_message, model_name, num_choices
                    )
                except InvalidRequestError as e:
                    LOG.error(f"Skipping {prompt_id}: request rejected by the API: {e}")
                    continue
                output_prompt = prompt.copy()
                output_prompt["conversations"] = output_conversation
                output.append(output_prompt)
                # Checkpoint after every prompt so progress survives interruption.
                _save_output(output, outputs_file)
    except KeyboardInterrupt:
        print(f"Stopping at {len(output)}")
    return output


if __name__ == "__main__":
    # Expected layout of --prompts-file (shown in --help).
    prompts_format = """
    {
        "header": "This describes the task",
        "example":
        {
            "conversations":
            {
                "id1": [
                    {"from": "human", "value": "This is an example prompt"},
                    {"from": "gpt", "value": "This is an example response"},
                    {"from": "human", "value": "This is an example continuation"},
                    {"from": "gpt", "value": "This is an example response to the continuation"}
                ],
                ...
            }
        },
        "prompts":
        [
            {
                "id": "000000379143",
                "conversations":
                [
                    {"from": "human", "value": "Example prompt 2"},
                    {"from": "human", "value": "Example continuation 2"}
                ]
            }
        ]
    }
    """
    # Layout that --outputs-file is written in (shown in --help).
    outputs_format = """
    [
        {
            "id": "000000379143",
            "conversations":
            [
                {"from": "human", "value": "Example prompt 2"},
                {"from": "gpt", "value": "Example response 2"},
                {"from": "human", "value": "Example continuation 2"},
                {"from": "gpt", "value": "Example response to continuation 2"}
            ]
        }
    ]
    """
    parser = argparse.ArgumentParser(
        description="Generate assistant turns for conversation prompts with the OpenAI ChatCompletion API.",
        epilog=f"example prompts file:{prompts_format}\nexample outputs file:{outputs_format}",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--prompts-file", type=str, required=True,
                        help="JSON file with the header, example conversations and prompts")
    parser.add_argument("--outputs-file", type=str, required=True,
                        help="JSON file to write results to (existing entries are kept and skipped)")
    parser.add_argument("--model", default="gpt-3.5-turbo", help="OpenAI model to use")
    parser.add_argument("--num-choices", type=int, default=1,
                        help="number of completions to request per turn")
    args = parser.parse_args()
    # Load the prompts and close the file before starting the long-running generation loop.
    with open(args.prompts_file) as f:
        prompts = json.load(f)
    call_to_open_api(prompts, args.outputs_file, args.model, num_choices=args.num_choices)
