import numpy as np
import cv2

import matplotlib.pyplot as plt
from typing import Any, Dict, Optional
from home_robot.core.interfaces import Observations

import imageio
import skimage.draw
from helios.agent.utils.utils import gps_to_px, px_to_gps, xy_to_px
from helios.agent.utils.visualization import generate_times_image
from helios.agent.planner.fm2_planner import get_times

def calculate_instance_heightmap(instmap, height_map):
    """Find the dominant flat surface of one instance in a height map.

    The heights under ``instmap`` are histogrammed into bins at most
    ``bin_size`` wide; the lower edge of the most populated bin is taken as
    the surface height ``h``. Instance pixels whose height lies in
    ``[h, h + bin_size]`` are kept, and of those only connected components
    (8-connectivity) larger than ``min_area`` pixels survive.

    Args:
        instmap: Boolean mask selecting the instance's pixels, same shape as
            ``height_map``. It is never modified by this function.
        height_map: Per-pixel heights, in the same units as the 0.45 / 1.05
            thresholds below (presumably meters -- confirm with the caller).

    Returns:
        ``(True, mask)`` when a large enough flat surface within the allowed
        height range exists, otherwise ``(False, None)``. When every selected
        height is identical the returned mask is ``instmap`` itself.
    """
    min_area = 200  # a surface must cover more than this many pixels
    bin_size = 0.05  # height spread tolerated within one "flat" surface
    min_height, max_height = 0.45, 1.05  # allowable surface height range

    hmap = height_map[instmap]
    if hmap.size == 0:
        # Empty instance mask: np.max/np.min would raise on a zero-size array.
        return False, None

    # Divide the instance's height range into equal-sized bins and take the
    # most populated bin as the "flat" height.
    hrange = np.max(hmap) - np.min(hmap)
    n_bins = int(np.ceil(hrange / bin_size))
    if n_bins == 0:
        # Every selected pixel has the same height, so the whole instance is
        # the flat surface; only the height range and area need checking.
        h = np.mean(hmap)
        if h < min_height or h > max_height:
            return False, None
        if np.sum(instmap) > min_area:
            return True, instmap
        return False, None

    counts, edges = np.histogram(hmap, n_bins)
    h = edges[np.argmax(counts)]  # lower edge of the fullest bin

    # Check allowable height range.
    if h < min_height or h > max_height:
        return False, None

    # Restrict the instance to the selected height band. Work on a copy so the
    # caller's mask is left untouched.
    flat_mask = instmap.copy()
    flat_mask[height_map > h + bin_size] = False
    flat_mask[height_map < h] = False

    # Keep every connected component that is large enough (label 0 is the
    # background and is skipped).
    _, labels, stats, _ = cv2.connectedComponentsWithStats(
        flat_mask.astype(np.uint8) * 255, 8, cv2.CV_32S)
    big_labels = np.flatnonzero(stats[1:, cv2.CC_STAT_AREA] > min_area) + 1
    good_mask = np.isin(labels, big_labels)

    # Ensure at least one flat surface is large enough.
    if good_mask.any():
        return True, good_mask
    return False, None
    
def get_waypose(
    semantic_mask: np.ndarray,
    instance_height: float,
    obs: Observations,
    info: Dict[str, Any],
    min_goal_navigable_dist: Optional[float],
    min_goal_waypose_dist: Optional[float],
    max_goal_waypose_dist: Optional[float],
    height_map: Optional[np.ndarray] = None
):
    """Choose a goal pixel inside ``semantic_mask`` and a pose to view it from.

    All map work happens inside the crop ``info['bounds'] = (y0, y1, x0, x1)``.
    The goal is the mask pixel with the largest ``get_times`` value measured
    from the mask's complement (presumably its most interior pixel -- the
    semantics of ``get_times`` are not visible here). The waypoint is the pixel
    with the smallest ``get_times`` value to that goal among pixels reachable
    from the agent through ``info['dilated_obstacle_map']`` and allowed by the
    min/max waypose distance limits.

    Args:
        semantic_mask: Boolean full-map mask of the target.
        instance_height: Target height used for the pitch angle; replaced by a
            value read from ``height_map`` when that is given and the sampling
            window overlaps the map.
        obs: Current observation; only ``obs.gps`` is read.
        info: Map dictionary. Keys read: ``bounds``, ``pixels_per_meter``,
            ``map_size``, ``navigable_map``, ``wall_map``,
            ``dilated_obstacle_map``.
        min_goal_navigable_dist: If not None, goal candidates farther than
            this from ``navigable_map`` are discarded.
        min_goal_waypose_dist: If not None, waypoints closer than this to the
            goal are discarded.
        max_goal_waypose_dist: If not None, waypoints farther than this from
            the goal are discarded.
        height_map: Optional height map indexed in cropped pixel coordinates
            (it is sliced with crop-relative indices below -- confirm with the
            caller).

    Returns:
        ``np.array([waypoint[0], waypoint[1], heading, pitch])`` where
        ``waypoint`` comes from ``px_to_gps``.

    Raises:
        RuntimeError: If no waypoint is reachable, either before or after the
            min/max waypose distance limits are applied.
    """
    y0, y1, x0, x1 = info['bounds']
    dx = 1 / info['pixels_per_meter']

    # get goal location
    dist_from_outside = get_times(
        obstacle_map=None,
        goal_map=~semantic_mask[y0:y1, x0:x1],
        speed=None,
        dx=dx,
    )
    dist_from_navigable = get_times(
        obstacle_map=None,
        goal_map=info['navigable_map'][y0:y1, x0:x1],
        speed=None,
        dx=dx,
    )
    dist_from_outside[dist_from_outside == np.inf] = 0
    if min_goal_navigable_dist is not None:
        dist_from_outside[dist_from_navigable > min_goal_navigable_dist] = 0
    px_goal = np.array(np.unravel_index(np.argmax(dist_from_outside), dist_from_outside.shape))
    px_start = gps_to_px(obs.gps, info['pixels_per_meter'], info['map_size']) - np.array([y0, x0])

    # get closest navigable location
    dist_to_goal = get_times(
        obstacle_map=info['wall_map'][y0:y1, x0:x1],
        goal_map=tuple(px_goal),
        speed=None,
        dx=dx,
    )
    # todo: replace with floodfill check instead
    # if not reachable from here, ignore wall map
    if dist_to_goal[tuple(px_start)] == np.inf:
        dist_to_goal = get_times(
            obstacle_map=np.zeros((y1 - y0, x1 - x0), dtype=bool),
            goal_map=tuple(px_goal),
            speed=None,
            dx=dx,
        )
    dist_to_curr = get_times(
        obstacle_map=info['dilated_obstacle_map'][y0:y1, x0:x1],
        goal_map=tuple(px_start),
        speed=None,
        dx=dx,
    )
    # Only keep candidates the agent can actually reach.
    dist_to_goal[dist_to_curr == np.inf] = np.inf
    if np.min(dist_to_goal) == np.inf:
        # Dump debug images before failing.
        state_map = generate_times_image(
            info['dilated_obstacle_map'][y0:y1, x0:x1].astype(np.float32)
            + semantic_mask[y0:y1, x0:x1].astype(np.float32)
        )
        state_map[tuple(px_start)] = [0, 255, 0]
        imageio.imwrite('state_map.png', state_map)
        imageio.imwrite('dist_from_outside.png', generate_times_image(dist_from_outside))
        imageio.imwrite('dist_to_curr.png', generate_times_image(dist_to_curr))
        imageio.imwrite('dist_to_goal.png', generate_times_image(dist_to_goal))
        raise RuntimeError(
            '[get_waypose] Goal not reachable before goal limiting. '
            'debug visual saved to state_map.png, dist_from_outside.png, dist_to_curr.png, and dist_to_goal.png'
        )
    if min_goal_waypose_dist is not None:
        dist_to_goal[dist_to_goal < min_goal_waypose_dist] = np.inf
    if max_goal_waypose_dist is not None:
        dist_to_goal[dist_to_goal > max_goal_waypose_dist] = np.inf
    if np.min(dist_to_goal) == np.inf:
        # No debug images are written on this path, so the message must not
        # claim that they were.
        raise RuntimeError('[get_waypose] Goal not reachable after goal limiting.')
    px_waypoint = np.array(np.unravel_index(np.argmin(dist_to_goal), dist_to_goal.shape))
    waypoint = px_to_gps(px_waypoint + [y0, x0], info['pixels_per_meter'], info['map_size'])

    # Heading from the waypoint towards the goal, rotated by a fixed 10 degrees.
    heading = np.arctan2(*(px_goal - px_waypoint)) - np.radians(10)

    if height_map is not None:  # hack for when we don't choose an instance so don't have height
        px_window_half_size = 10
        # Sample point 0.7 * 40 pixels away from the waypoint along the heading.
        # NOTE(review): 40 looks like a hard-coded pixels-per-meter value
        # (i.e. 0.7 m) -- confirm before replacing with info['pixels_per_meter'].
        obj_px = px_waypoint + np.array([
            -np.cos(heading) * 0.7 * 40,
            np.sin(heading) * 0.7 * 40,
        ]).astype(np.int32)
        # Clamp the window to the array: negative slice bounds would wrap
        # around and yield an empty or unrelated region.
        row_lo = max(obj_px[0] - px_window_half_size, 0)
        row_hi = max(obj_px[0] + px_window_half_size, 0)
        col_lo = max(obj_px[1] - px_window_half_size, 0)
        col_hi = max(obj_px[1] + px_window_half_size, 0)
        window = height_map[row_lo:row_hi, col_lo:col_hi]
        if window.size > 0:
            instance_height = np.max(window)
        # Otherwise the window lies entirely off the map; keep the
        # caller-supplied instance_height.

    goal_dist = np.linalg.norm(px_goal - px_waypoint) / info['pixels_per_meter']
    waypose = np.array([
        waypoint[0],
        waypoint[1],
        heading,
        # Pitch: never look up; 1.3 is presumably the sensor height -- confirm.
        np.arctan2(min(instance_height - 1.3, 0), goal_dist)
    ])
    return waypose

def get_wayposes(
    instance_mask: np.ndarray,
    instance_height: float,
    obs: Observations,
    info: Dict[str, Any],
    max_dist_from_instance: int,
    min_dist_from_instance: int,
    min_dist_between_waypoints: int,
    height_map: Optional[np.ndarray] = None
):
    """Generate candidate viewing poses on a ring around an instance.

    Candidate pixels are taken from the largest external contour of the region
    within ``max_dist_from_instance`` of the instance (inside the crop
    ``info['bounds']``). They are then filtered by minimum distance to the
    instance, spacing between candidates, navigability, line of sight to the
    closest instance pixel (no ``info['wall_map']`` pixel on the straight
    line), and by being within 0.5 of the best ``get_times`` value to the
    instance.

    Args:
        instance_mask: Boolean full-map mask of the instance.
        instance_height: Instance height used for the pitch angle; replaced by
            a value read from ``height_map`` when that is given and the
            sampling window overlaps the map.
        obs: Current observation; only ``obs.gps`` is read.
        info: Map dictionary. Keys read: ``bounds``, ``pixels_per_meter``,
            ``map_size``, ``navigable_map``, ``wall_map``.
        max_dist_from_instance: Ring distance from the instance; multiplied by
            ``info['pixels_per_meter']``, so presumably in meters.
        min_dist_from_instance: Candidates at or below this distance from the
            instance are dropped (same units as above).
        min_dist_between_waypoints: Minimum spacing between kept candidates
            (same units as above).
        height_map: Optional height map indexed in cropped pixel coordinates
            (it is sliced with crop-relative indices below -- confirm with the
            caller).

    Returns:
        A list of arrays, each the concatenation of ``px_to_gps(...)`` for the
        waypixel, the heading towards its closest instance pixel, and the
        pitch. An empty list when no candidate survives the early filters.

    Raises:
        AssertionError: If none of the remaining candidates can reach the
            instance.
    """
    y0, y1, x0, x1 = info['bounds']
    pixels_per_meter = info['pixels_per_meter']
    instance_crop = instance_mask[y0:y1, x0:x1]

    # distance from instance and closest instance pixel
    dists, labels = cv2.distanceTransformWithLabels(
        src=(~instance_crop).astype(np.uint8),
        distanceType=cv2.DIST_L2,
        maskSize=5,
        labelType=cv2.DIST_LABEL_PIXEL
    )
    # Row 0 is a placeholder so that label k maps to the k-th instance pixel.
    label_coordinates = np.concatenate([[[0, 0]], np.argwhere(instance_crop)], axis=0)

    def closest_instance_px(waypixel):
        # Instance pixel nearest to `waypixel`, in cropped coordinates.
        return label_coordinates[labels[tuple(waypixel)]]

    # waypixels from pixels of biggest contour of instance + max distance
    px_max_dist_from_instance = int(round(max_dist_from_instance * pixels_per_meter))
    contours, _ = cv2.findContours(
        image=(dists <= px_max_dist_from_instance).astype(np.uint8),
        mode=cv2.RETR_EXTERNAL,
        method=cv2.CHAIN_APPROX_NONE
    )
    if len(contours) == 0:
        # Nothing within range (e.g. an empty instance mask).
        return []
    # [[[x1, y1]], [[x2, y2]], ...] -> [[i1, j1], [i2, j2], ...]
    waypixels = max(contours, key=len)[:, 0, ::-1]

    # filter waypixels too close to instance
    far_enough = dists[waypixels[:, 0], waypixels[:, 1]] > min_dist_from_instance * pixels_per_meter
    waypixels = waypixels[far_enough]

    # filter waypixels unnavigable or too close to other waypixels
    min_spacing_px = min_dist_between_waypoints * pixels_per_meter
    filtered = []
    for candidate in waypixels:  # [i, j]
        if filtered and (
            np.linalg.norm(candidate - filtered[0]) <= min_spacing_px
            or np.linalg.norm(candidate - filtered[-1]) <= min_spacing_px
        ):
            continue
        if info['navigable_map'][candidate[0] + y0, candidate[1] + x0]:
            filtered.append(candidate)
    if not filtered:
        return filtered

    # filter waypixels without line of sight
    px_start = gps_to_px(obs.gps, pixels_per_meter, info['map_size']) - [y0, x0]

    def has_line_of_sight(waypixel):
        # True when no wall pixel lies on the straight line from the closest
        # instance pixel to `waypixel` (full-map coordinates).
        rows, cols = skimage.draw.line(
            *(closest_instance_px(waypixel) + [y0, x0]),
            *(waypixel + [y0, x0]),
        )
        return not np.any(info['wall_map'][rows, cols])

    waypixels = [waypixel for waypixel in filtered if has_line_of_sight(waypixel)]
    if not waypixels:
        return waypixels

    # invert if last is closer than first
    dist_to_start = get_times(
        obstacle_map=info['wall_map'][y0:y1, x0:x1],
        goal_map=tuple(px_start),
        speed=None,
        dx=1 / pixels_per_meter,
    )
    if dist_to_start[tuple(waypixels[0])] > dist_to_start[tuple(waypixels[-1])]:
        waypixels = waypixels[::-1]

    # filter waypixels much farther than other waypoints
    dist_to_goal = get_times(
        obstacle_map=info['wall_map'][y0:y1, x0:x1],
        goal_map=instance_crop,
        speed=None,
        dx=1 / pixels_per_meter,
    )
    if dist_to_goal[tuple(px_start)] == np.inf:
        # Not reachable from the agent with walls considered; ignore walls.
        dist_to_goal = get_times(
            obstacle_map=np.zeros((y1 - y0, x1 - x0), dtype=bool),
            goal_map=instance_crop,
            speed=None,
            dx=1 / pixels_per_meter,
        )
    dist_to_goal[~info['navigable_map'][y0:y1, x0:x1]] = np.inf
    min_dist = min(dist_to_goal[tuple(waypixel)] for waypixel in waypixels)
    if min_dist == np.inf:
        # Raised explicitly (not via `assert`) so the check survives `python -O`.
        raise AssertionError('[get_waypose] no reachable wayposes found.')
    waypixels = [
        waypixel for waypixel in waypixels
        if dist_to_goal[tuple(waypixel)] < min_dist + 0.5
    ]

    if height_map is not None:  # hack for when we don't choose an instance so don't have height
        px_window_half_size = 10
        # Sample the height in front of the first waypixel, using that
        # waypixel's own direction to the instance.
        reference_px = waypixels[0]
        # NOTE(review): the 10 degree offset is applied to the angle, matching
        # get_waypose; the headings returned below carry no such offset --
        # confirm which is intended.
        heading = np.arctan2(*(closest_instance_px(reference_px) - reference_px)) - np.radians(10)
        # NOTE(review): 40 looks like a hard-coded pixels-per-meter value
        # (i.e. 0.7 m) -- confirm before replacing with info['pixels_per_meter'].
        obj_px = reference_px + np.array([
            -np.cos(heading) * 0.7 * 40,
            np.sin(heading) * 0.7 * 40,
        ]).astype(np.int32)
        # Clamp the window to the array: negative slice bounds would wrap
        # around and yield an empty or unrelated region.
        row_lo = max(obj_px[0] - px_window_half_size, 0)
        row_hi = max(obj_px[0] + px_window_half_size, 0)
        col_lo = max(obj_px[1] - px_window_half_size, 0)
        col_hi = max(obj_px[1] + px_window_half_size, 0)
        window = height_map[row_lo:row_hi, col_lo:col_hi]
        if window.size > 0:
            instance_height = np.max(window)
        # Otherwise the window lies entirely off the map; keep the
        # caller-supplied instance_height.

    # convert waypixels to wayposes
    wayposes = []
    for waypixel in waypixels:
        to_instance = closest_instance_px(waypixel) - waypixel
        wayposes.append(np.concatenate([
            px_to_gps(waypixel + [y0, x0], pixels_per_meter, info['map_size']),
            [np.arctan2(*to_instance)],
            # Pitch: never look up; 1.3 is presumably the sensor height -- confirm.
            [np.arctan2(min(instance_height - 1.3, 0), np.linalg.norm(to_instance) / pixels_per_meter)]
        ]))

    return wayposes