import os
import argparse
import openai
from tqdm import tqdm

from utils import *
from dataset_utils import read_synth_data, index_example

# Read the API key from the environment so it never has to live in the source;
# the value is None when OPENAI_API_KEY is not set.
openai.api_key = os.environ.get("OPENAI_API_KEY")


def _parse_args():
    """Build the command-line parser, parse the arguments and return them.

    Engine-related options are registered by ``add_engine_argumenet`` and
    post-processed by ``specify_engine``; both come from ``utils`` and are
    applied to the parser / parsed namespace respectively.
    """
    arg_parser = argparse.ArgumentParser()
    add_engine_argumenet(arg_parser)
    # Prompt style name (see zero_shot_prediction for the supported values).
    arg_parser.add_argument('--style', default="standard", type=str)
    # Without this flag the script only analyzes previously cached predictions.
    arg_parser.add_argument('--run_prediction', action='store_true', default=False)
    parsed = arg_parser.parse_args()
    specify_engine(parsed)
    return parsed

def result_cache_name(args):
    """Return the path of the JSON prediction cache for this engine and style."""
    return f"misc/zero_{args.engine_name}_{args.style}.json"

def zero_shot_prediction(ex, engine, style="standard"):
    """Query the completion API for a zero-shot answer to a single example.

    Args:
        ex: mapping providing the "context" and "question" strings that are
            substituted into the prompt template.
        engine: engine identifier forwarded to ``openai.Completion.create``.
        style: name of the prompt template; one of "standard", "insta",
            "instb" or "instc".

    Returns:
        The first choice of the API response, with the prompt stored under
        "prompt" and "text" reduced to what follows the echoed prompt.

    Raises:
        RuntimeError: if ``style`` is not one of the supported template names.
    """
    templates = {
        "standard": "{}\nQ: {}\nA:",
        "insta": "Answer the question based on the background.\nBackground: {}\nQ: {}\nA:",
        "instb": "Answer the question based on the background.\nBackground: {}\nQuestion: {}\nAnswer:",
        "instc": "Read the background and answer the question.\nBackground: {}\nQuestion: {}\nAnswer:",
    }
    template = templates.get(style)
    if template is None:
        raise RuntimeError("Unsupported prompt style")
    prompt = template.format(ex["context"], ex["question"])

    # Use the `engine` argument instead of the module-level `args`, so the
    # function also works when it is imported or called outside of __main__.
    resp = openai.Completion.create(
        engine=engine,
        prompt=prompt,
        temperature=0.0,
        max_tokens=25,
        logprobs=5,
        echo=True,
        stop='\n',
    )

    pred = resp["choices"][0]
    pred["prompt"] = prompt
    # echo=True is requested above, so the prompt is cut off the front of the
    # returned text to keep only the completion.
    pred["text"] = pred["text"][len(prompt):]
    return pred

def normalize_prediction(x):
    """Reduce a raw prediction to its first whitespace-delimited token.

    Returns the stripped input (an empty string) when it contains no token.
    """
    stripped = x.strip()
    # Only the first token matters, so stop splitting after it.
    tokens = stripped.split(None, 1)
    return tokens[0] if tokens else stripped

def evaluate_qa_predictions(dev_set, predictions, do_print=False):
    """Print exact-match and substring-hit accuracy of the predictions.

    Args:
        dev_set: sized collection of examples; each provides the gold answer
            string under "answer".
        predictions: prediction dicts paired with ``dev_set`` by position;
            each provides "text" (and "prompt", read only when ``do_print``
            is set).
        do_print: when true, additionally print the prompt, the normalized
            and raw prediction, and the gold answer for every example.

    "ACC" counts predictions whose normalized text (see
    ``normalize_prediction``) equals the gold answer; "HIT ACC" counts
    predictions whose raw text contains the gold answer. Both are divided by
    ``len(dev_set)``; an empty ``dev_set`` is reported as 0.0 instead of
    raising ``ZeroDivisionError``.
    """
    acc = 0
    hit_acc = 0
    for ex, pred in zip(dev_set, predictions):
        gt = ex["answer"]
        orig_p = pred["text"]
        hit_acc += gt in orig_p
        p = normalize_prediction(orig_p)
        is_correct = gt == p
        acc += is_correct
        if do_print:
            if is_correct:
                print("-----------Correct-----------")
            else:
                print("-----------Wrong-----------")
            print(pred["prompt"])
            print('PR:', p, "\t|\t", pred["text"])
            print('GT:', gt)

    total = len(dev_set)
    # Guard the division so that an empty dev set yields 0.0 accuracies.
    print("ACC", acc / total if total else 0.0)
    print("HIT ACC", hit_acc / total if total else 0.0)

def test_zero_shot_performance(args):
    """Predict on the dev set, store the predictions, and print accuracy.

    The predictions are handed to ``dump_json`` together with the path from
    ``result_cache_name(args)`` before they are evaluated.
    """
    print("Running prediction")
    dev_set = read_synth_data("data/50-dev_synth.json")

    progress = tqdm(dev_set, total=len(dev_set), desc="Predicting")
    predictions = [
        zero_shot_prediction(example, args.engine, style=args.style)
        for example in progress
    ]

    # Store the predictions first, then report accuracy on them.
    dump_json(predictions, result_cache_name(args))
    evaluate_qa_predictions(dev_set, predictions)

def analyze_zero_shot_performance(args):
    """Load cached predictions and print a per-example report plus accuracy."""
    cached_predictions = read_json(result_cache_name(args))
    examples = read_synth_data("data/50-dev_synth.json")

    # do_print=True makes the evaluation print every prompt and prediction.
    evaluate_qa_predictions(examples, cached_predictions, do_print=True)

if __name__ == '__main__':
    # --run_prediction queries the API and caches the results; without it the
    # previously cached predictions are analyzed.
    args = _parse_args()
    entry_point = test_zero_shot_performance if args.run_prediction else analyze_zero_shot_performance
    entry_point(args)