import json
import os

import openai
from tqdm import tqdm
# Model aliases available per API mode.
MODE_DICT = {
    "COMPLETION": {"davinci": "text-davinci-003"},
    "CHAT_MODE": {"turbo": "gpt-3.5-turbo", "gpt4": "gpt-4"},
}
# Which mode of MODE_DICT to use, and which model alias inside that mode.
# NOTE(review): the request helper in this file calls openai.ChatCompletion, so
# only "CHAT_MODE" aliases appear to be usable here.
MODE = "CHAT_MODE"
MODEL = "gpt4"
# Read the API key from the environment instead of committing it to source.
# A live key literal used to be hard-coded here; treat it as leaked and revoke it.
openai.api_key = os.environ.get("OPENAI_API_KEY")



# Chain-of-thought few-shot exemplars for GSM8K math word problems.
# The string intentionally ends on an open "Q:" line: prepare_prompt_gsm8k
# strips it, appends " " + the test question, then "\nA: ".
# NOTE(review): the mid-sentence line breaks look like copy/paste artifacts;
# they are part of the prompt text, so changing them changes the prompt.
gsm8k_few_shot_demos = """
Q: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there
will be 21 trees. How many trees did the grove workers plant today?
A: There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have
been 21 - 15 = 6. The answer is 6.

Q: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?
A: There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5. The answer is 5.

Q: 
"""

# Few-shot exemplars for the analytical-entailment task. The string ends on an
# open "Q:" line so prepare_entailment can append the test input after it.
# Both exemplars follow the "<premise>. So, <conclusion>." shape. Each rationale
# uses only facts stated in the question, so the model is not shown reasoning
# that depends on unstated outside knowledge.
entailment_few_shot_demos = """
Q: John is a bachelor. So, John is married.
A: A "bachelor" is an unmarried man. Therefore, if John is a bachelor, it cannot be concluded that he is married.
Answer: No Entailment

Q: All mammals are vertebrates, and a cat is a mammal. So, a cat is a vertebrate.
A: Every mammal is a vertebrate and a cat is a mammal. Therefore, a cat must be a vertebrate.
Answer: Entailment

Q:
"""

# Placeholder exemplars for a hallucination task: two empty Q/A slots followed by
# the open "Q:" line. Nothing in this file builds prompts from it yet.
hallucinations_few_shot_demos = "\nQ:\nA:\n\nQ:\nA:\n\nQ:\n"

# Few-shot exemplars for date-understanding questions. The string ends on an
# open "Q:" line so prepare_prompt_date_understanding can append the test input.
# Each rationale spells out the intermediate date before the final answer.
date_understanding_few_shot_demos = """
Q: It is 12/05/1969 today. What is the date one week ago from today in MM/DD/YYYY?
A: Today is 12/05/1969.
One week ago from today would be 11/28/1969.
Answer: 11/28/1969.

Q: Sarah was born on the last day of February in 2001. Today is her 20th birthday. What is the date 20 days ago in MM/DD/YYYY?
A: 2001 is not a leap year, so Sarah was born on 02/28/2001.
Her 20th birthday, today, is therefore 02/28/2021.
20 days before 02/28/2021 is 02/08/2021.
Answer: 02/08/2021.

Q:
"""

# Few-shot exemplars for the anachronism task: "Answer: Yes" means the statement
# contains an anachronism, and "Answer: No" means it does not. The string ends
# on an open "Q:" line so prepare_anachronisms can append the test input.
anachronisms_few_shot_demos = """
Q: Ancient Romans used smartphones to communicate.
A: Smartphones did not exist during the time of the Ancient Romans, so this statement contains an anachronism.
Answer: Yes

Q: The steam engine was a major invention during the Industrial Revolution.
A: The steam engine was developed and widely adopted during the Industrial Revolution, so there is no anachronism.
Answer: No

Q:
"""

# Placeholder exemplars for a sports task: two empty Q/A slots followed by the
# open "Q:" line. Nothing in this file builds prompts from it yet.
sports_few_shot_demos = "\nQ:\nA:\n\nQ:\nA:\n\nQ:\n"

import backoff, requests
@backoff.on_exception(
    backoff.expo,
    # Retry only transient API failures. The previous blanket `Exception` also
    # retried authentication errors, invalid requests and a KeyError from the
    # model lookup, stalling up to 61 s per example before re-raising anyway.
    # `openai.error` is the exception module of the pre-1.0 SDK, which matches
    # the `openai.ChatCompletion` API used below.
    (
        openai.error.RateLimitError,
        openai.error.APIError,
        openai.error.APIConnectionError,
        openai.error.ServiceUnavailableError,
        openai.error.Timeout,
    ),
    max_tries=50,
    max_time=61,
)
def get_output_chat_mode(prompt):
    """Send *prompt* as one user message to the chat model and return the reply text.

    The model name is looked up as MODE_DICT[MODE][MODEL]. Transient OpenAI
    errors are retried with exponential backoff (at most 50 tries or 61 s).
    Any other error, including an invalid MODE/MODEL combination, propagates
    immediately.

    Returns:
        str: the content of the first choice's message.
    """
    model_name = MODE_DICT[MODE][MODEL]
    response = openai.ChatCompletion.create(
        model=model_name,
        messages=[{"role": "user", "content": prompt}],
        temperature=1.0,
        top_p=1,
        frequency_penalty=0.0,
        presence_penalty=0.0,
    )
    return response["choices"][0]["message"]["content"]


def _build_few_shot_prompt(demos, test_instance):
    """Append *test_instance* to the few-shot *demos* and open an answer slot.

    *demos* is expected to end with an empty "Q:" line. Surrounding whitespace
    is stripped so the test instance follows that "Q:" after a single space.
    The prompt is echoed to stdout, followed by a separator line.
    """
    prompt = demos.strip() + " " + test_instance + "\nA: "
    print(prompt)
    print('-' * 100)
    return prompt


def prepare_prompt_date_understanding(data, demos=None):
    """Build the date-understanding prompt for one example (uses ``data["input"]``).

    *demos* overrides the default ``date_understanding_few_shot_demos``.
    """
    if demos is None:
        demos = date_understanding_few_shot_demos
    return _build_few_shot_prompt(demos, data["input"])


def prepare_anachronisms(data, demos=None):
    """Build the anachronism prompt for one example (uses ``data["input"]``).

    *demos* overrides the default ``anachronisms_few_shot_demos``.
    """
    if demos is None:
        demos = anachronisms_few_shot_demos
    return _build_few_shot_prompt(demos, data["input"])


def prepare_entailment(data, demos=None):
    """Build the analytical-entailment prompt for one example (uses ``data["input"]``).

    *demos* overrides the default ``entailment_few_shot_demos``.
    """
    if demos is None:
        demos = entailment_few_shot_demos
    return _build_few_shot_prompt(demos, data["input"])


def prepare_prompt_gsm8k(data, demos=None):
    """Build the GSM8K prompt for one example (uses ``data["question"]``).

    *demos* overrides the default ``gsm8k_few_shot_demos``.
    """
    if demos is None:
        demos = gsm8k_few_shot_demos
    return _build_few_shot_prompt(demos, data["question"])

def read_dataset(dataset_name, dataset_address):
    """Load the examples of *dataset_name* from the file at *dataset_address*.

    Datasets whose name contains "gsm" are read as JSON Lines, with one example
    per line; blank lines (e.g. a trailing newline) are skipped. Every other
    dataset is read as a single JSON document whose "examples" key holds the
    list of examples.

    The file is opened as UTF-8 (the JSON interchange encoding) and closed
    before returning.

    Returns:
        list: the parsed examples, in file order.
    """
    with open(dataset_address, encoding="utf-8") as f:
        if "gsm" in dataset_name:
            return [json.loads(line) for line in f if line.strip()]
        return json.load(f)["examples"]


# Dataset name -> function that turns one example of that dataset into a prompt.
prompt_caller = dict(
    entailment=prepare_entailment,
    gsm8k=prepare_prompt_gsm8k,
    date_understanding=prepare_prompt_date_understanding,
    anachronisms=prepare_anachronisms,
)

def evaluate_few_shot(dataset_name, dataset_address):
    """Prompt the configured model on every example of a dataset and save the results.

    Each example dict gains "few_shot_question" (the prompt sent) and
    "few_shot_output" (the model's reply). The augmented dataset is written as
    JSON to ./outputs/{dataset_name}_few_shot_{MODEL}.json.

    Raises:
        KeyError: if *dataset_name* has no entry in prompt_caller.
    """
    dataset = read_dataset(dataset_name, dataset_address)
    # Resolve the prompt builder and create the output directory before the
    # loop. Results are written only after every (slow, paid) API call has
    # finished, so a missing ./outputs directory used to discard the whole run.
    build_prompt = prompt_caller[dataset_name]
    os.makedirs("./outputs", exist_ok=True)
    output_path = f"./outputs/{dataset_name}_few_shot_{MODEL}.json"

    for d in tqdm(dataset, desc=f"Prompting {MODEL} for {dataset_name}"):
        prompt = build_prompt(d)
        response = get_output_chat_mode(prompt)
        print(response)
        print('-' * 200)
        d["few_shot_output"] = response
        d["few_shot_question"] = prompt

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(dataset, f, indent=4)


if __name__ == "__main__":
    # Script entry point: run the few-shot evaluation on the analytical-entailment set.
    evaluate_few_shot(
        dataset_name="entailment",
        dataset_address="data/analytical_entailment/analytical_entailment.json",
    )




