#!/usr/bin/env python3



import functools
from typing import Dict

from partnr.llm.instruct.utils import build_single_step_prompt
from partnr.planner import LLMPlanner
from partnr.tools.tool import PerceptionTool


# Copied from finetuning code
def get_world_descr(world_graph):
    """
    Builds a string description of the environment from the world graph.

    The returned text has two sections:

    * ``Furniture:`` -- one ``<room>: <comma separated furniture names>`` line
      per entry returned by ``world_graph.group_furniture_by_room()``.
    * ``Objects:`` -- one ``<object name>: <location>`` line per object, where
      the location is the robot node name, the human node name, the name of
      the furniture the object was found on, or ``unknown`` when no furniture
      is found for it.

    Args:
        world_graph: graph object providing ``group_furniture_by_room``,
            ``get_spot_robot``, ``get_human``, ``get_all_objects``,
            ``is_object_with_robot``, ``is_object_with_human`` and
            ``find_furniture_for_object``. A plain dict is rejected.

    Returns:
        The description string ``"Furniture:\\n...\\nObjects:\\n..."``.

    Raises:
        ValueError: if ``world_graph`` is a dict instead of a graph object.
    """
    if isinstance(world_graph, Dict):
        # f-string so that the offending value really shows up in the message
        raise ValueError(
            "Expected world_graph to be a WorldGraph object, not a Dict. "
            f"Received: {world_graph}"
        )

    ## house description -- rooms and their furniture list
    furniture_lines = []
    for room_name, furniture in world_graph.group_furniture_by_room().items():
        furniture_names = ", ".join(furn.name for furn in furniture)
        furniture_lines.append(room_name + ": " + furniture_names + "\n")
    house_info = "".join(furniture_lines)

    ## agent nodes, used to name the holder of objects carried by an agent
    spot_node = world_graph.get_spot_robot()
    human_node = world_graph.get_human()

    ## locations of objects in the house
    object_lines = []
    for obj in world_graph.get_all_objects():
        if world_graph.is_object_with_robot(obj):
            location = spot_node.name
        elif world_graph.is_object_with_human(obj):
            location = human_node.name
        else:
            furn_node = world_graph.find_furniture_for_object(obj)
            location = "unknown" if furn_node is None else furn_node.name
        object_lines.append(obj.name + ": " + location + "\n")
    objs_info = "".join(object_lines)

    return f"Furniture:\n{house_info}\nObjects:\n{objs_info}"


class ThoughtlessLLMPlanner(LLMPlanner):
    """
    Planner that builds its prompt in the single step, thoughtless (i.e. no
    chain of thought prompting) format used for LLMs finetuned on partnr data.

    It drives exactly one agent (enforced in ``get_next_action``) and keeps a
    running ``prompt_history`` that is written into ``planner_info["prompts"]``
    for logging.
    """

    def __init__(self, plan_config, env_interface):
        """
        Set up the stop word, end expression, prompt header and an empty
        prompt history on top of the base planner initialisation.
        """
        super().__init__(plan_config, env_interface)
        self.stopword = "<end_act>"
        self.end_expression = "Done"

        # Keep a handle on the original parser; get_next_action rebinds
        # self.actions_parser to a partial of it with the agent uid pre-applied.
        self._actions_parser = self.actions_parser

        self.prompt_history = []
        self.prompt_header = "Solve the given multi-agent planning problem as best as you can. The task assigned to you will be situated in a house and will generally involve navigating to objects, picking and placing them on different receptacles to achieve rearrangement. Below is the detailed description of the actions you can use for solving the task. You can assign them to Agent_0 and/or Agent_1 as required."

    def reset(self):
        """
        Start a fresh prompt history, then defer to the base class reset and
        return whatever it returns.
        """
        self.prompt_history = []
        return super().reset()

    # Ideally the result of this should be cached on reset, but the first reset
    # is called before the agents are initialized, so there is no convenient
    # hook where this could be cached at the correct time.
    def _get_perception_tool_names(self, agent):
        """
        Return the names of the agent's tools that are PerceptionTool
        instances, in the iteration order of ``agent.tools``.
        """
        return [
            name
            for name, candidate in agent.tools.items()
            if isinstance(candidate, PerceptionTool)
        ]

    def build_response_grammar(self, world_graph):
        """
        Build a grammar that accepts all valid responses based on a world graph.

        The root rule is a tool call followed by the stop word; the tool rules
        come from ``build_tool_grammar`` and are appended on the next line.
        """
        root_rule = f'root ::= tool_call "{self.stopword}"'
        tool_rules = self.build_tool_grammar(world_graph)
        return "\n".join((root_rule, tool_rules))

    def get_next_action(
        self, instruction, observations, world_graph, verbose: bool = False
    ):
        """
        Gives the next low level action to execute.

        Builds the single step prompt for the only agent, delegates to the base
        planner, and replaces ``planner_info["prompts"]`` with the accumulated
        prompt history for that agent.

        Returns the ``(low_level_actions, planner_info, is_done)`` tuple
        obtained from the base class call.
        """
        assert len(self.agents) == 1
        agent = self.agents[0]
        uid = agent.uid

        # Perception tools are left out of the prompt
        skipped_tools = self._get_perception_tool_names(agent)
        self.curr_prompt = build_single_step_prompt(
            instruction,
            world_graph[uid],
            str(uid),
            self.env_interface.agent_action_history,
            tools_to_skip=skipped_tools,
        )

        # This type of planner needs the agent id handed to the actions parser
        self.actions_parser = functools.partial(self._actions_parser, uid)

        replans_before = self.replanning_count
        result = super().get_next_action(
            instruction, observations, world_graph, verbose
        )
        low_level_actions, planner_info, is_done = result

        # A higher replanning count means replanning occurred during this call
        if self.replanning_count > replans_before:
            self.prompt_history.append(self.curr_prompt)

        if "responses" in planner_info:
            agent_response = planner_info["responses"][uid]
            if agent_response != "":
                # Not really part of the prompting, but the "prompts" key is what
                # gets written to the log on disk and having the agent's response
                # in that log is useful
                self.prompt_history.append("Observation: " + agent_response)

        # Log the whole history joined together, since future prompts do not
        # necessarily contain the previous prompt
        joined_history = "\n-----------------\n".join(self.prompt_history)
        planner_info["prompts"] = {uid: joined_history}
        return low_level_actions, planner_info, is_done
