import pandas as pd
import openai, os, json
import tiktoken
# Base directory; the input and output locations below are derived from it.
PATH = "."
# Directory that read_dataset() resolves dataset files against.
PATH_DATA = f"{PATH}/promptd_output"
# Directory that collect_results() writes its JSON snapshots into.
OUTPUT_PATH = f"{PATH}/outputs"
# Endpoint family -> short alias -> model identifier sent to the OpenAI API.
MODE_DICT = {"COMPLETION": {"davinci":"text-davinci-003"}, "CHAT_MODE":{"turbo":"gpt-3.5-turbo", "gpt4": "gpt-4"}}
# Active selection; the request helpers look up MODE_DICT[MODE][MODEL].
MODE = "CHAT_MODE"
MODEL = "gpt4"
# SECURITY: never commit an API key to source control. The credential is read
# from the environment instead; any key that was previously embedded in this
# file must be treated as leaked and revoked.
openai.api_key = os.environ.get("OPENAI_API_KEY")




# Short instruction for the anachronism task.
# NOTE(review): this value is dead at runtime -- `text_to_replace` is
# reassigned further down before any function reads it.
text_to_replace = "Identify whether a given statement contains an anachronism:"
# Longer replacement instruction for the anachronism task.
# NOTE(review): `replace_with` is not referenced anywhere in the visible part
# of this file -- confirm whether it is still needed.
replace_with = "Your task is to determine whether a given statement contains an anachronism, which is something or someone not in its correct historical or chronological time. Use analytical and critical thinking skills to identify potential anachronistic elements, considering elements such as language, technology or phenomena that could not have existed at the time the statement is supposed to occur. Print the output in the format \"The answer is \\answer{}\""




# Active prompt template (overrides the assignment above). get_output_chat_mode()
# substitutes the question text for the "[TEST_PROBLEM]" placeholder; the
# surrounding backticks stay in the final prompt.
text_to_replace = "Let's solve a given mathematical problem in a step-by-step manner to ensure the accuracy of the solution. Please apply a systematic approach to this mathematical operation step by step:\n`[TEST_PROBLEM]`. Print the output in the format \"The answer is \\answer{}\"."
import backoff, requests
@backoff.on_exception(backoff.expo, Exception, max_tries=50, max_time=61)
def get_output_chat_mode(prompt):
    """Send one question to the chat endpoint and return the reply text.

    The stripped ``prompt`` replaces the ``[TEST_PROBLEM]`` placeholder in the
    module-level ``text_to_replace`` template; the result is sent as a single
    user message to the model selected by ``MODE_DICT[MODE][MODEL]``. Both the
    filled-in prompt and the reply are echoed to stdout.

    Any exception triggers a retry through the ``backoff`` decorator, which is
    configured with exponential waits, ``max_tries=50`` and ``max_time=61``.

    Args:
        prompt: Question text to insert into the template.

    Returns:
        The ``content`` field of the first choice's message.
    """
    filled_prompt = text_to_replace.replace("[TEST_PROBLEM]", prompt.strip())
    print(filled_prompt)

    request = {
        "model": MODE_DICT[MODE][MODEL],
        "messages": [{"role": "user", "content": filled_prompt}],
        "temperature": 1.0,
        "top_p": 1,
        "frequency_penalty": 0.0,
        "presence_penalty": 0.0,
    }
    response = openai.ChatCompletion.create(**request)

    reply = response["choices"][0]["message"]["content"]
    print(reply)
    print("-" * 200)
    return reply

@backoff.on_exception(backoff.expo, Exception, max_tries=50, max_time=61)
def get_output_completion_mode(prompt):
    """Send ``prompt`` to the completion endpoint and return the stripped text.

    The prompt is passed through unchanged to the model selected by
    ``MODE_DICT[MODE][MODEL]``. Any exception triggers a retry through the
    ``backoff`` decorator (exponential waits, ``max_tries=50``,
    ``max_time=61``).

    Args:
        prompt: Text submitted verbatim as the completion prompt.

    Returns:
        The ``text`` of the first choice with surrounding whitespace removed.
    """
    sampling_options = {
        "temperature": 1.0,
        "top_p": 1,
        "frequency_penalty": 0.0,
        "presence_penalty": 0.0,
    }
    completion = openai.Completion.create(
        model=MODE_DICT[MODE][MODEL],
        prompt=prompt,
        **sampling_options,
    )
    first_choice = completion["choices"][0]
    return first_choice["text"].strip()

def read_dataset(dataset_name, dataset_file):
    """Load a JSON dataset from ``PATH_DATA/<dataset_name>/<dataset_file>``.

    Args:
        dataset_name: Sub-directory of ``PATH_DATA`` holding the dataset.
        dataset_file: File name of the JSON document inside that directory.

    Returns:
        The parsed JSON content, exactly as returned by ``json.load``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    dataset_path = os.path.join(PATH_DATA, dataset_name, dataset_file)
    # The context manager closes the handle deterministically; the original
    # left it open until garbage collection. JSON is read as UTF-8 rather than
    # relying on the platform's default encoding.
    with open(dataset_path, encoding="utf-8") as f:
        return json.load(f)



def find_errors_in_dataset(dataset):
    """Return references to records that lack a usable ``ins_only_response``.

    A record is considered faulty when its ``"ins_only_response"`` entry is
    missing/``None``, is not a list, or is an empty list. Diagnostics for each
    faulty record and a summary count are printed to stdout.

    Args:
        dataset: Either a dict mapping keys to record dicts, or a sequence of
            record dicts.

    Returns:
        A list of references usable as ``dataset[ref]``: the record's key when
        ``dataset`` is a dict, otherwise its positional index.
    """
    is_mapping = isinstance(dataset, dict)
    error_indices = []
    for idx, d in enumerate(dataset):
        # Iterating a dict yields keys; iterating a sequence yields records.
        record = dataset[d] if is_mapping else d
        # Always report the key for dicts (even non-string keys) so the caller
        # can look the record up again with dataset[ref].
        ref = d if is_mapping else idx

        final_response = record.get("ins_only_response")
        if final_response is None:
            print(f"d --- {d}   type (d) {type(d)}")
            error_indices.append(ref)
            continue
        if not isinstance(final_response, list) or not final_response:
            print("Here---", final_response)
            error_indices.append(ref)

    print('-' * 100)
    print(f"Error in {len(error_indices)} indices")
    print('-' * 100)
    return error_indices
        

import time
# Upper bound on the number of re-query rounds performed by collect_results().
MAX_ATTEMPT = 10
def collect_results(dataset, dataset_name):
    """Re-query every faulty record until none remain or attempts run out.

    Each round asks ``find_errors_in_dataset`` for the faulty records, sends
    each record's ``"question"`` to ``get_output_chat_mode``, stores the reply
    as a one-element list under ``"ins_only_response"`` (mutating ``dataset``
    in place), and writes a per-attempt JSON snapshot. A final snapshot is
    always written once the loop ends.

    Args:
        dataset: Dict or list of record dicts, as accepted by
            ``find_errors_in_dataset``; each faulty record must contain a
            ``"question"`` string.
        dataset_name: Prefix used for the output file names.

    Raises:
        Exception: Whatever ``get_output_chat_mode`` raises once its own
            retries are exhausted.
    """
    model_name = MODE_DICT[MODE][MODEL]
    # Create the output directory up front so a missing folder cannot discard
    # the results of a whole round of API calls.
    os.makedirs(OUTPUT_PATH, exist_ok=True)

    attempt = 0
    while True:
        error_indices = find_errors_in_dataset(dataset)
        if not error_indices or attempt == MAX_ATTEMPT:
            break
        if attempt > 0:
            # Cool down for a minute between consecutive rounds of requests;
            # no wait is spent when no further round follows.
            time.sleep(60 * 1)

        for didx in tqdm.tqdm(error_indices, desc=f"Attempt {attempt+1}", total=len(error_indices)):
            record = dataset[didx]
            output = get_output_chat_mode(record["question"].strip() + "\n")
            record["ins_only_response"] = [output if output is not None else ""]

        snapshot_path = f"{OUTPUT_PATH}/{dataset_name}_{model_name}_attempt_{attempt + 1}_ins_only.json"
        with open(snapshot_path, "w") as f:
            json.dump(dataset, f, indent=4)
        attempt += 1

    final_path = f"{OUTPUT_PATH}/{dataset_name}_{model_name}_final_ins_only.json"
    with open(final_path, "w") as f:
        json.dump(dataset, f, indent=4)

        




import tqdm
if __name__ == "__main__":
    # Script entry point: load the GSM8K result file and re-query the records
    # that still lack a valid response.
    dataset_name, dataset_file = "gsm8k", "gsm8k_gpt-4_final.json"
    collect_results(read_dataset(dataset_name, dataset_file), dataset_name)


# anachronisms --- Accuracy: 0.8043478260869565
