from typing import List
from habitat.utils.visualizations import maps
from home_robot.core.interfaces import Observations
import numpy as np
from tqdm import tqdm
from vlfm.utils.img_utils import reorient_rescale_map, resize_images
from vlfm.utils.visualization import add_text_to_image
from vlfm.utils.habitat_visualizer import HabitatVis, color_point_cloud_on_map
import cv2

class GeneratingHabitatVis(HabitatVis):
    """Habitat visualizer that buffers per-step data and renders frames lazily.

    ``collect_data`` stores the depth image, RGB image, top-down map, auxiliary
    map images and text lines of one step. ``flush_frames`` hands back the
    visualizer itself; iterating over it builds one composite frame per step
    on demand instead of building every frame up front.

    The buffers ``self.depth``, ``self.rgb``, ``self.maps``, ``self.vis_maps``
    and ``self.texts`` are not created in this class -- presumably
    ``HabitatVis.__init__`` sets them up (TODO confirm).
    """

    def collect_data(
        self,
        info,
        policy_info,
    ) -> None:
        """Buffer everything needed to render the frame of the current step.

        Args:
            info: Mapping read for the keys ``'depth'``, ``'rgb'``,
                ``'top_down_map'`` and ``'bounds'``.
            policy_info: Mapping read for ``'target_objects'``,
                ``'target_point_clouds'``, ``'obstacle_map'`` and
                ``'value_maps'``, plus the optional keys
                ``'annotated_depth'``, ``'annotated_rgb'`` and
                ``'render_below_images'``.
        """
        self.target_objects = policy_info['target_objects']

        if "annotated_depth" in policy_info:
            depth = policy_info["annotated_depth"]
            self.using_annotated_depth = True
        else:
            # Assumes info['depth'] is normalised to [0, 1] -- TODO confirm.
            depth = (info['depth'] * 255.0).astype(np.uint8)
            depth = cv2.cvtColor(depth, cv2.COLOR_GRAY2RGB)
        self.depth.append(depth)

        if "annotated_rgb" in policy_info:
            rgb = policy_info["annotated_rgb"]
            self.using_annotated_rgb = True
        else:
            rgb = info['rgb']
        self.rgb.append(rgb)

        # Visualize target point cloud on the map. The map is rendered at the
        # height of the first buffered depth image.
        top_down_map = draw_top_down_map(
            info["top_down_map"],
            policy_info['target_point_clouds'],
            self.depth[0].shape[0],
        )
        self.maps.append(top_down_map)

        # NOTE(review): the maps used to be read unconditionally from
        # ``self.policy_info``, an attribute this class never assigns. If a
        # base class provides it, it is still used (unchanged behaviour);
        # otherwise fall back to the ``policy_info`` argument instead of
        # raising AttributeError. Confirm which source is intended.
        map_source = getattr(self, "policy_info", policy_info)

        y0, y1, x0, x1 = info['bounds']
        vis_map_imgs = [reorient_rescale_map(map_source['obstacle_map'][y0:y1, x0:x1])]
        # Only the first two value maps go through reorient_rescale_map; any
        # further ones are stored as given.
        vis_map_imgs += [
            reorient_rescale_map(value_map) if i_map < 2 else value_map
            for i_map, value_map in enumerate(map_source['value_maps'])
        ]
        # The list always holds at least the obstacle map, so it is never empty.
        self.using_vis_maps = True
        self.vis_maps.append(vis_map_imgs)

        text = [
            policy_info[text_key]
            for text_key in policy_info.get("render_below_images", [])
            if text_key in policy_info
        ]
        self.texts.append(text)

    @staticmethod
    def _pad_row_to_width(row: np.ndarray, width: int) -> np.ndarray:
        """Right-pad ``row`` with white (255) columns until it is ``width`` wide."""
        missing = width - row.shape[1]
        if missing <= 0:
            return row
        pad_width = [(0, 0), (0, missing)] + [(0, 0)] * (row.ndim - 2)
        return np.pad(row, pad_width, mode="constant", constant_values=255)

    def _create_frame(
        self,
        depth: np.ndarray,
        rgb: np.ndarray,
        map: np.ndarray,
        vis_map_imgs: List[np.ndarray],
        text: List[str],
    ) -> np.ndarray:
        """Compose the images of one step into a single labelled two-row frame.

        Args:
            depth: Depth image of the step.
            rgb: RGB image of the step.
            map: Top-down map image of the step.
            vis_map_imgs: Obstacle map followed by the value maps.
            text: Lines rendered above the image grid, first entry on top.

        Returns:
            The composed frame with the text lines stacked on top.
        """
        all_imgs = [depth, rgb, map] + vis_map_imgs
        all_imgs = list(resize_images(all_imgs, match_dimension="height"))

        labels = ['depth', 'rgb', 'map', 'obstacle', *self.target_objects, 'seen', 'gaze']
        # zip stops at the shorter sequence, so surplus images stay unlabelled.
        for i_img, (img, label) in enumerate(zip(all_imgs, labels)):
            # White text on the two camera images, black on the map images.
            color = (255, 255, 255) if i_img < 2 else (0, 0, 0)
            # putText draws into ``img`` in place.
            cv2.putText(img, label, (10, 35), cv2.FONT_HERSHEY_SIMPLEX, 1, color, 2, cv2.LINE_AA)

        # Add a white filler tile so both rows hold the same number of images.
        if len(all_imgs) % 2 == 1:
            all_imgs.append(np.full_like(all_imgs[-1], 255))

        # Even-indexed images form the top row, odd-indexed ones the bottom row.
        top_row = np.concatenate(all_imgs[::2], axis=1)
        bottom_row = np.concatenate(all_imgs[1::2], axis=1)

        # Images only share their height, so the two rows can differ in total
        # width; pad the narrower one so they can be stacked vertically.
        frame_width = max(top_row.shape[1], bottom_row.shape[1])
        frame = np.concatenate(
            [
                self._pad_row_to_width(top_row, frame_width),
                self._pad_row_to_width(bottom_row, frame_width),
            ],
            axis=0,
        )

        # Add text to the top of the frame. Each call stacks above the
        # previous one, so iterate in reverse to keep text[0] uppermost.
        for t in reversed(text):
            frame = add_text_to_image(frame, t, top=True)

        return frame

    def flush_frames(self, failure_cause: str) -> "GeneratingHabitatVis":
        """Record the failure cause and return this visualizer for iteration.

        Frames are not built here; they are produced while iterating over the
        returned object.
        """
        self.failure_cause = failure_cause
        return self

    def __len__(self) -> int:
        """Return the number of frames iteration yields (never negative).

        One less than the number of buffered steps, because the last buffered
        step is excluded from rendering. Clamped at zero: a negative value
        would make ``len()`` raise ``ValueError`` on an empty buffer.
        """
        return max(len(self.depth) - 1, 0)

    def __iter__(self):
        """Yield one composed frame per buffered step, annotated with the failure cause.

        ``flush_frames`` must have been called beforehand, since it is what
        sets ``self.failure_cause``.
        """
        # NOTE: an earlier revision rotated the rgb / depth / vis_maps buffers
        # by one entry here to compensate for a one-step delay; that
        # realignment is currently disabled.

        num_frames = len(self)  # last frame is from next episode, remove it
        for i in tqdm(range(num_frames)):
            frame = self._create_frame(
                self.depth[i],
                self.rgb[i],
                self.maps[i],
                self.vis_maps[i],
                self.texts[i],
            )
            failure_cause_text = "Failure cause: " + self.failure_cause
            frame = add_text_to_image(frame, failure_cause_text, top=True)
            yield frame


def draw_top_down_map(top_down_map_dict, target_point_clouds, top_down_height):
    """Render the top-down map with the agent and target point clouds drawn on it.

    Args:
        top_down_map_dict: Mapping read for the keys ``"map"``,
            ``"fog_of_war_mask"``, ``"agent_map_coord"``, ``"agent_angle"``,
            ``'lower_bound'`` and ``'upper_bound'``.
        target_point_clouds: Sequence of point arrays, one per target. Only
            the first two columns of each array are used; empty clouds are
            skipped.
        top_down_height: Height in pixels of the returned image.

    Returns:
        The coloured map, cropped to the lower/upper bounds and resized to
        ``top_down_height`` with the width scaled to keep the aspect ratio.
    """
    palette = [(255, 0, 0), (0, 255, 0), (0, 0, 255)]

    top_down_map = maps.colorize_topdown_map(
        top_down_map_dict["map"], top_down_map_dict["fog_of_war_mask"]
    )
    top_down_map = maps.draw_agent(
        image=top_down_map,
        agent_center_coord=top_down_map_dict["agent_map_coord"][0],
        agent_rotation=top_down_map_dict["agent_angle"][0],
        agent_radius_px=10,
    )
    map_height, map_width = top_down_map.shape[:2]

    # Reuse the palette cyclically so clouds beyond the third are still drawn
    # instead of being silently dropped.
    for i_cloud, target_point_cloud in enumerate(target_point_clouds):
        if not len(target_point_cloud):
            continue
        color = palette[i_cloud % len(palette)]

        # NOTE(review): 20, 500 and 1000 presumably encode a 1000x1000 map at
        # 20 px per unit with the origin at the centre -- confirm against the
        # code that builds the map.
        xy_points = target_point_cloud[:, :2]
        pixel_points = np.rint(xy_points[..., ::-1] * 20) + 500
        pixel_points = (1000 - pixel_points).astype(np.int32)
        rows, cols = pixel_points[:, 1], pixel_points[:, 0]

        # Drop points that fall outside the image: negative indices would wrap
        # around and paint the wrong pixels, too-large ones raise IndexError.
        inside = (rows >= 0) & (rows < map_height) & (cols >= 0) & (cols < map_width)
        top_down_map[rows[inside], cols[inside]] = color

    y0, x0 = top_down_map_dict['lower_bound']
    y1, x1 = top_down_map_dict['upper_bound']
    # Rows are counted from the bottom edge of the image. Use explicit offsets
    # rather than ``[-y1:-y0]``, which yields an empty crop when y0 == 0.
    row_start = max(map_height - y1, 0)
    row_stop = max(map_height - y0, 0)
    top_down_map = top_down_map[row_start:row_stop, x0:x1]

    old_h, old_w, _ = top_down_map.shape
    # Keep the aspect ratio, but never ask cv2 for a zero-width image.
    top_down_width = max(1, int(float(top_down_height) / old_h * old_w))
    # cv2 resize (dsize is width first)
    top_down_map = cv2.resize(
        top_down_map,
        (top_down_width, top_down_height),
        interpolation=cv2.INTER_CUBIC,
    )
    return top_down_map