from utils import *
import datasets
import random
import numpy as np
import openai
import os
import openai
from metric_utils import *
from utils import process_code
def generate_code(example):
    """Generate a program for ``example["question"]`` without tool retrieval.

    The question is sent to the 'codex' backend through ``forward`` (image
    input type) and the raw completion is cleaned up with ``process_code``.

    Returns:
        dict with a single ``"code"`` entry, in the shape ``Dataset.map`` expects.
    """
    raw_completion = forward('codex', prompt=example["question"], input_type="image")
    return {"code": process_code(raw_completion)}

from retrieve_tools import retrieve_tool
# from retrieve_tools_ablation import retrieve_tool, construct_vector_library
# Base code-generation prompt. generate_code_with_retrieval substitutes its
# literal "INSERT_TOOL_HERE" placeholder with retrieved tool sources (or "").
# Read inside `with` so the file handle is closed immediately.
with open(config.codex.prompt) as _prompt_file:
    PROMPT = _prompt_file.read()

# Header placed in front of the retrieved tools' source code when they are
# inserted into the prompt.
inserted_tools_prompt = """**Note: If necessary, you may also leverage the following tools to directly perform complex operations. 
However, please carefully review the implementation code of the tool functions to determine whether to utilize any of them.
Additionally, consider the appropriate method of passing parameters based on your comprehension of the internal implementation of the tool functions, rather than solely relying on the docstring.**\n"""


def _split_signature_params(func_head):
    """Split the parameter list of ``func_head`` on commas at bracket depth 0.

    The list starts at the first "(" and ends at its matching ")". Commas
    inside (), [], {} or string literals (e.g. ``Dict[str, int]``) do not split.

    Raises:
        ValueError: if there is no "(" or it is never closed.
    """
    start = func_head.find("(")
    if start == -1:
        raise ValueError(f"no parameter list in function head: {func_head!r}")
    params = []
    current = []
    depth = 0
    quote = None  # active string delimiter while inside a literal default
    escaped = False
    for char in func_head[start + 1:]:
        if quote is not None:
            # Inside a string default value: brackets and commas are literal.
            current.append(char)
            if escaped:
                escaped = False
            elif char == "\\":
                escaped = True
            elif char == quote:
                quote = None
            continue
        if char == "," and depth == 0:
            params.append("".join(current))
            current = []
            continue
        if char in ")]}":
            if depth == 0:
                # Closing parenthesis of the parameter list itself.
                params.append("".join(current))
                return params
            depth -= 1
        elif char in "([{":
            depth += 1
        elif char in "'\"":
            quote = char
        current.append(char)
    raise ValueError(f"unbalanced parentheses in function head: {func_head!r}")


def wrap_into_function(func_head, docstring):
    """Build a thin wrapper function that forwards to a function of the same name.

    Args:
        func_head: signature without the leading ``def`` and trailing colon,
            e.g. ``"crop(image_patch: ImagePatch, ratio: float = 0.5) -> ImagePatch"``.
        docstring: text placed in the wrapper's tab-indented docstring.

    Returns:
        Tab-indented source of ``def <func_head>:`` whose body is the docstring
        followed by ``return <name>(<params>)``. Each parameter is forwarded by
        its name only, with annotations and default values stripped, so the
        caller's value is passed instead of the default. ``*args`` and
        ``**kwargs`` keep their stars. Parameters after ``*`` / ``*args`` are
        forwarded as ``name=name``. The ``/`` marker is dropped.
    """
    name = func_head.split("(")[0].strip()
    forwarded = []
    keyword_only = False
    for param in _split_signature_params(func_head):
        # Drop any default value, then any annotation: "b: int = 1" -> "b".
        param_name = param.split("=")[0].split(":")[0].strip()
        if param_name in ("", "/"):
            continue  # empty list, trailing comma, or positional-only marker
        if param_name == "*":
            keyword_only = True  # bare "*": the following params are keyword-only
            continue
        if param_name.startswith("*"):
            if not param_name.startswith("**"):
                keyword_only = True  # params after *args are keyword-only
            forwarded.append(param_name)
        elif keyword_only:
            forwarded.append(f"{param_name}={param_name}")
        else:
            forwarded.append(param_name)
    call_args = ", ".join(forwarded)
    return f"def {func_head}:" + "\n" + f"\t'''{docstring}\n\t'''" + "\n" + f"\treturn {name}({call_args})\n"

def wrap_into_incontext_sample(query, call):
    """Format one in-context example: a query line plus an ``execute_command`` program.

    The program wraps ``image`` in an ``ImagePatch`` and returns the expression
    ``call``. Body lines are tab-indented and the result ends with a newline.
    """
    lines = [
        f"Query: {query}",
        "def execute_command(image):",
        "\timage_patch = ImagePatch(image)",
        f"\treturn {call}",
    ]
    return "\n".join(lines) + "\n"

def generate_code_with_retrieval(example, vector_library, model, tokenizer):
    """Generate a program for ``example["question"]`` with retrieved tools in the prompt.

    Retrieves 10 candidate tools with ``retrieve_tool``. The source of the
    first (up to) 3 is inserted into ``PROMPT`` in place of "INSERT_TOOL_HERE".
    If the backend raises ``openai.error.InvalidRequestError`` (presumably
    because the prompt is too long), the request is retried with one tool
    fewer each time, down to zero tools.

    Side effects: sets ``example["query"]``, prints progress, and writes the
    final prompt to ``temp.txt`` in the working directory. Reads the module
    global ``toolbase``.

    Args:
        example: dataset row; must contain "question".
        vector_library, model, tokenizer: forwarded unchanged to ``retrieve_tool``.

    Returns:
        dict with:
          * "code": the processed program;
          * "inserted_tool_prompts": the tool text inserted into the prompt,
            "" if none;
          * "retrieved_tools": the tool keys into ``toolbase["tool"]`` that
            were actually used.

    Raises:
        openai.error.InvalidRequestError: if the request still fails with no
            tools inserted. Before this fix, top_k went negative and the loop
            never ended.
    """
    print()
    print(example["question"])

    example["query"] = example["question"]

    retrieval_results = retrieve_tool(example, vector_library, model, tokenizer, 10)
    retrieved_tools = retrieval_results["retrieved_tools"]

    # Fetch the tool column once instead of on every access.
    all_tools = toolbase["tool"]

    # Starting above the number of retrieved tools would only repeat identical requests.
    top_k = min(3, len(retrieved_tools))
    while True:
        tools = retrieved_tools[:top_k]
        tool_sources = "\n\n".join([all_tools[tool] for tool in tools])
        inserted_tools = inserted_tools_prompt + "\n" + tool_sources if len(tools) > 0 else ""
        base_prompt = PROMPT.replace("INSERT_TOOL_HERE", inserted_tools)

        try:
            code = forward('codex', prompt=example["question"], input_type="image", base_prompt=base_prompt)
        except openai.error.InvalidRequestError as e:  # e.g. exceeds max token length
            print(e)
            if top_k <= 0:
                # Nothing left to drop; a negative top_k would slice tools back in.
                raise
            top_k -= 1
            continue
        code = process_code(code)
        print()
        print(tool_sources)
        break

    # Keep the last prompt on disk for inspection.
    with open("temp.txt", "w") as f:
        f.write(base_prompt)

    print()
    print(example["question"])
    print(code)
    print("\n" * 3)

    return {
        "code": code,
        "inserted_tool_prompts": inserted_tools,
        "retrieved_tools": tools,
    }

def compute_tool_usage_rate(dataset):
    """Return the number of retrieved-tool references per example.

    For each example, every entry of ``data["retrieved_tools"]`` (a key into
    the module global ``toolbase["tool"]``) counts once if that tool's
    function name occurs as a substring of ``data["code"]``. The total is
    divided by ``len(dataset)``. The value can exceed 1 when an example uses
    several of its retrieved tools. An empty dataset raises ZeroDivisionError.
    """
    names = [extract_function_name(tool_source) for tool_source in toolbase["tool"]]
    used = sum(
        names[tool_index] in data["code"]
        for data in dataset
        for tool_index in data["retrieved_tools"]
    )
    return used / len(dataset)

def count_args_from_call(call):
    """Count the arguments of the first parenthesized call in ``call``.

    ``call`` is typically one stripped source line, e.g.
    ``res = my_tool(image_patch, ["a", "b"], key=f(x, y))`` -> 3.

    The argument list runs from the first "(" to its matching ")". Only commas
    at nesting depth zero separate arguments; commas inside (), [], {} or
    string literals do not. So nested calls, indexing such as ``patch[0]``,
    collections and quoted commas are handled correctly. A trailing comma adds
    no argument, and ``f()`` counts as 0.

    Raises:
        ValueError: if ``call`` has no "(" or the first "(" is never closed.
            This replaces the previous print-and-``exit()``.
    """
    start = call.find("(")
    if start == -1:
        raise ValueError(f"no argument list found in call: {call!r}")

    depth = 0
    quote = None  # active string delimiter while inside a literal
    escaped = False
    commas = 0
    has_content = False  # any non-space char since the last top-level comma
    for char in call[start + 1:]:
        if quote is not None:
            if escaped:
                escaped = False
            elif char == "\\":
                escaped = True
            elif char == quote:
                quote = None
            continue
        if char in "'\"":
            quote = char
            has_content = True
        elif char in "([{":
            depth += 1
            has_content = True
        elif char in ")]}":
            if depth == 0:
                # Matching ")" of the outer call: the last argument ends here.
                return commas + (1 if has_content else 0)
            depth -= 1
        elif char == "," and depth == 0:
            commas += 1
            has_content = False
        elif not char.isspace():
            has_content = True
    raise ValueError(f"unbalanced parentheses in call: {call!r}")
    
def remove_extra_functions(code, all_tools, retrieved_tools):
    """Keep only the retrieved tools that ``code`` calls with a matching arity.

    A retrieved tool survives when both of these hold:
      * some line of ``code`` contains its function name as a whole word, so
        ``foo`` no longer matches ``foo_bar``;
      * that line's first parenthesized argument list (``count_args_from_call``)
        has the same number of arguments as the tool definition (``count_args``).

    Args:
        code: generated program source.
        all_tools: indexable collection of tool source strings.
        retrieved_tools: indices into ``all_tools``.

    Returns:
        list of surviving indices, in the order of ``retrieved_tools``.

    Raises:
        IndexError: if an index is out of range for ``all_tools``. Previously
            this printed and exited the process.
    """
    import re

    try:
        selected = [all_tools[i] for i in retrieved_tools]
    except IndexError as err:
        raise IndexError(
            f"retrieved tool index out of range: {len(all_tools)} tools, got {retrieved_tools}"
        ) from err
    function_names = [extract_function_name(item) for item in selected]
    num_args = [count_args(item) for item in selected]

    # (name, stripped line) pairs. A set, since a tool may appear on several
    # lines and identical lines collapse.
    tool_calls = set()
    for line in code.split("\n"):
        for func_name in function_names:
            if re.search(rf"\b{re.escape(func_name)}\b", line):
                tool_calls.add((func_name, line.strip()))

    # (name, arity) pairs observed in the code; only lines with parentheses count.
    observed = {
        (func_name, count_args_from_call(call))
        for func_name, call in tool_calls
        if "(" in call and ")" in call
    }
    return [
        tool_index
        for tool_index, func_name, num_arg in zip(retrieved_tools, function_names, num_args)
        if (func_name, num_arg) in observed
    ]


def _validation_save_paths():
    """Return ``(blip_save_path, skip_save_path)`` for the current module-global ``args``.

    vqa_v2 results use a ``.csv`` extension; all other datasets use ``.json``.
    """
    paths = []
    for variant in ("blip", "skip"):
        if args.retrieval:
            path = f"./results/eval/{args.eval_dataset}/{args.model}_retrieval_{variant}.json"
        else:
            path = f"./results/eval/{args.eval_dataset}/{args.model}_{variant}.json"
        if args.tool_epoch != -1:
            path = path.replace("_retrieval", f"_{args.tool_epoch}_retrieval")
        if args.ablation != "none":
            path = path.replace("_retrieval", f"_retrieval_no_{args.ablation}")
        if args.eval_dataset == "vqa_v2":
            # The result of this replace used to be discarded, so CSV data was
            # written to (and later json-loaded from) a *.json path.
            path = path.replace(".json", ".csv")
        paths.append(path)
    return paths[0], paths[1]


def _parse_groundtruth_cell(text):
    """Recover a list that ``DataFrame.to_csv`` stored as its repr; otherwise return ``text``."""
    import ast

    try:
        value = ast.literal_eval(text)
    except (ValueError, SyntaxError):
        return text
    return list(value) if isinstance(value, (list, tuple)) else text


def _load_prediction_table(path):
    """Load a cached prediction table as ``{column: {str(row_index): value}}``.

    JSON files are read with ``json.load``. This matches ``DataFrame.to_json``'s
    default column orientation, which the original loader relied on. CSV files
    are converted to the same shape.
    NOTE(review): CSV predictions come back as strings, so numeric answers are
    not restored as ints the way they are from JSON.
    """
    if path.endswith(".csv"):
        import pandas as pd

        # keep_default_na=False keeps "" predictions as "" instead of NaN.
        df = pd.read_csv(path, index_col=0, dtype=str, keep_default_na=False)
        table = {}
        for column in df.columns:
            parse = _parse_groundtruth_cell if column == "groundtruth" else (lambda value: value)
            table[column] = {str(idx): parse(value) for idx, value in df[column].items()}
        return table

    import json

    with open(path, "r") as f:
        return json.load(f)


def _save_prediction_table(path, predictions, groundtruths):
    """Write predictions and groundtruths to ``path``, as CSV if it ends with .csv, else JSON."""
    import pandas as pd

    df = pd.DataFrame({"prediction": predictions, "groundtruth": groundtruths})
    if path.endswith(".csv"):
        df.to_csv(path)
    else:
        df.to_json(path, indent=4)


def _bool_to_yes_no(prediction):
    """Map predictions whose ``str()`` is "True"/"False" to "yes"/"no"; pass others through."""
    text = str(prediction)
    if text == "True":
        return "yes"
    if text == "False":
        return "no"
    return prediction


def _predict_one(data, fixed_code):
    """Execute one example's program and return ``(blip_prediction, skip_prediction)``.

    If the program is missing, flagged ("pixel", "input(" or no "return"),
    raises, or returns None:
      * the blip variant falls back to ``fixed_code``;
      * the skip variant is "".
    A missing program gives "" for both.
    """
    image = load_image(data["image_path"])
    code = data["code"]
    # Checked before tool prepending, which would otherwise crash on None.
    if code is None:
        return "", ""

    if args.retrieval:
        # Prepend only the retrieved tools the program actually calls, with
        # their docstring explanations stripped.
        all_tools = toolbase["tool"]
        used_tools = [all_tools[i] for i in remove_extra_functions(code, all_tools, data["retrieved_tools"])]
        explanations = [extract_function_docstring(tool)[0] for tool in used_tools]
        used_tools = [tool.replace(explanation, "") for tool, explanation in zip(used_tools, explanations)]
        code = "\n\n".join([*used_tools, code])

    if ("pixel" in code) or ("input(" in code) or ("return" not in code):  # dead loop or no return
        print("Error in turbo-generated code.")
        return execute_code(fixed_code, image, data["question"]), ""

    try:  # normal cases
        prediction = execute_code(code, image)
        print()
        print(data["question"])
        print(prediction, data["answers"])
        if prediction is None:
            raise ValueError("generated program returned None")
    # SystemExit is included because generated code may call exit(); Ctrl-C still propagates.
    except (Exception, SystemExit):
        print("Error in turbo-generated code. ")
        prediction = execute_code(fixed_code, image, data["question"])
        print()
        print(data["question"])
        print(prediction, data["answers"])
        return prediction, ""
    return prediction, prediction


def validate_results(dataset):
    """Execute each example's generated program and collect predictions.

    Two prediction variants are produced per example:
      * "blip": on failure, falls back to the program read from
        ./prompts/fixed_code/blip2.prompt;
      * "skip": on failure, the prediction is "".
    Results whose str() is "True"/"False" become "yes"/"no". Both variants are
    cached (CSV for vqa_v2, JSON otherwise). They are reloaded instead of
    recomputed when both cache files exist.

    Uses the module globals ``args`` and (with retrieval) ``toolbase``, which
    are set in the ``__main__`` section.

    Returns:
        (predictions, groundtruths): the "blip" variant if ``args.use_blip``,
        otherwise the "skip" variant.
    """
    # f.read() keeps the file's own newlines. "\n".join(readlines()) doubled them.
    with open("./prompts/fixed_code/blip2.prompt", "r") as f:
        fixed_code = "def execute_command(image, question):\n" + f.read()

    blip_save_path, skip_save_path = _validation_save_paths()

    print(blip_save_path, os.path.exists(blip_save_path), skip_save_path, os.path.exists(skip_save_path))
    if os.path.exists(blip_save_path) and os.path.exists(skip_save_path):
        blip_results = _load_prediction_table(blip_save_path)
        keys = list(range(len(blip_results["prediction"].keys())))
        print(len(keys))
        blip_predictions = [blip_results["prediction"][str(i)] for i in keys]

        skip_results = _load_prediction_table(skip_save_path)
        skip_predictions = [skip_results["prediction"][str(i)] for i in keys]
        groundtruths = [skip_results["groundtruth"][str(i)] for i in keys]
    else:
        blip_predictions = []
        skip_predictions = []
        groundtruths = []
        for data in tqdm(dataset):
            blip_pred, skip_pred = _predict_one(data, fixed_code)
            blip_predictions.append(_bool_to_yes_no(blip_pred))
            skip_predictions.append(_bool_to_yes_no(skip_pred))
            groundtruths.append(data["answers"])

        print("save")
        _save_prediction_table(blip_save_path, blip_predictions, groundtruths)
        _save_prediction_table(skip_save_path, skip_predictions, groundtruths)
        print("saved")

    predictions = blip_predictions if args.use_blip else skip_predictions
    return predictions, groundtruths

def blip_prediction(dataset):
    """Answer every question directly with BLIP, without program generation.

    If ./results/eval/{eval_dataset}/{model}.json exists, it is reused and only
    the rows listed in selected_indices_gpt4.npy are returned. Otherwise BLIP
    runs over the whole dataset in batches of 64 and the results are cached at
    that path.

    Returns:
        (predictions, groundtruths) lists.
    """
    # Imported here: the only other `import json` in this file is local to another function.
    import json

    save_path = f"./results/eval/{args.eval_dataset}/{args.model}.json"
    if os.path.exists(save_path):
        with open(save_path, "r") as f:
            results = json.load(f)
        # NOTE(review): the cached branch returns only the GPT-4 subset, while a
        # fresh run returns every row. Confirm this asymmetry is intended.
        keys = np.load(f"./results/eval/{args.eval_dataset}/selected_indices_gpt4.npy").tolist()
        predictions = [results["prediction"][str(i)] for i in keys]
        groundtruths = [results["groundtruth"][str(i)] for i in keys]
        return predictions, groundtruths

    from vision_models import BLIPModel
    import pandas as pd

    blip = BLIPModel()
    batch_size = 64
    predictions = []
    # Stepping by batch_size avoids an empty trailing batch when len(dataset)
    # is an exact multiple of batch_size.
    for start in tqdm(range(0, len(dataset), batch_size)):
        batch = dataset.select(range(start, min(start + batch_size, len(dataset))))
        images = [load_image(data["image_path"]) for data in batch]
        questions = [data["question"] for data in batch]
        predictions.extend(blip.forward(images, questions, ["qa"] * len(images)))
    groundtruths = dataset["answers"]

    df = pd.DataFrame({"prediction": predictions, "groundtruth": groundtruths})
    df.to_json(save_path, indent=4)
    return predictions, groundtruths



if __name__ == "__main__":
    # Kept as a flat script on purpose: the functions above read `args`,
    # `toolbase`, `model` and `tokenizer` as module globals.
    import argparse

    import pandas as pd

    parser = argparse.ArgumentParser()
    parser.add_argument("--eval_dataset", type=str)
    parser.add_argument("--model", type=str)
    parser.add_argument("--retrieval", action="store_true")
    parser.add_argument("--use_blip", action="store_true")
    parser.add_argument("--tool_epoch", type=int, default=-1)
    parser.add_argument("--ablation", type=str, default="none")

    args = parser.parse_args()

    print(os.getenv('CONFIG_NAMES', None))

    dataset = datasets.load_from_disk(f"./datasets/eval/{args.eval_dataset}")
    assert len(dataset) == 1000

    # "_<epoch>" when a specific tool-library epoch is requested, else "".
    epoch_suffix = f"_{args.tool_epoch}" if args.tool_epoch != -1 else ""

    ####################################### Code Gen #######################################

    reserved_columns = ["image_id", "image_path", "question", "answers", "code"]
    if args.retrieval:
        reserved_columns = reserved_columns + ["retrieved_tools"]
    all_columns = list(dataset.features.keys())
    # Input columns dropped after code generation.
    dropped_columns = set(all_columns) - set(reserved_columns)

    init_vision_models()

    if args.retrieval:
        save_path = f"./results/eval/{args.eval_dataset}_{args.model}_retrieval{epoch_suffix}.json"
    else:
        save_path = f"./results/eval/{args.eval_dataset}_{args.model}.json"
    if args.ablation != "none":
        save_path = save_path.replace(".json", f"_no_{args.ablation}.json")

    if args.retrieval:
        from transformers import AutoModel, AutoTokenizer

        simcse_path = "/data/private/yuanlifan/.cache/huggingface/transformers/models--princeton-nlp--sup-simcse-roberta-large/snapshots/96d164d9950b72f4ce179cb1eb3414de0910953f"
        tokenizer = AutoTokenizer.from_pretrained(simcse_path)
        model = AutoModel.from_pretrained(simcse_path).cuda()

        toolbase = datasets.Dataset.from_csv(f'./results/viper/5_deduplicated_tool{epoch_suffix}.csv')
        print("toolbase length:", len(toolbase))
        vector_library = torch.load(f"./results/viper/vector_library{epoch_suffix}.pt")

        def code_generator(example):
            return generate_code_with_retrieval(example, vector_library, model, tokenizer)
    else:
        code_generator = generate_code

    # Generated programs are cached at save_path and reused when present.
    if os.path.exists(save_path):
        dataset = datasets.Dataset.from_json(save_path)
    else:
        dataset = dataset.map(code_generator, load_from_cache_file=False).remove_columns(dropped_columns)
        dataset.to_json(save_path)

    ####################################### Validate #######################################

    # The "abstraction" ablation validates against a separate tool library.
    tool_dir = "viper" if args.ablation != "abstraction" else "viper_ablation"
    toolbase = datasets.Dataset.from_csv(f'./results/{tool_dir}/5_deduplicated_tool{epoch_suffix}.csv')
    print(toolbase)

    def process_path(example):
        return {
            "image_path": os.path.join(f"./datasets/eval/{args.eval_dataset}/images", example['image_path'])
        }
    dataset = dataset.map(process_path)

    if args.model == "blip2":
        predictions, groundtruths = blip_prediction(dataset)
    else:
        predictions, groundtruths = validate_results(dataset)

    vqa_acc = 100.00 * compute_vqa_acc(predictions, groundtruths)
    f1 = 100.00 * compute_f1(predictions, groundtruths)
    print(f"VQA accuracy (exact match): {vqa_acc}")
    print(f"F1 score: {f1}")
    # Insertion order fixes the CSV column order.
    metrics = {"vqa_acc": [vqa_acc], "f1": [f1]}
    if args.retrieval:
        tool_usage_rate = 100.00 * compute_tool_usage_rate(dataset)
        print("Tool Usage Rate:", tool_usage_rate)
        metrics["tool_usage_rate"] = [tool_usage_rate]
    vqa_acc_turbo_eval = 0  # placeholder: no turbo-based evaluation is run here
    metrics["vqa_acc_turbo_eval"] = [vqa_acc_turbo_eval]

    # write metrics to csv using pandas
    os.makedirs(f"./results/metrics/{args.eval_dataset}", exist_ok=True)
    bug_process = "blip" if args.use_blip else "skip"
    if not args.retrieval:
        save_path = f"./results/metrics/{args.eval_dataset}/{args.model}_{bug_process}.csv"
    else:
        save_path = f"./results/metrics/{args.eval_dataset}/{args.model}_retrieval{epoch_suffix}_{bug_process}.csv"
        # Applied for every tool epoch. Previously it was skipped when
        # --tool_epoch was set, so runs with different ablations overwrote each other.
        if args.ablation != "none":
            save_path = save_path.replace("_retrieval", f"_retrieval_no_{args.ablation}")
    pd.DataFrame(metrics).to_csv(save_path)
    
    









