import openai
import os
import json
import time
import math
from pathlib import Path
from tqdm import tqdm
import argparse
from multiprocessing import Pool, Lock


# Pool of OpenAI API keys, to be filled in by the user (one string per key).
# The number of workers is set to len(api_key_pool), and the worker with
# index i is given api_key_pool[i].
api_key_pool = [
    # your api pool
]



def query_openai_api_per_example(model, 
                                 prompt, 
                                 input,
                                 demonstrations=None,
                                 max_tokens=512):
    """Send one few-shot query to the OpenAI API and return the generated text.

    Args:
        model: Model name. "gpt-3.5-turbo" is queried through
            ``openai.ChatCompletion``; "text-davinci-002" and
            "text-davinci-003" are queried through ``openai.Completion``.
        prompt: Mapping with an "instructions" string. When demonstrations
            are given it must also provide "input_prefix", "input_suffix",
            "output_prefix" and "output_suffix".
        input: Text of the example to answer. It is appended as-is after the
            demonstrations (no prefix/suffix is added to it here).
        demonstrations: Optional iterable of dicts with "input" and "output"
            strings used as few-shot examples. ``None`` means zero-shot.
        max_tokens: Maximum number of tokens to generate.

    Returns:
        The text of the first completion choice returned by the API.

    Raises:
        ValueError: If ``model`` is not one of the supported model names.
    """
    if demonstrations is None:
        demonstrations = ()

    if model == "gpt-3.5-turbo":
        messages = [
            {"role": "system", "content": "You are a helpful, pattern-following assistant."},
            {"role": "user", "content": prompt["instructions"]},
        ]
        # Each demonstration becomes a user turn followed by the expected
        # assistant turn.
        for example in demonstrations:
            messages.append(
                {
                    "role": "user",
                    "content": prompt["input_prefix"] + example["input"] + prompt["input_suffix"]
                }
            )
            messages.append(
                {
                    "role": "assistant",
                    "content": prompt["output_prefix"] + example["output"] + prompt["output_suffix"]
                }
            )
        messages.append({"role": "user", "content": input})
        response = openai.ChatCompletion.create(
            model=model,
            messages=messages,
            max_tokens=max_tokens
        )
        return response['choices'][0]['message']['content']
    # The first name used to be misspelled ("text-davinvi-002"), which made
    # the real text-davinci-002 model fall through to the ValueError below.
    elif model in ("text-davinci-002", "text-davinci-003"):
        pieces = [prompt["instructions"] + "\n"]
        for example in demonstrations:
            pieces.append(
                prompt["input_prefix"] + example["input"] + prompt["input_suffix"]
                + prompt["output_prefix"] + example["output"] + prompt["output_suffix"]
            )
        pieces.append(input)
        response = openai.Completion.create(
            engine=model,
            prompt="".join(pieces),
            max_tokens=max_tokens
        )
        return response['choices'][0]['text']
    else:
        raise ValueError(f"Unsupported model: {model!r}")


def query_openai_api(idx, 
                     n,
                     api_key,
                     data_chunk,
                     prompt,
                     demonstrations,
                     output_folder,
                     args):
    """Query the OpenAI API for every instance of one data chunk and save the results.

    Each answered instance gets a "request" entry and is written as one JSON
    line to ``{args.range[0]}_{args.range[1]}_{idx}.jsonl`` inside
    ``output_folder``. On the first failed query the worker prints the
    document range it did not process and stops; results written before the
    failure are kept.

    Args:
        idx: Index of this worker/chunk. Worker 0 shows a tqdm progress bar.
        n: Number of parts the evaluated range was divided into; used only to
            reconstruct this chunk's document range for the failure report.
        api_key: OpenAI API key assigned to ``openai.api_key``.
        data_chunk: Sequence of request-state dicts; the query text is read
            from ``instance["instance"]["input"]["text"]``.
        prompt: Prompt description forwarded to query_openai_api_per_example.
        demonstrations: Few-shot examples forwarded to
            query_openai_api_per_example.
        output_folder: Directory the result file is written to.
        args: Parsed arguments providing ``range``, ``model``, ``max_tokens``
            and ``sleep_second``.
    """
    openai.api_key = api_key
    out_path = os.path.join(output_folder, f"{args.range[0]}_{args.range[1]}_{idx}.jsonl")
    if idx == 0:
        data_chunk = tqdm(data_chunk)
    s_time = 0
    # The context manager guarantees the file is closed (and flushed) even
    # when the loop stops early.
    with open(out_path, "w") as out_file:
        for i, instance in enumerate(data_chunk):
            # Rate limiting: keep at least args.sleep_second between the
            # start of two consecutive requests. The remaining time is
            # computed once so it can never turn negative before sleeping.
            remaining = args.sleep_second - (time.time() - s_time)
            if remaining > 0:
                time.sleep(remaining)
            s_time = time.time()
            try:
                response = query_openai_api_per_example(args.model, prompt, instance["instance"]["input"]["text"], demonstrations, args.max_tokens)
            except Exception as err:
                print(f"Thread {idx} failed with {err!r}")
                # Rebuild the absolute document range of this chunk so the
                # user knows which range to pass to --range for the rerun.
                part = int(math.ceil((args.range[1] - args.range[0]) / n))
                range_start = range(args.range[0], args.range[1], part)[idx]
                range_end = min(range_start + part, args.range[1])
                thread_range = range(range_start, range_end)
                not_processed_range = [thread_range[i], range_end]
                print(f'Thread {idx} wrong! Need to rerun this range of documents!. Not processed data range: {not_processed_range}')
                break
            instance["request"] = {
                "result": {
                    "success": True,
                    "completions": [{"text": response}],
                },
                "request_time": time.time() - s_time,
                "request_datetime": time.time()
            }
            out_file.write(json.dumps(instance) + "\n")
            # Flush per instance so finished results survive an interrupted run.
            out_file.flush()


if __name__ == "__main__":
    # Script entry point: read the few-shot demonstrations and the evaluation
    # data from --input_dir, split the selected document range into one chunk
    # per API key and query each chunk in its own worker process.
    parser = argparse.ArgumentParser(description="Query OpenAI")
    # multiple processing
    parser.add_argument("--n_threads", type=int, default=8)
    # I/O
    parser.add_argument("--input_dir", type=str, default="prompts")
    parser.add_argument("--output_dir", type=str, default="outputs")
    parser.add_argument("--range", type=int, nargs='+', default=[0, -1],
                        help='the range of evaluated documents is [args.range[0], args.range[1]). '
                             '-1 means evaluate until the last document (included)')

    # model & parameters
    parser.add_argument("--model", type=str, default="text-davinci-003")
    parser.add_argument("--max_tokens", type=int, default=512)
    parser.add_argument("--sleep_second", type=float, default=2.0)

    args = parser.parse_args()

    # the number of threads is equal to the number of api keys
    # (this overrides whatever was passed as --n_threads)
    if not api_key_pool:
        parser.error("api_key_pool is empty; add at least one OpenAI API key")
    args.n_threads = len(api_key_pool)

    if len(args.range) != 2:
        parser.error("--range expects exactly two integers")

    # mkdir
    output_folder = Path(args.output_dir)
    output_folder.mkdir(exist_ok=True, parents=True)

    # train
    train_file = os.path.join(args.input_dir, "train.json")

    # test: use test.json when it exists, otherwise fall back to dev.json
    path = os.path.join(args.input_dir, "test.json")
    if os.path.exists(path):
        test_file = path
    else:
        test_file = os.path.join(args.input_dir, "dev.json")

    # demonstrations
    with open(train_file) as f:
        train_data = json.load(f)
    examples = []
    for instance in train_data["request_states"]:
        if instance["instance"]["split"] != "train":
            raise ValueError(f"{train_file} contains an instance whose split is not 'train'")
        examples.append({
            "input": instance["instance"]["input"]["text"],
            "output": instance["instance"]["references"][0]["output"]["text"]
        })

    # test data
    with open(test_file) as f:
        test_data = json.load(f)
    prompt = test_data["prompt"]
    ## divide into multiple splits
    n = int(args.n_threads)
    if args.range[1] == -1:
        args.range[1] = len(test_data["request_states"])
    if args.range[0] >= args.range[1]:
        parser.error(f"empty evaluated range: [{args.range[0]}, {args.range[1]})")
    print(f'total range: [0, {len(test_data["request_states"])}); evaluated range: [{args.range[0]}, {args.range[1]})')
    part = int(math.ceil((args.range[1] - args.range[0]) / n))
    test_data_chunks = [test_data["request_states"][i: min(i + part, args.range[1])] for i in range(args.range[0], args.range[1], part)]

    # multiple threads processing
    # Rounding `part` up can produce fewer than n chunks, so the pool size and
    # the loop follow the chunks actually built. `n` itself is still passed
    # to the workers because they recompute `part` from it.
    p = Pool(len(test_data_chunks))
    results = [
        p.apply_async(query_openai_api, args=(i, n, api_key_pool[i], chunk, prompt, examples, output_folder, args))
        for i, chunk in enumerate(test_data_chunks)
    ]
    p.close()
    # apply_async stores exceptions in the result object; fetch every result
    # so that a failed worker is reported instead of being silently ignored.
    for i, result in enumerate(results):
        try:
            result.get()
        except Exception as err:
            print(f"Worker {i} failed: {err!r}")
    p.join()