#!/usr/bin/env python3



import random
import time
from typing import Any, Dict, Union

from partnr.planner.llm_planner import LLMPlanner


class StateBasedLLMPlanner(LLMPlanner):
    def __init__(self, plan_config, env_interface):
        """
        Build the base LLM planner, load the summarizer prompt settings and
        seed the task-progress summary with an "idle" line for agents 0 and 1.
        """
        # Base-class setup runs first: the summarizer settings are read from
        # self.instruct, which this class never assigns itself
        super().__init__(plan_config, env_interface)
        self.__initialize_state_summarizer()

        # Placeholder summary used until the summarizer LLM produces a real one
        idle_line = "Agent {uid}: Agent {uid} hasn't taken any actions for the task yet and is currently waiting."
        self.summary = "\n".join(idle_line.format(uid=uid) for uid in (0, 1))

    def reset(self):
        """
        Reset the planner state.

        On top of the base-class reset, this restores the task-progress summary
        to its initial "idle" text so that a summary generated before the reset
        cannot be carried over.
        """
        super().reset()
        # Same text that __init__ assigns as the starting summary
        self.summary = "Agent 0: Agent 0 hasn't taken any actions for the task yet and is currently waiting.\nAgent 1: Agent 1 hasn't taken any actions for the task yet and is currently waiting."

    def __initialize_state_summarizer(self):
        """
        This method instantiates prompts for summarizer using the existing planner llm
        """
        # Setup the LLM parameters
        self.prompt_summarizer = self.instruct.prompt_summarizer
        self.stopword_summarizer = self.instruct.stopword_summarizer
        self.end_expression_summarizer = self.instruct.end_expression_summarizer

        return

    def prepare_summarizer_prompt(
        self, instruction, world_graph, agents_activity_history
    ):
        """Prepare the prompt for the summarizer LLM"""
        if isinstance(world_graph, Dict):
            raise ValueError(
                "Expected world_graph to be a WorldGraph object, not a Dict. Received: {world_graph}"
            )

        params = {
            "input": instruction,
            "world_graph": world_graph,
        }
        if "{agent0_actions}" in self.prompt_summarizer:
            params["agent0_actions"] = agents_activity_history[0]
        if "{agent1_actions}" in self.prompt_summarizer:
            params["agent1_actions"] = agents_activity_history[1]
        if "{system_tag}" in self.prompt_summarizer:
            params["system_tag"] = self.planner_config.llm.system_tag
        if "{user_tag}" in self.prompt_summarizer:
            params["user_tag"] = self.planner_config.llm.user_tag
        if "{assistant_tag}" in self.prompt_summarizer:
            params["assistant_tag"] = self.planner_config.llm.assistant_tag
        if "{eot_tag}" in self.prompt_summarizer:
            params["eot_tag"] = self.planner_config.llm.eot_tag
        return self.prompt_summarizer.format(**params), params

    def print_current_state(self, world_graph, agent_prev_obs):
        if isinstance(world_graph, Dict):
            raise ValueError(
                "Expected world_graph to be a WorldGraph object, not a Dict. Received: {world_graph}"
            )
        ## get objects held by the agent
        spot_node = world_graph.get_spot_robot()
        human_node = world_graph.get_human()

        ## locations of objects in the house
        objs_info = ""
        all_objs = world_graph.get_all_objects()
        for obj in all_objs:
            if world_graph.is_object_with_robot(obj):
                objs_info += obj.name + ": " + spot_node.name + "\n"
            elif world_graph.is_object_with_human(obj):
                objs_info += obj.name + ": " + human_node.name + "\n"
            else:
                fur_object = world_graph.find_furniture_for_object(obj)
                if fur_object:
                    fur_object_name = fur_object.name
                    objs_info += obj.name + ": " + fur_object_name + "\n"

        state_str = f"\nObjects in the house:\n{objs_info}\n"
        state_str += (
            "Agent's observations of the last executed action (if available):\n"
        )
        if agent_prev_obs:
            for agent in self.agents:
                if str(agent.uid) in agent_prev_obs:
                    state_str += f"Agent_{agent.uid}_Observation: {agent_prev_obs[str(agent.uid)]}\n"
        print(state_str)
        return state_str

    def prepare_prompt(self, instruction, world_graph):
        if isinstance(world_graph, Dict):
            raise ValueError(
                "Expected world_graph to be a WorldGraph object, not a Dict. Received: {world_graph}"
            )
        """Prepare the prompt for the LLM, both by adding the input
        and the agent descriptions"""
        params = {
            "input": instruction,
            "tool_list": self.tool_list,
            "world_graph": world_graph,
            "id": self.agents[0].uid,
        }

        ## house description -- rooms and their furniture list
        furn_room = world_graph.group_furniture_by_room()
        house_info = ""
        for k, v in furn_room.items():
            furn_names = [furn.name for furn in v]
            all_furn = ", ".join(furn_names)
            house_info += k + ": " + all_furn + "\n"

        ## get objects held by the agent
        spot_node = world_graph.get_spot_robot()
        human_node = world_graph.get_human()

        ## locations of objects in the house
        objs_info = ""
        all_objs = world_graph.get_all_objects()
        for obj in all_objs:
            if world_graph.is_object_with_robot(obj):
                objs_info += obj.name + ": " + spot_node.name + "\n"
            elif world_graph.is_object_with_human(obj):
                objs_info += obj.name + ": " + human_node.name + "\n"
            else:
                fur_object = world_graph.find_furniture_for_object(obj)
                if fur_object:
                    fur_object_name = fur_object.name
                    objs_info += obj.name + ": " + fur_object_name + "\n"

        # We motify the prompt if we want to use RAG and the prompt has not
        # been motified
        if self.rag is not None and self.swap_instrction:
            # Select the example based on the instruction
            _, index = self.rag.retrieve_top_k_given_query(
                instruction, top_k=1, agent_id=self._agents[0].uid
            )
            index = index[0]
            target_str = self.planner_config.rag_data_target_text + "\n{eot_tag}\n"
            last_index = self.prompt.rfind(target_str)
            assert (
                last_index != -1
            ), "Cannnot find the target string to insert the RAG prompt. It is often due to prompt or format changes"
            # Add example starting from ID
            example_id = self.planner_config.start_example_id
            add_prompt = ""
            # Shuffle the example traces
            _copy_trace = self.rag.data_dict[index]["trace"].copy()
            random.shuffle(_copy_trace)
            for i, add_trace in enumerate(_copy_trace):
                if i == self.planner_config.max_number_of_rag_example_added:
                    break
                add_prompt += (
                    "\n"
                    + self.planner_config.llm.user_tag
                    + f"Example {example_id}:\n"
                    + add_trace
                    + "\n"
                    + self.planner_config.llm.eot_tag
                )
                example_id += 1

            self.prompt = (
                self.prompt[0 : last_index + len(target_str)]
                + add_prompt
                + self.prompt[last_index + len(target_str) :]
            )
            self.swap_instrction = False

        if "{agent_list}" in self.prompt:
            params["agent_list"] = self.agent_list

        # TODO: We are inserting tool description of one of the agents
        # Assuming the other agent has same tools
        # Tool descriptions were used because gave better performance with llama2
        if "{tool_descriptions}" in self.prompt:
            params["tool_descriptions"] = self.agents[0].tool_descriptions
        if "{agent_descriptions}" in self.prompt:
            params["agent_descriptions"] = self.agent_descriptions
        if "{house_description}" in self.prompt:
            params["house_description"] = house_info
        if "{all_objects}" in self.prompt:
            params["all_objects"] = objs_info
        if "{summary}" in self.prompt:
            params["summary"] = self.summary
        if "{agent_obs}" in self.prompt:
            agent_current_obs_str = ""
            for k, v in self.latest_agent_response.items():
                agent_current_obs_str += f"""Agent_{k}_Observation: {v}\n"""
            params["agent_obs"] = agent_current_obs_str
        if "{agent0_obs}" in self.prompt:
            if "0" in self.latest_agent_response:
                params[
                    "agent0_obs"
                ] = f"""Agent_0_Observation: {self.latest_agent_response['0']}\n"""
            else:
                params["agent0_obs"] = ""
        if "{agent1_obs}" in self.prompt:
            if "1" in self.latest_agent_response:
                params[
                    "agent1_obs"
                ] = f"""Agent_1_Observation: {self.latest_agent_response['1']}\n"""
            else:
                params["agent1_obs"] = ""

        if "{system_tag}" in self.prompt:
            params["system_tag"] = self.planner_config.llm.system_tag
        if "{user_tag}" in self.prompt:
            params["user_tag"] = self.planner_config.llm.user_tag
        if "{assistant_tag}" in self.prompt:
            params["assistant_tag"] = self.planner_config.llm.assistant_tag
        if "{eot_tag}" in self.prompt:
            params["eot_tag"] = self.planner_config.llm.eot_tag

        # print(self.prompt.format(**params))
        return self.prompt.format(**params), params

    def get_next_action(
        self, instruction, observations, world_graph, verbose: bool = False
    ):
        """
        Gives the next low level action to execute
        """
        planner_info: Dict[str, Union[Any, str]] = {
            "replanned": {agent.uid: False for agent in self.agents}
        }

        # Early return if planner is already done
        if self.is_done:
            planner_info = {
                "replanned": {agent.uid: False for agent in self.agents},
                "traces": {agent.uid: self.trace for agent in self.agents},
                "prompts": {agent.uid: self.curr_prompt for agent in self.agents},
                "replanning_count": {
                    agent.uid: self.replanning_count for agent in self.agents
                },
                "replan_required": {
                    agent.uid: self.replan_required for agent in self.agents
                },
                "is_done": {agent.uid: self.is_done for agent in self.agents},
            }
            return {}, planner_info, self.is_done

        print_str = ""
        self.is_done = False

        # saving additional info for logging in trace
        if self.trace == "":
            self.trace += f"Task:{instruction}\n"
            ## house description -- rooms and their furniture list
            furn_room = world_graph[self.agents[0].uid].group_furniture_by_room()
            house_info = ""
            for k, v in furn_room.items():
                furn_names = [furn.name for furn in v]
                all_furn = ", ".join(furn_names)
                house_info += k + ": " + all_furn + "\n"
            self.trace += f"House description:{house_info}\n"

        if self.replan_required:
            planner_info["replanned"] = {agent.uid: True for agent in self.agents}
            if verbose:
                # calculate the total time of response generation
                start_time = time.time()

            # Obtain updated state and summary of task progress
            agents_activity_history = self.env_interface.agent_state_history
            # saving state info in trace
            print_str += self.print_current_state(
                world_graph[self.agents[0].uid], self.latest_agent_response
            )

            self.prompt_summary, self.params_prompt = self.prepare_summarizer_prompt(
                instruction, world_graph[self.agents[0].uid], agents_activity_history
            )
            self.summary = self.llm.generate(
                self.prompt_summary, self.stopword_summarizer
            )

            # Format the response
            # This removes extra text followed by end expression when needed.
            self.summary = self.format_response(self.summary, self.stopword_summarizer)

            # Planner prompt
            self.curr_prompt, self.params = self.prepare_prompt(
                instruction, world_graph[self.agents[0].uid]
            )
            # Generate response
            if self.planner_config.get("constrained_generation", False):
                llm_response = self.llm.generate(
                    self.curr_prompt,
                    self.stopword,
                    generation_args={
                        "grammar_definition": self.build_response_grammar(
                            world_graph[self._agents[0].uid]
                        )
                    },
                )
            else:
                llm_response = self.llm.generate(self.curr_prompt, self.stopword)

            # Format the response
            # This removes extra text followed by end expression when needed.
            llm_response = self.format_response(llm_response, self.stopword)

            if verbose:
                total_time = time.time() - start_time
                print(
                    f"Time taken for LLM response generation: {total_time}; replanning_count: {self.replanning_count}"
                )

            # Parse thought from the response
            thought = self.parse_thought(llm_response)

            # Update prompt with the first response
            print_str += f"{self.summary}\n"
            print_str += f"Thought: {llm_response}\n{self.stopword}\n"

            # Check if the planner should stop
            # Stop if the replanning count exceed a certain threshold
            # or end expression is found in llm response
            # This is helpful to break infinite planning loop.
            self.is_done = (self.end_expression in llm_response) or (
                self.replanning_count == self.planner_config.replanning_threshold
            )
            # Increment the llm call counter on every replan
            # doesn't get incremented before comparison as first "replan" is technically
            # the first required plan
            self.replanning_count += 1

            # Early return if stop is required
            if self.is_done:
                planner_info = {
                    "replanned": {agent.uid: True for agent in self.agents},
                    "print": print_str,
                    "traces": {agent.uid: self.trace for agent in self.agents},
                    "prompts": {agent.uid: self.curr_prompt for agent in self.agents},
                    "replanning_count": {
                        agent.uid: self.replanning_count for agent in self.agents
                    },
                    "replan_required": {
                        agent.uid: self.replan_required for agent in self.agents
                    },
                    "is_done": {agent.uid: self.is_done for agent in self.agents},
                    "thought": {agent.uid: thought for agent in self.agents},
                    "high_level_actions": {
                        agent.uid: ("Done", None, None) for agent in self.agents
                    },
                }
                return {}, planner_info, self.is_done

            # Parse high level action directives from llm response
            high_level_actions = self.actions_parser(
                self.agents, llm_response, self.params
            )
            print(f"\n\nExecuting: {high_level_actions}\n\n")

            # Get low level actions and/or responses
            low_level_actions, responses = self.process_high_level_actions(
                high_level_actions, observations
            )

            # Store last executed high level action
            self.last_high_level_actions = high_level_actions
        else:
            # Set thought to None
            thought = None

            # Get low level actions and/or responses using last high level actions
            low_level_actions, responses = self.process_high_level_actions(
                self.last_high_level_actions, observations
            )

        # Log if replanning was done or not before overwriting the value
        planner_info["replan_required"] = {
            agent.uid: self.replan_required for agent in self.agents
        }

        # Check if replanning is required
        # Replanning is required when any of the actions being executed
        # have a response indicating success or failure (and the reason)
        self.replan_required = any(responses.values())

        # Add responses to the print and prompt
        for agent_uid in sorted(responses.keys()):
            # If the response for a given agent is valid, add to the prompt and printout
            if responses[agent_uid]:
                # Update print string
                print_str += (
                    f"""Agent_{agent_uid}_Observation:{responses[agent_uid]}\n"""
                )

            # If the response is empty then indicate the action is still in progress
            # only when replanning was required
            elif self.replan_required:
                responses[
                    agent_uid
                ] = f"Action {self.last_high_level_actions[agent_uid][0]}[{self.last_high_level_actions[agent_uid][1]}] is still in progress."

                # Update print string
                print_str += (
                    f"""Agent_{agent_uid}_Observation:{responses[agent_uid]}\n"""
                )

            # save agent observations to get feedback on skill execution
            self.latest_agent_response[agent_uid] = responses[agent_uid]

        # Update planner info
        self.trace += print_str
        planner_info["responses"] = responses
        planner_info["thought"] = {agent.uid: thought for agent in self.agents}
        planner_info["print"] = print_str
        planner_info["high_level_actions"] = self.last_high_level_actions
        planner_info["traces"] = {agent.uid: self.trace for agent in self.agents}
        planner_info["prompts"] = {agent.uid: self.curr_prompt for agent in self.agents}
        planner_info["replanning_count"] = {
            agent.uid: self.replanning_count for agent in self.agents
        }
        planner_info["agent_states"] = self.get_last_agent_states()
        planner_info["agent_positions"] = self.get_last_agent_positions()
        planner_info["agent_collisions"] = self.get_agent_collisions()
        planner_info["is_done"] = {agent.uid: self.is_done for agent in self.agents}
        return low_level_actions, planner_info, self.is_done


class CentralizedStateBasedLLMPlanner(StateBasedLLMPlanner):
    def __init__(self, plan_config, env_interface):
        """
        Construct the planner by delegating to the parent class constructor;
        no additional state is set up here.
        """
        # Both arguments are forwarded unchanged to the parent constructor
        super().__init__(plan_config, env_interface)

    def get_next_action(self, instruction, observations, world_graph):
        """
        Gives the next low level action to execute
        """
        # Early return if planner is already done
        if self.is_done:
            planner_info = {
                "replanned": {agent.uid: False for agent in self.agents},
                "traces": {agent.uid: self.trace for agent in self.agents},
                "prompts": {agent.uid: self.curr_prompt for agent in self.agents},
                "replanning_count": {
                    agent.uid: self.replanning_count for agent in self.agents
                },
                "replan_required": {
                    agent.uid: self.replan_required for agent in self.agents
                },
                "is_done": {agent.uid: self.is_done for agent in self.agents},
            }
            return {}, planner_info, self.is_done

        planner_info = {"replanned": {agent.uid: False for agent in self.agents}}
        print_str = ""
        self.is_done = False

        if self.replan_required:
            planner_info["replanned"] = {agent.uid: True for agent in self.agents}
            # Obtain updated summary of task progress
            agents_activity_history = self.env_interface.agent_state_history
            self.print_current_state(world_graph[0], self.latest_agent_response)

            self.prompt_summary, self.params_prompt = self.prepare_summarizer_prompt(
                instruction, world_graph[0], agents_activity_history
            )
            self.summary = self.llm.generate(
                self.prompt_summary, self.stopword_summarizer
            )

            # Format the response
            # This removes extra text followed by end expression when needed.
            self.summary = self.format_response(self.summary, self.stopword_summarizer)

            # Current prompt update
            self.curr_prompt, self.params = self.prepare_prompt(
                instruction, world_graph[0]
            )

            # Generate response
            if self.planner_config.get("constrained_generation", False):
                llm_response = self.llm.generate(
                    self.curr_prompt,
                    self.stopword,
                    generation_args={
                        "grammar_definition": self.build_response_grammar(
                            world_graph[self._agents[0].uid]
                        )
                    },
                )
            else:
                llm_response = self.llm.generate(self.curr_prompt, self.stopword)
            # Format the response
            # This removes extra text followed by end expression when needed.
            llm_response = self.format_response(llm_response, self.stopword)

            # Parse thought from the response
            thought = self.parse_thought(llm_response)

            # Update prompt with the first response
            print_str += f"{self.summary}\n"
            print_str += f"{llm_response}{self.stopword}\n"
            # Check if the planner should stop
            # Stop if the replanning count exceed a certain threshold
            # or end expression is found in llm response
            # This is helpful to break infinite planning loop.
            self.is_done = (self.end_expression in llm_response) or (
                self.replanning_count == self.planner_config.replanning_threshold
            )

            # Early return if stop is required
            if self.is_done:
                planner_info = {
                    "replanned": {agent.uid: True for agent in self.agents},
                    "print": print_str,
                    "traces": {agent.uid: self.trace for agent in self.agents},
                    "prompts": {agent.uid: self.curr_prompt for agent in self.agents},
                    "replanning_count": {
                        agent.uid: self.replanning_count for agent in self.agents
                    },
                    "replan_required": {
                        agent.uid: self.replan_required for agent in self.agents
                    },
                    "is_done": {agent.uid: self.is_done for agent in self.agents},
                    "thought": {agent.uid: thought for agent in self.agents},
                    "high_level_actions": {
                        agent.uid: ("Done", None, None) for agent in self.agents
                    },
                }
                return {}, planner_info, self.is_done

            # Increment the llm call counter on every replan
            self.replanning_count += 1

            # Parse high level action directives from llm response
            high_level_actions = self.actions_parser(
                self.agents, llm_response, self.params
            )

            # Get low level actions and/or responses
            low_level_actions, responses = self.process_high_level_actions(
                high_level_actions, observations
            )

            # Store last executed high level action
            self.last_high_level_actions = high_level_actions
        else:
            # Set thought to None
            thought = None

            # Get low level actions and/or responses using last high level actions
            low_level_actions, responses = self.process_high_level_actions(
                self.last_high_level_actions, observations
            )

        # Log if replanning was done or not before overwriting the value
        planner_info["replan_required"] = {
            agent.uid: self.replan_required for agent in self.agents
        }

        # Check if replanning is required
        # Replanning is required when any of the actions being executed
        # have a response indicating success or failure (and the reason)
        self.replan_required = any(responses.values())

        # Add responses to the print
        for agent_uid in sorted(responses.keys()):
            # If the response for a given agent is valid, add to the prompt and printout
            if responses[agent_uid]:
                # Update print string
                print_str += (
                    f"""Agent_{agent_uid}_observation:{responses[agent_uid]}\n"""
                )

            # If the response is empty then indicate the action is still in progress
            # only when replanning was required
            elif self.replan_required:
                responses[
                    agent_uid
                ] = f"Action {self.last_high_level_actions[agent_uid][0]}[{self.last_high_level_actions[agent_uid][1]}] is still in progress."

                # Update print string
                print_str += (
                    f"""Agent_{agent_uid}_observation:{responses[agent_uid]}\n"""
                )

            # save agent observations to get feedback on skill execution
            self.latest_agent_response[agent_uid] = responses[agent_uid]

        # Update planner info
        self.trace += print_str
        planner_info["responses"] = responses
        planner_info["thought"] = {agent.uid: thought for agent in self.agents}
        planner_info["print"] = print_str
        planner_info["high_level_actions"] = self.last_high_level_actions
        planner_info["traces"] = {agent.uid: self.trace for agent in self.agents}
        planner_info["prompts"] = {agent.uid: self.curr_prompt for agent in self.agents}
        planner_info["replanning_count"] = {
            agent.uid: self.replanning_count for agent in self.agents
        }
        planner_info["agent_states"] = self.get_last_agent_states()
        planner_info["agent_positions"] = self.get_last_agent_positions()
        planner_info["agent_collisions"] = self.get_agent_collisions()
        planner_info["is_done"] = {agent.uid: self.is_done for agent in self.agents}
        return low_level_actions, planner_info, self.is_done
