from typing import Any, Dict, List, Sequence, Tuple

import numpy as np

from helios.agent.gaze.gs_object_map import GsObjectMap
from helios.agent.utils.utils import xy_to_px, px_to_xy, gs_to_gps, gps_to_px

class MapInterface:
    """Rasterise object locations from a ``GsObjectMap`` onto a 2-D pixel grid.

    Attributes:
        map_shape: Shape of the output grid; it is passed to ``np.zeros`` and
            its first two entries are used as the (row, column) bounds.
        objmap_px_per_m: Scale factor forwarded to ``gps_to_px`` (pixels per
            metre, judging by the name — TODO confirm).
    """

    def __init__(self, map_shape: np.ndarray, objmap_px_per_m: float):
        self.map_shape = map_shape
        self.objmap_px_per_m = objmap_px_per_m

    def visualize(
        self,
        map_img: np.ndarray,
        obj_rep: "GsObjectMap",
        obj_ind: Sequence[int] = (1, 2, 3),
    ):
        """Paint object locations into ``map_img`` in place and return it.

        For the k-th class id in ``obj_ind``, channel k of ``map_img`` is set
        to 125 wherever the class is located at threshold 0.3. It is then set
        to 255 wherever the class is located at threshold 0.7, so the stricter
        pass overwrites the looser one. Nothing is drawn when
        ``obj_rep.initialized`` is false.

        Args:
            map_img: Image indexed as ``map_img[row, col, channel]``. It needs
                at least ``len(obj_ind)`` channels, and its first two
                dimensions must match ``self.map_shape``.
            obj_rep: Object map queried through ``get_obj_locs``.
            obj_ind: Object class ids, one per output channel. The default is
                an immutable tuple so no list is shared between calls.

        Returns:
            The same ``map_img`` array, modified in place.
        """
        if not obj_rep.initialized:
            return map_img

        # (threshold, intensity) pairs, applied in order so the high-confidence
        # pass overwrites the low-confidence one.
        levels = ((0.3, 125), (0.7, 255))
        for channel, obj_idx in enumerate(obj_ind):
            for threshold, intensity in levels:
                hits = self.recpt_loc_to_map(obj_idx, obj_rep, threshold=threshold)
                map_img[:, :, channel][hits > 0] = intensity
        return map_img

    @staticmethod
    def _in_bounds_mask(pixel_points: np.ndarray, shape) -> np.ndarray:
        """Return a boolean mask of the rows of ``pixel_points`` that lie inside ``shape``.

        Column 0 is checked against ``shape[0]`` and column 1 against
        ``shape[1]``. Negative indices are rejected explicitly: NumPy would
        otherwise wrap them around to the opposite edge of the grid.
        """
        rows = pixel_points[:, 0]
        cols = pixel_points[:, 1]
        return (rows >= 0) & (rows < shape[0]) & (cols >= 0) & (cols < shape[1])

    def recpt_loc_to_map(self, obj_idx: int, obj_rep: "GsObjectMap", threshold: float):
        """Return a grid marking where instances of ``obj_idx`` are located.

        Args:
            obj_idx: Object class id passed to ``obj_rep.get_obj_locs``.
            obj_rep: Object map to query.
            threshold: Detection threshold forwarded to ``get_obj_locs``.

        Returns:
            A float array ``np.zeros(self.map_shape)`` with 1 at the pixel of
            every located object that falls inside the grid. Points outside
            the grid are dropped. The array is all zeros when
            ``get_obj_locs`` reports no object of that class.
        """
        has_recpt, recpt_loc = obj_rep.get_obj_locs(
            obj_class=obj_idx,
            threshold=threshold,
            use_uncertainty=False,
            return_scores=False,
            return_instances=False,
        )

        obj_map = np.zeros(self.map_shape)
        if not has_recpt:
            return obj_map

        xy_points = gs_to_gps(recpt_loc)
        # NOTE(review): assumes gps_to_px returns an integer (N, 2) array of
        # (row, col) indices into obj_map — TODO confirm against the helper.
        pixel_points = gps_to_px(xy_points, self.objmap_px_per_m, self.map_shape[0])
        inside = pixel_points[self._in_bounds_mask(pixel_points, self.map_shape)]
        obj_map[inside[:, 0], inside[:, 1]] = 1
        return obj_map
