import pandas as pd
import openai, os, json
import tiktoken
PATH = "."
PATH_DATA = f"{PATH}/data"
OUTPUT_PATH = f"{PATH}/outputs"
# Model aliases grouped by API style. The request functions below look up the
# concrete model id as MODE_DICT[MODE][MODEL]; presumably "COMPLETION" ids are
# meant for openai.Completion and "CHAT_MODE" ids for openai.ChatCompletion.
MODE_DICT = {
    "COMPLETION": {"davinci": "text-davinci-003", "instructGPT": "gpt-3.5-turbo-instruct"},
    "CHAT_MODE": {"turbo": "gpt-3.5-turbo", "gpt4": "gpt-4"},
}
MODE = "CHAT_MODE"  # top-level key into MODE_DICT
MODEL = "gpt4"  # alias inside MODE_DICT[MODE]; also used as a prefix for output field names
# Never commit API secrets: read the key from the environment instead.
# NOTE(review): the key that was previously hard-coded on this line is exposed
# in version control history and must be revoked/rotated.
openai.api_key = os.environ.get("OPENAI_API_KEY")





import backoff, requests
@backoff.on_exception(backoff.expo, Exception, max_tries=50, max_time=61)
def get_output_chat_mode(prompt):
    """Send ``prompt`` as one user message to the configured chat model.

    The ``backoff.on_exception`` decorator retries with exponential backoff on
    any Exception, bounded by ``max_tries=50`` and ``max_time=61``.

    Args:
        prompt: Text of the single user message.

    Returns:
        The ``content`` of the first choice's message in the response.
    """
    # Echo the outgoing prompt between markers for debugging.
    print('>>>>IP', prompt, 'IP<<<<')
    model_id = MODE_DICT[MODE][MODEL]
    conversation = [{"role": "user", "content": prompt}]
    sampling_options = dict(
        temperature=1.0,
        top_p=1,
        frequency_penalty=0.0,
        presence_penalty=0.0,
    )
    reply = openai.ChatCompletion.create(
        model=model_id,
        messages=conversation,
        **sampling_options,
    )
    first_choice = reply["choices"][0]
    return first_choice["message"]["content"]


# Instruction prefix that get_output_completion_mode prepends to every request,
# read once at import time. ``prompt`` keeps the raw list of lines (module-level
# name preserved for compatibility); ``promptD`` is their concatenation.
# The context manager closes the file; the original left the handle open.
with open(f'{PATH_DATA}/new_output_prompt.txt', encoding="utf-8") as _prompt_file:
    prompt = _prompt_file.readlines()
promptD = ''.join(prompt)
import sys
@backoff.on_exception(backoff.expo, Exception, max_tries=50, max_time=61)
def get_output_completion_mode(prompt):
    """Ask the completion model whether ``prompt`` contains an anachronism.

    The statement is stripped and embedded in quotes after the module-level
    ``promptD`` prefix, followed by the answer choices and a ``###Reason###``
    header. The assembled prompt and the trimmed reply are both printed.
    Retries with exponential backoff on any Exception (``max_tries=50``,
    ``max_time=61``).

    Args:
        prompt: The statement to classify.

    Returns:
        The first choice's ``text`` with surrounding whitespace removed.
    """
    statement = prompt.strip()
    full_prompt = (
        f'{promptD}\nIdentify whether the following statement contains an anachronism: "{statement}"'
        f"\nYour Choices are (A) True (B) False\n###Reason###\n"
    )
    print(full_prompt)
    completion = openai.Completion.create(
        model=MODE_DICT[MODE][MODEL],
        prompt=full_prompt,
        temperature=1.0,
        top_p=1,
        frequency_penalty=0.0,
        presence_penalty=0.0,
        max_tokens=500,
    )
    answer = completion["choices"][0]["text"].strip()
    print(answer)
    return answer

def read_dataset(dataset_name, dataset_file, path=None):
    """Load and return the JSON dataset that collect_results will process.

    NOTE(review): ``dataset_name`` and ``dataset_file`` are accepted but
    ignored. Unless ``path`` is given, the data always comes from a hard-coded
    anachronisms file, even when the caller passes a different dataset name.
    Confirm which file is intended.

    Args:
        dataset_name: Currently unused.
        dataset_file: Currently unused.
        path: Optional JSON file to read instead of the default
            ``./outputs/anachronisms-davinci-003_davinci_rew.json``.

    Returns:
        The parsed JSON document, unchanged.
    """
    if path is None:
        path = "./outputs/anachronisms-davinci-003_davinci_rew.json"
    # Use a context manager so the handle is closed (the original leaked it).
    with open(path, encoding="utf-8") as handle:
        return json.load(handle)



def find_errors_in_dataset(dataset, *, model=None):
    """Return the keys/positions of entries that lack a usable final response.

    An entry is flagged when its ``f"{model}_final_response"`` field is
    missing or None, is not a list, or is an empty list. Diagnostics are
    printed for each flagged entry, followed by a summary count.

    Args:
        dataset: Either a dict mapping keys to entry dicts, or a list of entry
            dicts.
        model: Model alias used to build the field name; defaults to the
            module-level ``MODEL``.

    Returns:
        A list of values ``i`` such that ``dataset[i]`` is a flagged entry:
        the dict keys for a mapping, the list positions for a sequence.
    """
    if model is None:
        model = MODEL
    field = f"{model}_final_response"
    is_mapping = isinstance(dataset, dict)
    # Iterate (key, entry) pairs so the recorded index is always valid for
    # dataset[...]; the original recorded enumerate() positions for non-str
    # dict keys, which collect_results would then look up incorrectly.
    pairs = dataset.items() if is_mapping else enumerate(dataset)
    error_indices = []
    for key, entry in pairs:
        final_response = entry.get(field)
        if final_response is None:
            # Same diagnostic as before: the key for dicts, the entry for lists.
            shown = key if is_mapping else entry
            print(f"d --- {shown}   type (d) {type(shown)}")
            error_indices.append(key)
        elif not isinstance(final_response, list) or not final_response:
            print("Here---", final_response)
            error_indices.append(key)
    print('-'*100)
    print(f"Error in {len(error_indices)} indices")
    print('-'*100)
    return error_indices
        
# a = value["question"]
#             output = get_output_chat_mode(a.strip() + "\n")
#             all_outputs.append(output if output is not None else "")
#             value["ins_only_response"] = all_outputs
#             dataset[didx] = value
import time
# Attempt limit for collect_results: its loop exits once ``attempt`` equals
# this value. NOTE(review): that loop currently breaks unconditionally after
# one pass, so this limit is not reached in practice.
MAX_ATTEMPT = 10
def collect_results(dataset, dataset_name):
    """Query the chat model for every flagged entry and save the results.

    Entries flagged by find_errors_in_dataset are processed one by one. The
    first element of each entry's ``"davinci_rewrite"`` value is used, and is
    assumed to be a list of strings. That element is stripped, cut down to the
    text after the last ``"###Better Prompt###"`` marker, and sent to
    get_output_chat_mode. Each entry then gains two fields:
    ``f"{MODEL}_question"`` holds the raw ``davinci_rewrite`` value, and
    ``f"{MODEL}_rewrite"`` holds a one-element list with the stripped reply
    ("" if the reply was None). The dataset is mutated in place and written
    twice under OUTPUT_PATH: once to a per-attempt JSON file and once to a
    final ``..._davinci_rew.json`` file.

    NOTE(review): find_errors_in_dataset checks ``f"{MODEL}_final_response"``,
    a field this function never writes, so every later pass would re-query
    every entry. The unconditional ``break`` at the end of the loop body
    therefore keeps this to a single pass, and MAX_ATTEMPT is never reached.
    Confirm the intended field name before removing that break.

    Args:
        dataset: Dict or list of entry dicts, updated in place.
        dataset_name: Prefix for the output file names.
    """
    attempt = 0
    while True:
        error_indices = find_errors_in_dataset(dataset)
        errors_json = [(dataset[i], i) for i in error_indices]
        if not errors_json or attempt == MAX_ATTEMPT:
            break
        for value, didx in tqdm.tqdm(errors_json, desc=f"Attempt {attempt+1}", total=len(errors_json)):
            a = value["davinci_rewrite"]
            question = a[0].strip().split("###Better Prompt###")[-1]
            output = get_output_chat_mode(question)
            # Check for None *before* stripping: the original called .strip()
            # first, so a None reply raised and the "" fallback was dead code.
            output = output.strip() if output is not None else ""
            print('>>>', output)
            value[f"{MODEL}_question"] = a
            value[f"{MODEL}_rewrite"] = [output]
            dataset[didx] = value
        attempt_file = (
            f"{OUTPUT_PATH}/{dataset_name}_{MODE_DICT[MODE][MODEL]}_{attempt + 1}"
            "_rewritten_turbo_gpt4_answer.json"
        )
        with open(attempt_file, "w", encoding="utf-8") as f:
            json.dump(dataset, f, indent=4)
        attempt += 1
        # Single pass only; see the NOTE in the docstring. (The inter-attempt
        # time.sleep(60) that followed this break was unreachable and was removed.)
        break
    final_file = f"{OUTPUT_PATH}/{dataset_name}_{MODE_DICT[MODE][MODEL]}_davinci_rew.json"
    with open(final_file, "w", encoding="utf-8") as f:
        json.dump(dataset, f, indent=4)

        




import tqdm
if __name__ == "__main__":
    # Script entry: load the dataset, then query the model for flagged entries.
    # NOTE(review): read_dataset ignores these arguments and loads a hard-coded
    # anachronisms file, so outputs get labelled "gsm8k" although the data
    # appears to be a different dataset. Confirm which is intended.
    dataset_name, dataset_file = "gsm8k", "gsm8k_subset.json"
    collect_results(read_dataset(dataset_name, dataset_file), dataset_name)
