import pandas as pd
import openai
import os
import json
import tiktoken
from tqdm import tqdm
from langchain.chat_models import AzureChatOpenAI
from langchain import LLMChain

import traceback
import backoff, requests
import time

# Absolute path of the project directory.
PATH = "F:/user-repos/saurasrivastava/prompt_gen"
# Directory containing the input files for the experiments.
PATH_DATA = f"{PATH}/data"
# Directory in which the outputs of the experiments are stored.
OUTPUT_PATH = f"{PATH}/outputs"
# Delimiter used to split the response of the OpenAI API calls
# (it was already decided in the prompt-creation step).
delimiter = "###Better Prompt###"


# Extract the better prompt, the reason, and the prompt type from the OpenAI response.
# If the response cannot be parsed (e.g. OpenAI could not create a better prompt), the better prompt and the
# prompt type are set to "" and the raw response is returned in place of the reason, for debugging purposes.
def extract_reason_response(response):
    """Split a raw model response into better prompt, reason and prompt type.

    The response is expected to look like
    ``<reason>###Better Prompt Type###<prompt type><delimiter><better prompt>``
    where ``<delimiter>`` is the module-level ``delimiter`` marker and each
    marker occurs exactly once.

    Args:
        response: Text returned by the model.

    Returns:
        A ``(better_prompt, reason, prompt_type)`` tuple with surrounding
        whitespace stripped. When the response does not have the expected
        layout, ``("", response, "")`` is returned so the caller can inspect
        the unparsed response.
    """
    try:
        # Each split must yield exactly two parts; otherwise the tuple
        # unpacking raises ValueError and the response counts as unparsable.
        reason, better_prompt = response.split(delimiter)
        reason, prompt_type = reason.split("###Better Prompt Type###")
    except (AttributeError, TypeError, ValueError):
        # AttributeError/TypeError cover a response that is not a string.
        return "", response, ""
    return better_prompt.strip(), reason.strip(), prompt_type.strip()


#The self-improvement stage: the prompt that asks the model to improve a bad prompt is applied repeatedly,
#each round feeding the previous round's better prompt back in as the new candidate.
#Why? If something was missed while creating the better prompt in one iteration, the next one may cover it.
def call_self_improvement(llm, prompt, candidate_prompt, self_improvement_steps=3):
    """Refine ``candidate_prompt`` over several rounds of model calls.

    In every round the text ``prompt + candidate + "\\n###Reason###\\n"`` is
    sent to ``llm.predict`` and the reply is parsed with
    ``extract_reason_response``. The better prompt of one round becomes the
    candidate of the next round.

    Args:
        llm: Object exposing ``predict(text)``; presumably it returns the
            model's reply as a string.
        prompt: Instruction text placed in front of the candidate each round.
        candidate_prompt: The prompt to improve.
        self_improvement_steps: Number of refinement rounds to run.

    Returns:
        ``(better_prompt, reason, prompt_type, all_attempts)`` taken from the
        last round, where ``all_attempts`` lists the better prompt of every
        round. If any round yields a blank better prompt the result is
        ``("", response, "", [])`` with the raw response of that round. If
        ``self_improvement_steps`` is zero or negative no call is made and
        ``("", "", "", [])`` is returned.
    """
    # Pre-bind the results so a non-positive step count returns cleanly
    # instead of hitting unbound names at the final return.
    better_prompt, reason, prompt_type = "", "", ""
    all_attempts = []
    for _ in range(self_improvement_steps):
        response = llm.predict(prompt + candidate_prompt + "\n###Reason###\n")
        better_prompt, reason, prompt_type = extract_reason_response(response)
        all_attempts.append(better_prompt)
        if better_prompt.strip() == "":
            # Probably the input was blank or held nothing worth improving.
            return "", response, "", []
        # Feed this round's result into the next round.
        candidate_prompt = better_prompt.strip()
    return better_prompt, reason, prompt_type, all_attempts

#the module combining the prompts, and the candidate prompts (the bad prompt) and calling the self-improvement stage
#iteratively to get better prompts.
def get_better_prompt(llm, prompt, candidate_prompt, self_improvement_steps=1):
    """Improve a candidate prompt and return the outcome as a dictionary.

    Thin wrapper around ``call_self_improvement``. The (potentially very
    long) original prompt is deliberately left out of the result.

    Returns:
        A dict with the keys ``"reason"``, ``"better_prompt"``,
        ``"output_prompt_type"`` and ``"all_attempts"``.
    """
    improved, rationale, kind, attempts = call_self_improvement(
        llm, prompt, candidate_prompt, self_improvement_steps)
    return dict(
        reason=rationale,
        better_prompt=improved,
        output_prompt_type=kind,
        all_attempts=attempts,
    )

#FOR BENCHMARKING EXPERIMENTS ONLY
#For benchmarking to test the algorithm on the public datasets, we get the dataset first and then call the better prompt getter for each data point.
def from_dataset(llm, prompt, self_improvement_steps=3, dataset_name = "test"):
    """Run the prompt improver over every item returned by ``prompt_eng_for_dataset``.

    Args:
        llm: Model handle forwarded to ``get_better_prompt``.
        prompt: Instruction text forwarded to ``get_better_prompt``.
        self_improvement_steps: Number of refinement rounds per item.
        dataset_name: Used for the progress-bar label and the output file name.

    Returns:
        A list with one element per dataset item: the result dictionary of
        ``get_better_prompt``, or the exception message (a string) if that
        call raised. The same list is written as JSON to
        ``./outputs/<dataset_name>.json`` (relative to the working directory).
    """
    full_output = []
    dataset = prompt_eng_for_dataset()
    for line in tqdm(dataset, total=len(dataset), desc=f"Processing - {dataset_name}"):
        candidate_prompt = line.strip()
        try:
            # Honour the caller's step count (it used to be hard-coded to 3).
            output = get_better_prompt(
                llm, prompt, candidate_prompt,
                self_improvement_steps=self_improvement_steps)
        except Exception as e:
            # Best effort: keep the failure message in place of the result so
            # the remaining items are still processed.
            output = str(e)
        full_output.append(output)
    with open("./outputs/" + dataset_name + ".json", "w") as f:
        json.dump(full_output, f, indent=4)
    return full_output

#FOR BENCHMARKING EXPERIMENTS ONLY
#for GSM8K (all versions) only
def prompt_eng_for_dataset(path=None):
    """Load the benchmark questions/examples from a JSON-lines or JSON file.

    The file is first read as JSON lines, taking the ``"question"`` field of
    every line. If that fails, the whole file is parsed as a single JSON
    document instead; when that document is a dict with an ``"examples"``
    key, only the examples are returned.

    Args:
        path: File to read. Defaults to the GSM8K socratic output file that
            was previously hard-coded in this function.

    Returns:
        A list of question strings (JSON-lines input), or the parsed JSON
        document / its ``"examples"`` entry (plain JSON input).
    """
    if path is None:
        path = "F:\\user-repos\\saurasrivastava\\prompt_gen\\outputs\\gsm8k_socratic_with_errors.json"

    try:
        with open(path, encoding="utf-8") as source:
            all_data = [json.loads(line.strip())["question"] for line in source]
    except (ValueError, KeyError, TypeError):
        # Not JSON lines with a "question" field (json.JSONDecodeError is a
        # ValueError) - fall back to reading the file as one JSON document.
        with open(path, encoding="utf-8") as source:
            all_data = json.load(source)

    if isinstance(all_data, dict) and "examples" in all_data:
        all_data = all_data["examples"]
    return all_data

#FOR BENCHMARKING EXPERIMENTS ONLY
#for csv data files
def from_csv(address, prompt_columns, delimiter=","):
    """Read prompts from a CSV file into an index-keyed dictionary.

    Args:
        address: Location of the CSV file, passed to ``pandas.read_csv``.
        prompt_columns: Column label (or list of labels) to select.
        delimiter: Field separator used in the file.

    Returns:
        ``{row_position: {"input": value}}`` with one entry per data row,
        where ``value`` is the selected cell (or list of cells).
    """
    frame = pd.read_csv(address, delimiter=delimiter)
    selected_rows = frame[prompt_columns].values.tolist()
    return {position: {"input": row} for position, row in enumerate(selected_rows)}

#FOR BENCHMARKING EXPERIMENTS ONLY
#For Big-Bench datasets
def prompt_eng_for_big_bench(address):
    """Load a Big-Bench task file and prefix every input with the task description.

    Args:
        address: Path of a JSON file holding a ``"description"`` string and
            an ``"examples"`` list whose items have an ``"input"`` string.

    Returns:
        The list of examples, each with ``"input"`` rewritten to
        ``"<description>:\\n<original input>"``; other fields are untouched.
    """
    # The context manager closes the file handle, which was previously leaked.
    with open(address, encoding="utf-8") as source:
        task = json.load(source)
    task_prefix = task["description"]
    examples = task["examples"]
    for example in examples:
        example["input"] = task_prefix + ":\n" + example["input"]
    return examples

#Returns a new LLM instance every time it is called.
def getLLM():
    """Create and return a fresh ``AzureChatOpenAI`` object for ``gpt-4-32k``.

    The API key, API base and API version are passed as empty strings here.
    """
    settings = {
        "deployment_name": "gpt-4-32k",
        "model_name": "gpt-4-32k",
        "openai_api_key": "",
        "openai_api_base": "",
        "openai_api_version": "",
    }
    return AzureChatOpenAI(**settings)


#FOR BENCHMARKING EXPERIMENTS ONLY
#While creating prompts for a dataset, some entries may not have been processed successfully, e.g. because an OpenAI call failed.
#This function finds those entries. How? The output format is fixed: every valid entry must carry an "all_attempts" list
#with exactly 3 items. Entries that do not match this format are reported as errors.
def find_error_indices(dataset, expected_attempts=3):
    """Return the indices/keys of entries lacking a complete ``all_attempts`` list.

    Args:
        dataset: A list of entries, a dict mapping keys to entries, or the
            path of a JSON file containing either of those.
        expected_attempts: Number of items a valid ``all_attempts`` list
            must have.

    Returns:
        For list input, the positions of the invalid entries; for dict
        input, their keys. An entry is valid when its ``"all_attempts"``
        value is a list of ``expected_attempts`` items. For list entries the
        value may also live under ``entry["new_response"]["all_attempts"]``.

    Every invalid entry is printed together with a traceback, followed by a
    summary line with the number of invalid entries.
    """
    if isinstance(dataset, (dict, list)):
        entries = dataset
    else:
        with open(dataset, encoding="utf-8") as source:
            entries = json.load(source)
    # Decide on the loaded data, not on the argument: a JSON file may hold a
    # dict as well, and then its keys (not positions) identify the entries.
    keyed = isinstance(entries, dict)

    error_indices = []
    for idx, j in enumerate(entries):
        try:
            if keyed:
                all_attempts = entries[j]["all_attempts"]
            else:
                all_attempts = j.get("all_attempts")
                if all_attempts is None:
                    all_attempts = j["new_response"]["all_attempts"]
            # Explicit check instead of `assert`, which is stripped under -O.
            if not isinstance(all_attempts, list) or len(all_attempts) != expected_attempts:
                raise ValueError(
                    f"expected a list of {expected_attempts} attempts, got {all_attempts!r}")
        except (AttributeError, KeyError, IndexError, TypeError, ValueError) as e:
            print(e, j)
            traceback.print_exc()
            error_indices.append(j if keyed else idx)
    print(f"Error in {len(error_indices)} indices")
    return error_indices


#OpenAI method to retry on failed prompt calls
@backoff.on_exception(backoff.expo, Exception, max_tries=50, max_time=61)
def get_better_prompt_caller(llm, prompt, value):
    """Call ``get_better_prompt`` for ``value`` with three self-improvement steps.

    The ``backoff`` decorator re-invokes this function when it raises.
    """
    return get_better_prompt(llm, prompt, value, self_improvement_steps=3)

# Upper bound on the number of retry rounds performed by fix_till_done.
MAX_ATTEMPT = 20
#FOR BENCHMARKING EXPERIMENTS ONLY
#Since OpenAI may throw errors while the APIs are being called, this function retries (up to MAX_ATTEMPT rounds)
#to collect prompts for the indices on which the previous calls were unsuccessful.
def fix_till_done(prompt, dataset_name, with_error=True, dataset_path=None,
                  retry_delay_seconds=60 * 5):
    """Generate better prompts for a Big-Bench task, retrying failed entries.

    Each round collects the entries that ``find_error_indices`` reports as
    invalid, runs ``get_better_prompt_caller`` on them and stores the results
    by index. Rounds repeat until no invalid entries remain or
    ``MAX_ATTEMPT`` rounds have been run.

    Args:
        prompt: Instruction text forwarded to ``get_better_prompt_caller``.
        dataset_name: Prefix of the output files written to ``OUTPUT_PATH``.
        with_error: If true, the first round processes only the entries
            reported by ``find_error_indices``; otherwise it processes the
            whole dataset. Later rounds always process only reported entries.
        dataset_path: Big-Bench task file to load. Defaults to the
            implicatures task file that was previously hard-coded here.
        retry_delay_seconds: Pause between two consecutive rounds.

    Side effects:
        Writes ``<dataset_name>_attempt_<k>.json`` after round ``k`` and
        ``<dataset_name>_final.json`` at the end, both under ``OUTPUT_PATH``.
        Exceptions escaping ``get_better_prompt_caller`` are not caught here.
    """
    if dataset_path is None:
        dataset_path = "F:\\user-repos\\saurasrivastava\\prompt_gen\\data\\implicatures\\data\\task.json"
    dataset = prompt_eng_for_big_bench(dataset_path)
    responses = {}
    attempt = 0
    while True:
        errors_indices = find_error_indices(dataset)
        if attempt > 0 or with_error:
            errors_json = [(dataset[i], i) for i in errors_indices]
        else:
            errors_json = dataset
        if len(errors_json) <= 0 or attempt == MAX_ATTEMPT:
            break
        if attempt > 0:
            # Pause only when another round is really going to run; the wait
            # used to happen even after the last useful round.
            time.sleep(retry_delay_seconds)
        print(errors_json[-1])
        # A fresh LLM object is created for every round.
        llm = getLLM()
        for idx, value in tqdm(enumerate(errors_json), desc=f"Attempt {attempt+1}", total=len(errors_json)):
            # Retry items are (entry, original_index) pairs; a first round
            # over the full dataset yields bare entries instead.
            if isinstance(value, tuple):
                value, idx = value
            question = value["question"] if value.get("question") is not None else value["input"]
            response = get_better_prompt_caller(llm, prompt, question)
            # Keep the question so a later round can retry from this entry.
            response["question"] = question
            responses[idx] = response
        with open(f"{OUTPUT_PATH}/{dataset_name}_attempt_{attempt + 1}.json", "w") as f:
            json.dump(responses, f, indent=4)
        attempt += 1
        # From now on, validate (and retry from) the collected responses.
        dataset = responses
    with open(f"{OUTPUT_PATH}/{dataset_name}_final.json", "w") as f:
        json.dump(responses, f, indent=4)


if __name__ == "__main__":
    # NOTE(review): the OpenAI settings are empty strings here - presumably
    # they are filled in locally before the script is run.
    openai.api_key = ""
    openai.api_type = ""
    openai.api_base = ""
    openai.api_version = ""

    # Same configuration as getLLM(); only the commented-out single-prompt
    # example below uses this handle.
    llm = getLLM()

    # Read the whole instruction prompt; the context manager closes the file.
    with open(f'{PATH_DATA}/output_prompt.txt', encoding="utf-8") as prompt_file:
        prompt = prompt_file.read()
    # prompt_outputs = from_dataset(llm, prompt, self_improvement_steps=3, dataset_name = "strategyQA")
    fix_till_done(prompt, "implicatures")
    # candidate_prompt = input("Please enter your first candidate prompt: ")
    # print(json.dumps(get_better_prompt(llm, prompt,
    #       candidate_prompt, self_improvement_steps=3)))


'''
-->23.5 daN expressed in units of N is 


-->Convert the given value of 23.5 daN (decanewtons) to its equivalent in Newtons (N). Please keep in mind that 1 daN is equal to 10 N. Provide the converted value in the format: "The answer is: [VALUE] N".


########fix gsm8k errors :)

ff = open("F:\\user-repos\\saurasrivastava\\prompt_gen\\data\\gsm8k\\data\\test_socratic.jsonl")
jsns = [json.loads(j.strip()) for j in ff]
j2 = json.load(open("F:\\user-repos\\saurasrivastava\\prompt_gen\\outputs\\gsm8k_socratic_w_errors.json"))

for idx, (output_j, input_j) in enumerate(zip(j2, jsns)):
    if(type(output_j)!=type({})):
        output_j = {}
    output_j["input"] = input_j["question"]
    output_j["answer"] = input_j["answer"]
    j2[idx] = output_j


    
with open("gsm8k_socratic_with_errors.json", "w") as f:
    json.dump(j2, f, indent = 4)

'''


#combine GSM8K better output files
# original_file = [json.loads(x) for x in open("/Volumes/Academic/Projects/PRoMTd/data/gsm8k/test_gsm8k.jsonl")]
# better_file = json.load(open("/Volumes/Academic/Projects/PRoMTd/outputs/gsm8k/gsm8k_with_errors.json"))
# remaining_186 = json.load(open("/Volumes/Academic/Projects/PRoMTd/outputs/gsm8k/gsm8k_186.json"))

# for idx in remaining_186:
#     data = remaining_186[idx]
#     idx = int(idx)
#     del data["question"]
#     better_file[idx].update(data)


# with open("/Volumes/Academic/Projects/PRoMTd/outputs/gsm8k/gsm8k_final.json", "w") as f:
#     json.dump(better_file, f, indent = 4)