import pandas as pd
import openai, os, json
import tiktoken

from langchain.chat_models import AzureChatOpenAI
from langchain import LLMChain

# Filesystem layout used by this script (absolute, machine-specific paths).
PATH = "F:/user-repos/saurasrivastava/prompt_gen"
# Directory holding the per-dataset result JSON files: find_error_indices reads
# "<dataset>.json" from here and fix_errors_in_dataset writes its output here.
PATH_DATA = f"{PATH}/outputs"
# Name of the result file (without ".json") processed when run as a script.
DATASET_NAME = "error_output_gsm8k"
# Directory holding the inputs: the prompt text file and the gsm8k test split.
PATH_DATASET = f"{PATH}/data"


# Module-level OpenAI client settings. They are blank in this file.
# NOTE(review): presumably these must be filled in (or injected from the
# environment) before the script can reach the service - confirm.
openai.api_key = ""
openai.api_type = ""
openai.api_base =  ""
openai.api_version = ""


# Chat model handle passed to get_better_prompt_caller for every record.
# The key, base URL and API version are blank here as well.
llm = AzureChatOpenAI(deployment_name="gpt-4-32k",
                      model_name="gpt-4-32k", 
                      openai_api_key = "", 
                      openai_api_base = "", 
                      openai_api_version = "")


def find_error_indices(dataset):
    """Return the positions of records in a result file that lack 3 attempts.

    Loads ``<PATH_DATA>/<dataset>.json`` (expected to be a JSON array) and
    checks every record. A record is considered valid when its attempts list
    has exactly three entries. The list is taken from ``record["all_attempts"]``
    or, when that key is missing or ``None``, from
    ``record["new_response"]["all_attempts"]``. Every other record, including
    one whose structure is malformed, is reported as an error.

    Args:
        dataset: File name of the result file without the ``.json`` suffix.

    Returns:
        A list of zero-based indices of the failing records, in file order.

    Raises:
        OSError: If the result file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # Context manager so the handle is closed as soon as parsing is done.
    with open(os.path.join(PATH_DATA, dataset + ".json")) as result_file:
        records = json.load(result_file)

    error_indices = []
    for idx, record in enumerate(records):
        try:
            all_attempts = record.get("all_attempts")
            if all_attempts is None:
                all_attempts = record["new_response"]["all_attempts"]
            # Explicit comparison instead of `assert`, which is stripped
            # under `python -O` and would silently disable this check.
            is_complete = len(all_attempts) == 3
        except (AttributeError, KeyError, TypeError):
            # Malformed record: not a dict, missing keys, or attempts that
            # have no length. Counted as an error rather than crashing.
            is_complete = False
        if not is_complete:
            error_indices.append(idx)

    print(f"Error in {len(error_indices)} indices")
    return error_indices


import backoff, requests
@backoff.on_exception(backoff.expo, Exception, max_tries=50, max_time=61)
def get_better_prompt_caller(llm, prompt, value):
    """Call ``get_better_prompt`` with three self-improvement steps, retrying on failure.

    The ``backoff`` decorator re-invokes this function whenever it raises any
    ``Exception``, using the ``backoff.expo`` wait generator, and is configured
    with ``max_tries=50`` and ``max_time=61``.

    Args:
        llm: Model handle forwarded unchanged to ``get_better_prompt``.
        prompt: Prompt text forwarded unchanged to ``get_better_prompt``.
        value: Input forwarded unchanged to ``get_better_prompt``.

    Returns:
        Whatever ``get_better_prompt`` returns.
    """
    return get_better_prompt(llm, prompt, value, self_improvement_steps=3)


from betterPromptCaller import get_better_prompt
# Prompt text handed to get_better_prompt_caller for every record. It is read
# once at import time inside a context manager so the file handle is closed
# promptly instead of lingering until garbage collection.
with open(f'{PATH_DATASET}/output_prompt.txt', encoding = "utf-8") as _prompt_file:
    prompt = _prompt_file.read()
import tqdm
def fix_errors_in_dataset(dataset):
    """Regenerate responses for the failing records of ``dataset`` and save them.

    Steps:
      1. ``find_error_indices(dataset)`` yields the positions of bad records.
      2. The records at those positions are taken from
         ``<PATH_DATASET>/gsm8k/data/test.jsonl``.
      3. ``get_better_prompt_caller`` is called with each record's
         ``"question"`` field.
      4. All results are written as a JSON array to
         ``<PATH_DATA>/error_output_<dataset>.json``.

    Args:
        dataset: Name of the result file (without ``.json``) to repair.

    Returns:
        None. The results are only written to disk.

    Raises:
        IndexError: If an error index is beyond the end of ``test.jsonl``.
        KeyError: If a source record has no ``"question"`` field.
    """
    errors_indices = find_error_indices(dataset)

    # JSON Lines files are UTF-8, so decode explicitly instead of relying on
    # the platform default. Blank lines (e.g. a trailing newline at EOF) are
    # skipped because json.loads would raise on them.
    questions_path = PATH_DATASET + "/" + "gsm8k" + "/data/" + "test.jsonl"
    with open(questions_path, encoding="utf-8") as questions_file:
        jsn = [json.loads(line) for line in questions_file if line.strip()]

    # NOTE(review): this assumes positions in "<dataset>.json" line up with
    # row positions in test.jsonl - confirm, especially when `dataset` is
    # itself an "error_output_*" file produced by a previous run.
    error_jsn = [jsn[i] for i in errors_indices]

    errors_responses = []
    for value in tqdm.tqdm(error_jsn, desc=f"Fixing {dataset}", total=len(error_jsn)):
        response = get_better_prompt_caller(llm, prompt, value["question"])
        # NOTE(review): despite its name, "idx" stores the whole source
        # record, not a numeric index. Kept as-is so the output format does
        # not change for downstream readers - verify this is intended.
        errors_responses.append({"idx": value, "new_response": response})

    with open(os.path.join(PATH_DATA, "error_output_" + dataset + ".json"), "w") as f:
        json.dump(errors_responses, f, indent=4)


# Script entry point: repair the failing records of the configured dataset.
if __name__ == "__main__":
    fix_errors_in_dataset(DATASET_NAME)
