from glob import glob
import os
import shutil
import threading
from typing import List
from habitat.utils.visualizations import maps
from home_robot.core.interfaces import Observations
import imageio
from natsort import natsorted
import numpy as np
import nvtx
from tqdm import tqdm
from vlfm.utils.img_utils import reorient_rescale_map, resize_images
from vlfm.utils.visualization import add_text_to_image
from vlfm.utils.habitat_visualizer import HabitatVis, color_point_cloud_on_map
import cv2

from helios.agent.nav.vlfm.vlfm_map_policy import VlfmMapPolicy


class SavingThread(threading.Thread):
    """Background thread that renders one visualization frame and saves it as a PNG.

    The thread queries ``map_policy`` for its policy info, assembles depth, RGB,
    the top-down map and the policy's obstacle/value maps into a two-row grid,
    downsamples the result and writes it to
    ``{vis_dir}/vlfm_snapshot_{frame_idx:04d}.png``.

    Args:
        info: Mapping for the current step. Reads ``'top_down_map'`` and
            ``'bounds'``; ``'depth'`` / ``'rgb'`` when the policy supplies no
            annotated version; and, optionally, ``'merged_object_map'``.
        map_policy: Policy queried through ``get_policy_info(visualize=True)``.
        vis_dir: Directory the snapshot is written to (created if missing).
        frame_idx: Integer used to number the snapshot file.
    """

    def __init__(self, info, map_policy: VlfmMapPolicy, vis_dir, frame_idx):
        super().__init__()
        self.info = info
        self.map_policy = map_policy
        self.vis_dir = vis_dir
        self.frame_idx = frame_idx
        self.policy_info = None
        # Always defined, so readers never hit an AttributeError when the
        # policy did not provide annotated images.
        self.using_annotated_depth = False
        self.using_annotated_rgb = False

    def run(self):
        """Render and save the frame, then drop references to the inputs."""
        try:
            self._render_and_save()
        finally:
            # Release the (potentially large) observation data even when
            # rendering fails, so a finished thread object kept in a list
            # does not pin it in memory.
            self.info = None
            self.policy_info = None

    def _render_and_save(self):
        """Build the frame from ``self.info`` and the policy info and write it to disk."""
        self.policy_info = self.map_policy.get_policy_info(visualize=True)
        policy_info = self.policy_info
        info = self.info
        target_objects = policy_info['target_objects']

        if "annotated_depth" in policy_info:
            depth = policy_info["annotated_depth"]
            self.using_annotated_depth = True
        else:
            # Scale the raw depth by 255, cast to uint8 and expand to 3 channels.
            depth = (info['depth'] * 255.0).astype(np.uint8)
            depth = cv2.cvtColor(depth, cv2.COLOR_GRAY2RGB)

        if "annotated_rgb" in policy_info:
            rgb = policy_info["annotated_rgb"]
            self.using_annotated_rgb = True
        else:
            rgb = info['rgb']

        # Visualize target point cloud on the map
        top_down_map = draw_top_down_map(
            info["top_down_map"],
            policy_info['target_point_clouds'],
            depth.shape[0],
        )

        y0, y1, x0, x1 = info['bounds']
        vis_map_imgs = [reorient_rescale_map(policy_info['obstacle_map'][y0:y1, x0:x1])]
        vis_map_imgs.extend(
            reorient_rescale_map(value_map) for value_map in policy_info['value_maps']
        )
        # The merged object map is optional: when it is absent the frame is
        # rendered without that panel instead of aborting the thread.
        if 'merged_object_map' in info:
            vis_map_imgs.append(
                reorient_rescale_map(info['merged_object_map'].visualize())
            )

        text = [
            policy_info[text_key]
            for text_key in policy_info.get("render_below_images", [])
            if text_key in policy_info
        ]

        frame = self._create_frame(depth, rgb, top_down_map, vis_map_imgs, text, target_objects)
        # Keep every third pixel along both axes to shrink the saved image.
        frame = frame[::3, ::3]
        os.makedirs(self.vis_dir, exist_ok=True)
        imageio.imwrite(f"{self.vis_dir}/vlfm_snapshot_{self.frame_idx:04d}.png", frame)

    def _create_frame(
        self,
        depth: np.ndarray,
        rgb: np.ndarray,
        map: np.ndarray,
        vis_map_imgs: List[np.ndarray],
        text: List[str],
        target_objects: List[str],
    ) -> np.ndarray:
        """Tile the images into a two-row grid and add labels and header text.

        Args:
            depth: Depth visualization.
            rgb: RGB visualization.
            map: Top-down map visualization.
            vis_map_imgs: Additional map images appended after ``map``.
            text: Strings rendered at the top of the frame, first entry topmost.
            target_objects: Labels placed between ``'obstacle'`` and ``'seen'``.

        Returns:
            The assembled frame.
        """
        all_imgs = list(
            resize_images([depth, rgb, map] + vis_map_imgs, match_dimension="height")
        )
        labels = ['depth', 'rgb', 'map', 'obstacle'] + target_objects + ['seen', 'object']
        # zip stops at the shorter sequence, so surplus images stay unlabeled
        # and surplus labels are ignored.
        for i_img, (img, label) in enumerate(zip(all_imgs, labels)):
            # White text on the first two images (depth, rgb), black on the maps.
            color = (255, 255, 255) if i_img < 2 else (0, 0, 0)
            # Store the returned image rather than relying on in-place drawing.
            # NOTE(review): OpenCV may copy or reject non-contiguous arrays,
            # hence the ascontiguousarray - confirm against the installed version.
            all_imgs[i_img] = cv2.putText(
                np.ascontiguousarray(img),
                label,
                (10, 35),
                cv2.FONT_HERSHEY_SIMPLEX,
                1,
                color,
                2,
                cv2.LINE_AA,
            )

        # Pad with a white image so both rows hold the same number of tiles.
        if len(all_imgs) % 2 == 1:
            all_imgs.append(np.ones_like(all_imgs[-1]) * 255)

        # Even-indexed images form the top row, odd-indexed the bottom row.
        top_row = np.concatenate(all_imgs[::2], axis=1)
        bottom_row = np.concatenate(all_imgs[1::2], axis=1)
        frame = np.concatenate([top_row, bottom_row], axis=0)

        # Add text to the top of the frame; iterate in reverse so the first
        # entry ends up topmost.
        for t in reversed(text):
            frame = add_text_to_image(frame, t, top=True)

        return frame


class SavingHabitatVis:
    """Collects visualization frames by saving them to disk on background threads.

    Each call to :meth:`collect_data` starts a ``SavingThread`` that writes one
    PNG into the current visualization directory. Iterating over the object
    reads the saved PNGs back in natural-sorted filename order.

    Args:
        image_dir: Root directory under which per-episode directories are created.
    """

    def __init__(self, image_dir: str):
        self.image_dir = image_dir
        # Set by set_vis_dir(); None until an episode directory has been chosen.
        self.vis_dir = None
        self.threads = []
        self.frame_idx = 0

    def reset(self,):
        """Wait for pending saver threads and reset the frame counter."""
        self._join_threads()
        self.frame_idx = 0

    def set_vis_dir(self, scene_id, episode_id):
        """Select ``{image_dir}/{scene_id}_{episode_id}`` as output directory.

        Any existing directory at that path is removed and recreated empty.
        """
        self.vis_dir = os.path.join(self.image_dir, f"{scene_id}_{episode_id}")
        shutil.rmtree(self.vis_dir, ignore_errors=True)
        os.makedirs(self.vis_dir, exist_ok=True)

    @nvtx.annotate("SavingHabitatVis.collect_data")
    def collect_data(
        self,
        info,
        map_policy: VlfmMapPolicy,
    ) -> None:
        """Start a background thread that renders and saves the current frame.

        Args:
            info: Step info; ``info["timestep"]`` is used as the frame index.
            map_policy: Policy handed to the saver thread.

        Raises:
            RuntimeError: If :meth:`set_vis_dir` has not been called yet.
        """
        vis_dir = self._require_vis_dir()
        self.frame_idx = info["timestep"]
        self.add_and_start_thread(SavingThread(
            info, map_policy, vis_dir, self.frame_idx
        ))

    def add_and_start_thread(self, thread: threading.Thread):
        """Track ``thread`` so it can be joined later, then start it."""
        self.threads.append(thread)
        thread.start()

    def flush_frames(self) -> "SavingHabitatVis":
        """Wait for all pending saver threads and return ``self``.

        The returned object is iterable and yields the saved frames.
        """
        self._join_threads()
        return self

    def __len__(self):
        # Derived from the most recent timestep, not from the files on disk.
        return self.frame_idx+1

    def __iter__(self):
        """Yield the saved snapshot images in natural-sorted filename order.

        Raises:
            RuntimeError: If :meth:`set_vis_dir` has not been called yet.
        """
        from glob import escape

        # Escape the directory so glob metacharacters in it are matched literally.
        vis_dir = escape(self._require_vis_dir())
        paths = natsorted(glob(f'{vis_dir}/vlfm_snapshot_*.png'))
        for path in tqdm(paths):
            yield imageio.imread(path)

    def _join_threads(self):
        """Block until every tracked thread has finished, then forget them."""
        for thread in self.threads:
            thread.join()
        self.threads = []

    def _require_vis_dir(self):
        """Return the current visualization directory or fail with a clear error."""
        if self.vis_dir is None:
            raise RuntimeError(
                "set_vis_dir() must be called before collecting or reading frames"
            )
        return self.vis_dir


def draw_top_down_map(top_down_map_dict, target_point_clouds, top_down_height):
    """Render the top-down map with agent and target points, cropped and resized.

    Args:
        top_down_map_dict: Mapping providing ``"map"``, ``"fog_of_war_mask"``,
            ``"agent_map_coord"``, ``"agent_angle"``, ``'lower_bound'`` and
            ``'upper_bound'``.
        target_point_clouds: Sequence of point arrays; the first two columns of
            each are used. Only the first three clouds are drawn (in red, green
            and blue) because the colour list has three entries.
        top_down_height: Height in pixels of the returned image.

    Returns:
        The colourised map, cropped to the given bounds and resized to
        ``top_down_height`` rows with its aspect ratio preserved.
    """
    top_down_map = maps.colorize_topdown_map(
        top_down_map_dict["map"], top_down_map_dict["fog_of_war_mask"]
    )
    top_down_map = maps.draw_agent(
        image=top_down_map,
        agent_center_coord=top_down_map_dict["agent_map_coord"][0],
        agent_rotation=top_down_map_dict["agent_angle"][0],
        agent_radius_px=10,
    )

    map_h, map_w = top_down_map.shape[:2]
    for target_point_cloud, color in zip(
        target_point_clouds,
        [(255, 0, 0), (0, 255, 0), (0, 0, 255)],
    ):
        if not len(target_point_cloud):
            continue
        xy_points = target_point_cloud[:, :2]
        # NOTE(review): 20, 500 and 1000 look like pixels-per-unit, centre
        # offset and map size of the source map - confirm against its producer.
        pixel_points = np.rint(xy_points[..., ::-1] * 20) + 500
        pixel_points = (1000 - pixel_points).astype(np.int32)
        rows = pixel_points[:, 1]
        cols = pixel_points[:, 0]
        # Drop points that fall outside the image: unchecked indices either
        # raise IndexError or, when negative, wrap around to the wrong pixels.
        in_bounds = (rows >= 0) & (rows < map_h) & (cols >= 0) & (cols < map_w)
        top_down_map[rows[in_bounds], cols[in_bounds]] = color

    y0, x0 = top_down_map_dict['lower_bound']
    y1, x1 = top_down_map_dict['upper_bound']
    # Rows are counted from the bottom edge. Explicit indices are used instead
    # of [-y1:-y0] because -0 == 0 would make the slice empty when y0 == 0.
    row_start = max(map_h - y1, 0)
    row_stop = max(map_h - y0, 0)
    top_down_map = top_down_map[row_start:row_stop, x0:x1]

    old_h, old_w, _ = top_down_map.shape
    top_down_width = int(float(top_down_height) / old_h * old_w)
    # cv2 resize (dsize is width first)
    top_down_map = cv2.resize(
        top_down_map,
        (top_down_width, top_down_height),
        interpolation=cv2.INTER_CUBIC,
    )
    return top_down_map