import os
import shutil
from typing import Optional, Tuple

from habitat_sim.errors import GreedyFollowerError
import matplotlib
import nvtx
import numpy as np
from vlfm.utils.geometry_utils import rho_theta
from home_robot.core.interfaces import DiscreteNavigationAction

from helios.agent.planner.fm2_planner import get_action_goal, get_times
from helios.env.privileged_env import PrivEnv


class GtPlanner:
    """Pick a waypoint on the local map and query the privileged env for the
    ground-truth action that heads toward it.

    Call ``set_envs`` and ``reset`` before planning. ``plan_ours`` also reads
    ``self.vis_dir``, which is only set by ``set_vis_dir``.
    """

    def __init__(self, map_shape, pixels_per_meter, image_dir):
        self.map_shape = map_shape
        self.pixels_per_meter = pixels_per_meter
        self.image_dir = image_dir
        self.envs = None

    def reset(self):
        """Record the robot's current rotation for later GT path queries.

        Raises:
            ValueError: if no environment has been set via ``set_envs``.
        """
        if self.envs is None:
            raise ValueError("Env not set for GtNavAgent")
        self.start_rot = self.envs.get_robot_rot()

    def set_envs(self, envs: PrivEnv):
        """Attach the privileged environment used to query GT actions."""
        self.envs = envs

    def set_vis_dir(self, scene_id: str, episode_id: str):
        """Create a fresh, empty visualization directory for one episode.

        Any existing directory at ``<image_dir>/<scene_id>_<episode_id>`` is
        deleted first.
        """
        self.vis_dir = os.path.join(self.image_dir, f"{scene_id}_{episode_id}")
        shutil.rmtree(self.vis_dir, ignore_errors=True)
        os.makedirs(self.vis_dir, exist_ok=True)

    @nvtx.annotate("GtPlanner.plan")
    def plan(
        self,
        obstacle_map: np.ndarray,
        goal_map: np.ndarray,
        frontier_map: np.ndarray,
        sensor_pose: np.ndarray,
        found_goal: bool,
        debug: bool = True,
        use_dilation_for_stg: bool = False,
        timestep: Optional[int] = None,
        seen_map: Optional[np.ndarray] = None,
    ):
        """Plan one discrete action toward the goal, or toward the frontier.

        The maps are cast to ``bool``. ``goal_map`` is used as the target when
        ``found_goal`` is true, otherwise ``frontier_map``. ``debug`` and
        ``use_dilation_for_stg`` are accepted but unused here.

        Returns:
            ``(action, None, None, None)``. The trailing ``None`` values are
            presumably placeholders required by the caller's planner
            interface (TODO confirm).
        """
        obstacle_map = obstacle_map.astype(bool)
        goal_map = goal_map.astype(bool)
        frontier_map = frontier_map.astype(bool)

        action = self.plan_ours(
            obstacle_map=obstacle_map,
            seen_map=seen_map,
            goal_map=goal_map if found_goal else frontier_map,
            sensor_pose=sensor_pose,
            timestep=timestep,
            collision=False,
        )
        return action, None, None, None

    @nvtx.annotate("GtPlanner.plan_ours")
    def plan_ours(
        self,
        obstacle_map: np.ndarray,
        seen_map: Optional[np.ndarray],
        goal_map: np.ndarray,
        sensor_pose: np.ndarray,
        timestep: Optional[int],
        collision: bool,
    ) -> DiscreteNavigationAction:
        """Return the GT action toward a waypoint near ``goal_map``.

        Rings of increasing distance around the goal cells (radii
        ``np.linspace(0, 5, 10)``, each 0.5 wide) are tried in order. For each
        ring, ``get_action_goal`` chooses a local waypoint. The waypoint is
        converted to world coordinates and passed to
        ``envs.get_gt_path_action``. The first ring for which both calls
        succeed decides the action. If every ring fails, STOP is returned.

        ``seen_map`` and ``collision`` are accepted but not used.

        Raises:
            ValueError: if no environment has been set via ``set_envs``.
        """
        if self.envs is None:
            raise ValueError("Env not set for GtNavAgent")

        # Distance field from the goal cells, computed with obstacles zeroed
        # out and uniform speed. dx = 1 / pixels_per_meter, so values are
        # presumably in metres (TODO confirm get_times semantics).
        distances_from_goal = get_times(
            obstacle_map=np.zeros_like(obstacle_map),
            goal_map=goal_map,
            speed=np.ones_like(obstacle_map),
            dx=1 / self.pixels_per_meter,
        )

        # The robot's start pixel in the local map does not depend on the
        # radius, so compute it once. Assumes sensor_pose[[3, 5]] is the local
        # map origin in world pixels and sensor_pose[[1, 0]] the robot position
        # in metres (TODO confirm against the caller).
        world_t_local_px = sensor_pose[[3, 5]].astype(np.int32)
        world_t_start_px = (sensor_pose[[1, 0]] * self.pixels_per_meter).astype(np.int32)
        local_start_px = world_t_start_px - world_t_local_px

        action_t = 0  # falls back to STOP when no ring yields a usable goal
        # NOTE(review): consecutive radii are 5/9 ~= 0.556 apart but each ring
        # is only 0.5 wide, so distances in the small gaps are never targeted.
        for goal_radius in np.linspace(0, 5, 10):
            ring_map = (goal_radius - 0.5 <= distances_from_goal) & (
                distances_from_goal <= goal_radius
            )

            try:
                local_t_goal_px = get_action_goal(
                    obstacle_map=obstacle_map,
                    seen_map=None,
                    collision_map=None,
                    goal_map=ring_map,
                    start_px=local_start_px,
                    max_displacement=None,
                    min_displacement=None,
                    pixels_per_meter=self.pixels_per_meter,
                    cost='goal',
                    vis_dir=self.vis_dir,
                    timestep=timestep,
                )
            except Exception as err:  # best effort: try the next ring
                print(
                    f"GtPlanner.plan_ours().get_action_goal() raised "
                    f"{type(err).__name__} {goal_radius:.2f}"
                )
                continue

            world_t_goal_px = local_t_goal_px + world_t_local_px
            # NOTE(review): map_shape[0] is used as the centre offset for both
            # axes, which assumes a square map.
            world_t_goal = (world_t_goal_px - self.map_shape[0] // 2)[::-1] / self.pixels_per_meter

            try:
                action_t = self.envs.get_gt_path_action(world_t_goal, self.start_rot).item()
            except Exception as err:  # e.g. GreedyFollowerError; try next ring
                print(
                    f"GtPlanner.plan_ours().get_gt_path_action() raised "
                    f"{type(err).__name__} {goal_radius:.2f}"
                )
                continue
            break

        return self._action_from_index(action_t)

    @staticmethod
    def _action_from_index(action_t) -> DiscreteNavigationAction:
        """Map the integer from ``get_gt_path_action`` to a discrete action.

        0 -> STOP, 1 -> MOVE_FORWARD, 2 -> TURN_LEFT, 3 -> TURN_RIGHT. Any
        other value logs a warning and turns left or right with equal
        probability.
        """
        action = {
            0: DiscreteNavigationAction.STOP,
            1: DiscreteNavigationAction.MOVE_FORWARD,
            2: DiscreteNavigationAction.TURN_LEFT,
            3: DiscreteNavigationAction.TURN_RIGHT,
        }.get(action_t)
        if action is not None:
            return action

        print(f"Unknown action {action_t} from gt path! Turning instead")
        # Uniform draw in [0, 1) gives a fair 50/50 left/right choice. The
        # previous np.random.randn() > 0.5 picked left only ~31% of the time.
        if np.random.rand() > 0.5:
            return DiscreteNavigationAction.TURN_LEFT
        return DiscreteNavigationAction.TURN_RIGHT

