import pandas as pd
import openai, os, json
import tiktoken

from langchain.chat_models import AzureChatOpenAI
from langchain import LLMChain

# Root of the local checkout; the data and output locations below are derived from it.
PATH = "F:/user-repos/saurasrivastava/prompt_gen"
# Directory that read_dataset loads dataset JSON files from.
PATH_DATA = f"{PATH}/outputs"
# Directory that collect_results writes its JSON snapshots to (currently the
# same directory as PATH_DATA).
OUTPUT_PATH = f"{PATH}/outputs"

# Module-level settings for the openai package, left blank in source.
# NOTE(review): presumably filled in locally before running -- confirm. getLLM
# passes its own (also blank) credentials rather than reading these.
openai.api_key = ""
openai.api_type = ""
openai.api_base =  ""
openai.api_version = ""


def getLLM():
    """Build the chat-model client used to answer prompts.

    Returns:
        An ``AzureChatOpenAI`` instance targeting the ``gpt-4-32k``
        deployment and model. The API key, base URL and API version are
        passed as empty strings here.
    """
    client_settings = {
        "deployment_name": "gpt-4-32k",
        "model_name": "gpt-4-32k",
        "openai_api_key": "",
        "openai_api_base": "",
        "openai_api_version": "",
    }
    return AzureChatOpenAI(**client_settings)


import backoff, requests
@backoff.on_exception(backoff.expo, Exception, max_tries=50, max_time=61)
def get_output(prompt, llm):
    """Send ``prompt`` to ``llm`` and return whatever ``llm.predict`` returns.

    The ``backoff`` decorator retries the call with exponential waits whenever
    any ``Exception`` is raised, configured with ``max_tries=50`` and
    ``max_time=61``.

    Args:
        prompt: Prompt text handed to ``llm.predict``.
        llm: Object exposing a ``predict(prompt)`` method.

    Returns:
        The value returned by ``llm.predict(prompt)``.
    """
    return llm.predict(prompt)

def read_dataset(dataset_name="gsm8k.json"):
    """Load a dataset JSON file from ``PATH_DATA``.

    Args:
        dataset_name: File name (relative to ``PATH_DATA``) of the JSON file.

    Returns:
        The parsed JSON content, exactly as ``json.load`` produces it.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    dataset_path = os.path.join(PATH_DATA, dataset_name)
    # Use a context manager so the handle is closed even when parsing fails;
    # the original left the file object open.
    with open(dataset_path) as f:
        return json.load(f)



def find_errors_in_dataset(dataset):
    """Return the keys/indices of records that lack a usable ``final_response``.

    A record is reported when its ``"final_response"`` entry is missing or
    ``None``, is not a list, or is an empty list.

    Args:
        dataset: Either a mapping of key -> record or a sequence of records.
            Every record must support ``.get("final_response")``.

    Returns:
        A list with one entry per failing record: the record's key when
        ``dataset`` is a mapping, otherwise its position in the sequence.
        Each entry can be used directly as ``dataset[entry]``.
    """
    is_mapping = isinstance(dataset, dict)
    # For mappings always report the real key (the original reported the
    # enumerate position for non-string keys, which is not a valid lookup).
    entries = dataset.items() if is_mapping else enumerate(dataset)

    error_indices = []
    for key, record in entries:
        final_response = record.get("final_response")
        if final_response is None:
            # Debug trace: shows the key for mappings, the whole record for
            # sequences.
            shown = key if is_mapping else record
            print(f"d --- {shown}   type (d) {type(shown)}")
            error_indices.append(key)
        elif not isinstance(final_response, list) or not final_response:
            print("Here---", final_response)
            error_indices.append(key)
    print(f"Error in {len(error_indices)} indices")
    return error_indices
        

import time
# Upper bound on the number of retry passes collect_results will run.
MAX_ATTEMPT = 10


def collect_results(dataset, dataset_name, retry_delay=60 * 5):
    """Re-run the LLM on every record that lacks a usable ``final_response``.

    Each pass asks ``find_errors_in_dataset`` for the failing records, sends
    every prompt in a record's ``"all_attempts"`` list through ``get_output``,
    stores the replies (``None`` replaced by ``""``) under
    ``"final_response"``, and writes the whole dataset to
    ``<OUTPUT_PATH>/<dataset_name>_attempt_<n>.json``. Passes repeat until no
    record is failing or ``MAX_ATTEMPT`` passes have run; the dataset is then
    written to ``<OUTPUT_PATH>/<dataset_name>_final.json``.

    Args:
        dataset: List or dict of records; it is mutated in place.
        dataset_name: Prefix used for the output file names.
        retry_delay: Seconds to wait between two consecutive passes.
            Defaults to five minutes, the original hard-coded value.
    """
    attempt = 0
    while True:
        error_indices = find_errors_in_dataset(dataset)
        if not error_indices or attempt == MAX_ATTEMPT:
            break
        # Wait only when another pass will actually run, so the final write
        # is not delayed by a pointless sleep after the last pass.
        if attempt > 0:
            time.sleep(retry_delay)
        # A fresh client per pass, created only when there is work to do.
        llm = getLLM()
        for didx in tqdm.tqdm(error_indices, desc=f"Attempt {attempt+1}", total=len(error_indices)):
            record = dataset[didx]
            outputs = []
            for prompt in record["all_attempts"]:
                # NOTE(review): get_output re-raises once its backoff gives
                # up, which aborts the run before this pass is written to
                # disk -- consider catching here so the record is simply
                # retried on the next pass.
                output = get_output(prompt.strip() + "\n", llm)
                outputs.append(output if output is not None else "")
            # The record object lives inside dataset, so this updates it in place.
            record["final_response"] = outputs
        with open(f"{OUTPUT_PATH}/{dataset_name}_attempt_{attempt + 1}.json", "w") as f:
            json.dump(dataset, f, indent=4)
        attempt += 1
    with open(f"{OUTPUT_PATH}/{dataset_name}_final.json", "w") as f:
        json.dump(dataset, f, indent=4)

        




import tqdm
if __name__ == "__main__":
    # No client is created here: collect_results builds the LLM client it
    # needs, so the previous module-level getLLM() call was unused.
    dataset_name = "implicatures_final.json"
    prompt_dataset = read_dataset(dataset_name)
    collect_results(prompt_dataset, dataset_name)
    # outputs, errors_index = [], []
    # for p_idx, prompt in tqdm.tqdm(enumerate(prompt_dataset), desc = f"Working on {dataset_name}", total = len(prompt_dataset)):
    #     # print(prompt)
    #     try:
    #         output_list = []
    #         all_attempts = prompt["all_attempts"] if prompt.get("all_attempts") is not None else prompt["new_response"]["all_attempts"]
    #         for idx, attempt in enumerate(all_attempts):
    #             try:
    #                 output = get_output(attempt, llm)
    #             except Exception as e:
    #                 output = "Error in output" + str(e)
    #                 print(e)
    #             output_list.append({"id":idx, "prompt": attempt, "output":output})
    #     except Exception as e:
    #         output_list = [{"Error": f"Error at Idx: {p_idx}", "Error Desc":str(e)}]
    #     outputs.append(output_list)
    # with open(os.path.join(PATH_DATA, "output_" + dataset_name), "w") as f:
    #     json.dump(outputs, f, indent = 4)
    

# Disabled one-off merge snippet, kept as a bare string literal so it is never
# executed. If run, it would load two JSON files, copy every entry of the
# second into the first at the entry's key converted to int, and write the
# result to "gsm8k_socratic_full_final.json".
'''
errors_dataset = json.load(open("F:\\user-repos\\saurasrivastava\\prompt_gen\\outputs\\gsm8k_socratic_with_errors.json"))
new_remaining_dataset = json.load(open("F:\\user-repos\\saurasrivastava\\prompt_gen\\outputs\\gsm8k_socratic_final_final.json"))


for remaining_points in new_remaining_dataset:
    old_index = remaining_points if type(remaining_points) == type(0) else int(remaining_points)
    errors_dataset[old_index] = new_remaining_dataset[remaining_points]


with open("gsm8k_socratic_full_final.json", "w") as f:
    json.dump(errors_dataset, f, indent = 4)
'''


