import pandas as pd
import openai, os, json
import tiktoken
PATH = "."
PATH_DATA = f"{PATH}/outputs"  # root directory that read_dataset loads from
OUTPUT_PATH = f"{PATH}/outputs"  # directory where collect_results writes its JSON files
# MODE picks the API style (a key of MODE_DICT) and MODEL the short name inside
# it; MODE_DICT[MODE][MODEL] is the model id sent to the OpenAI API.
MODE_DICT = {"COMPLETION": {"davinci":"text-davinci-003"}, "CHAT_MODE":{"turbo":"gpt-3.5-turbo", "gpt4": "gpt-4"}}
MODE = "CHAT_MODE"
MODEL = "turbo"
# Read the API key from the environment so no secret lives in source control.
# Any key that was previously committed to this file must be treated as leaked
# and revoked.
openai.api_key = os.environ.get("OPENAI_API_KEY")





import backoff, requests
@backoff.on_exception(backoff.expo, Exception, max_tries=50, max_time=61)
def get_output_chat_mode(prompt):
    """Send *prompt* as one user message to the chat endpoint and return the reply text.

    The prompt and the reply are both echoed to stdout. The model id is
    ``MODE_DICT[MODE][MODEL]``. Any exception raised inside the call is retried
    by ``backoff`` with exponentially growing waits (at most 50 tries or 61
    seconds in total).
    """
    print(prompt)
    chat_messages = [{"role": "user", "content": prompt}]
    sampling = dict(temperature=1.0, top_p=1, frequency_penalty=0.0, presence_penalty=0.0)
    reply = openai.ChatCompletion.create(
        model=MODE_DICT[MODE][MODEL],
        messages=chat_messages,
        **sampling,
    )
    content = reply["choices"][0]["message"]["content"]
    print(content)
    return content

@backoff.on_exception(backoff.expo, Exception, max_tries=50, max_time=61)
def get_output_completion_mode(prompt):
    """Send *prompt* to the text-completion endpoint and return the whitespace-stripped text.

    The prompt and the stripped completion are both echoed to stdout. The
    model id is ``MODE_DICT[MODE][MODEL]``. Any exception raised inside the
    call is retried by ``backoff`` with exponentially growing waits (at most 50
    tries or 61 seconds in total).
    """
    print(prompt)
    sampling = {"temperature": 1.0, "top_p": 1, "frequency_penalty": 0.0, "presence_penalty": 0.0}
    result = openai.Completion.create(model=MODE_DICT[MODE][MODEL], prompt=prompt, **sampling)
    text = result["choices"][0]["text"].strip()
    print(text)
    return text

def read_dataset(dataset_name, dataset_file, base_path=None):
    """Load ``<base_path>/<dataset_name>/<dataset_file>`` as JSON and return it.

    Args:
        dataset_name: sub-directory that holds the dataset.
        dataset_file: name of the JSON file inside that directory.
        base_path: root directory; defaults to the module-level ``PATH_DATA``.

    Returns:
        The decoded JSON document (the rest of this script treats it as a
        list or dict of record dicts).
    """
    root = PATH_DATA if base_path is None else base_path
    # A ``with`` block closes the handle deterministically; JSON is UTF-8 by spec.
    with open(os.path.join(root, dataset_name, dataset_file), encoding="utf-8") as f:
        return json.load(f)



def find_errors_in_dataset(dataset):
    """Locate records whose ``"final_response"`` is missing or unusable.

    A record is flagged when ``record.get("final_response")`` is None, is not
    a list, or is an empty list.

    Args:
        dataset: a dict mapping keys to record dicts, or a list (or other
            sequence) of record dicts.

    Returns:
        The flagged locations, usable directly as ``dataset[loc]``: the keys
        when *dataset* is a dict, the positions otherwise. A summary count is
        also printed.
    """
    is_mapping = isinstance(dataset, dict)
    error_indices = []
    for idx, d in enumerate(dataset):
        # Iterating a dict yields its keys: report the key itself (never the
        # enumeration position) so the caller can index back into the dict.
        location = d if is_mapping else idx
        record = dataset[d] if is_mapping else d
        final_response = record.get("final_response")
        if final_response is None:
            print(f"d --- {d}   type (d) {type(d)}")
            error_indices.append(location)
        elif not isinstance(final_response, list) or len(final_response) == 0:
            print("Here---", final_response)
            error_indices.append(location)
    print('-'*100)
    print(f"Error in {len(error_indices)} indices")
    print('-'*100)
    return error_indices
        

import time
MAX_ATTEMPT = 10


def _needs_icl_retry(record):
    """Return True when *record* has no usable ``"final_response_icl_test"``.

    Usable means a list containing no empty reply; ``collect_results`` stores
    ``""`` for a prompt whose model call returned None.
    """
    outputs = record.get("final_response_icl_test")
    return not isinstance(outputs, list) or "" in outputs


def collect_results(dataset, dataset_name):
    """Query the model for flagged records and save the replies to disk.

    The records to process are those reported by ``find_errors_in_dataset``.
    For each one, every prompt in ``record["all_attempts"]`` is sent (stripped,
    plus a trailing newline) to the model selected by ``MODE``/``MODEL``. The
    replies are stored, in order, under ``record["final_response_icl_test"]``
    (None replies become ``""``). Records that still contain an empty reply are
    retried in later rounds, for at most ``MAX_ATTEMPT`` rounds, with a 60 s
    pause before each retry round.

    After every round the whole dataset is dumped to
    ``<OUTPUT_PATH>/<dataset_name>_<model>_attempt_<n>_icl_test.json``. The
    final state is written to
    ``<OUTPUT_PATH>/<dataset_name>_<model>_icl_test_final.json``.

    Args:
        dataset: list or dict of record dicts; mutated in place.
        dataset_name: prefix used for the output file names.
    """
    model_name = MODE_DICT[MODE][MODEL]
    # The model id comes from MODE_DICT[MODE], so the matching endpoint must be
    # used too; a completion model sent to the chat endpoint fails.
    get_output = get_output_completion_mode if MODE == "COMPLETION" else get_output_chat_mode

    # find_errors_in_dataset inspects "final_response", which this function
    # never writes. Re-running it every round would re-query the same records
    # until MAX_ATTEMPT, so later rounds only retry records whose
    # "final_response_icl_test" still needs work.
    pending = find_errors_in_dataset(dataset)
    attempt = 0
    while pending and attempt < MAX_ATTEMPT:
        if attempt > 0:
            time.sleep(60 * 1)  # pause between rounds, presumably for API rate limits
        for didx in tqdm.tqdm(pending, desc=f"Attempt {attempt+1}", total=len(pending)):
            record = dataset[didx]
            replies = []
            for prompt in record["all_attempts"]:
                reply = get_output(prompt.strip() + "\n")
                replies.append(reply if reply is not None else "")
            record["final_response_icl_test"] = replies  # mutates dataset in place
        snapshot_path = os.path.join(
            OUTPUT_PATH, f"{dataset_name}_{model_name}_attempt_{attempt+1}_icl_test.json"
        )
        with open(snapshot_path, "w") as f:
            json.dump(dataset, f, indent=4)
        attempt += 1
        pending = [didx for didx in pending if _needs_icl_retry(dataset[didx])]
    final_path = os.path.join(OUTPUT_PATH, f"{dataset_name}_{model_name}_icl_test_final.json")
    with open(final_path, "w") as f:
        json.dump(dataset, f, indent=4)

        




import tqdm
if __name__ == "__main__":
    import argparse

    # Defaults reproduce the original hard-coded run; the flags let the same
    # pipeline be pointed at another dataset under PATH_DATA.
    parser = argparse.ArgumentParser(
        description="Collect ICL-test model outputs for a prompt dataset."
    )
    parser.add_argument(
        "--dataset-name",
        default="gsm8k",
        help="sub-directory of PATH_DATA that holds the dataset (default: gsm8k)",
    )
    parser.add_argument(
        "--dataset-file",
        default="gsm8k_icl_only_final_icl_test.json",
        help="JSON file of records with an 'all_attempts' list of prompts",
    )
    args = parser.parse_args()
    prompt_dataset = read_dataset(args.dataset_name, args.dataset_file)
    collect_results(prompt_dataset, args.dataset_name)
    

# NOTE: the triple-quoted string below is an inert module-level literal; none of
# the code inside it runs. It preserves a one-off merge step:
#   1. Load a results file that contains errors.
#   2. Load a file of re-run results keyed by record index (keys may be str).
#   3. Copy each re-run record over the entry at that integer index.
#   4. Save the merged dataset to gsm8k_socratic_full_final.json.
# The hard-coded Windows paths are specific to the original author's machine.
'''
errors_dataset = json.load(open("F:\\user-repos\\saurasrivastava\\prompt_gen\\outputs\\gsm8k_socratic_with_errors.json"))
new_remaining_dataset = json.load(open("F:\\user-repos\\saurasrivastava\\prompt_gen\\outputs\\gsm8k_socratic_final_final.json"))


for remaining_points in new_remaining_dataset:
    old_index = remaining_points if type(remaining_points) == type(0) else int(remaining_points)
    errors_dataset[old_index] = new_remaining_dataset[remaining_points]


with open("gsm8k_socratic_full_final.json", "w") as f:
    json.dump(errors_dataset, f, indent = 4)
'''





# sports_dataset = json.load(open("/Volumes/Academic/Projects/PRoMTd/outputs/sports/sports.json"))

# for keys in sports_dataset:
#     del sports_dataset[keys]["final_response"]

# with open("/Volumes/Academic/Projects/PRoMTd/outputs/sports/sports_only_better_prompts.json", "w") as f:
#     json.dump(sports_dataset, f, indent = 4)