import json
import os

import openai
address = "promptd_output/maths/MATHS_easy_to_very_easy_gpt-4_final.json"

# API mode -> short model alias -> OpenAI model name.
MODE_DICT = {
    "COMPLETION": {"davinci": "text-davinci-003"},
    "CHAT_MODE": {"turbo": "gpt-3.5-turbo", "gpt4": "gpt-4"},
}
MODE = "CHAT_MODE"
MODEL = "gpt4"
# Never commit API keys to source; read it from the environment instead.
# NOTE(review): any key that was ever committed here must be treated as
# leaked and revoked.
openai.api_key = os.environ.get("OPENAI_API_KEY")


import backoff, requests
@backoff.on_exception(backoff.expo, Exception, max_tries=50, max_time=61)
def get_output_chat_mode(prompt, model=None):
    """Send ``prompt`` as a single user message and return the reply text.

    The call is retried with exponential backoff on *any* exception, for up to
    50 attempts or 61 seconds in total, whichever limit is hit first.

    Args:
        prompt: Full text sent as the user's chat message.
        model: Optional explicit model name. When omitted, the model selected
            by the module-level ``MODE`` / ``MODEL`` keys in ``MODE_DICT`` is
            used.

    Returns:
        The ``content`` string of the first choice in the response.
    """
    if model is None:
        model = MODE_DICT[MODE][MODEL]
    response = openai.ChatCompletion.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        temperature=1.0,
        top_p=1,
        frequency_penalty=0.0,
        presence_penalty=0.0,
    )
    # Printing/logging the reply is left to the caller so it is not emitted twice.
    return response["choices"][0]["message"]["content"]

def prepare_prompt(q_res):
    """Build a prompt from a rewritten question plus the original's diagrams.

    The rewritten question is the text after the last ``###Better Prompt###``
    delimiter in ``q_res["final_response"][0]``, cut at its first ``[asy]``
    marker (so any diagram in the rewrite is dropped). Every ``[asy]`` diagram
    from the original ``q_res["zero_shot_question"]`` is then appended.

    Args:
        q_res: Mapping with ``"final_response"`` (a sequence whose first item
            is a string) and ``"zero_shot_question"`` (a string).

    Returns:
        The rewritten question alone when the original has no ``[asy]``
        diagram. Otherwise the rewritten question, a newline, the text
        ``The diagram is: ``, and each original diagram re-prefixed with
        ``[asy]`` and followed by one space.
    """
    delimiter = "###Better Prompt###"
    rewritten = q_res["final_response"][0].split(delimiter)[-1].strip()
    rewritten_text_only = rewritten.split("[asy]")[0].strip()
    original_question = q_res["zero_shot_question"].strip()
    # The diagrams of the original question: everything after each "[asy]".
    # A list slice never raises, so no exception handling is needed here.
    original_diagrams = original_question.split("[asy]")[1:]
    if not original_diagrams:
        return rewritten_text_only
    diagram = "".join("[asy]" + dia + " " for dia in original_diagrams)
    return rewritten_text_only + "\nThe diagram is: " + diagram

# Instruction-only preamble placed before each question's text in prep_MATHS.
# Split into one sentence per line; the concatenated value is a single line.
gsm8k_examples_only_instruction = (
    "Consider solving a word problem step by step. "
    "Begin by reading the problem thoroughly to ensure a complete understanding. "
    "Identify all the variables and the relationships between them. "
    "Next, break the problem into smaller manageable parts and develop a strategy to solve each part. "
    "Lastly, review the solutions of all parts and check whether the final solution makes sense in the context of the original problem."
)

from tqdm import tqdm
import os, sys
def prep_MATHS(jsn, output_path=None):
    """Query the model for every problem in ``jsn`` and save the annotated list.

    Each prompt is ``gsm8k_examples_only_instruction``, a newline, then the
    record's ``"question"`` text. The model's reply is stored on the record
    under ``"promptD_output"``, so ``jsn`` is mutated in place.

    Args:
        jsn: List of problem dicts; each must contain a ``"question"`` key.
        output_path: Destination for the JSON dump. Defaults to the
            module-level ``address`` with ``_ins_only`` inserted before its
            extension.

    The list is written even if the loop stops early (e.g. an API call
    exhausts its retries, or the run is interrupted), so answers already
    obtained are not lost. The original exception still propagates.
    """
    if output_path is None:
        # splitext rather than str.replace(".json", ...): replace would match
        # ".json" anywhere in the path. Without that suffix it would also return
        # the path unchanged and overwrite the input file.
        root, ext = os.path.splitext(address)
        output_path = root + "_ins_only" + ext
    try:
        for record in tqdm(jsn, total=len(jsn), desc="Eval MATHS..."):
            prompt = gsm8k_examples_only_instruction + "\n" + record["question"]
            print("-" * 100)
            print(prompt)
            print("-" * 100)
            response = get_output_chat_mode(prompt)
            print("Solution\n", response)
            print("-" * 200)
            record["promptD_output"] = response
    finally:
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(jsn, f, indent=4)



if __name__ == "__main__":
    # Load the problem list from ``address``; the context manager closes the
    # file instead of leaking the handle.
    with open(address, encoding="utf-8") as f:
        jsn = json.load(f)

    prep_MATHS(jsn)