import pandas as pd
import openai
import os
import json
import tiktoken
from tqdm import tqdm
from langchain.chat_models import AzureChatOpenAI
from langchain import LLMChain
from glob import glob
import traceback
import backoff, requests
import time

# Directory of the BIG-bench "list_functions" task; its sub-directories are
# scanned for the benchmark JSON files.
PATH = "C:\\Users\\saurasrivastava\\Downloads\\BIG-bench-main\\bigbench\\benchmark_tasks\\list_functions"
# Delimiter used to break up the response from the OpenAI API calls (it was
# already decided in the prompt creation step).
delimiter = "###Better Prompt###"
# Which benchmark files to load (shot count and mode are parsed from the file
# names) and how many model calls are made per example.
SHOTS = 5
MODE = "default"
API_CALLS = 3

# Extracts the reason, the better prompt and the prompt type from an OpenAI response.
def extract_reason_response(response, prompt_delimiter=None,
                            type_delimiter="###Better Prompt Type###"):
    """Split a model response into better prompt, reason and prompt type.

    The response is expected to have the layout
    ``<reason> <type_delimiter> <prompt type> <prompt_delimiter> <better prompt>``
    with each marker occurring exactly once.

    Args:
        response: Raw text returned by the model.
        prompt_delimiter: Marker that precedes the better prompt. ``None``
            (the default) falls back to the module-level ``delimiter``.
        type_delimiter: Marker that separates the reason from the prompt type.

    Returns:
        A ``(better_prompt, reason, prompt_type)`` tuple of stripped strings.
        When the response does not follow the expected layout (for example
        because the model could not produce a better prompt) the result is
        ``("", response, "")`` so the raw response stays available for
        debugging.
    """
    if prompt_delimiter is None:
        prompt_delimiter = delimiter
    try:
        # Unpacking into exactly two names raises ValueError when a marker is
        # missing or repeated; a non-string response raises AttributeError or
        # TypeError.
        head, better_prompt = response.split(prompt_delimiter)
        reason, prompt_type = head.split(type_delimiter)
    except (ValueError, AttributeError, TypeError):
        return "", response, ""
    return better_prompt.strip(), reason.strip(), prompt_type.strip()


# The self-improvement stage: the improvement prompt is applied again and again,
# so that something missed in one iteration may be covered in the next one.
def call_self_improvement(llm, prompt, candidate_prompt, self_improvement_steps=1):
    """Repeatedly ask the model to improve a candidate prompt.

    Each step sends ``prompt + candidate_prompt + "\\n###Reason###\\n"`` to
    ``llm.predict`` and feeds the extracted better prompt into the next step.

    Args:
        llm: Object exposing ``predict(text)`` that returns the response text.
        prompt: Instruction text that is prepended to the candidate prompt.
        candidate_prompt: The prompt to improve.
        self_improvement_steps: Number of improvement rounds to run.

    Returns:
        ``(better_prompt, reason, prompt_type, all_attempts)`` from the last
        round, where ``all_attempts`` lists the better prompt of every round.
        As soon as a round yields no better prompt the function returns
        ``("", response, "", [])`` with that round's raw response. When
        ``self_improvement_steps`` is zero or negative no call is made and
        ``("", "", "", [])`` is returned.
    """
    # Initialised up front so that a zero-step run returns empty results
    # instead of failing on unbound locals.
    better_prompt, reason, prompt_type = "", "", ""
    all_attempts = []
    for _ in range(self_improvement_steps):
        response = llm.predict(prompt + candidate_prompt + "\n###Reason###\n")
        better_prompt, reason, prompt_type = extract_reason_response(response)
        all_attempts.append(better_prompt)
        if better_prompt.strip() == "":
            # probably the input was blank or there was nothing important in the input prompt
            return "", response, "", []
        # The next round refines the prompt produced by this one.
        candidate_prompt = better_prompt.strip()
    return better_prompt, reason, prompt_type, all_attempts

# Combines the instruction prompt with the candidate (bad) prompt and runs the
# self-improvement stage to obtain a better prompt.
def get_better_prompt(llm, prompt, candidate_prompt, self_improvement_steps=1):
    """Improve ``candidate_prompt`` and package the outcome as a dictionary.

    This is a thin wrapper around ``call_self_improvement``; all four
    arguments are forwarded to it unchanged.

    Returns:
        A dict with the keys ``"reason"``, ``"better_prompt"``,
        ``"output_prompt_type"`` and ``"all_attempts"``. The original prompt
        is deliberately left out because it can be very long.
    """
    improved, why, kind, history = call_self_improvement(
        llm, prompt, candidate_prompt, self_improvement_steps
    )
    return {
        "reason": why,
        "better_prompt": improved,
        "output_prompt_type": kind,
        "all_attempts": history,
    }



#FOR BENCHMARKING EXPERIMENTS ONLY
#For Big-Bench datasets
def prompt_eng_for_list_functions():
    """Load the list_functions examples matching ``MODE`` and ``SHOTS``.

    Every sub-directory of ``PATH`` (except ``results``) is treated as one
    task. Its ``*.json`` files are selected by file name: a three-part name
    ``<mode>_<shot>_<x>`` or a longer name ``<mode_a>_<mode_b>_<shot>_<x>...``,
    where the first two parts together form the mode. Files whose name has
    fewer than three parts (such as ``task.json``) are skipped.

    Returns:
        A list with the parsed JSON object of every selected file, each
        extended with a ``"task_ID"`` entry holding the task directory name.
    """
    wanted_shot = str(SHOTS)
    all_data = []
    for task_id in os.listdir(PATH):
        task_dir = os.path.join(PATH, task_id)
        if task_id == "results" or not os.path.isdir(task_dir):
            continue
        # Joining the pattern as its own component keeps this portable
        # instead of relying on a hard-coded Windows separator.
        for file in glob(os.path.join(task_dir, "*.json")):
            file_name = os.path.splitext(os.path.basename(file))[0]
            name_parts = file_name.split("_")
            if len(name_parts) == 3:
                mode, shot = name_parts[0], name_parts[1]
            elif len(name_parts) >= 4:
                # The mode itself contains an underscore.
                mode, shot = "_".join(name_parts[:2]), name_parts[2]
            else:
                # Not a benchmark file (e.g. task.json).
                continue
            if mode != MODE or shot != wanted_shot:
                continue
            with open(file, encoding="utf-8") as f:
                jsn = json.load(f)
            jsn.update({"task_ID": task_id})
            all_data.append(jsn)
    return all_data

# Returns a new LLM client every time it is called.
def getLLM():
    """Build a fresh ``AzureChatOpenAI`` client for the ``gpt-4-32k`` deployment.

    The API key, base URL and API version are empty strings in this file.
    """
    azure_settings = {
        "deployment_name": "gpt-4-32k",
        "model_name": "gpt-4-32k",
        "openai_api_key": "",
        "openai_api_base": "",
        "openai_api_version": "",
    }
    return AzureChatOpenAI(**azure_settings)


#FOR BENCHMARKING EXPERIMENTS ONLY
def find_error_indices(dataset, expected_calls=None):
    """Find the entries of a dataset that still lack a complete set of responses.

    While creating prompts for a dataset some entries may not have been
    covered, e.g. because of failed OpenAI calls. An entry counts as complete
    only when its ``"all_responses"`` value is a list with exactly
    ``expected_calls`` items; everything else is reported.

    Args:
        dataset: A dict of entries, a list of entries, or the path of a JSON
            file containing either.
        expected_calls: Required number of responses per entry. ``None`` (the
            default) uses the module-level ``API_CALLS``.

    Returns:
        The identifiers of the incomplete entries: the keys when ``dataset``
        was passed as a dict, otherwise the positional indices.
    """
    if expected_calls is None:
        expected_calls = API_CALLS
    if isinstance(dataset, (dict, list)):
        entries = dataset
    else:
        with open(dataset) as f:
            entries = json.load(f)

    # Keys are only reported when the caller handed in a dict directly.
    report_keys = isinstance(dataset, dict)
    keyed = isinstance(entries, dict)
    error_indices = []
    for idx, item in enumerate(entries):
        # Iterating a dict yields its keys, iterating a list yields the
        # entries themselves.
        entry = entries[item] if keyed else item
        try:
            attempts = entry["all_responses"]
        except (KeyError, TypeError, IndexError):
            attempts = None
        if not (isinstance(attempts, list) and len(attempts) == expected_calls):
            error_indices.append(item if report_keys else idx)
    print(f"Error in {len(error_indices)} indices")
    return error_indices


# Retrying entry point for the OpenAI prompt-improvement call.
@backoff.on_exception(backoff.expo, Exception, max_tries=50, max_time = 61)
def get_better_prompt_caller(llm, prompt, value):
    """Run one self-improvement step on ``value`` and return the result dict.

    Any ``Exception`` raised by the call is retried with exponential backoff,
    bounded by the ``max_tries`` and ``max_time`` limits of the decorator.
    """
    return get_better_prompt(llm, prompt, value, self_improvement_steps=1)

# Upper bound on the number of collection rounds made by fix_till_done.
MAX_ATTEMPT = 20
#FOR BENCHMARKING EXPERIMENTS ONLY
def fix_till_done(prompt, dataset_name, with_error=True):
    """Collect better prompts for the benchmark, retrying incomplete entries.

    Since OpenAI calls may fail or return unusable responses, the entries that
    did not get ``API_CALLS`` usable responses are processed again in the next
    round, for at most ``MAX_ATTEMPT`` rounds with a five-minute pause between
    rounds. After every round the results collected so far are written to
    ``<OUTPUT_PATH>/PI/<dataset_name>_attempt_<n>.json``; the final state is
    written to ``<OUTPUT_PATH>/PI/<dataset_name>_final.json``.

    Args:
        prompt: Instruction text prepended to every training prompt.
        dataset_name: Base name of the output files.
        with_error: When true, the first round only processes the entries
            reported by ``find_error_indices``; otherwise it processes the
            whole dataset. Later rounds always process the reported entries.
    """
    output_dir = OUTPUT_PATH + "/PI/"
    # Create the target directory before spending any API calls.
    os.makedirs(output_dir, exist_ok=True)

    dataset = prompt_eng_for_list_functions()
    responses = {}
    attempt = 0
    while attempt < MAX_ATTEMPT:
        errors_indices = find_error_indices(dataset)
        if attempt > 0 or with_error:
            pending = [(dataset[i], i) for i in errors_indices]
        else:
            pending = [(value, idx) for idx, value in enumerate(dataset)]
        if not pending:
            break
        if attempt > 0:
            # Pause between rounds only when another round is actually needed.
            time.sleep(60 * 5)

        # A fresh client is created for every round.
        llm = getLLM()
        for value, idx in tqdm(pending, desc=f"Attempt {attempt+1}", total=len(pending)):
            all_responses = []
            for _ in range(API_CALLS):
                response = get_better_prompt_caller(llm, prompt, value["train_prompt"])
                # Keep only responses that actually contain a better prompt;
                # a short list marks the entry for another round.
                if (response.get("better_prompt") or "").strip():
                    all_responses.append(response)
            responses[idx] = {
                "task_ID": value["task_ID"],
                "train_prompt": value["train_prompt"],
                "gold_function": value["gold_function"],
                "all_responses": all_responses,
                "validation_examples": value["validation_examples"],
            }
        with open(output_dir + dataset_name + f"_attempt_{attempt+1}.json", "w") as f:
            json.dump(responses, f, indent=4)
        attempt += 1
        # From now on completeness is judged on the collected responses.
        dataset = responses
    with open(output_dir + dataset_name + "_final.json", "w") as f:
        json.dump(responses, f, indent=4)
    
# Base directory from which the data and output directories below are derived.
X_PATH =  "F:/user-repos/saurasrivastava/prompt_gen"
# Directory containing the input files for the experiments (the instruction
# prompt is read from here).
PATH_DATA = f"{X_PATH}/data"
# Directory under which the collected responses are written as JSON files.
OUTPUT_PATH = f"{X_PATH}/outputs"
if __name__ == "__main__":
    # The OpenAI settings are left blank in this file.
    openai.api_key = ""
    openai.api_type = ""
    openai.api_base = ""
    openai.api_version = ""

    # Only referenced by the commented-out experiments below; fix_till_done
    # creates its own client through getLLM().
    llm = getLLM()

    # Read the whole instruction prompt and close the file right away.
    with open(f'{PATH_DATA}/output_prompt.txt', encoding="utf-8") as prompt_file:
        prompt = prompt_file.read()
    # prompt_outputs = from_dataset(llm, prompt, self_improvement_steps=3, dataset_name = "strategyQA")
    fix_till_done(prompt, f"list_functions_shot_{SHOTS}_mode_{MODE}")
    # candidate_prompt = input("Please enter your first candidate prompt: ")
    # print(json.dumps(get_better_prompt(llm, prompt,
    #       candidate_prompt, self_improvement_steps=3)))
