#!/usr/bin/env python3



from typing import Dict

from partnr.planner.llm_planner import LLMPlanner


class CentralizedLLMPlanner(LLMPlanner):
    """LLM planner that renders a single prompt template for all agents.

    Construction is delegated entirely to ``LLMPlanner``; this subclass only
    customizes how ``self.prompt`` is filled in (see ``prepare_prompt``).
    """

    def __init__(self, plan_config, env_interface):
        # Planner-config storage and LLM initialization are handled by the
        # base class; nothing extra is needed here.
        super().__init__(plan_config, env_interface)

    @staticmethod
    def _describe_house(world_graph):
        """Return one ``"<room>: <furniture>, <furniture>, ..."`` line per room."""
        furn_room = world_graph.group_furniture_by_room()
        return "".join(
            room + ": " + ", ".join(furn.name for furn in furniture) + "\n"
            for room, furniture in furn_room.items()
        )

    @staticmethod
    def _describe_objects(world_graph):
        """Return one ``"<object>: <furniture>"`` line per object in the graph."""
        lines = []
        for obj in world_graph.get_all_objects():
            # NOTE(review): this raises AttributeError if no furniture is
            # found for ``obj`` (e.g. if the lookup returns None) — confirm
            # whether that can happen for held/loose objects.
            furniture = world_graph.find_furniture_for_object(obj)
            lines.append(obj.name + ": " + furniture.name + "\n")
        return "".join(lines)

    def prepare_prompt(self, instruction, world_graph):
        """Fill the placeholders of ``self.prompt`` for one planning call.

        The mandatory values (``{input}``, ``{tool_list}``, ``{world_graph}``)
        are always supplied. Every other placeholder is computed only when
        the template actually contains it, so unused descriptions cost
        nothing.

        Args:
            instruction: Task instruction substituted for ``{input}``.
            world_graph: World-graph object (not a ``dict``). It is used for
                the house and object descriptions.

        Returns:
            A ``(prompt, params)`` tuple. ``prompt`` is the formatted string
            and ``params`` is the dict passed to ``str.format``.

        Raises:
            ValueError: If ``world_graph`` is a ``dict``.
        """
        if isinstance(world_graph, Dict):
            raise ValueError(
                f"Expected world_graph to be a WorldGraph object, not a Dict. Received: {world_graph}"
            )

        params = {
            "input": instruction,
            "tool_list": self.tool_list,
            "world_graph": world_graph,
        }

        # Placeholder name -> zero-arg callable producing its value. The
        # callables are only invoked when "{name}" appears in the template.
        optional_values = {
            "agent_list": lambda: self.agent_list,
            # TODO: We are inserting tool description of one of the agents,
            # assuming the other agents have the same tools. Tool
            # descriptions were used because they gave better performance
            # with llama2.
            "tool_descriptions": lambda: self.agents[0].tool_descriptions,
            "agent_descriptions": lambda: self.agent_descriptions,
            "house_description": lambda: self._describe_house(world_graph),
            "all_objects": lambda: self._describe_objects(world_graph),
            "system_tag": lambda: self.planner_config.llm.system_tag,
            "user_tag": lambda: self.planner_config.llm.user_tag,
            "assistant_tag": lambda: self.planner_config.llm.assistant_tag,
            "eot_tag": lambda: self.planner_config.llm.eot_tag,
        }
        for name, compute in optional_values.items():
            if "{" + name + "}" in self.prompt:
                params[name] = compute()

        return self.prompt.format(**params), params
