from partnr.llm.instruct.utils import get_objects_descr
from partnr.planner import LLMPlanner
from partnr.utils.grammar import FREE_TEXT


class ZeroShotReactPlanner(LLMPlanner):
    """
    Single-step, zero-shot, single-agent planner whose prompt format matches
    LLMs that were fine-tuned on partnr data.
    """

    def __init__(self, plan_config, env_interface):
        # No extra state of its own; all setup is delegated to LLMPlanner.
        super().__init__(plan_config, env_interface)

    def build_response_grammar(self, world_graph):
        """
        Build the grammar text that constrains the LLM's response.

        The root rule accepts ``FREE_TEXT``, a newline escape, one
        ``tool_call``, then a newline escape followed by the literal
        ``Assigned!``. The tool rules returned by
        ``self.build_tool_grammar(world_graph)`` follow the root rule on the
        next line.

        :param world_graph: forwarded unchanged to ``build_tool_grammar``.
        :return: root rule and tool rules joined by a newline.
        """
        tool_rules = self.build_tool_grammar(world_graph)

        # Two-character backslash-n written inside the grammar's quoted
        # literals (presumably decoded to a newline by the grammar engine).
        newline_escape = "\\n"
        root_rule = (
            f"root ::= {FREE_TEXT} "
            f'"{newline_escape}" tool_call "{newline_escape}Assigned!"'
        )

        return "\n".join((root_rule, tool_rules))

    def _add_responses_to_prompt(self, responses):
        """
        Append the latest agent response to the running prompt and trace.

        If ``planner_config.objects_response`` is falsy, this simply defers to
        ``LLMPlanner._add_responses_to_prompt``. Otherwise the planner must
        control exactly one agent. When that agent's response is non-empty,
        a user turn of the form ``Result: ...`` / ``Objects: ...`` is built.
        The objects text comes from ``get_objects_descr`` with room names,
        plus state info when ``objects_response_include_states`` is set. The
        turn is wrapped in the configured user/eot tags. It is appended,
        followed by the assistant tag, to both ``curr_prompt`` and ``trace``,
        and echoed to stdout.

        :param responses: indexed by agent uid; values are compared against
            ``""`` (presumably the tool-response text -- confirm with caller).
        :return: the formatted user turn without the assistant tag, ``""``
            when the agent's response is empty, or the base-class result
            when ``objects_response`` is disabled.
        """
        config = self.planner_config
        if not config.objects_response:
            return super()._add_responses_to_prompt(responses)

        # The objects-response format only supports a single agent.
        assert len(self.agents) == 1
        agent = self.agents[0]
        uid = agent.uid
        # Looked up before the empty-response check, as in the base flow.
        world_graph = self.env_interface.world_graph[uid]

        response = responses[uid]
        if response == "":
            # Nothing to report this step; prompt and trace stay untouched.
            return ""

        objects = get_objects_descr(
            world_graph,
            uid,
            include_room_name=True,
            add_state_info=config.objects_response_include_states,
        )
        template = "{user_tag}Result: {result}\nObjects: {objects}{eot_tag}"
        user_turn = template.format(
            result=response,
            objects=objects,
            user_tag=config.llm.user_tag,
            eot_tag=config.llm.eot_tag,
        )

        # The assistant tag opens the model's next turn.
        appended = user_turn + config.llm.assistant_tag
        self.curr_prompt += appended
        print(appended, end="")
        self.trace += appended
        return user_turn
