#!/usr/bin/env python3



from typing import List

from partnr.agent import Agent


# Abstract base class for planners: subclasses implement get_next_action() and reset().
class Planner:
    """
    Abstract base for planners.

    Holds the planner config, the environment interface and the agents this
    planner controls. Provides shared helpers that:

    * turn per-agent high-level actions into low-level actions;
    * push action effects into the agents' world graphs when the environment
      is partially observable.

    Subclasses must implement ``get_next_action`` and ``reset``.
    """

    def __init__(self, plan_config, env_interface):
        # Planner configuration: a dict-like config object, or a plain string
        # (test_planner.py passes a string).
        self.planner_config = plan_config
        self.env_interface = env_interface
        self._agents: List[Agent] = []
        self.is_done = False
        if isinstance(plan_config, str):
            # String configs carry no options, so RAG stays disabled.
            self.enable_rag = False
        else:
            self.enable_rag = plan_config.get("enable_rag", False)
        # NOTE(review): the attribute name is misspelled ("instrction"). It is
        # kept as-is because code outside this view may read it.
        self.swap_instrction = True

    def get_next_action(self, instruction, observations, world_graph):
        """
        Give the next low level action to execute.

        Must be implemented by subclasses.

        :raises NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    @property
    def agent_indices(self):
        """The uids of the agents this planner controls, in ``agents`` order."""
        return [agent.uid for agent in self._agents]

    def reset(self):
        """
        Reset planner state. Must be implemented by subclasses.

        :raises NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    @property
    def agents(self):
        """The agents controlled by this planner."""
        return self._agents

    @agents.setter
    def agents(self, agents):
        self._agents = agents

    @property
    def agent_descriptions(self):
        """
        Return a string listing the descriptions of all agents.

        Each ``agent_description`` is concatenated as-is, with no separator
        added between them.
        """
        return "".join(agent.agent_description for agent in self.agents)

    def get_agent_from_uid(self, agent_uid):
        """
        Return the controlled agent whose ``uid`` equals ``agent_uid``.

        :raises ValueError: if no controlled agent has that uid.
        """
        for agent in self.agents:
            if agent.uid == agent_uid:
                return agent
        raise ValueError(f'Agent with uid "{agent_uid}" not found')

    def filter_obs_space(self, batch, agent_uid):
        """
        Return the observations in ``batch`` that belong to agent ``agent_uid``.

        In single-agent mode ``batch`` is returned unchanged. Otherwise:

        * only keys containing ``agent_{uid}`` are kept;
        * every ``agent_{uid}_`` substring is removed from the kept keys
          (e.g. ``agent_0_rgb`` -> ``rgb``).
        """
        # TODO: this can go in utils
        # FIXME: code uses both this method as well as env_interface.parse_observations
        # to do the same thing. This is redundant and should be fixed.
        # NOTE(review): the key test is a substring match, so "agent_1" would also
        # select keys belonging to an "agent_10". This is harmless for uids 0/1 but
        # should be revisited if more agents are supported.
        if self.env_interface._single_agent_mode:
            return batch
        agent_name = f"agent_{agent_uid}"
        agent_name_bar = f"{agent_name}_"
        return {
            obs_name.replace(agent_name_bar, ""): obs_value
            for obs_name, obs_value in batch.items()
            if agent_name in obs_name
        }

    def process_high_level_actions(self, hl_actions, observations):
        """
        Convert per-agent high-level actions into low-level actions.

        :param hl_actions: mapping from agent uid to a sequence whose first three
            items are (action name, action input, error message).
        :param observations: observation batch; it is filtered per agent with
            ``filter_obs_space`` before being handed to the agent.
        :return: ``(low_level_actions, responses)``, both keyed by agent uid.

            * If ``hl_actions`` is empty, no actions are returned and every
              controlled agent gets a prompt asking for an action.
            * An agent with a non-empty error message gets that message as its
              response and no low-level action.
            * An agent whose action yields ``None`` as the low-level action gets
              only a response.
            * Responses have trailing newlines stripped.
        """
        if not hl_actions:
            response = "No actions were assigned. Please assign action to this agent."
            return {}, dict.fromkeys(self.agent_indices, response)

        low_level_actions = {}
        responses = {}

        for agent in self.agents:
            agent_uid = agent.uid
            if agent_uid not in hl_actions:
                continue

            hl_action = hl_actions[agent_uid]
            hl_action_name = hl_action[0]
            hl_action_input = hl_action[1]
            hl_error_message = hl_action[2]

            # Report the error back instead of executing the action.
            if hl_error_message:
                responses[agent_uid] = hl_error_message
                continue

            # Give the agent only its own observations.
            filtered_observations = self.filter_obs_space(observations, agent_uid)
            low_level_action, response = agent.process_high_level_action(
                hl_action_name, hl_action_input, filtered_observations
            )

            if low_level_action is not None:
                low_level_actions[agent_uid] = low_level_action
            responses[agent_uid] = response.rstrip("\n")

        # Update the world based on the actions (only acts under partial obs).
        self.update_world(responses)

        return low_level_actions, responses

    def update_world(self, responses):
        """
        Update the world graph with the latest observations and actions.

        This is only required in the partial-observability case. Under full
        observability the function does NOTHING, because no action-based update
        is needed there.

        Action-based updates are necessary under partial observability because
        an object is not visible while being carried, so it would otherwise be
        dropped from the "agent-is-holding" relation.
        """
        if self.env_interface.partial_obs:
            self._partial_obs_update(responses)

    def _partial_obs_update(self, responses):
        """
        Apply the latest actions to the agents' world graphs, for both the CG and
        GT conditions under the partially-observable setting.

        For each agent uid in ``self.last_high_level_actions``, the action and its
        result are taken from one of two sources:

        * the env interface's composite-action responses, when an entry exists
          for that agent;
        * otherwise ``last_high_level_actions`` and ``responses``.

        Agents with neither are skipped. The action is applied to the agent's own
        world graph and, unless agent asymmetry restricts it, to the other
        agent's world graph.

        The other agent's uid is computed as ``1 - uid``, so this assumes two
        agents with integer-convertible uids 0 and 1.

        :param responses: per-agent responses, keyed like ``last_high_level_actions``.
        """
        env_interface = self.env_interface
        composite_action_response = env_interface._composite_action_response
        # NOTE(review): last_high_level_actions is not set in this class. Treat it
        # as empty (nothing to propagate) until a subclass records actions.
        last_high_level_actions = getattr(self, "last_high_level_actions", {})
        consumed_composite_response = False

        for agent_uid in last_high_level_actions:
            int_agent_uid = int(agent_uid)
            if int_agent_uid in composite_action_response:
                # Composite actions carry their result in the third slot.
                action_and_args = composite_action_response[int_agent_uid]
                action_results = action_and_args[2]
                consumed_composite_response = True
            elif agent_uid in responses:
                action_and_args = last_high_level_actions[agent_uid]
                action_results = responses[agent_uid]
            else:
                continue

            int_other_agent_uid = 1 - int_agent_uid
            # TODO: need to assess if RL skills need this action-based-update at all
            # Update the agent's own world graph with its own action.
            env_interface.world_graph[int_agent_uid].update_by_action(
                agent_uid,
                action_and_args,
                action_results,
            )

            # Update the other agent's graph with this agent's action.
            # NOTE: this is a separate function since two agents may refer to the
            # same entity using different descriptions; that call handles the
            # ambiguity.
            # With agent asymmetry, only the human agent's actions are pushed into
            # the other (robot) agent's graph. Without it, every agent's actions
            # update the other agent's graph.
            conf = env_interface.conf
            if (
                not conf.agent_asymmetry
                or int_agent_uid == env_interface.human_agent_uid
            ):
                env_interface.world_graph[
                    int_other_agent_uid
                ].update_by_other_agent_action(
                    agent_uid,
                    action_and_args,
                    action_results,
                )

        if consumed_composite_response:
            # Reset once, after every agent has been processed. Resetting inside
            # the loop would drop a later agent's pending composite response
            # whenever the reset empties the dict in place.
            env_interface.reset_composite_action_response()
