import copy
from typing import Dict, Tuple

import numpy as np
import numpy.typing as npt
from nuplan.common.actor_state.state_representation import Point2D
from nuplan.common.actor_state.tracked_objects import TrackedObject
from nuplan.common.actor_state.tracked_objects_types import (
    AGENT_TYPES,
    TrackedObjectType,
)

from src.planners.pdm_planner.utils.pdm_enums import (
    BBCoordsIndex,
)
from src.planners.pdm_planner.utils.pdm_geometry_utils import (
    normalize_angle,
)


MAX_DYNAMIC_OBJECTS: Dict[TrackedObjectType, int] = {
    TrackedObjectType.VEHICLE: 50,
    TrackedObjectType.PEDESTRIAN: 25,
    TrackedObjectType.BICYCLE: 10,
}

MAX_STATIC_OBJECTS: int = 50


class PDMObjectManager:
    """Class that stores and sorts tracked objects around the ego-vehicle."""

    def __init__(
        self,
    ):
        """Constructor of PDMObjectManager."""

        # all objects
        self._unique_objects: Dict[str, TrackedObject] = {}

        # dynamic objects
        self._dynamic_object_tokens = {key: [] for key in MAX_DYNAMIC_OBJECTS.keys()}
        self._dynamic_object_coords = {key: [] for key in MAX_DYNAMIC_OBJECTS.keys()}
        self._dynamic_object_dxy = {key: [] for key in MAX_DYNAMIC_OBJECTS.keys()}

        # static objects
        self._static_object_tokens = []
        self._static_object_coords = []

    @property
    def unique_objects(self) -> Dict[str, TrackedObject]:
        """
        Getter of unique_objects
        :return: Dictionary of uniquely tracked objects
        """
        return self._unique_objects
    
    def trans(self, center, heading, length, width):
        half_pi = np.pi / 2.0
        translation = np.stack(
            [
                (width * np.cos(heading + half_pi)) + (length * np.cos(heading)),
                (width * np.sin(heading + half_pi)) + (length * np.sin(heading)),
            ],
            axis=-1
        )
        return center + translation
    
    def pred_box(self, center, heading, length, width):
        fl = self.trans(center, heading, length, width)
        rl = self.trans(center, heading, -length, width)
        rr = self.trans(center, heading, -length, -width)
        fr = self.trans(center, heading, length, -width)
        return np.stack([fl, rl, rr, fr, center], axis=-2)

    def add_object(self, object: TrackedObject, object_preds= None) -> None:
        """
        Add object to manager and sort category (dynamic/static)
        :param object: any tracked object
        :param object_preds: np.array of predictions
        """
        self._unique_objects[object.track_token] = object

        coords_list = [
            [corner.x, corner.y] for corner in copy.deepcopy(object.box.all_corners())
        ]
        coords_list.append([object.center.x, object.center.y])
        #[5, 2] for current box
        coords: np.ndarray = np.array(coords_list, dtype=np.float64)

        if object_preds is not None:
            # print(object.center.x, object.center.y)
            # print(object_preds[0, :2])
            # assert 1==0
            length = object.box.length
            width = object.box.width
            pred_center, pred_heading  =  object_preds[..., :2], object_preds[..., -1]

            pred_coords = self.pred_box(pred_center, pred_heading, length/2, width /2)
            pred_coords = pred_coords.transpose(1, 0, 2)
            # print('pred:',pred_coords.shape)
            # pred_coords = np.concatenate([coords[:, None, :], pred_coords], axis=-2)

        if object.tracked_object_type in AGENT_TYPES:
            if object_preds is not None:
                self._add_dynamic_object(
                    object.tracked_object_type, object.track_token, coords, pred_coords
                )
            else:
                velocity = object.velocity
                velocity_angle = np.arctan2(velocity.y, velocity.x)
                agent_drives_forward = (
                    np.abs(normalize_angle(object.center.heading - velocity_angle))
                    < np.pi / 2
                )

                track_heading = (
                    object.center.heading
                    if agent_drives_forward
                    else normalize_angle(object.center.heading + np.pi)
                )

                dxy = np.array(
                    [
                        np.cos(track_heading) * velocity.magnitude(),
                        np.sin(track_heading) * velocity.magnitude(),
                    ],
                    dtype=np.float64,
                ).T  # x,y velocity [m/s]

                dxys = np.repeat(dxy[None, :], 81, axis=0)
                dxys = dxys * (np.arange(81)[:, None])*0.1
                pred_coords = coords[:, None, :2] + dxys[None, :, :2]
                # print(pred_coords.shape)

                self._add_dynamic_object(
                    object.tracked_object_type, object.track_token, coords, pred_coords
                )

        else:
            self._add_static_object(
                object.tracked_object_type, object.track_token, coords
            )

    def get_nearest_objects(self, position: Point2D) -> Tuple:
        """
        Retrieve nearest k objects depending on category.
        :param position: global map position
        :return: tuple containing tokens, coords, and dynamic information of objects
        """
        dynamic_object_tokens, dynamic_object_coords_list, dynamic_object_dxy_list = (
            [],
            [],
            [],
        )

        for dynamic_object_type in MAX_DYNAMIC_OBJECTS.keys():
            (
                dynamic_object_tokens_,
                dynamic_object_coords_,
                dynamic_object_dxy_,
            ) = self._get_nearest_dynamic_objects(position, dynamic_object_type)

            if dynamic_object_coords_.ndim < 3:
                continue

            dynamic_object_tokens.extend(dynamic_object_tokens_)
            dynamic_object_coords_list.append(dynamic_object_coords_)
            dynamic_object_dxy_list.append(dynamic_object_dxy_)

        if len(dynamic_object_coords_list) > 0:
            dynamic_object_coords = np.concatenate(
                dynamic_object_coords_list, axis=0, dtype=np.float64
            )
            dynamic_object_dxy = np.concatenate(
                dynamic_object_dxy_list, axis=0, dtype=np.float64
            )
        else:
            dynamic_object_coords = np.array([], dtype=np.float64)
            dynamic_object_dxy = np.array([], dtype=np.float64)

        static_object_tokens, static_object_coords = self._get_nearest_static_objects(
            position, None
        )

        return (
            static_object_tokens,
            static_object_coords,
            dynamic_object_tokens,
            dynamic_object_coords,
            dynamic_object_dxy,
        )

    def _add_dynamic_object(
        self,
        type: TrackedObjectType,
        token: str,
        coords: npt.NDArray[np.float64],
        dxy: npt.NDArray[np.float64],
    ) -> None:
        """
        Adds dynamic obstacle to the manager.
        :param type: Object type (vehicle, pedestrian, etc.)
        :param token: Temporally consistent object identifier
        :param coords: Bounding-box coordinates
        :param dxy: velocity (x,y) [m/s]
        """
        self._dynamic_object_tokens[type].append(token)
        self._dynamic_object_coords[type].append(coords)
        self._dynamic_object_dxy[type].append(dxy)

    def _add_static_object(
        self,
        type: TrackedObjectType,
        token: str,
        coords: npt.NDArray[np.float64],
    ) -> None:
        """
        Adds static obstacle to manager.
        :param type: Object type (e.g. generic, traffic cone, etc.), currently ignored
        :param token: Temporally consistent object identifier
        :param coords: Bounding-box coordinates
        """
        self._static_object_tokens.append(token)
        self._static_object_coords.append(coords)

    def _get_nearest_dynamic_objects(
        self, position: Point2D, type: TrackedObjectType
    ) -> Tuple:
        """
        Retrieves nearest k dynamic objects depending on type
        :param position: Ego-vehicle position
        :param type: Object type to sort
        :return: Tuple of tokens, coords, and velocity of nearest objects.
        """
        position_coords = position.array[None, ...]  # shape: (1,2)

        object_tokens = self._dynamic_object_tokens[type]
        object_coords = np.array(self._dynamic_object_coords[type], dtype=np.float64)
        # print(np.array(self._dynamic_object_dxy[type]).shape)
        object_dxy = np.array(self._dynamic_object_dxy[type], dtype=np.float64)

        if len(object_tokens) > 0:
            # add axis if single object found
            if object_coords.ndim == 1:
                object_coords = object_coords[None, ...]
                object_dxy = object_dxy[None, ...]

            position_to_center_dist = (
                (object_coords[..., BBCoordsIndex.CENTER, :] - position_coords) ** 2.0
            ).sum(axis=-1) ** 0.5

            object_argsort = np.argsort(position_to_center_dist)

            object_tokens = [object_tokens[i] for i in object_argsort][
                : MAX_DYNAMIC_OBJECTS[type]
            ]
            object_coords = object_coords[object_argsort][: MAX_DYNAMIC_OBJECTS[type]]
            object_dxy = object_dxy[object_argsort][: MAX_DYNAMIC_OBJECTS[type]]

        return (object_tokens, object_coords, object_dxy)

    def _get_nearest_static_objects(
        self, position: Point2D, type: TrackedObjectType
    ) -> Tuple:
        """
        Retrieves nearest k static obstacles around ego's position.
        :param position: ego's position
        :param type: type of static obstacle (currently ignored)
        :return: tuple of tokens and coords of nearest objects
        """
        position_coords = position.array[None, ...]  # shape: (1,2)

        object_tokens = self._static_object_tokens
        object_coords = np.array(self._static_object_coords, dtype=np.float64)

        if len(object_tokens) > 0:
            # add axis if single object found
            if object_coords.ndim == 1:
                object_coords = object_coords[None, ...]

            position_to_center_dist = (
                (object_coords[..., BBCoordsIndex.CENTER, :] - position_coords) ** 2.0
            ).sum(axis=-1) ** 0.5

            object_argsort = np.argsort(position_to_center_dist)

            object_tokens = [object_tokens[i] for i in object_argsort][
                :MAX_STATIC_OBJECTS
            ]
            object_coords = object_coords[object_argsort][:MAX_STATIC_OBJECTS]

        return (object_tokens, object_coords)
