import ast
import re
from collections import Counter

import pandas as pd
import numpy as np
import datasets
import openai
import torch

from utils import *

# Module-level LLM client used by `planning`. GPT3Model is not defined in this
# file (presumably it comes from the `utils` star import — confirm); it is
# constructed once, at import time.
gpt3_model = GPT3Model()

# Prompt template for the planning step. It asks the model to:
#   1. rewrite the query as a declarative command,
#   2. abstract it into an imperative instruction,
#   3. propose `verb_noun` tool-function names, and
#   4. draft a docstring for those tools.
# The only `str.format` placeholder is the `{}` after "Query:", which `planning`
# fills with the user query. Any literal braces added to this text must
# therefore be doubled (`{{` / `}}`).
# Principle 4 requires the answer to close with "The useful functions are: [...]"
# and "The final answer is: ...". `planning` parses exactly those phrases out of
# the response, so keep this wording in sync with that parser.
# NOTE(review): "declaritive" (sic) is misspelled. The few-shot answers in
# `planning` repeat the same spelling, so change both together if at all.
prompt = """Given a query, convert it into a declaritive command and then a brief and concise imperative instruction. 
Next, infer tool functions that can be used based on the instruction. 
Finally, infer the docstring of the tool functions.

Consider following principles:
1. The instruction should reflect the action to take, rather than emphasizing on specific noun phrases. So you should prioritize using general terms like `object`, `people`, and `action`, and so on, instead of directly saying specific names like `desk`, `american president`, and `stuffed animal`.
2. Use tool function names following the format `verb_noun` with less than five words. Consider utilizing the most frequently used words in function names listed below.
3. The docstring of the tool function should be general and abstract, not specific to the query. Consider utilizing the most frequently used words in function docstrings listed below.
4. End your answer with the format 'The useful functions are: [...]' and 'The final answer is: ...', where '[...]' is a list of useful functions and '...' is the returned answer.
5. The most frequently used words in function names: ['object', 'identify', 'check', 'objects', 'find', 'attribute', 'action', 'location', 'determine', 'existence', 'infer', 'type', 'describe', 'property', 'image', 'purpose', 'activity', 'count', 'interaction', 'state', 'position', 'query', 'based', 'multiple', 'relative', 'another', 'within', 'compare', 'two', 'around', 'category', 'condition', 'event', 'context', 'get', 'main', 'benefits', 'environment', 'status', 'interacting', 'analyze', 'reason', 'verify', 'area', 'types', 'attributes', 'color', 'instances', 'actions', 'reasons', 'associated', 'extract', 'info', 'setting', 'understand', 'best', 'match', 'colors', 'size', 'possible', 'items', 'specific', 'scene', 'information', 'properties', 'part', 'explain', 'material', 'visible', 'process', 'near', 'entity', 'significance', 'situation', 'arrangement', 'provide', 'discuss', 'group', 'topic', 'direction', 'side', 'potential', 'animal', 'feature', 'impact', 'appearance', 'inquire']
6. The most frequently used words in function docstrings: ['specific', 'object', 'identify', 'check', 'image', 'certain', 'given', 'another', 'objects', 'find', 'type', 'existence', 'attribute', 'determine', 'action', 'possible', 'list', 'two', 'infer', 'number', 'based', 'group', 'whether', 'location', 'describe', 'purpose', 'state', 'property', 'activity', 'count', 'instances', 'person', 'information', 'condition', 'position', 'performing', 'context', 'interacting', 'category', 'multiple', 'within', 'status', 'scene', 'event', 'relative', 'query', 'associated', 'one', 'main', 'benefits', 'types', 'attributes', 'compare', 'performed', 'presence', 'behind', 'entity', 'best', 'environment', 'reason', 'interaction', 'located', 'set', 'potential']

Query: {}
Let's think step by step:
"""


# Marker that precedes the abstracted instruction in the model's answer.
_FINAL_ANSWER_MARKER = "The final answer is: "
# Matches the tool-list line in either phrasing the model uses, e.g.
# "The useful functions are: ['a', 'b']." or "The useful function is: ['a'].".
# It captures everything after the colon up to the end of that line.
_USEFUL_FUNCTIONS_RE = re.compile(r"The useful functions? (?:are|is):[ \t]*(.*)")


def _parse_instruction(response):
    """Return the text following "The final answer is: " in *response*.

    Raises:
        ValueError: if the marker is absent (e.g. the completion was cut off).
    """
    _, marker, tail = response.partition(_FINAL_ANSWER_MARKER)
    if not marker:
        raise ValueError(f"planning response has no final answer: {response!r}")
    return tail.strip()


def _parse_expected_tools(response):
    """Return the list of tool names the model declared as useful.

    The list is model-generated text, so it is parsed with ``ast.literal_eval``
    (Python literals only) instead of ``eval``. A bare string is accepted and
    wrapped in a one-element list.

    Raises:
        ValueError: if the tool line is missing, or does not hold a string or a
            list/tuple of strings.
    """
    match = _USEFUL_FUNCTIONS_RE.search(response)
    if match is None:
        raise ValueError(f"planning response lists no useful functions: {response!r}")
    # Drop surrounding whitespace and the sentence-ending period.
    literal = match.group(1).strip().strip(".")
    try:
        tools = ast.literal_eval(literal)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as err:
        raise ValueError(f"cannot parse tool list {literal!r}") from err
    if isinstance(tools, str):
        tools = [tools]
    if not isinstance(tools, (list, tuple)) or not all(isinstance(t, str) for t in tools):
        raise ValueError(f"tool list is not a list of names: {tools!r}")
    return list(tools)


def planning(query):
    """Abstract *query* into a general instruction and candidate tool names.

    The module-level ``prompt`` is sent to ``gpt3_model`` with two few-shot
    examples (a GQA-style and an OK-VQA-style query). The model's answer is then
    parsed for its final instruction and its list of useful function names. The
    raw response, the plans and the tool names are printed as they are produced.

    Args:
        query: the natural-language query to plan for.

    Returns:
        ``(plans, expected_tools)``. ``plans`` is ``[query, instruction]`` and
        ``expected_tools`` is a list of tool-function names proposed by the model.

    Raises:
        ValueError: if the response lacks the final answer or a parseable tool
            list (for instance because the completion was truncated).
    """
    messages = [
        # gqa
        {"role": "user", "content": prompt.format("What is visible on the desk?")},
        {"role": "assistant", "content": (
            "First, the corresponding declaritive command of the query is 'Identify the visible objects on the desk'.\n\n"
            "After abstracting, the general instruction should be 'Identify the objects on the specific surface.'.\n\n"
            "So considering the naming rules of tool functions, the relevant and useful functions could be named as 'identify_objects' or 'identify_objects_on_surface'.\n\n"
            "Finally, we can infer that the docstring of the tool function could be 'Identify the objects on the specified surface.'.\n\n"
            "The useful functions are: ['identify_objects', 'identify_objects_on_surface'].\n\n"
            "The final answer is: Identify the objects on the specified surface."
        )},

        # okvqa
        {"role": "user", "content": prompt.format("Which american president is most associated with the stuffed animal seen here?")},
        {"role": "assistant", "content": (
            "First, the corresponding declaritive command of the query is 'Search the american president most associated with the stuffed animal seen here'.\n\n"
            "After abstracting, the general instruction should be 'Search people most associated with the specific object.'.\n\n"
            "So considering the naming rules of tool functions, the relevant and useful functions could be named as 'search_people_associated_with_object'.\n\n"
            "Finally, we can infer that the docstring of the tool function could be 'Search for people most associated with the specific object.'.\n\n"
            "The useful functions are: ['search_people_associated_with_object'].\n\n"
            "The final answer is: Search for people most associated with a specific object."
        )},
        {"role": "user", "content": prompt.format(query)},
    ]
    # NOTE(review): max_tokens=200 leaves limited room for the four-paragraph
    # answer format (assumption — check against real completions). A cut-off
    # answer surfaces as a ValueError from the parsers below.
    response = gpt3_model.query_with_message(messages, max_tokens=200)
    print(response)

    plans = [query, _parse_instruction(response)]
    print(plans)

    expected_tools = _parse_expected_tools(response)
    print(expected_tools)

    return plans, expected_tools


def match_plan_from_single_perspective(plan_embeddings, tool_embeddings, k=3):
    """Retrieve the ``k`` tools most similar to each plan embedding.

    Args:
        plan_embeddings: iterable of 1-D embedding tensors (e.g. the rows of a
            2-D tensor), one per plan / sub-task.
        tool_embeddings: 2-D tensor with one row per tool; presumably
            ``(num_tools, dim)`` with the same ``dim`` as the plans — confirm.
        k: number of tools to retrieve for each plan from this perspective. It
            is clamped to the number of tools, so asking for more than exist
            returns every tool, ranked.

    Returns:
        A list with one entry per plan. Each entry is the list of tool row
        indices ordered by cosine similarity, most similar first.
    """
    # torch.topk raises when k exceeds the number of candidates.
    k = min(k, tool_embeddings.shape[0])
    tool_list = []
    for plan_embedding in plan_embeddings:
        # Cosine similarity between this plan and every tool:
        # (1, dim) broadcast against (num_tools, dim) -> (num_tools,).
        sim = torch.nn.functional.cosine_similarity(
            plan_embedding.unsqueeze(0), tool_embeddings, dim=1
        )
        tool_list.append(torch.topk(sim, k=k).indices.tolist())
    return tool_list


def retrieve_tool(example, vector_library, model, tokenizer, k=10, *,
                  k_explanation=10, k_name=5, k_query=10, min_votes=2):
    """Retrieve tools for ``example['query']`` by voting across three perspectives.

    The query first goes through ``planning``, which yields
    ``plans = [query, instruction]`` and a list of model-proposed tool names.
    These are embedded with ``compute_simcse`` and matched against three
    embedding tables in ``vector_library``:

    * the abstracted instruction(s) vs. ``"explanation_embedding"``;
    * each proposed tool name vs. ``"name_embedding"``;
    * the raw query vs. ``"query_embedding"``.

    Every (embedding, retrieved tool) pair counts as one vote for that tool.

    Args:
        example: mapping with a ``'query'`` entry.
        vector_library: mapping holding the three embedding tables above
            (presumably one row per tool, aligned across tables — confirm).
        model, tokenizer: passed through to ``compute_simcse``.
        k: number of most-voted tools to consider. Only those with at least
            ``min_votes`` votes are kept, so at most ``k`` tools are returned.
        k_explanation, k_name, k_query: tools retrieved per embedding in each
            perspective.
        min_votes: minimum number of votes a tool needs to be returned.

    Returns:
        A dict with three keys:
        ``"instruction"`` — the abstracted instruction;
        ``"expected_tools"`` — the model-proposed tool names;
        ``"retrieved_tools"`` — tool row indices, most-voted first.
    """
    # Decompose the query into [query, instruction] plus proposed tool names.
    plans, expected_tools = planning(example['query'])
    plan_embeddings = compute_simcse(model, tokenizer, plans)
    expected_tool_embeddings = compute_simcse(model, tokenizer, expected_tools)

    # plans[0] is the raw query; plans[1:] holds the abstracted instruction(s).
    tool_by_explanation = match_plan_from_single_perspective(
        plan_embeddings[1:], vector_library["explanation_embedding"], k=k_explanation)
    tool_by_name = match_plan_from_single_perspective(
        expected_tool_embeddings, vector_library["name_embedding"], k=k_name)
    tool_by_query = match_plan_from_single_perspective(
        plan_embeddings[0].unsqueeze(0), vector_library["query_embedding"], k=k_query)

    # Count votes in perspective order (explanation, name, query). Counter
    # breaks ties in most_common by first-encountered order.
    votes = Counter(
        tool
        for candidates in (*tool_by_explanation, *tool_by_name, *tool_by_query)
        for tool in candidates
    )
    # most_common already yields distinct tools ranked by vote count; keep that order.
    tool_list = [tool for tool, count in votes.most_common(k) if count >= min_votes]

    return {"instruction": plans[1], "expected_tools": expected_tools, "retrieved_tools": tool_list}

