import math
import os
import shutil
from typing import List, Literal, Optional, Tuple
import cv2 
import nvtx
import numpy as np
from scipy.spatial.transform import Rotation
import skfmm
import imageio
import matplotlib

import home_robot.utils.pose as pu
from home_robot.core.interfaces import (
    ContinuousNavigationAction,
    DiscreteNavigationAction,
)
import skimage.draw
from helios.agent.utils.visualization import generate_times_image
from helios.agent.utils.utils import gps_to_px, get_bounds


class Fm2Planner:
    """Local planner that converts obstacle/goal maps into one navigation action.

    Each planning call crops the maps around the region of interest, asks
    `get_action_goal` for a short-term goal pixel, and converts that pixel into
    either an in-place turn or a forward move expressed as a
    `ContinuousNavigationAction` (or `DiscreteNavigationAction.STOP`).

    Map conventions used throughout this class (established by the indexing in
    `plan_to_map`): maps are indexed `[row, col]`, and a metric map pose
    `(x, y, theta)` corresponds to pixel `(row, col) = (y, x) * pixels_per_meter`
    with `theta` in radians.
    """

    def __init__(
        self,
        map_shape: Tuple[int, int],
        pixels_per_meter: int,
        min_displacement: float,
        max_displacement: float,
        pref_displacement: float,
        min_turn_degrees: float,
        max_turn_degrees: float,
        image_dir: str,
        window_size: int,
        visualize: bool,
    ):
        """Store the planner configuration.

        Args:
            map_shape: Shape of the full map in pixels; only `map_shape[0]` is
                read by this class.
            pixels_per_meter: Map resolution.
            min_displacement: Smallest forward move (meters) that is emitted.
            max_displacement: Largest forward move (meters) that is emitted.
            pref_displacement: Preferred short-term goal distance (meters),
                forwarded to `get_action_goal`.
            min_turn_degrees: Smallest in-place turn that is emitted.
            max_turn_degrees: Default upper bound on a single in-place turn.
            image_dir: Root directory for per-episode visualisation folders.
            window_size: Stored but not read anywhere in this class.
            visualize: When true, planner snapshots are requested from
                `get_action_goal`.
        """
        self.map_shape = map_shape
        self.pixels_per_meter = pixels_per_meter
        self.min_displacement = min_displacement
        self.max_displacement = max_displacement
        self.pref_displacement = pref_displacement
        self.min_turn_degrees = min_turn_degrees
        self.max_turn_degrees = max_turn_degrees
        self.image_dir = image_dir
        self.window_size = window_size
        self.visualize = visualize
        # Initialised here so that planning before the first `reset()` /
        # `set_vis_dir()` call does not fail with AttributeError.
        self.vis_dir: Optional[str] = None

    def reset(self):
        """Forget the current visualisation directory."""
        self.vis_dir = None

    def set_vis_dir(self, scene_id: str, episode_id: str):
        """Create a fresh `<image_dir>/<scene_id>_<episode_id>` output directory.

        Any existing directory with that name is deleted first.
        """
        self.vis_dir = os.path.join(self.image_dir, f"{scene_id}_{episode_id}")
        shutil.rmtree(self.vis_dir, ignore_errors=True)
        os.makedirs(self.vis_dir, exist_ok=True)

    def plan(
        self,
        obstacle_map: np.ndarray,
        goal_map: np.ndarray,
        frontier_map: np.ndarray,
        sensor_pose: np.ndarray,
        found_goal: bool,
        debug: bool = True,
        use_dilation_for_stg: bool = False,
        timestep: Optional[int] = None,
        seen_map: Optional[np.ndarray] = None,
    ):
        """Plan one action towards the goal (if found) or the frontier.

        When `found_goal` is true the planner targets a ring of pixels whose
        distance from the goal lies in `[radius - 0.5, radius]` meters, trying
        increasing radii from 0 to 5 until `plan_to_map` succeeds. Otherwise
        it targets `frontier_map` directly.

        Args:
            obstacle_map: Occupancy map; cast to bool.
            goal_map: Goal mask; cast to bool.
            frontier_map: Frontier mask; cast to bool.
            sensor_pose: Indexable pose vector. Elements 0/1 are used as x/y in
                meters, element 2 as heading in degrees, and elements 5/3 as
                pixel offsets subtracted from x/y respectively
                (presumably the local-map origin -- TODO confirm with caller).
            found_goal: Whether to target the goal rather than the frontier.
            debug: Unused.
            use_dilation_for_stg: Unused.
            timestep: Forwarded to `plan_to_map` for snapshot naming.
            seen_map: Forwarded to `plan_to_map`.

        Returns:
            `(action, None, None, None)`; `action` is
            `DiscreteNavigationAction.STOP` when every attempt failed.
        """
        obstacle_map = obstacle_map.astype(bool)
        goal_map = goal_map.astype(bool)
        frontier_map = frontier_map.astype(bool)

        # The start pose does not depend on the goal radius, so build it once.
        map_xyt_start = np.array([
            sensor_pose[0] - sensor_pose[5] / self.pixels_per_meter,
            sensor_pose[1] - sensor_pose[3] / self.pixels_per_meter,
            np.radians(sensor_pose[2])
        ])

        goal_radii = np.linspace(0, 5, 10)
        if found_goal:
            distances_from_goal = get_times(
                obstacle_map=None,
                goal_map=goal_map,
                speed=None,
                dx=1/self.pixels_per_meter,
            )
        else:
            # The frontier target is independent of the radius: planning is
            # deterministic, so retrying with identical inputs cannot succeed.
            distances_from_goal = None
            goal_radii = goal_radii[:1]

        action = DiscreteNavigationAction.STOP
        for goal_radius in goal_radii:
            if found_goal:
                target_map = (
                    (goal_radius - 0.5 <= distances_from_goal)
                    & (distances_from_goal <= goal_radius)
                )
            else:
                target_map = frontier_map
            try:
                action = self.plan_to_map(
                    obstacle_map=obstacle_map,
                    seen_map=seen_map,
                    goal_map=target_map,
                    map_xyt_start=map_xyt_start,
                    goal_theta=None,
                    timestep=timestep,
                )
                break
            except Exception as e:
                # Deliberate best-effort: a failed attempt falls through to the
                # next radius, and ultimately to STOP.
                print(f"Fm2Planner.plan() error goal_radius: {goal_radius} error: {e}")
                continue

        return action, None, None, None

    def plan_to_pose(
        self,
        obstacle_map: np.ndarray,
        seen_map: Optional[np.ndarray],
        xyt_start: np.ndarray,
        xyt_goal: np.ndarray, # [x, y] or [x, y, theta]
        timestep: Optional[int],
        gaze_points: Optional[List[Tuple[float, float]]] = None,
        max_turn_degrees: Optional[float] = None,
    ):
        """Plan one action from a metric start pose towards a metric goal pose.

        Args:
            obstacle_map: Occupancy map; cast to bool, as `plan` does, because
                `get_times` asserts boolean masks.
            seen_map: Forwarded to `plan_to_map`.
            xyt_start: Start pose `[x, y, theta]`; not modified.
            xyt_goal: Goal `[x, y]` or `[x, y, theta]`. When theta is given it
                is used as the final heading once the goal is reached.
            timestep: Forwarded to `plan_to_map` for snapshot naming.
            gaze_points: Optional points converted with `gps_to_px` and drawn
                on the planner snapshot only.
            max_turn_degrees: Optional per-call override of the turn limit.

        Returns:
            The action returned by `plan_to_map`.
        """
        obstacle_map = obstacle_map.astype(bool)

        map_xyt_start = xyt_start.copy()
        # NOTE(review): the floor divisions truncate the centre offset to whole
        # meters; this is exact only when map_shape[0] // 2 is a multiple of
        # pixels_per_meter. Confirm against gps_to_px before changing.
        map_xyt_start[:2] += self.map_shape[0] // 2 // self.pixels_per_meter
        goal_px = gps_to_px(xyt_goal[:2], self.pixels_per_meter, self.map_shape[0])
        goal_map = np.zeros_like(obstacle_map)
        goal_map[tuple(goal_px)] = 1
        if gaze_points is not None:
            gaze_map = np.zeros_like(obstacle_map)
            pxs = gps_to_px(gaze_points, self.pixels_per_meter, self.map_shape[0])
            gaze_map[pxs[:, 0], pxs[:, 1]] = 1
        else:
            gaze_map = None

        return self.plan_to_map(
            obstacle_map=obstacle_map,
            seen_map=seen_map,
            map_xyt_start=map_xyt_start,
            goal_map=goal_map,
            goal_theta=xyt_goal[2] if len(xyt_goal) == 3 else None,
            timestep=timestep,
            gaze_map=gaze_map,
            max_turn_degrees=max_turn_degrees,
        )

    @nvtx.annotate("Fm2Planner.plan_to_map")
    def plan_to_map(
        self,
        obstacle_map: np.ndarray,
        seen_map: Optional[np.ndarray],
        map_xyt_start: np.ndarray,
        goal_map: np.ndarray,
        goal_theta: Optional[float],
        timestep: Optional[int],
        gaze_map: Optional[np.ndarray] = None,
        max_turn_degrees: Optional[float] = None,
    ) -> "ContinuousNavigationAction | DiscreteNavigationAction":
        """Plan one action from a map-frame start pose towards a goal mask.

        Args:
            obstacle_map: Boolean occupancy map.
            seen_map: Unused by this method.
            map_xyt_start: Start pose `(x, y, theta)`, meters / radians, such
                that `(y, x) * pixels_per_meter` is the start pixel. Not
                modified.
            goal_map: Boolean goal mask with the same shape as `obstacle_map`.
            goal_theta: Optional final heading (radians) applied once the
                short-term goal is closer than `min_displacement / 2`.
            timestep: Snapshot index; snapshots are written only when
                `visualize` is set and a visualisation directory exists.
            gaze_map: Optional mask drawn on the snapshot.
            max_turn_degrees: Optional override of `self.max_turn_degrees`.

        Returns:
            A `ContinuousNavigationAction` holding `[forward, 0, turn]`, or
            `DiscreteNavigationAction.STOP` when neither a turn nor a move is
            required.

        Raises:
            AssertionError: If the start pixel lies outside the cropped map.
        """
        map_t_start_px = np.rint(map_xyt_start[[1, 0]] * self.pixels_per_meter).astype(np.int32)
        obstacle_bounds = get_bounds(obstacle_map, 0)
        goal_bounds = get_bounds(goal_map, 0)

        # Crop to the region containing obstacles, goal and start (with a one
        # pixel margin). The window is clamped to the map: a negative origin
        # would otherwise be interpreted by numpy as "count from the end".
        n_rows, n_cols = obstacle_map.shape[:2]
        y0 = max(min(obstacle_bounds[0], goal_bounds[0]-1, map_t_start_px[0]-1), 0)
        y1 = min(max(obstacle_bounds[1], goal_bounds[1]+1, map_t_start_px[0]+2), n_rows)
        x0 = max(min(obstacle_bounds[2], goal_bounds[2]-1, map_t_start_px[1]-1), 0)
        x1 = min(max(obstacle_bounds[3], goal_bounds[3]+1, map_t_start_px[1]+2), n_cols)

        obstacle_map = obstacle_map[y0:y1, x0:x1]
        goal_map = goal_map[y0:y1, x0:x1]
        gaze_map = gaze_map[y0:y1, x0:x1] if gaze_map is not None else None
        # Re-express the start pose relative to the crop origin.
        map_xyt_start = map_xyt_start.copy()
        map_xyt_start[:2] -= np.array([x0, y0]) / self.pixels_per_meter
        map_t_start_px = np.rint(map_xyt_start[[1, 0]] * self.pixels_per_meter).astype(np.int32)
        start_in_bounds = (
            0 <= map_t_start_px[0] < goal_map.shape[0]
            and 0 <= map_t_start_px[1] < goal_map.shape[1]
        )
        assert start_in_bounds,\
            f"start not in bound: map_t_start_px: {map_t_start_px} goal_map.shape: {goal_map.shape} y0, y1, x0, x1: {y0, y1, x0, x1}"\
            f" obstacle_bounds: {obstacle_bounds} goal_bounds: {goal_bounds}"

        # get short term goal
        can_visualize = self.visualize and self.vis_dir is not None
        map_t_goal_px = get_action_goal(
            obstacle_map=obstacle_map,
            goal_map=goal_map,
            map_t_start_px=map_t_start_px,
            max_displacement=self.max_displacement,
            min_displacement=self.min_displacement,
            pref_displacement=self.pref_displacement,
            pixels_per_meter=self.pixels_per_meter,
            vis_dir=self.vis_dir,
            timestep=timestep if can_visualize else None,
            theta=map_xyt_start[2],
            gaze_map=gaze_map,
        )
        if map_t_goal_px is None:
            return DiscreteNavigationAction.STOP

        # Homogeneous 2D transform of the start pose: rotation about z with the
        # translation column overwritten by the start position.
        map_T_start = Rotation.from_euler('z', map_xyt_start[2], degrees=False).as_matrix()
        map_T_start[:2, 2] = map_xyt_start[:2]
        map_t_goal = np.array([
            map_t_goal_px[1] / self.pixels_per_meter,
            map_t_goal_px[0] / self.pixels_per_meter,
            1
        ])
        # Short-term goal expressed in the start (robot) frame.
        start_goal = np.linalg.inv(map_T_start) @ map_t_goal

        # get action
        action_x = start_goal[0]
        action_t = np.arctan2(start_goal[1], start_goal[0])
        if max_turn_degrees is None:
            max_turn_degrees = self.max_turn_degrees
        if np.linalg.norm(start_goal[:2]) < self.min_displacement / 2 and goal_theta is not None:
            # Already at the goal position: rotate towards the requested
            # heading instead, wrapped into [-pi, pi].
            action_t = goal_theta - map_xyt_start[2]
            if action_t > np.pi:
                action_t -= 2 * np.pi
            elif action_t < -np.pi:
                action_t += 2 * np.pi

        if np.abs(action_t) > np.radians(self.min_turn_degrees) / 2:
            # Heading error is significant: turn in place first.
            action_x = 0
            action_t = np.sign(action_t) * np.clip(np.abs(action_t), np.radians(self.min_turn_degrees), np.radians(max_turn_degrees))
        elif np.abs(action_x) > self.min_displacement / 2:
            # Roughly facing the goal: drive straight.
            action_x = np.clip(action_x, self.min_displacement, self.max_displacement)
            action_t = 0
        else:
            action_x = 0
            action_t = 0

        # Decide on STOP from the scalars themselves rather than comparing the
        # action's stored `xyt` with a list, which breaks if `xyt` is an array.
        if action_x == 0 and action_t == 0:
            return DiscreteNavigationAction.STOP

        return ContinuousNavigationAction([action_x, 0, action_t])


@nvtx.annotate("get_action_goal")
def get_action_goal(
    obstacle_map: np.ndarray,
    goal_map: np.ndarray,
    map_t_start_px: Tuple[int],
    max_displacement: float,
    min_displacement: float,
    pref_displacement: Optional[float],
    pixels_per_meter: int,
    vis_dir: str,
    timestep: Optional[int],
    theta: float = 0,
    gaze_map: Optional[np.ndarray] = None,
) -> np.ndarray:
    """Pick a short-term goal pixel for the next action.

    If the start pixel is already on the goal, the start pixel is returned.
    Otherwise travel times to the goal are computed with `get_times`, using the
    distance from obstacles as the propagation speed, and the result is the
    candidate pixel with the smallest travel time. Candidates are the pixels of
    a circle around the start (plus the end of the greedy descent path when it
    is close enough) that are reachable along an obstacle-free straight line.

    Args:
        obstacle_map: Boolean occupancy map (2D).
        goal_map: Boolean goal mask with the same shape.
        map_t_start_px: Start pixel as `(row, col)`.
        max_displacement: Maximum descent-path length, in meters.
        min_displacement: Minimum step length, in meters.
        pref_displacement: Preferred step length, in meters; a falsy value
            falls back to `min_displacement`.
        pixels_per_meter: Map resolution.
        vis_dir: Directory for planner snapshots; no snapshot is written when
            this is `None`.
        timestep: Snapshot index; `None` disables the snapshot.
        theta: Unused.
        gaze_map: Optional mask whose pixels are marked on the snapshot.

    Returns:
        The chosen short-term goal pixel as a `(row, col)` array.

    Raises:
        IndexError: If no candidate pixel is feasible.
    """
    map_t_start_px = np.array(map_t_start_px)
    if goal_map[map_t_start_px[0], map_t_start_px[1]]:
        map_t_goal_px = map_t_start_px
        # Only used as a background for the snapshot in this branch.
        times_from_goal = obstacle_map.astype(float)
        circle2 = None
    else:
        pref_displacement = pref_displacement or min_displacement

        distances_from_obstacles = get_times(
            obstacle_map=None,
            goal_map=obstacle_map,
            speed=None,
            dx=1/pixels_per_meter
        )

        times_from_goal = get_times(
            obstacle_map=obstacle_map,
            goal_map=goal_map,
            speed=distances_from_obstacles,
            dx=1/pixels_per_meter
        )

        # Greedy descent on the travel-time field: repeatedly step to the
        # neighbour (or stay) with the smallest time.
        n_rows, n_cols = times_from_goal.shape
        path = []
        px_curr = map_t_start_px
        for _ in range(int(max_displacement * pixels_per_meter)):
            path.append(px_curr)
            if goal_map[px_curr[0], px_curr[1]]:
                break
            t, b = np.clip([px_curr[0]-1, px_curr[0]+2], 0, n_rows)
            l, r = np.clip([px_curr[1]-1, px_curr[1]+2], 0, n_cols)
            window = times_from_goal[t:b, l:r]
            # Index relative to the actual window: at the map edge the window
            # is smaller than 3x3 and is not centred on the current pixel.
            d_row, d_col = np.unravel_index(np.argmin(window), window.shape)
            px_curr = np.array([t + d_row, l + d_col])
        # The path is empty only when max_displacement is below one pixel.
        map_t_goal_px = path[-1] if path else map_t_start_px

        dist = int(round(np.linalg.norm(map_t_goal_px - map_t_start_px)))
        if dist > (pref_displacement + min_displacement) * pixels_per_meter:
            px_displacement = int(round(pref_displacement * pixels_per_meter))
        else:
            px_displacement = max(int(round(min_displacement * pixels_per_meter)), dist)

        # Candidate ring around the start. `shape` keeps every candidate inside
        # the map; without it negative coordinates wrap around when indexing
        # and large ones raise IndexError.
        circle1 = np.stack(
            skimage.draw.circle_perimeter(
                map_t_start_px[0],
                map_t_start_px[1],
                radius=px_displacement,
                shape=obstacle_map.shape[:2],
            ),
            -1
        )
        if dist <= pref_displacement * pixels_per_meter:
            circle1 = np.concatenate([
                circle1,
                map_t_goal_px[None]
            ], axis=0)

        min_px = min_displacement * pixels_per_meter

        def far_from_path_end_or_goal(px):
            # Keep pixels that are on the goal, or at least min_displacement
            # away from where the descent path ended.
            return np.linalg.norm(px - px_curr) > min_px or goal_map[px[0], px[1]]

        def ray_is_free(px):
            ray = skimage.draw.line(map_t_start_px[0], map_t_start_px[1], px[0], px[1])
            return not np.any(obstacle_map[ray])

        circle2 = [
            px for px in circle1
            if ray_is_free(px) and far_from_path_end_or_goal(px)
        ]
        start_time = times_from_goal[tuple(map_t_start_px)]
        if all(times_from_goal[px[0], px[1]] > start_time for px in circle2):
            # No line-of-sight candidate makes progress (or there is none):
            # relax the ray test to "candidate itself is not an obstacle".
            print('all times_from_goal > start')
            circle2 = [
                px for px in circle1
                if not obstacle_map[px[0], px[1]] and far_from_path_end_or_goal(px)
            ]
        if not circle2:
            raise IndexError(
                "get_action_goal: no feasible short-term goal candidate around "
                f"start {map_t_start_px}"
            )
        # First candidate with the smallest travel time to the goal.
        map_t_goal_px = min(circle2, key=lambda px: times_from_goal[px[0], px[1]])

    if timestep is not None and vis_dir is not None:
        planner_img = generate_times_image(times_from_goal)
        planner_img[obstacle_map] = [0, 0, 0]
        planner_img[goal_map] = [0, 0, 255]
        if circle2 is not None:
            planner_img[[y for y, x in circle2], [x for y, x in circle2]] = [255, 0, 0]
        planner_img[tuple(map_t_start_px)] = [255, 0, 0]
        planner_img[tuple(map_t_goal_px)] = [0, 255, 0]
        if gaze_map is not None:
            for y, x in np.argwhere(gaze_map):
                # Plain ints: some OpenCV builds reject numpy scalars as points.
                cv2.drawMarker(planner_img, (int(x), int(y)), (255, 255, 0), markerType=cv2.MARKER_TILTED_CROSS, markerSize=5, thickness=1)

        window_size = 100
        y0, x0 = np.maximum(map_t_start_px - window_size // 2, 0)
        y1, x1 = y0 + window_size, x0 + window_size
        # Crop rows y0..y1-1 first, then flip vertically; slicing with
        # [y1:y0:-1] would drop row y0 and include row y1 instead.
        planner_img = planner_img[y0:y1, x0:x1][::-1]
        imageio.imwrite(f'{vis_dir}/planner_snapshot_{timestep}.png', planner_img)
    return map_t_goal_px


@nvtx.annotate("get_times")
def get_times(
    obstacle_map: Optional[np.ndarray],
    goal_map: np.ndarray,
    speed: Optional[np.ndarray],
    dx: float
) -> np.ndarray:
    """Compute a distance or travel-time field towards `goal_map`.

    Three modes, selected by the arguments:

    * `obstacle_map is None and speed is None`: obstacle-free distance to the
      nearest goal pixel via `cv2.distanceTransform` (L2, mask size 5), scaled
      by `dx`.
    * `speed is None` (with obstacles): `skfmm.distance` around obstacles.
    * otherwise: `skfmm.travel_time` with the given per-pixel `speed`; a
      missing `obstacle_map` is treated as obstacle-free.

    In the skfmm modes the outermost rows and columns are masked like
    obstacles, unless they belong to the goal.

    Args:
        obstacle_map: Boolean occupancy map, or `None`. Not modified.
        goal_map: Boolean goal mask. The assertion also admits a tuple index,
            but the modes with `obstacle_map=None` require an array.
        speed: Optional per-pixel propagation speed.
        dx: Pixel size (e.g. meters per pixel).

    Returns:
        Array of distances/times; masked (unreachable or obstacle) cells are
        filled with `inf` and goal cells are 0 in the skfmm modes.

    Raises:
        Exception: Whatever skfmm raises is re-raised after a best-effort
            debug image is written to `fm2_planner_get_times_error.png`.
    """
    assert (isinstance(goal_map, np.ndarray) and goal_map.dtype == bool) or isinstance(goal_map, tuple)

    if obstacle_map is None:
        if speed is None:
            times = cv2.distanceTransform((~goal_map).astype(np.uint8), cv2.DIST_L2, 5)
            return times * dx
        obstacle_map = np.zeros_like(goal_map)

    # `~` allocates a new array, so the caller's obstacle map is untouched.
    phi = ~obstacle_map
    phi[0] = 0
    phi[-1] = 0
    phi[:, 0] = 0
    phi[:, -1] = 0
    phi = np.ma.masked_values(phi, 0)
    phi[goal_map] = 0 # 0: goal, 1: traversable, masked: obstacle
    try:
        if speed is None:
            with nvtx.annotate('skfmm.distance'):
                times = skfmm.distance(phi=phi, dx=dx)
        else:
            with nvtx.annotate("skfmm.travel_time"):
                times = skfmm.travel_time(phi=phi, speed=speed, dx=dx)
    except Exception:
        # Dump the inputs for debugging. The dump is best-effort: a failure
        # here must not replace the original skfmm error.
        try:
            planner_img = np.zeros((obstacle_map.shape[0], obstacle_map.shape[1], 3), dtype=np.uint8)
            planner_img[obstacle_map] = [255, 0, 0]
            planner_img[goal_map] = [0, 0, 255]
            planner_img[obstacle_map & goal_map] = [0, 255, 0]
            imageio.imwrite('fm2_planner_get_times_error.png', planner_img)
        except Exception as dump_error:
            print(f"get_times: could not write debug image: {dump_error}")
        raise
    times = np.ma.filled(times, np.inf)
    times[goal_map] = 0

    return times

def draw_arrowhead(planner_img, start_px, theta, arrow_size=10, color=(255, 0, 0)):
    """Draw a filled triangular arrowhead onto `planner_img` and return it.

    The triangle is defined around the origin, rotated by `90deg - theta`
    (with `theta` in radians), shifted by `start_px` and filled with
    `cv2.fillConvexPoly`, which draws into `planner_img` in place.

    Args:
        planner_img: Image to draw on (modified in place).
        start_px: Offset added to the rotated triangle vertices.
            NOTE(review): OpenCV interprets points as (x, y); confirm the
            order of `start_px` at the call site.
        theta: Heading in radians.
        arrow_size: Controls the triangle size; the tip sits at
            `arrow_size // 2` and the base corners at `arrow_size // 4`.
        color: Fill colour passed to OpenCV.

    Returns:
        The same `planner_img` object.
    """
    tip = arrow_size // 2
    wing = arrow_size // 4
    triangle = np.array(
        [[0, tip], [-wing, -wing], [wing, -wing]],
        dtype=np.float32,
    )

    # cv2.transform expects an (N, 1, 2)- or (1, N, 2)-shaped point array.
    angle_deg = np.degrees(np.pi / 2 - theta)
    rotation = cv2.getRotationMatrix2D((0, 0), angle_deg, 1.0)
    rotated = cv2.transform(np.expand_dims(triangle, 0), rotation)[0]

    # Shift to the requested position and truncate to integer pixels.
    vertices = (rotated + np.array(start_px)).astype(np.int32)
    cv2.fillConvexPoly(planner_img, vertices, color)

    return planner_img