import pandas as pd
import openai, os, json
import tiktoken
# Base directory; both the input datasets and the generated outputs live
# under "<PATH>/outputs".
PATH = "."
PATH_DATA = f"{PATH}/outputs"
OUTPUT_PATH = f"{PATH}/outputs"
# Request mode -> short model alias -> model identifier sent to the API.
MODE_DICT = {
    "COMPLETION": {"davinci": "text-davinci-003"},
    "CHAT_MODE": {"turbo": "gpt-3.5-turbo", "gpt4": "gpt-4"},
}
# Active selection: the model used is MODE_DICT[MODE][MODEL].
MODE = "CHAT_MODE"
MODEL = "gpt4"
# Read the API key from the environment instead of embedding it in source.
# NOTE: any key that was ever committed to this file must be considered
# leaked and revoked.
openai.api_key = os.environ.get("OPENAI_API_KEY")



# Read the prompt text file into a single string. The context manager
# closes the file handle as soon as the content has been read.
with open('./data/new_output_prompt.txt', encoding="utf-8") as _prompt_file:
    promtd = _prompt_file.read()


import backoff, requests, sys
@backoff.on_exception(backoff.expo, Exception, max_tries=50, max_time=61)
def get_output_chat_mode(prompt):
    """Send ``prompt`` as a single user message and return the reply text.

    The prompt is stripped of surrounding whitespace before sending. The
    model identifier is resolved through ``MODE_DICT[MODE][MODEL]``. The
    reply is echoed to stdout, followed by a 100-dash separator line.

    Any exception raised inside this function is retried with exponential
    backoff according to the decorator's ``max_tries`` / ``max_time``.
    """
    user_message = {"role": "user", "content": prompt.strip()}
    sampling_options = {
        "temperature": 1.0,
        "top_p": 1,
        "frequency_penalty": 0.0,
        "presence_penalty": 0.0,
        "max_tokens": 3000,
    }
    response = openai.ChatCompletion.create(
        model=MODE_DICT[MODE][MODEL],
        messages=[user_message],
        **sampling_options,
    )
    # Only the first choice is used.
    reply = response["choices"][0]["message"]["content"]
    print(reply)
    print('-'*100)
    return reply

@backoff.on_exception(backoff.expo, Exception, max_tries=50, max_time=61)
def get_output_completion_mode(prompt):
    """Send ``prompt`` to the completion endpoint and return the stripped text.

    The model identifier is resolved through ``MODE_DICT[MODE][MODEL]``.
    Any exception raised inside this function is retried with exponential
    backoff according to the decorator's ``max_tries`` / ``max_time``.

    NOTE(review): the model comes from the same MODE/MODEL globals used by
    the chat helper; confirm MODE selects a completion model before calling.
    """
    sampling_options = {
        "temperature": 1.0,
        "top_p": 1,
        "frequency_penalty": 0.0,
        "presence_penalty": 0.0,
    }
    response = openai.Completion.create(
        model=MODE_DICT[MODE][MODEL],
        prompt=prompt,
        **sampling_options,
    )
    # Only the first choice is used.
    first_choice = response["choices"][0]
    return first_choice["text"].strip()

def read_dataset(dataset_name, dataset_file):
    """Load and return the JSON document at PATH_DATA/dataset_name/dataset_file.

    Args:
        dataset_name: sub-directory of ``PATH_DATA`` that holds the file.
        dataset_file: file name of the JSON document.

    Returns:
        The decoded JSON value.
    """
    dataset_path = os.path.join(PATH_DATA, dataset_name, dataset_file)
    # Use a context manager so the handle is closed as soon as parsing ends.
    with open(dataset_path) as handle:
        return json.load(handle)

def read_MATHS(dataset_address):
    """Flatten a nested task -> difficulty -> problems JSON file into a list.

    Args:
        dataset_address: path of a JSON file shaped as
            ``{task: {difficulty: [{"problem": ..., "solution": ...}, ...]}}``.

    Returns:
        A list of ``{"id", "question", "solution"}`` dicts in file order.
        Each id is ``"<task>_<difficulty><n>"`` where ``n`` is a running
        counter over the whole file (it does not restart per task).

    Raises:
        KeyError: if a problem lacks ``"problem"`` or ``"solution"``.
    """
    # Context manager closes the handle as soon as parsing ends.
    with open(dataset_address) as handle:
        raw = json.load(handle)

    final_dataset = []
    for task, by_difficulty in raw.items():
        for difficulty, problems in by_difficulty.items():
            for problem in problems:
                # len(final_dataset) is the number of problems emitted so
                # far, i.e. the file-wide running counter used in the id.
                final_dataset.append({
                    "id": f"{task}_{difficulty}{len(final_dataset)}",
                    "question": problem["problem"],
                    "solution": problem["solution"],
                })
    return final_dataset
            


def find_errors_in_dataset(dataset):
    """Return the locations of entries without a usable ``final_response``.

    An entry is flagged when its ``"final_response"`` is missing/``None``,
    is not a list, or is an empty list.

    Args:
        dataset: either a dict mapping keys to entry dicts, or a list of
            entry dicts.

    Returns:
        A list of locations usable as ``dataset[location]``: the keys of
        the flagged entries for a dict dataset, their positions for a list.
    """
    is_mapping = isinstance(dataset, dict)
    error_indices = []
    for idx, d in enumerate(dataset):
        # Iterating a dict yields keys; iterating a list yields the entries.
        entry = dataset[d] if is_mapping else d
        # Always report something the caller can index the dataset with:
        # the key for dicts (whatever its type), the position for lists.
        location = d if is_mapping else idx
        final_response = entry.get("final_response")
        if final_response is None:
            print(f"d --- {d}   type (d) {type(d)}")
            error_indices.append(location)
            continue
        if not isinstance(final_response, list) or not final_response:
            print("Here---", final_response)
            error_indices.append(location)
    print('-'*100)
    print(f"Error in {len(error_indices)} indices")
    print('-'*100)
    return error_indices
        
# a = value["question"]
#             output = get_output_chat_mode(a.strip() + "\n")
#             all_outputs.append(output if output is not None else "")
#             value["ins_only_response"] = all_outputs
#             dataset[didx] = value
import time
# Attempt cap compared against the attempt counter in collect_results.
MAX_ATTEMPT = 10
def collect_results(dataset, dataset_name):
    """Re-query entries flagged by ``find_errors_in_dataset`` and save the result.

    For every flagged entry, the text after the last ``###Better Prompt###``
    marker in its ``"examples_only_response"`` is sent to
    ``get_output_chat_mode``. The question and the reply are stored on the
    entry as ``"no_response_question"`` and ``"no_response_answer"`` (a
    one-element list; ``""`` when the reply is ``None``).

    Files written under ``OUTPUT_PATH`` (``<model>`` is
    ``MODE_DICT[MODE][MODEL]``):
      * ``<dataset_name>_<model>_attempt_1_no_response_answers.json`` -
        only when at least one entry was processed;
      * ``<dataset_name>_<model>_no_response_answers.json`` - always.

    Args:
        dataset: dict or list of entry dicts; updated in place.
        dataset_name: prefix used for the output file names.

    Raises:
        KeyError: if a flagged entry has no ``"examples_only_response"``.
    """
    # Exactly one pass is made. Entries are flagged on "final_response",
    # which this function never writes, so another pass would only re-send
    # the same prompts.
    attempt = 0
    model_name = MODE_DICT[MODE][MODEL]
    output_stem = f"{OUTPUT_PATH}/{dataset_name}_{model_name}"

    flagged = [(key, dataset[key]) for key in find_errors_in_dataset(dataset)]
    # MAX_ATTEMPT == 0 disables the pass entirely.
    if flagged and attempt != MAX_ATTEMPT:
        progress = tqdm.tqdm(flagged, desc=f"Attempt {attempt+1}", total=len(flagged))
        for key, entry in progress:
            question = entry["examples_only_response"].split("###Better Prompt###")[-1].strip()
            output = get_output_chat_mode(question + "\n")
            print('-'*200)
            entry["no_response_question"] = question
            entry["no_response_answer"] = [output if output is not None else ""]
            dataset[key] = entry
        with open(f"{output_stem}_attempt_{attempt + 1}_no_response_answers.json", "w") as f:
            json.dump(dataset, f, indent=4)

    with open(f"{output_stem}_no_response_answers.json", "w") as f:
        json.dump(dataset, f, indent=4)

        




import tqdm
if __name__ == "__main__":
    # Input file, resolved by read_dataset relative to PATH_DATA.
    dataset_name, dataset_file = (
        "gsm8k",
        "gsm8k_gpt-4_attempt_1_examples_only_no_reason.json",
    )
    prompt_dataset = read_dataset(dataset_name, dataset_file)
    # Alternative input source:
    #   prompt_dataset = read_MATHS('data/MATHS/maths_easy_to_very_easy.json')
    # The second argument is the prefix of the output file names.
    collect_results(prompt_dataset, "gsm8_gpt-4")
