#!/usr/bin/env python3



import copy
import random
from collections import Counter
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch

from partnr.utils.geometric import (
    opengl_to_opencv,
    unproject_masked_depth_to_xyz_coordinates,
)
from partnr.world_model import (
    Entity,
    Floor,
    Furniture,
    House,
    Human,
    Object,
    Room,
    SpotRobot,
    UncategorizedEntity,
    WorldGraph,
)
from partnr.world_model.world_graph import flip_edge


class DynamicWorldGraph(WorldGraph):
    """
    This derived class collects all methods specific to world-graph created and
    maintained based on observations instead of privileged sim data.

    The graph is populated from ConceptGraph (CG) relation dicts
    (``create_cg_edges``), extended from per-frame detector output
    (``update_by_obs``) and kept in sync with the outcome of this agent's and the
    other agent's high-level actions (``update_by_action`` /
    ``update_by_other_agent_action``).
    """

    # Directed CG relation strings -> (index of child node, index of parent node,
    # edge label). "next to" relations are handled separately because they are
    # matched by substring ("a next to b" / "b next to a").
    _DIRECTED_CG_RELATIONS = {
        "a on b": (0, 1, "on"),
        "b on a": (1, 0, "on"),
        "a in b": (0, 1, "inside"),
        "b in a": (1, 0, "inside"),
    }

    def __init__(
        self,
        max_neighbors: int = 5,
        *args,
        **kwargs,
    ):
        """
        :param max_neighbors: number of closest entities that vote on which room a
            new agent/object belongs to.
        All other positional/keyword arguments are forwarded to ``WorldGraph``.
        """
        super().__init__(*args, **kwargs)
        self.max_neighbors_for_room_assignment = max_neighbors
        # whether CG entities tagged "object" (not only "furniture") become nodes
        self.include_objects = False
        # names of every node added by this class; used to avoid duplicate nodes
        self._entity_names: List[str] = []

    def _cg_object_to_object_uid(self, cg_object: dict) -> str:
        """
        Build the node name "<id>_<object_tag>" for a CG object, with spaces and
        dashes in the tag replaced by "_" and "/" replaced by "_or_".
        """
        return f"{cg_object['id']}_{cg_object['object_tag'].replace(' ', '_').replace('/', '_or_').replace('-', '_')}"

    def _cg_object_to_entity_input(self, cg_object: dict) -> Dict[str, Any]:
        """
        Convert a CG object dict into keyword arguments for an Entity constructor.

        The translation is bbox_center with its last two components swapped
        (indices 0, 2, 1) -- presumably converting CG's axis convention to the
        graph's; TODO confirm.
        """
        center = cg_object["bbox_center"]
        return {
            "name": self._cg_object_to_object_uid(cg_object),
            "properties": {
                "type": cg_object["category_tag"],
                "translation": [center[0], center[2], center[1]],
                "bbox_extent": cg_object["bbox_extent"],
            },
        }

    @staticmethod
    def _is_valid_cg_entity(cg_object: dict, include_objects: bool) -> bool:
        """
        Return True when a CG object should become a graph node. Its tag must not
        be "invalid" and must not mention "floor" or "wall". Its category must be
        "furniture", or "object" when ``include_objects`` is set.
        """
        tag = cg_object["object_tag"]
        if tag == "invalid" or "floor" in tag or "wall" in tag:
            return False
        category = cg_object["category_tag"]
        return category == "furniture" or (include_objects and category == "object")

    def _get_or_create_room(
        self, room_name: str, house: Entity, verbose: bool = False
    ) -> Room:
        """
        Return the node called ``room_name``. If it does not exist yet, create a
        Room with that name together with its floor node ("floor_<room_name>").
        The room is placed "inside" ``house`` and the floor "inside" the room.
        """
        room_node = None
        try:
            room_node = self.get_node_from_name(room_name)
        except ValueError as e:
            # expected the first time a room region is mentioned
            if verbose:
                print(e)
        if room_node is not None:
            return room_node

        room_node = Room(name=room_name, properties={"type": room_name})
        self.add_node(room_node)
        self._entity_names.append(room_name)
        self.add_edge(room_node, house, "inside", opposite_label=flip_edge("inside"))
        room_floor = Floor(f"floor_{room_node.name}", {})
        self.add_node(room_floor)
        self._entity_names.append(room_floor.name)
        self.add_edge(room_floor, room_node, "inside", flip_edge("inside"))
        if verbose:
            print(f"Added new room: {room_node.name}")
        return room_node

    def _add_or_get_cg_entity(
        self,
        cg_object: dict,
        room_node: Optional[Room],
        include_objects: bool,
        verbose: bool = False,
    ) -> Optional[Entity]:
        """
        Return the node for a CG object. An existing node is reused. Otherwise a
        new Object/Furniture node is created and put "inside" ``room_node``
        (when one is given).

        :return: the node, or None when the CG object is filtered out by
            ``_is_valid_cg_entity``.
        """
        obj_uid = self._cg_object_to_object_uid(cg_object)
        if obj_uid in self._entity_names:
            existing_node = self.get_node_from_name(obj_uid)
            if verbose:
                print(f"Found existing entity: {existing_node.name}")
            return existing_node
        if not self._is_valid_cg_entity(cg_object, include_objects):
            return None

        # _is_valid_cg_entity only admits the "object" and "furniture" categories
        node_cls = Object if cg_object["category_tag"] == "object" else Furniture
        new_node = node_cls(**self._cg_object_to_entity_input(cg_object))
        self.add_node(new_node)
        self._entity_names.append(obj_uid)
        if verbose:
            print(f"Added new entity: {new_node.name}")
        if room_node is not None:
            # make a child of the room CG allocated to this entity
            self.add_edge(
                obj_uid,
                room_node.name,
                "inside",
                opposite_label=flip_edge("inside"),
            )
            if verbose:
                print(f"Added above object to room: {room_node.name}")
        return new_node

    def create_cg_edges(
        self,
        cg_dict_list: Optional[List[Dict[str, Any]]] = None,
        include_objects: bool = False,
        verbose: bool = False,
    ):
        """
        This method populates the graph from the dict output of CG

        Creates a graph to store different entities in the world
        and their relations to one another
        example of what 1 CG relation looks like:
        {
            'object1': {'bbox_center': [-5.3, 0.9, 6.1],
                        'bbox_extent': [0.5, 0.4, 0.1],
                        'id': 215,
                        'object_tag': 'television set',
                        'category_tag': 'object'},
            'object2': {'bbox_center': [-5.3, 0.4, 6.1],
                        'bbox_extent': [0.5, 0.3, 0.1],
                        'id': 216,
                        'object_tag': 'couch',
                        'category_tag': 'receptacle'},
            'object_relation': 'a next to b',
            'reason': 'television set and couch are commonly found in living room, and '
                        'they are typically placed next to each other.',
            'room_region': 'living room'
            }
        - object_tag may be invalid, in that case just add non-invalid object as a node
        - object_relation may be:
            "none of these", "a/b next to b/a", "a/b on b/a", "a/b in b/a"
        - room_region may be unknown, in that case just add unknown room as a node and add objects to that node
        - room_region "FAIL" means no room is attached here; furniture left
          without a room is later attached to a default "unknown" room

        NOTE(review): only the category tags "furniture" and "object" become
        nodes. An entry tagged "receptacle", as in the example above, is skipped.
        Confirm this against CG's category vocabulary.

        :param cg_dict_list: list of CG relation dicts; None is treated as empty.
        :param include_objects: also add entities categorized as "object".
        :param verbose: print progress information.
        :raises ValueError: for an unrecognized ``object_relation``.
        """
        self.include_objects = include_objects
        self._raw_cg = cg_dict_list

        # Create root node
        house = House("house", {"type": "root"}, "house_0")
        self.add_node(house)
        self._entity_names.append("house")

        for edge_candidate in cg_dict_list or []:
            object1 = edge_candidate["object1"]
            object2 = edge_candidate["object2"]
            edge_relation = edge_candidate["object_relation"]
            room_region = edge_candidate["room_region"].replace(" ", "_")
            if verbose:
                print("RAW CG OUTPUT:")
                print(edge_candidate)

            room_node = None
            if room_region != "FAIL":
                room_node = self._get_or_create_room(room_region, house, verbose)

            object_nodes: List[Entity] = []
            for obj in (object1, object2):
                node = self._add_or_get_cg_entity(
                    obj, room_node, include_objects, verbose
                )
                if node is not None:
                    object_nodes.append(node)

            # add edge between object1 and object2
            if len(object_nodes) != 2 or edge_relation in ("none of these", "FAIL"):
                continue
            if "next to" in edge_relation:
                child, parent, label = object_nodes[0], object_nodes[1], "next to"
            elif edge_relation in self._DIRECTED_CG_RELATIONS:
                child_idx, parent_idx, label = self._DIRECTED_CG_RELATIONS[
                    edge_relation
                ]
                child, parent = object_nodes[child_idx], object_nodes[parent_idx]
            else:
                raise ValueError(
                    f"Unknown edge candidate: {edge_relation}, between objects: {object1} and {object2}"
                )
            self.add_edge(child, parent, label, opposite_label=flip_edge(label))
            if verbose:
                print(
                    f"Added edge {edge_relation} b/w {object_nodes[0].name} and {object_nodes[1].name}"
                )

        if verbose:
            print("[DynamicWorldGraph.create_cg_edges] Before pruning")
            self.display_hierarchy()
        self._fix_furniture_without_assigned_room()
        self._clean_up_room_and_floor_locations()
        if verbose:
            print("[DynamicWorldGraph.create_cg_edges] After pruning")
            self.display_hierarchy()

    def _fix_furniture_without_assigned_room(self):
        """
        Makes sure each furniture is assigned to some room; default=unknown

        Furniture with no Room neighbor is attached ("in") to the first room whose
        name contains "unknown". If no such room exists, a room called "unknown"
        (with its floor) is created under the "house" node.
        """
        furnitures = self.get_all_furnitures()
        orphans = [
            fur for fur in furnitures if len(self.get_neighbors_of_type(fur, Room)) == 0
        ]
        if orphans:
            default_room = next(
                (
                    room
                    for room in self.get_all_nodes_of_type(Room)
                    if "unknown" in room.name
                ),
                None,
            )
            if default_room is None:
                default_room = self._get_or_create_room(
                    "unknown", self.get_node_from_name("house")
                )
            for fur in orphans:
                # child -> parent, like every other containment edge in this class
                self.add_edge(fur, default_room, "in", flip_edge("in"))
        # internal invariant: every furniture now has a room
        assert all(
            len(self.get_neighbors_of_type(fur, Room)) > 0 for fur in furnitures
        )

    def _clean_up_room_and_floor_locations(self):
        """
        Iterates over the now-filled graph and attaches positions of known furniture
        belonging to a room to the floor of that room as translation property.
        Also prunes out rooms without any furniture in them.

        We use furniture in a room to set the room's location. If a room has no
        (non-floor) furniture with a translation, we can't get this geometric
        information, so we remove such rooms together with their floors.
        """
        prune_list = []
        for current_room in self.get_all_rooms():
            # copy so the graph's own neighbor list is never mutated
            neighbors = list(self.get_neighbors_of_type(current_room, Furniture) or [])
            floors = [node for node in neighbors if isinstance(node, Floor)]
            candidates = [
                node
                for node in neighbors
                if not isinstance(node, Floor)
                and node.properties.get("translation") is not None
            ]
            if not candidates:
                prune_list.extend(floors)
                prune_list.append(current_room)
                continue
            # choose an arbitrary furniture with a translation and place the room
            # and its floor(s) there
            valid_translation = random.choice(candidates).properties["translation"]
            current_room.properties["translation"] = copy.copy(valid_translation)
            for room_floor in floors:
                room_floor.properties["translation"] = copy.copy(valid_translation)

        for prune_node in prune_list:
            self.remove_node(prune_node)

    def add_agent_node_and_update_room(self, agent_node: Union[Human, SpotRobot]):
        """
        Add ``agent_node`` to the graph and link it ("in"/"contains") to the room
        chosen by ``find_room_of_entity``.

        :raises ValueError: if no room could be determined. The agent node has
            already been added to the graph at that point.
        """
        self.add_node(agent_node)
        self._entity_names.append(agent_node.name)
        room_node = self.find_room_of_entity(agent_node)
        if room_node is None:
            raise ValueError(
                f"[DynamicWorldGraph.initialize_agent_nodes] No room found for {agent_node.name}"
            )
        self.add_edge(agent_node, room_node, "in", opposite_label="contains")

    def initialize_agent_nodes(self, subgraph: WorldGraph, init: bool = False):
        """
        Initializes or updates the agent node in the graph

        The first Human and the first SpotRobot node of ``subgraph`` are copied
        into this graph. Each copy keeps the name and a copy of the translation.
        Each copy is then assigned to a room.

        :param subgraph: graph holding the agent nodes to copy.
        :param init: currently unused; kept for backward compatibility.
        """
        for agent_cls, missing_msg in (
            (Human, "[DynamicWorldGraph.update_agent_nodes] No human node found"),
            (
                SpotRobot,
                "[DynamicWorldGraph.update_agent_nodes] No SpotRobot node found",
            ),
        ):
            source_nodes = subgraph.get_all_nodes_of_type(agent_cls)
            if len(source_nodes) == 0:
                print(missing_msg)
                continue
            source_node = source_nodes[0]
            dynamic_node = agent_cls(source_node.name, {"type": "agent"})
            dynamic_node.properties["translation"] = source_node.properties[
                "translation"
            ].copy()
            self.add_agent_node_and_update_room(dynamic_node)

    def _vote_for_room(
        self,
        neighbors,
        first_room_only: bool = False,
        log_prefix: Optional[str] = None,
    ) -> Dict[Room, int]:
        """
        Count, for every room, how many of ``neighbors`` have it as a Room neighbor.

        :param neighbors: iterable of entities whose rooms are counted.
        :param first_room_only: count only the first Room neighbor of each entity.
        :param log_prefix: if given, print one line per counted (entity, room).
        :return: Counter of room -> votes, in first-counted order.
        """
        votes: Counter = Counter()
        for neighbor in neighbors:
            rooms = list(self.get_neighbors_of_type(neighbor, Room))
            if first_room_only:
                rooms = rooms[:1]
            for room in rooms:
                if log_prefix is not None:
                    print(
                        f"{log_prefix} --> Closest object: {neighbor.name} is in room: {room.name}"
                    )
                votes[room] += 1
        return votes

    def find_room_of_entity(
        self, entity_node: Union[Human, SpotRobot], verbose: bool = False
    ) -> Optional[Room]:
        """
        This method finds the room node that the agent is in

        Logic: Find the objects closest to the agent and assign the agent to the room
        that contains the most number of these objects

        :return: the winning room (ties go to the room counted first), or None if
            none of the closest entities is in a room.
        """
        closest_objects = self.get_closest_entities(
            self.max_neighbors_for_room_assignment,
            object_node=entity_node,
            dist_threshold=-1.0,
        )
        room_counts = self._vote_for_room(
            closest_objects,
            log_prefix=(
                f"[DynamicWorldGraph.find_room_of_entity] {entity_node.name}"
                if verbose
                else None
            ),
        )
        if not room_counts:
            return None
        if verbose:
            print(f"room_counts={dict(room_counts)}")
        return max(room_counts, key=room_counts.get)

    def move_object_from_agent_to_placement_node(
        self,
        object_node: Union[Entity, Object],
        agent_node: Union[Entity, Human, SpotRobot],
        placement_node: Union[Entity, Furniture],
    ):
        """
        Utility method to move object to a placement node from a given agent. Does in-place manipulation of the world-graph

        The object-agent edge is removed and an "on" edge from the object to
        ``placement_node`` is added. The object's translation becomes a copy of
        the placement node's translation.

        :raises KeyError: if ``placement_node`` has no "translation" property.
        """
        # Detach the object from the agent
        self.remove_edge(object_node, agent_node)

        # Add new edge from object to the receptacle
        # TODO: We should add edge to default receptacle instead of fur
        self.add_edge(object_node, placement_node, "on", flip_edge("on"))
        # snap the object to furniture's center in absence of actual location;
        # copy so later edits to one node's translation don't leak into the other
        object_node.properties["translation"] = copy.copy(
            placement_node.properties["translation"]
        )

    def update_by_obs(self, frame_desc: Dict[str, Any], verbose: bool = False):
        """
        This method updates the graph based on the processed observations

        Each detected object's masked depth is unprojected to 3D points. The
        centroid of those points becomes the new Object node's translation. The
        node is then linked ("in") to the room most common among its closest
        objects/furniture.

        :param frame_desc: mapping of detector name -> frame dict with keys
            - "object_category_mapping": object id -> category string; frames
              where it is empty are skipped
            - "object_masks":      object id -> mask image of that object
            - "depth":             depth image, unpacked as (H, W, C)
            - "camera_intrinsics": 3x3 camera intrinsics matrix
            - "camera_pose":       4x4 camera pose, converted with
              ``opengl_to_opencv`` (presumably given in OpenGL convention)
        :param verbose: print progress information.
        """
        # create masked point-clouds per object and then extract centroid
        # as a proxy for object's location
        # NOTE: using bboxes may also include non-object points to contribute
        # to the object's position...we can fix this with nano-SAM or using
        # analytical approaches to prune object PCD
        for _detector_name, detector_frame in frame_desc.items():
            obj_id_to_category_mapping = detector_frame["object_category_mapping"]
            if not obj_id_to_category_mapping:
                continue
            depth_numpy = detector_frame["depth"]
            H, W, C = depth_numpy.shape
            pose = opengl_to_opencv(detector_frame["camera_pose"])
            # NOTE(review): reshape (not transpose) HWC -> 1xCxHxW is only a valid
            # layout change when C == 1, which is assumed for depth; TODO confirm
            depth_tensor = torch.from_numpy(depth_numpy.reshape(1, C, H, W))
            pose_tensor = torch.from_numpy(pose.reshape(1, 4, 4))
            inv_intrinsics_tensor = torch.from_numpy(
                np.linalg.inv(detector_frame["camera_intrinsics"]).reshape(1, 3, 3)
            )
            for object_id, object_mask in detector_frame["object_masks"].items():
                if not np.any(object_mask):
                    continue
                category = obj_id_to_category_mapping[object_id]
                if verbose:
                    print(f"Found object: {category} with id: {object_id}")
                mask_tensor = ~torch.from_numpy(object_mask.reshape(1, C, H, W)).bool()
                object_xyz = unproject_masked_depth_to_xyz_coordinates(
                    depth_tensor,
                    pose_tensor,
                    inv_intrinsics_tensor,
                    mask_tensor,
                )
                if object_xyz.numel() == 0:
                    # no 3D points: the mean below would be NaN
                    continue
                object_centroid = object_xyz.mean(dim=0).numpy()

                # add this object to the graph
                # TODO: this is just a proof-of-concept, we need to use a better
                # vocab for object detection here and need to discuss if we want to
                # only add new objects but also furniture pieces that may have been
                # missed during pre-exploration
                new_object_node = Object(
                    f"{object_id}_{category}",
                    {
                        "type": category,
                        "translation": object_centroid.tolist(),
                        "camera_pose_of_view": detector_frame["camera_pose"],
                    },
                )

                # vote for a room using the first room of each of the top N
                # closest objects (N = self.max_neighbors_for_room_assignment)
                # TODO: get only the 5 closest objects viewable from current location?
                closest_objects = self.get_closest_object_or_furniture(
                    new_object_node, self.max_neighbors_for_room_assignment
                )
                room_counts = self._vote_for_room(
                    closest_objects,
                    first_room_only=True,
                    log_prefix=f"Adding {new_object_node.name}" if verbose else None,
                )
                self.add_node(new_object_node)
                self._entity_names.append(new_object_node.name)
                if room_counts:
                    closest_room = max(room_counts, key=room_counts.get)
                    self.add_edge(
                        new_object_node,
                        closest_room,
                        "in",
                        opposite_label="contains",
                    )
                # TODO: logic for checking containment within another furniture
                # TODO: logic for checking surface-placement over another furniture

    def _resolve_place_arguments(self, action_args: str):
        """
        Look up the object and placement nodes named in a Place/Rearrange argument
        string with the comma-separated layout
        ``<object>, <spatial_relation>, <furniture/floor to be placed>, <spatial_constraint>, <reference_object>``.

        :return: (object node, placement node)
        :raises IndexError: if there is no third field. Lookup errors from
            ``get_node_from_name`` propagate.
        """
        fields = [field.strip() for field in action_args.split(",")]
        return self.get_node_from_name(fields[0]), self.get_node_from_name(fields[2])

    def _apply_state_action(self, high_level_action, log_tag: str, verbose: bool) -> bool:
        """
        Apply the state change implied by a successful pour/fill, power on/off or
        clean action to the node named by ``high_level_action[1]``.

        :param high_level_action: (action name, argument string) pair.
        :param log_tag: tag printed in brackets at the start of verbose messages.
        :return: True if the action is one of these state actions and was applied;
            False otherwise (the graph is not touched).
        :raises ValueError: for a power action containing neither "on" nor "off".
        """
        action_name = high_level_action[0].lower()
        if "pour" in action_name or "fill" in action_name:
            new_state, description = {"is_filled": True}, "filled"
        elif "power" in action_name:
            if "on" in action_name:
                new_state, description = {"is_powered_on": True}, "powered on"
            elif "off" in action_name:
                new_state, description = {"is_powered_on": False}, "powered off"
            else:
                raise ValueError(
                    f"Expected 'on' or 'off' in power action, got: {high_level_action[0]}"
                )
        elif "clean" in action_name:
            new_state, description = {"is_clean": True}, "clean"
        else:
            return False

        entity_node = self.get_node_from_name(high_level_action[1].strip())
        entity_node.set_state(new_state)
        if verbose:
            print(
                f"[{log_tag}] {entity_node.name} is now {description}, {entity_node.properties}"
            )
        return True

    def update_by_action(
        self,
        agent_uid,
        high_level_action,
        action_response,
        verbose: bool = True,
    ):
        """
        Update the graph after this agent executed ``high_level_action``.

        Only actions whose ``action_response`` contains "success"
        (case-insensitive) change the graph:
        - Place/Rearrange: the object is moved from the agent onto the target
          furniture. A malformed Rearrange argument string is reported and
          ignored.
        - Pour/Fill, Power on/off, Clean: the target entity's state is updated.

        :param agent_uid: uid of the acting agent; its node is "agent_<uid>".
        :param high_level_action: (action name, argument string) pair.
        :param action_response: textual result of the action.
        :param verbose: print progress information.
        """
        if "success" not in action_response.lower():
            return
        print(f"{agent_uid=}: {high_level_action=}, {action_response=}")
        agent_node = self.get_node_from_name(f"agent_{agent_uid}")
        action_name = high_level_action[0].lower()

        if "place" in action_name or "rearrange" in action_name:
            object_node, placement_node = None, None
            if "place" in action_name:
                # TODO: Add floor support
                object_node, placement_node = self._resolve_place_arguments(
                    high_level_action[1]
                )
            else:
                # Handle the case for rearrange proposition usage for place skills;
                # best effort: a malformed argument string must not crash the update
                try:
                    object_node, placement_node = self._resolve_place_arguments(
                        high_level_action[1]
                    )
                except Exception as e:
                    print(f"Issue when split comma: {e}")

            # TODO: replace following with the right inside/on relation
            # based on 2nd string argument to Pick when implemented
            if object_node is not None and placement_node is not None:
                self.move_object_from_agent_to_placement_node(
                    object_node, agent_node, placement_node
                )
                if verbose:
                    print(
                        f"{self.update_by_action.__name__} Moved object: {object_node.name} from {agent_node.name} to {placement_node.name}"
                    )
            elif verbose:
                print(
                    f"{self.update_by_action.__name__} Could not move object from agent to placement-node: {high_level_action}"
                )
            return

        handled = self._apply_state_action(
            high_level_action, "DWG.update_by_action", verbose
        )
        if not handled and verbose:
            print(
                "[DWG.update_by_action] Not updating world graph for successful action: ",
                high_level_action,
            )

    def _update_cg_by_other_agent_action(
        self,
        other_agent_uid,
        high_level_action_and_args,
        action_results,
        verbose=False,
    ):
        """Semantic-similarity based update from another agent's action (not implemented)."""
        raise NotImplementedError

    def _update_gt_graph_by_other_agent_action(
        self,
        other_agent_uid,
        high_level_action_and_args,
        action_results,
        verbose: bool = True,
    ):
        """
        Uses the exact object and receptacle names given by the other agent to update the
        graph. Supports following actions:
        1. Place / Rearrange (object moved from the other agent onto furniture)
        2. Pour / Fill, Power on/off, Clean (entity state updated)

        Only actions whose ``action_results`` contains "success" are applied. If
        the receptacle of a Place/Rearrange action cannot be resolved, the
        failure is printed and the graph is left unchanged.
        """
        if "success" not in action_results.lower():
            return
        print(
            f"[DWG._update_gt_graph_by_other_agent_action] {high_level_action_and_args=} {other_agent_uid=}"
        )
        agent_node = self.get_node_from_name(f"agent_{other_agent_uid}")
        action_name = high_level_action_and_args[0].lower()

        if "place" in action_name or "rearrange" in action_name:
            # <object>, <spatial_relation>, <furniture/floor to be placed>, <spatial_constraint>, <reference_object>
            high_level_action_args = [
                arg.strip() for arg in high_level_action_and_args[1].split(",")
            ]
            object_node = self.get_node_from_name(high_level_action_args[0])
            # uid of this (observing) agent; assumes two agents with uids 0 and 1
            observer_uid = 1 - int(other_agent_uid)
            try:
                placement_node = self.get_node_from_name(high_level_action_args[2])
                self.move_object_from_agent_to_placement_node(
                    object_node, agent_node, placement_node
                )
                print(
                    f"[DWG.update_gt_graph_by_other_agent_action] From the perspective of agent_{observer_uid}:\n{agent_node.name} PLACED OBJECT {object_node.name} on {placement_node.name}"
                )
            except (IndexError, KeyError, ValueError) as e:
                receptacle_name = (
                    high_level_action_args[2]
                    if len(high_level_action_args) > 2
                    else "<missing>"
                )
                print(
                    f"Could not find matching receptacle in agent {observer_uid} graph for {receptacle_name} that agent {other_agent_uid} is trying to place on.\nException: {e}"
                )
            return

        handled = self._apply_state_action(
            high_level_action_and_args,
            "DWG._update_gt_graph_by_other_agent_action",
            verbose,
        )
        if not handled and verbose:
            print(
                "[DWG._update_gt_graph_by_other_agent_action] Not updating world graph for successful action: ",
                high_level_action_and_args,
            )

    def update_by_other_agent_action(
        self,
        other_agent_uid,
        high_level_action_and_args,
        action_results,
        use_semantic_similarity=False,
        verbose=False,
    ):
        """
        Update the graph with the result of another agent's action.

        Dispatches to the semantic-similarity variant (not implemented) when
        ``use_semantic_similarity`` is set. Otherwise it dispatches to the
        exact-name variant.
        """
        if use_semantic_similarity:
            self._update_cg_by_other_agent_action(
                other_agent_uid,
                high_level_action_and_args,
                action_results,
                verbose=verbose,
            )
        else:
            self._update_gt_graph_by_other_agent_action(
                other_agent_uid,
                high_level_action_and_args,
                action_results,
                verbose=verbose,
            )