import os
import re
import time

import openai


class Decoder:
    """Stateless object facade over :func:`decoder_for_gpt3`."""

    def __init__(self):
        """Create a decoder; there is no state to set up."""

    def decode(self, args, input, CNT_SUM):
        """Forward all arguments to ``decoder_for_gpt3`` and return its result."""
        return decoder_for_gpt3(args, input, CNT_SUM)

# Sentence Generator (Decoder) for GPT-3 ...
def decoder_for_gpt3(args, input, CNT_SUM):
    """Return the text of one GPT-3 completion for the prompt *input*.

    Sleeps for ``args.api_time_interval`` seconds before calling the API, then
    requests a temperature-0 completion from the ``code-davinci-002`` engine,
    stopping at ``--``, a blank line, or ``#``.

    Args:
        args: Config object; ``api_time_interval`` and ``max_length`` are read.
        input: Prompt text sent to the API.
        CNT_SUM: Unused here; kept so existing callers keep working.

    Returns:
        The ``text`` of the first choice in the API response.
    """
    # Throttle calls to stay within the API's per-minute rate limit
    # (roughly 60 requests per minute per user).
    time.sleep(args.api_time_interval)

    # Paste a key below, or leave the list empty-ish and set the
    # OPENAI_API_KEY environment variable instead.
    # https://beta.openai.com/account/api-keys
    OPENAI_API_KEYs = [
        "",  # FILL IN WITH YOUR OPENAI API KEY
    ]
    api_key = next((key for key in OPENAI_API_KEYs if key), "")
    if not api_key:
        api_key = os.environ.get("OPENAI_API_KEY", "")
    # Only assign a real key: writing the empty placeholder would overwrite a
    # key configured elsewhere (e.g. picked up by the openai package from the
    # environment) and make every request fail authentication.
    if api_key:
        openai.api_key = api_key

    engine = "code-davinci-002"

    response = openai.Completion.create(
        engine=engine,
        prompt=input,
        max_tokens=args.max_length,
        temperature=0,
        stop=['--', '\n\n', '#'],
    )

    return response["choices"][0]["text"]

# ver 0.2
def answer_cleansing(args, pred):
    """Extract the final answer from a raw model completion.

    The completion is split on ``args.direct_answer_trigger_for_fewshot`` and
    only the text after the last trigger occurrence is kept. Candidate answers
    are then extracted with a dataset-specific rule. If the trigger appeared,
    the first candidate (the one right after the trigger) is returned;
    otherwise the last candidate in the text is returned. One trailing period
    is stripped. The raw and cleaned predictions are printed.

    Args:
        args: Config object; ``dataset`` and
            ``direct_answer_trigger_for_fewshot`` are read.
        pred: Raw completion text.

    Returns:
        The cleaned answer string, or ``""`` when no candidate is found.

    Raises:
        ValueError: If ``args.dataset`` is not a supported dataset name.
    """
    print("pred_before : " + pred)

    preds = pred.split(args.direct_answer_trigger_for_fewshot)
    answer_flag = len(preds) > 1
    pred = preds[-1]

    dataset = args.dataset
    # "commonsensqa" is kept spelled exactly as before; presumably it matches
    # the dataset key callers pass — confirm before renaming.
    if dataset in ("aqua", "commonsensqa"):
        # NOTE: matches any capital A-E in the text, not only option labels.
        candidates = re.findall(r"A|B|C|D|E", pred)
    elif dataset == "bigbench_date":
        candidates = re.findall(r"A|B|C|D|E|F", pred)
    elif dataset == "object_tracking":
        # Exact match: the old `in ("object_tracking")` was a substring test
        # on a plain string and accepted names like "tracking".
        candidates = re.findall(r"A|B|C", pred)
    elif dataset in ("gsm8k", "addsub", "multiarith", "svamp", "singleeq"):
        # Drop thousands separators so "1,000" is read as one number.
        candidates = re.findall(r"-?\d+\.?\d*", pred.replace(",", ""))
    elif dataset in ("strategyqa", "coin_flip"):
        # Turn quotes, periods, colons, commas and whitespace into spaces,
        # then keep only the yes/no tokens.
        words = re.sub(r"[\"'.:,\s]", " ", pred.lower()).split(" ")
        candidates = [word for word in words if word in ("yes", "no")]
    elif dataset == "last_letters":
        # Remove quotes, periods and all whitespace; the remainder is the answer.
        candidates = [re.sub(r"[\"'.\s]", "", pred)]
    else:
        raise ValueError("dataset is not properly defined ...")

    if not candidates:
        pred = ""
    elif answer_flag:
        pred = candidates[0]
    else:
        pred = candidates[-1]

    # (For arithmetic tasks) a number captured with its sentence-ending period,
    # e.g. "7.", loses the period.
    if pred.endswith("."):
        pred = pred[:-1]

    print("pred_after : " + pred)
    return pred


if __name__ == "__main__":
    from types import SimpleNamespace

    # Minimal stand-in for the CLI args object: decoder_for_gpt3 reads
    # api_time_interval and max_length from it. The previous call passed
    # args=None and four positional arguments, which raised TypeError.
    demo_args = SimpleNamespace(api_time_interval=1.0, max_length=256)

    test_str = "/* Create a JavaScript dictionary of 5 countries and capitals: */\n"
    z = decoder_for_gpt3(demo_args, test_str, 0)
    print(z)