import random
random.seed(1331)
# Few-shot GSM8K demonstrations (question, worked solution, "#### n", "Answer: n"),
# ending with an open "Q: " for the next question. Kept under its own name:
# it used to be bound to ICL_gsm and was immediately overwritten below, so it
# was unreachable.
ICL_gsm_few_shot = "Q: Albert is wondering how much pizza he can eat in one day. He buys 2 large pizzas and 2 small pizzas. A large pizza has 16 slices and a small pizza has 8 slices. If he eats it all, how many pieces does he eat that day?\nA: He eats 32 from the largest pizzas because 2 x 16 = <<2*16=32>>32\nHe eats 16 from the small pizza because 2 x 8 = <<2*8=16>>16\nHe eats 48 pieces because 32 + 16 = <<32+16=48>>48\n#### 48\nAnswer: 48\n\nQ: Randy has 60 mango trees on his farm. He also has 5 less than half as many coconut trees as mango trees. How many trees does Randy have in all on his farm?\nA: Half of the number of Randy's mango trees is 60/2 = <<60/2=30>>30 trees.\nSo Randy has 30 - 5 = <<30-5=25>>25 coconut trees.\nTherefore, Randy has 60 + 25 = <<60+25=85>>85 trees on his farm.\n#### 85\nAnswer: 85\n\nQ: "
# Instruction-only prompt for step-by-step math word problems.
ICL_gsm = "Given a set of word problems, your task is to break down each problem step by step, performing all the mathematical operations, and provide a detailed solution. Keep in mind that you should provide a clear explanation for each step. The solution to each problem should be represented as a separate paragraph, and it should eventually lead to a final solution. Using the information from the problem, make all necessary calculations and conclude with the final answer. Then, print the output in the format \"The answer is \\answer{}.\""
# Instruction prompts for the anachronism task; both ask for Yes/No wrapped in
# "The answer is \answer{}".
ICL_ins_ana = "Given a statement, indicate if it contains an anachronism. An anachronism is an error in the timing of events or objects, or anything that belongs to a period other than that being portrayed. If the statement contains an anachronism, print \"Yes\", otherwise print \"No\". Make the output in a common format like \"The answer is \\answer{}\"."
ICL_only_ana = "Your task is to determine if the historical context in the input passage is plausible given known timelines and the invention and release dates of certain technologies, such as radios and iPhones. If the historical context and the technologies mentioned match, print 'Yes'. If they do not, print 'No'. Print the output in the format \"The answer is \\answer{}\"."


import pandas as pd
import openai, os, json
import tiktoken
PATH = "."
PATH_DATA = f"{PATH}/outputs"
OUTPUT_PATH = f"{PATH}/outputs"
# Maps request mode -> short model alias -> API model id.
MODE_DICT = {"COMPLETION": {"davinci":"text-davinci-003"}, "CHAT_MODE":{"turbo":"gpt-3.5-turbo", "gpt4": "gpt-4"}}
MODE = "CHAT_MODE"
MODEL = "gpt4"
# Never commit credentials: read the key from the environment instead.
# The key that used to be hard-coded here has been exposed and must be revoked.
openai.api_key = os.environ.get("OPENAI_API_KEY")

import backoff, requests, sys
@backoff.on_exception(backoff.expo, Exception, max_tries=50, max_time = 61)
def get_output_chat_mode(prompt, instruction=None):
    """Send an instruction-prefixed prompt to the configured chat model.

    The request is a single user message made of ``instruction``, a newline
    and the stripped ``prompt``. Both the full prompt and the reply are
    printed for inspection.

    Args:
        prompt: Task input; surrounding whitespace is stripped.
        instruction: Text prepended to the prompt. Defaults to
            ``ICL_only_ana``, looked up at call time so later reassignments of
            that module constant still take effect.

    Returns:
        The ``content`` of the first choice's message.

    The decorator retries on any exception with exponential backoff, up to 50
    tries or about 61 seconds. By backoff's default it then re-raises the
    last exception.
    """
    if instruction is None:
        instruction = ICL_only_ana
    full_prompt = instruction + "\n" + prompt.strip()
    print(full_prompt)
    response = openai.ChatCompletion.create(
        model = MODE_DICT[MODE][MODEL],
        messages = [{"role": "user", "content": full_prompt}],
        temperature = 1.0,
        top_p = 1,
        frequency_penalty = 0.0,
        presence_penalty = 0.0,
        max_tokens=3000,
        )
    content = response["choices"][0]["message"]["content"]
    print(content)
    print('-'*100)
    return content

@backoff.on_exception(backoff.expo, Exception, max_tries=50, max_time = 61)
def get_output_completion_mode(prompt, max_tokens=None):
    """Send ``prompt`` verbatim to the configured completion model.

    Args:
        prompt: Full prompt text; no instruction prefix is added here.
        max_tokens: Optional cap on generated tokens. When ``None`` the field
            is omitted and the API default applies. That default is small for
            the legacy Completions endpoint, so pass a value for multi-step
            answers.

    Returns:
        The first choice's text with surrounding whitespace stripped.

    Note: the model id comes from ``MODE_DICT[MODE][MODEL]``. Presumably
    ``MODE``/``MODEL`` must be set to a completion model (e.g. "COMPLETION" /
    "davinci") before calling this function; confirm before use.
    """
    request = dict(
        model = MODE_DICT[MODE][MODEL],
        prompt = prompt,
        temperature = 1.0,
        top_p = 1,
        frequency_penalty = 0.0,
        presence_penalty = 0.0,
        )
    if max_tokens is not None:
        request["max_tokens"] = max_tokens
    response = openai.Completion.create(**request)
    return response["choices"][0]["text"].strip()

def read_dataset(dataset_name, dataset_file, base_dir=None):
    """Load ``<base_dir>/<dataset_name>/<dataset_file>`` as JSON.

    Args:
        dataset_name: Sub-directory that holds the dataset.
        dataset_file: JSON file name inside that sub-directory.
        base_dir: Root directory. Defaults to ``PATH_DATA``, looked up at
            call time.

    Returns:
        The parsed JSON value, whatever the file contains.
    """
    if base_dir is None:
        base_dir = PATH_DATA
    # `with` closes the handle; the previous bare open() leaked it.
    with open(os.path.join(base_dir, dataset_name, dataset_file), encoding="utf-8") as f:
        return json.load(f)

def read_MATHS(dataset_address):
    """Flatten a nested task -> difficulty -> problems JSON file into a list.

    Args:
        dataset_address: Path to a JSON object mapping task name to a mapping
            of difficulty to a list of problems. Each problem has ``"problem"``
            and ``"solution"`` fields.

    Returns:
        A list of ``{"id", "question", "solution"}`` dicts in file order. The
        id is ``task + "_" + difficulty + str(count)``, where ``count`` is a
        running counter over the whole file. There is no separator before the
        counter, so ids are unique only because the counter is global.
    """
    with open(dataset_address, encoding="utf-8") as f:
        jsn = json.load(f)
    final_dataset = []
    count = 0
    for task, by_difficulty in jsn.items():
        for difficulty, problems in by_difficulty.items():
            for problem in problems:
                final_dataset.append({
                    "id": task + "_" + difficulty + str(count),
                    "question": problem["problem"],
                    "solution": problem["solution"],
                })
                count += 1
    return final_dataset
            


def find_errors_in_dataset(dataset, key="final_response"):
    """Return the entries of ``dataset`` that still lack a usable response.

    An entry counts as an error when ``entry[key]`` is missing, is not a list,
    or is an empty list. A summary count is printed.

    Args:
        dataset: A sequence of dicts, or a dict mapping ids to dicts.
        key: Field that holds the model responses. Defaults to
            ``"final_response"``.

    Returns:
        For a sequence, the positions of failing entries. For a dict, their
        keys (whatever their type). Either form can index ``dataset`` directly.
    """
    entries = dataset.items() if isinstance(dataset, dict) else enumerate(dataset)
    error_indices = []
    for idx, entry in entries:
        response = entry.get(key)
        if not isinstance(response, list) or not response:
            print(f"Missing/invalid {key!r} at {idx!r}: {response!r}")
            error_indices.append(idx)
    print('-'*100)
    print(f"Error in {len(error_indices)} indices")
    print('-'*100)
    return error_indices
        
# a = value["question"]
#             output = get_output_chat_mode(a.strip() + "\n")
#             all_outputs.append(output if output is not None else "")
#             value["ins_only_response"] = all_outputs
#             dataset[didx] = value
import time
MAX_ATTEMPT = 10


def _missing_response_keys(dataset, key):
    """Return positions (sequence) or keys (dict) whose ``key`` is not a non-empty list."""
    entries = dataset.items() if isinstance(dataset, dict) else enumerate(dataset)
    missing = []
    for idx, entry in entries:
        response = entry.get(key)
        if not isinstance(response, list) or not response:
            missing.append(idx)
    return missing


def collect_results(dataset, dataset_name, max_attempts=MAX_ATTEMPT, retry_delay=60):
    """Query the chat model for every entry still lacking an IC-only response.

    Each round selects the entries whose ``"IC_only_response"`` is missing or
    empty. It sends ``entry["question"]`` to ``get_output_chat_mode`` and
    stores the question under ``"IC_only_question"`` and the reply, as a
    one-element list, under ``"IC_only_response"``. ``dataset`` is updated in
    place.

    After each round a snapshot is written to
    ``OUTPUT_PATH/<dataset_name>_<model>_attempt_IC_only<round>.json``. When
    nothing is pending, or after ``max_attempts`` rounds, the dataset is
    written to ``OUTPUT_PATH/<dataset_name>_<model>_IC_only_final.json``.

    Args:
        dataset: Sequence of dicts, or dict of dicts, each with a
            ``"question"`` field.
        dataset_name: Prefix for the output file names.
        max_attempts: Maximum number of query rounds.
        retry_delay: Seconds to wait before each retry round. There is no
            wait before the first round or after the last.
    """
    response_key = "IC_only_response"
    model_name = MODE_DICT[MODE][MODEL]
    attempt = 0
    while attempt < max_attempts:
        # Check the same key this loop writes. Otherwise completed entries
        # are never recognised as done and get re-queried every round.
        pending = _missing_response_keys(dataset, response_key)
        print('-'*100)
        print(f"Error in {len(pending)} indices")
        print('-'*100)
        if not pending:
            break
        if attempt > 0:
            time.sleep(retry_delay)
        for didx in tqdm.tqdm(pending, desc=f"Attempt {attempt+1}", total=len(pending)):
            value = dataset[didx]
            question = value["question"]
            try:
                output = get_output_chat_mode(question.strip() + "\n")
            except Exception as err:
                # Retries were exhausted for this item. Leave it pending for
                # the next round instead of aborting the whole run.
                print(f"Request for {didx!r} failed: {err!r}")
                continue
            print('-'*200)
            value["IC_only_question"] = question
            value[response_key] = [output if output is not None else ""]
        attempt += 1
        snapshot_path = os.path.join(OUTPUT_PATH, f"{dataset_name}_{model_name}_attempt_IC_only{attempt}.json")
        with open(snapshot_path, "w") as f:
            json.dump(dataset, f, indent = 4)
    final_path = os.path.join(OUTPUT_PATH, f"{dataset_name}_{model_name}_IC_only_final.json")
    with open(final_path, "w") as f:
        json.dump(dataset, f, indent = 4)

        




import tqdm
if __name__ == "__main__":
    dataset_name = "anachronisms"
    dataset_file = "anachronisms_new_final.json"
    prompt_dataset = read_dataset(dataset_name, dataset_file)
    # prompt_dataset = read_MATHS('data/MATHS/maths_med_to_hard.json')
    # Use the variable for the output prefix too. It was previously passed as
    # the misspelled literal "anachornisms".
    collect_results(prompt_dataset, dataset_name)
