import warnings
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

import habitat_sim
import numpy as np
from habitat import registry
from habitat.core.dataset import Dataset
from habitat.core.env import Env
from habitat.sims.habitat_simulator.actions import HabitatSimActions

# from habitat.tasks.nav.shortest_path_follower import ShortestPathFollower
from omegaconf import DictConfig
from scipy.spatial.transform import Rotation as R
from torch import Tensor, tensor

from home_robot_sim.env.habitat_ovmm_env.habitat_ovmm_env import (
    HabitatOpenVocabManipEnv,
)

if TYPE_CHECKING:
    from habitat.sims.habitat_simulator.habitat_simulator import HabitatSim

import magnum as mn
from frontier_exploration.utils.fog_of_war import reveal_fog_of_war
from habitat.tasks.nav.shortest_path_follower import (
    _quat_to_xy_heading,
    action_to_one_hot,
)
from habitat.tasks.rearrange.rearrange_sim import RearrangeSim
from habitat.tasks.utils import cartesian_to_polar
from habitat.utils.geometry_utils import quaternion_from_coeff, quaternion_rotate_vector
from habitat.utils.visualizations import maps

from helios.env.greedy_geodesic_follower import GreedyGeodesicFollower
# from habitat.tasks.nav.shortest_path_follower import ShortestPathFollower


class ShortestPathFollower:
    r"""Utility class for extracting the action on the shortest path to the
    goal.

    The underlying follower is
    ``helios.env.greedy_geodesic_follower.GreedyGeodesicFollower`` built with
    ``fix_thrashing=True`` and ``thrashing_threshold=16``.

    :param sim: HabitatSim instance.
    :param goal_radius: Distance between the agent and the goal for it to be
            considered successful.
    :param return_one_hot: If true, returns a one-hot encoding of the action
            (useful for training ML agents). If false, returns the
            SimulatorAction.
    :param stop_on_error: Return stop if the follower is unable to determine a
                          suitable action to take next.  If false, will raise
                          a habitat_sim.errors.GreedyFollowerError instead
    """

    def __init__(
        self,
        sim: "HabitatSim",
        goal_radius: float,
        return_one_hot: bool = True,
        stop_on_error: bool = True,
    ):
        self._return_one_hot = return_one_hot
        self._sim = sim
        self._goal_radius = goal_radius
        # Built lazily in _build_follower (needs the sim's current scene).
        self._follower: Optional[GreedyGeodesicFollower] = None
        # Scene the current follower was built for; None until first build.
        self._current_scene = None
        self._stop_on_error = stop_on_error

    def _build_follower(self):
        """(Re)build the follower if the simulator's scene has changed."""
        if self._current_scene != self._sim.habitat_config.scene:
            self._follower = GreedyGeodesicFollower(
                self._sim.pathfinder,
                self._sim,
                self._goal_radius,
                stop_key=HabitatSimActions.stop,
                forward_key=HabitatSimActions.move_forward,
                left_key=HabitatSimActions.turn_left,
                right_key=HabitatSimActions.turn_right,
                fix_thrashing=True,
                thrashing_threshold=16,
            )
            self._current_scene = self._sim.habitat_config.scene

    def _get_return_value(self, action) -> Union[int, np.ndarray]:
        """Return ``action`` as a one-hot array or unchanged, per config."""
        if self._return_one_hot:
            return action_to_one_hot(action)
        else:
            return action

    def get_next_action(
        self, goal_pos: Union[List[float], np.ndarray]
    ) -> Optional[Union[int, np.ndarray]]:
        """Returns the next action along the shortest path.

        :param goal_pos: goal position forwarded to the follower's
            ``next_action_along``.
        :return: the action (one-hot if ``return_one_hot``).
        :raises habitat_sim.errors.GreedyFollowerError: only when
            ``stop_on_error`` is False and the follower fails.
        """
        self._build_follower()
        assert self._follower is not None
        try:
            next_action = self._follower.next_action_along(
                goal_pos,
            )
        # NOTE(review): assumes the helios follower raises habitat_sim's
        # GreedyFollowerError like habitat_sim's own implementation — confirm.
        except habitat_sim.errors.GreedyFollowerError:
            if not self._stop_on_error:
                raise
            # Documented contract: fall back to STOP when no action is found.
            next_action = HabitatSimActions.stop

        return self._get_return_value(next_action)

    @property
    def mode(self) -> str:
        """Deprecated; always returns an empty string."""
        # stacklevel=2 attributes the warning to the caller, not this file.
        warnings.warn(".mode is deprecated", DeprecationWarning, stacklevel=2)
        return ""

    @mode.setter
    def mode(self, new_mode: str):
        """Deprecated; the assigned value is ignored."""
        warnings.warn(".mode is deprecated", DeprecationWarning, stacklevel=2)


class PrivEnv(HabitatOpenVocabManipEnv):
    """``HabitatOpenVocabManipEnv`` with privileged simulator accessors.

    Reads the wrapped habitat simulator directly to provide ground-truth
    shortest-path actions, the robot heading, and top-down map /
    fog-of-war helpers.
    """

    def __init__(
        self,
        habitat_env: Env,
        config: DictConfig,
        dataset: Dataset,
        goal_radius: float = 0.4,
    ) -> None:
        """
        :param habitat_env: environment forwarded to the parent class.
        :param config: configuration forwarded to the parent class.
        :param dataset: dataset forwarded to the parent class.
        :param goal_radius: success radius for the ground-truth
            shortest-path follower.
        """
        super().__init__(habitat_env, config, dataset)

        self.shortest_path_follower = ShortestPathFollower(
            sim=self._priv_inner_env().sim,
            goal_radius=goal_radius,
            return_one_hot=False,
        )

    def _priv_inner_env(self):
        """Return the innermost habitat env (``self.habitat_env.env.env.habitat_env``)."""
        return self.habitat_env.env.env.habitat_env

    def _priv_robot_heading_quat(self) -> mn.Quaternion:
        """Return the robot heading (minus pi/2) as a rotation about +y."""
        agent_idx = 0
        robot = self._priv_inner_env().sim.get_agent_data(agent_idx).articulated_agent
        raw = robot.sim_obj.rotation
        curr_quat = quaternion_from_coeff(
            [raw.vector.x, raw.vector.y, raw.vector.z, raw.scalar]
        )
        # Subtract on the array before indexing so the dtype/precision matches
        # the array arithmetic used before this refactor.
        heading = (_quat_to_xy_heading(curr_quat.inverse()) - np.pi / 2)[0]
        return mn.Quaternion(
            mn.Vector3(0, np.sin(heading / 2), 0), np.cos(heading / 2)
        )

    def _priv_robot_map_yaw(self) -> float:
        """Return the first ``"yxz"`` Euler angle of ``get_robot_rot()`` minus pi."""
        return R.from_matrix(self.get_robot_rot()).as_euler("yxz")[0] - np.pi

    def get_gt_path_action(self, goal: np.ndarray, start_rot: np.ndarray) -> Tensor:
        """Return the follower's next action towards a start-relative goal.

        :param goal: goal relative to the episode start; ``goal[0]`` maps to
            -z and ``goal[1]`` to -x before rotation.
        :param start_rot: rotation matrix applied (via ``@``) to the local
            goal vector to orient it in world frame.
        :return: the action wrapped as ``tensor([[action]])``.
        """
        start_pos = self._priv_inner_env().current_episode.start_position

        # Build the goal at the start height, rotate it into the world frame,
        # then translate horizontally (x, z) by the episode start position.
        goal_w = start_rot @ np.array([-goal[1], start_pos[1], -goal[0]])
        goal_w[0] += start_pos[0]
        goal_w[2] += start_pos[2]
        best_action = self.shortest_path_follower.get_next_action(goal_w)

        return tensor([[best_action]])

    def print_agent_rot(self):
        """Print robot pose, episode start pose and agent state (debugging)."""
        inner = self._priv_inner_env()
        episode = inner.current_episode
        agent_idx = 0
        robot = inner.sim.get_agent_data(agent_idx).articulated_agent
        curr_rot = self._priv_robot_heading_quat()
        curr_pos = robot.base_pos

        print("#######################")
        print("CURR ROBOT: ", curr_pos, curr_rot)
        print("EPISODE START: ", episode.start_position, episode.start_rotation)
        print("STATE: ", inner.sim.get_agent_state())
        print("#######################")

    def get_robot_rot(self) -> np.ndarray:
        """Return the robot heading (minus pi/2) as a 3x3 rotation matrix."""
        curr_rot = self._priv_robot_heading_quat()
        vec = curr_rot.vector
        # scipy expects scalar-last (x, y, z, w) quaternions.
        return R.from_quat([vec.x, vec.y, vec.z, curr_rot.scalar]).as_matrix()

    def get_topdown_map(self, map_resolution=1000):
        """Return habitat's top-down map for the current simulator."""
        return maps.get_topdown_map_from_sim(
            self._priv_inner_env().sim, map_resolution=map_resolution
        )

    def get_topdown_pos(self, td_map) -> np.ndarray:
        """Return the agent's grid coordinates on ``td_map`` as ``[x, y]``."""
        sim = self._priv_inner_env().sim
        pos = sim.get_agent_state().position
        # World z is passed as the first coordinate and world x as the second.
        x, y = maps.to_grid(pos[2], pos[0], td_map.shape, sim=sim)
        return np.array([x, y])

    def get_topdown_map_info(self, fow_mask) -> dict:
        """Bundle the top-down map, ``fow_mask``, agent grid position and yaw.

        The map is generated with ``map_resolution=fow_mask.shape[1]``.
        """
        td_map = self.get_topdown_map(map_resolution=fow_mask.shape[1])
        return {
            "map": td_map,
            "fog_of_war_mask": fow_mask,
            "agent_map_coord": [self.get_topdown_pos(td_map)],
            "agent_angle": [self._priv_robot_map_yaw()],
        }

    def reveal_fog_of_war(self, fow_mask):
        """Apply ``frontier_exploration``'s ``reveal_fog_of_war`` at the agent pose.

        NOTE(review): uses the default map resolution (1000), whereas
        ``get_topdown_map_info`` sizes the map from ``fow_mask.shape[1]``;
        confirm ``fow_mask`` was built at this resolution.
        """
        td_map = self.get_topdown_map()
        xy_p = self.get_topdown_pos(td_map)
        yaw = self._priv_robot_map_yaw()

        # Resolves to the module-level imported function, not this method.
        return reveal_fog_of_war(td_map, fow_mask, xy_p, yaw)
