import openai
import os
import json
import time
import math
from pathlib import Path
from tqdm import tqdm
import argparse
from multiprocessing import Pool, Lock

# OpenAI API keys to rotate through -- fill these in before running.
# main() caps the number of worker processes at len(api_key_pool) and
# hands each request in a batch its own key from this list.
api_key_pool = [
    # "your-api-key",
]



# Models served by the chat endpoint vs. the legacy text-completion endpoint.
_CHAT_MODELS = ("gpt-3.5-turbo",)
_COMPLETION_MODELS = ("text-davinci-002", "text-davinci-003")
_SYSTEM_MESSAGE = "You are a helpful, pattern-following assistant."


def _call_with_retry(request_fn, api_key, sleep_second):
    """Return ``request_fn()``, retrying indefinitely whenever it raises.

    Each failed attempt prints the API key in use and the error, then sleeps
    ``sleep_second`` seconds before trying again.
    """
    while True:
        try:
            return request_fn()
        except Exception as e:  # broad on purpose: any request failure is retried
            print(api_key)
            print(e)
            time.sleep(sleep_second)


def _build_chat_messages(prompt, query_text, demonstrations):
    """Return chat messages: system message, instructions, few-shot turns, then the query."""
    messages = [
        {"role": "system", "content": _SYSTEM_MESSAGE},
        {"role": "user", "content": prompt["instructions"]},
    ]
    for example in demonstrations or ():
        messages.append(
            {
                "role": "user",
                "content": prompt["input_prefix"] + example["input"] + prompt["input_suffix"],
            }
        )
        messages.append(
            {
                "role": "assistant",
                "content": prompt["output_prefix"] + example["output"] + prompt["output_suffix"],
            }
        )
    messages.append(
        {"role": "user", "content": prompt["input_prefix"] + query_text + prompt["input_suffix"]}
    )
    return messages


def _build_completion_prompt(prompt, query_text, demonstrations):
    """Return one prompt string: instructions, formatted few-shot examples, then the query."""
    parts = [prompt["instructions"] + "\n"]
    for example in demonstrations or ():
        parts.append(
            prompt["input_prefix"] + example["input"] + prompt["input_suffix"]
            + prompt["output_prefix"] + example["output"] + prompt["output_suffix"]
        )
    # NOTE(review): unlike the chat path, the query is appended without
    # input_prefix/input_suffix -- preserved as-is; confirm this is intended.
    parts.append(query_text)
    return "".join(parts)


def query_openai_api_per_example(api_key,
                                 prompt,
                                 instance,
                                 model,
                                 sleep_second,
                                 max_tokens,
                                 demonstrations=None):
    """Query the OpenAI API for one request state and record the completion.

    Args:
        api_key: OpenAI API key to use for this call (set on the ``openai`` module).
        prompt: Dict with ``instructions``, ``input_prefix``, ``input_suffix``,
            ``output_prefix`` and ``output_suffix`` strings.
        instance: Request state; the query text is read from
            ``instance["instance"]["input"]["text"]``. Mutated in place: its
            ``"request"`` key is replaced with the result record.
        model: One of the supported chat or completion model names.
        sleep_second: Seconds to wait between retries after a failed call.
        max_tokens: ``max_tokens`` passed to the API.
        demonstrations: Optional list of ``{"input": ..., "output": ...}``
            few-shot examples placed before the query.

    Returns:
        The new ``instance["request"]`` dict (result text, success flag, timings).

    Raises:
        ValueError: If ``model`` is not supported (checked before any API state
            is touched).

    Failed API calls are retried indefinitely, so this blocks until one succeeds.
    """
    if model not in _CHAT_MODELS and model not in _COMPLETION_MODELS:
        raise ValueError(f"Unsupported model: {model!r}")

    openai.api_key = api_key
    query_text = instance["instance"]["input"]["text"]
    s_time = time.time()

    if model in _CHAT_MODELS:
        messages = _build_chat_messages(prompt, query_text, demonstrations)
        response = _call_with_retry(
            lambda: openai.ChatCompletion.create(
                model=model,
                messages=messages,
                max_tokens=max_tokens,
            ),
            api_key,
            sleep_second,
        )
        result = response["choices"][0]["message"]["content"]
    else:
        text_prompt = _build_completion_prompt(prompt, query_text, demonstrations)
        response = _call_with_retry(
            lambda: openai.Completion.create(
                engine=model,
                prompt=text_prompt,
                max_tokens=max_tokens,
            ),
            api_key,
            sleep_second,
        )
        result = response["choices"][0]["text"]

    instance["request"] = {
        "result": {
            # Always True here: _call_with_retry only returns on success.
            "success": True,
            "completions": [{"text": result}],
        },
        "request_time": time.time() - s_time,
        "request_datetime": time.time(),
    }
    return instance["request"]


def _resolve_test_file(input_dir, input_file):
    """Return ``input_file`` if given, else ``input_dir``/test.json, falling back to dev.json."""
    if input_file is not None:
        return input_file
    path = os.path.join(input_dir, "test.json")
    if os.path.exists(path):
        return path
    return os.path.join(input_dir, "dev.json")


def _load_demonstrations(train_file):
    """Read few-shot ``{"input", "output"}`` examples from a train-split request-state file."""
    with open(train_file) as f:
        train_data = json.load(f)
    examples = []
    for instance in train_data["request_states"]:
        split = instance["instance"]["split"]
        if split != "train":
            raise ValueError(
                f"{train_file}: expected only 'train' instances, found split {split!r}"
            )
        examples.append({
            "input": instance["instance"]["input"]["text"],
            "output": instance["instance"]["references"][0]["output"]["text"],
        })
    return examples


def _run_batch(pool, batch, prompt, args, examples):
    """Query every sample in ``batch`` in parallel, one API key each; store results in place."""
    jobs = [
        (api_key, prompt, sample, args.model, args.sleep_second, args.max_tokens, examples)
        for api_key, sample in zip(api_key_pool, batch)
    ]
    # Workers get pickled copies of each sample, so results are written back here.
    for sample, request in zip(batch, pool.starmap(query_openai_api_per_example, jobs)):
        sample["request"] = request


def _checkpoint(f, data):
    """Overwrite the already-open file ``f`` with ``data`` so finished work is persisted."""
    f.seek(0)
    json.dump(data, f, indent=4)
    f.truncate()


def main(args):
    """Query the API for every request state in the test file not yet marked successful.

    Few-shot demonstrations come from ``args.input_dir``/train.json. The test
    file is ``args.input_file`` if given, else test.json (or dev.json) in
    ``args.input_dir``. It is rewritten in place after every batch of
    ``args.n_threads`` requests, so an interrupted run can be resumed.

    Raises:
        ValueError: If no worker can run (empty ``api_key_pool`` or
            ``n_threads`` < 1), or if train.json contains non-train instances.
    """
    # One worker per API key at most: each in-flight request needs its own key.
    args.n_threads = min(len(api_key_pool), args.n_threads)
    if args.n_threads < 1:
        raise ValueError(
            "need at least one API key in api_key_pool and --n_threads >= 1"
        )

    train_file = os.path.join(args.input_dir, "train.json")
    test_file = _resolve_test_file(args.input_dir, args.input_file)
    examples = _load_demonstrations(train_file)

    with open(test_file, "r+") as f:
        data = json.load(f)
        prompt = data["prompt"]
        batch = []
        # A single pool for the whole run instead of re-spawning one per batch.
        with Pool(args.n_threads) as pool:
            for sample in tqdm(data["request_states"]):
                if sample["request"]["result"].get("success", False):
                    continue  # already answered in a previous run
                batch.append(sample)
                if len(batch) == args.n_threads:
                    _run_batch(pool, batch, prompt, args, examples)
                    _checkpoint(f, data)
                    batch.clear()

            # Process the remaining (partial) batch.
            if batch:
                _run_batch(pool, batch, prompt, args, examples)
                _checkpoint(f, data)
                batch.clear()



if __name__ == "__main__":
    cli = argparse.ArgumentParser(description="Query OpenAI")

    # Parallelism (main() additionally caps this at len(api_key_pool)).
    cli.add_argument("--n_threads", type=int, default=8)

    # Where to read prompts from: a directory of split files, or one explicit file.
    cli.add_argument("--input_dir", type=str, default="prompts")
    cli.add_argument("--input_file", type=str, default=None)

    # Model choice and per-request settings.
    cli.add_argument("--model", type=str, default="gpt-3.5-turbo")
    cli.add_argument("--max_tokens", type=int, default=512)
    cli.add_argument("--sleep_second", type=float, default=20.0)

    main(cli.parse_args())
    