#!/usr/bin/env python3



from typing import Any, Dict, List

import cv2
import imageio
import numpy as np

from partnr.perception.perception_sim import PerceptionSim


def as_intrinsics_matrix(intrinsics):
    """
    Build the 3x3 camera matrix from pinhole intrinsics.

    :param intrinsics: indexable holding ``fx, fy, cx, cy`` at positions 0..3;
        any further entries are ignored.
    :return: 3x3 float array laid out as
        ``[[fx, 0, cx], [0, fy, cy], [0, 0, 1]]``.
    """
    # Only the first four entries are read, so longer sequences are accepted.
    fx, fy, cx, cy = (intrinsics[i] for i in range(4))

    # Start from the identity so the off-diagonal zeros and the homogeneous
    # 1 in the bottom-right corner are already in place.
    matrix = np.eye(3)
    matrix[0, 0], matrix[1, 1] = fx, fy  # focal lengths on the diagonal
    matrix[0, 2], matrix[1, 2] = cx, cy  # principal point in the last column
    return matrix


class PerceptionObs(PerceptionSim):
    """
    This class uses only the simulated panoptic sensors to detect objects and then
    ground their location based on depth images being streamed by the agents. Note
    that no other privileged information about the state of the world is used to
    enhance object location or inter-object relations. We use previously detected
    objects and furniture through CG to ground properties of newer objects detected
    through panoptic sensors.
    """

    def __init__(self, sim, metadata_dict: Dict[str, str], *args, **kwargs):
        """
        Set up the perception module with the single ``gt_panoptic`` detector.

        :param sim: simulator handle, passed through to the parent constructor.
        :param metadata_dict: metadata mapping, passed through to the parent
            constructor.

        NOTE: extra ``*args`` / ``**kwargs`` are accepted but are not forwarded to
        the parent constructor.
        """
        super().__init__(sim, metadata_dict=metadata_dict, detectors=["gt_panoptic"])

        # a list of cached panoptic images for debugging (filled only while
        # self._verbose is True)
        self._outs: List[Any] = []
        # number of observations handled by process_obs so far
        self._iteration = 0
        # when True, process_obs caches panoptic frames in self._outs and writes
        # per-object debug PNGs into the current working directory
        self._verbose = True

    def get_sim_handle_and_key_from_panoptic_image(
        self, obs: np.ndarray
    ) -> Dict[int, str]:
        """
        This method uses the instance segmentation output to create a mapping of
        all objects present in the given agent's FOV.

        :param obs: panoptic / instance-segmentation image.
        :return: dict keyed by the IDs as they appear in ``obs`` with the
            corresponding sim object handle as value. IDs that the rigid object
            manager cannot resolve are left out.
        """
        handles_to_key_map: Dict[int, str] = {}

        for panoptic_id in np.unique(obs):
            # raw ID 0 is never looked up
            if panoptic_id == 0:
                continue

            # 100 gets added to object IDs that are recognized by ROM/AOM in
            # sim/lab; subtracting 100 here to get the original object ID.
            # NOTE(review): if `obs` has an unsigned dtype, raw IDs below 100
            # wrap around here instead of going negative -- confirm the dtype.
            obj_idx = panoptic_id - 100

            # a shifted ID of 0 (raw ID 100) is skipped as well
            if obj_idx == 0:
                continue

            sim_object = self.rom.get_object_by_id(obj_idx)
            if sim_object is None:
                continue

            # add 100 back so the keys match the values found in `obs`; code
            # using `obs` as is doesn't have to care about adding or subtracting
            # 100 on their end. User sees indices in `obs` mapped to handles
            # directly.
            handles_to_key_map[obj_idx + 100] = sim_object.handle

        return handles_to_key_map

    def process_obs(self, obs: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process the observation to extract the necessary information for the
        detectors and VLMs.

        :param obs: observation dict; the keys ``masks``, ``rgb``, ``depth``,
            ``camera_intrinsics`` and ``camera_pose`` are read.
        :return: dict with a single entry keyed by ``self.detectors[0]`` that holds
            the per-object masks, the panoptic image, the ID-to-name mapping,
            depth, RGB, the 3x3 intrinsics matrix and the camera pose.
        :raises ValueError: if the panoptic image is not 3-dimensional.

        While ``self._verbose`` is True, the panoptic frame is cached in
        ``self._outs`` (e.g. to dump it as a GIF with ``imageio.mimsave`` while
        debugging) and two debug PNGs per detected object are written into the
        current working directory.
        """
        processed_obs = {}
        out_img = obs["masks"]
        in_img = obs["rgb"]
        if self._verbose:
            self._outs.append(out_img)

        # Get handles of all objects and receptacles in agent's FOVs
        id_to_handle = self.get_sim_handle_and_key_from_panoptic_image(out_img)

        # Convert handles to names; IDs whose handle has no known name are dropped
        id_to_object_mapping = {
            obj_idx: self.sim_handle_to_name[handle]
            for obj_idx, handle in id_to_handle.items()
            if handle in self.sim_handle_to_name
        }

        # the masks built below rely on a 3-dimensional panoptic image
        if len(out_img.shape) != 3:
            raise ValueError(
                "expected a 3-dimensional panoptic image, got shape "
                f"{tuple(out_img.shape)}"
            )

        # semantic_input is a dictionary with object index as key and mask as
        # value. masks have the same shape as the panoptic image, with 1s at the
        # location of the object
        semantic_input: Dict[int, np.ndarray] = {}

        for obj_idx, obj_name in id_to_object_mapping.items():
            mask = (out_img == obj_idx).astype(np.uint8)
            semantic_input[obj_idx] = mask
            if self._verbose:
                # debug dumps: the input frame and the same frame masked to the
                # object, one pair of files per object name
                imageio.imwrite(f"in_img_{obj_name}.png", in_img)
                # NOTE(review): cv2.imwrite treats 3-channel data as BGR; if
                # `rgb` is RGB-ordered the saved image has red/blue swapped.
                cv2.imwrite(f"mask_{obj_name}.png", in_img * mask)

        # create an output dict to be used by graph updater
        processed_obs[self.detectors[0]] = {
            "object_masks": semantic_input,
            "out_img": out_img,
            "object_category_mapping": id_to_object_mapping,
            "depth": obs["depth"],
            "rgb": obs["rgb"],
            "camera_intrinsics": as_intrinsics_matrix(obs["camera_intrinsics"]),
            "camera_pose": obs["camera_pose"],
        }
        self._iteration += 1
        return processed_obs
