#---------------------------------------------------------------------------------#
# UniAD: Planning-oriented Autonomous Driving (https://arxiv.org/abs/2212.10156)  #
# Source code: https://github.com/OpenDriveLab/UniAD                              #
# Copyright (c) OpenDriveLab. All rights reserved.                                #
#---------------------------------------------------------------------------------#

from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import numpy.typing as npt
from casadi import DM, Opti, OptiSol, cos, diff, sin, sumsqr, vertcat, exp
from agentdriver.utils.det2occ import det2occ
from agentdriver.utils.geometry import CAR_WIDTH, CAR_LENGTH


Pose = Tuple[float, float, float]  # (x, y, yaw)


class CollisionNonlinearOptimizer:
    """
    Optimize planned trajectory with predicted occupancy
    Solved with direct multiple-shooting.
    modified from https://github.com/motional/nuplan-devkit
    :param trajectory_len: trajectory length
    :param dt: timestep (sec)
    """

    def __init__(self, trajectory_len: int, dt: float, sigma, alpha_collision, obj_pixel_pos):
        """
        :param trajectory_len: the length of trajectory to be optimized.
        :param dt: the time interval between trajectory points.
        """
        self.dt = dt
        self.trajectory_len = trajectory_len
        self.current_index = 0
        self.sigma = sigma
        self.alpha_collision = alpha_collision
        self.obj_pixel_pos = obj_pixel_pos
        # Use a array of dts to make it compatible to situations with varying dts across different time steps.
        self._dts: npt.NDArray[np.float32] = np.asarray([[dt] * trajectory_len])
        self._init_optimization()

    def _init_optimization(self) -> None:
        """
        Initialize related variables and constraints for optimization.
        """
        self.nx = 2  # state dim

        self._optimizer = Opti()  # Optimization problem
        self._create_decision_variables()
        self._create_parameters()
        self._set_objective()

        # Set default solver options (quiet)
        self._optimizer.solver("ipopt", {"ipopt.print_level": 0, "print_time": 0, "ipopt.sb": "yes"})

    def set_reference_trajectory(self, reference_trajectory: Sequence[Pose]) -> None:
        """
        Set the reference trajectory that the smoother is trying to loosely track.
        :param x_curr: current state of size nx (x, y)
        :param reference_trajectory: N x 3 reference, where the second dim is for (x, y)
        """
        self._optimizer.set_value(self.ref_traj, DM(reference_trajectory).T)
        self._set_initial_guess(reference_trajectory)

    def set_solver_optimizerons(self, options: Dict[str, Any]) -> None:
        """
        Control solver options including verbosity.
        :param options: Dictionary containing optimization criterias
        """
        self._optimizer.solver("ipopt", options)

    def solve(self) -> OptiSol:
        """
        Solve the optimization problem. Assumes the reference trajectory was already set.
        :return Casadi optimization class
        """
        return self._optimizer.solve()

    def _create_decision_variables(self) -> None:
        """
        Define the decision variables for the trajectory optimization.
        """
        # State trajectory (x, y)
        self.state = self._optimizer.variable(self.nx, self.trajectory_len)
        self.position_x = self.state[0, :]
        self.position_y = self.state[1, :]

    def _create_parameters(self) -> None:
        """
        Define the expert trjactory and current position for the trajectory optimizaiton.
        """
        self.ref_traj = self._optimizer.parameter(2, self.trajectory_len)  # (x, y)

    def _set_objective(self) -> None:
        """Set the objective function. Use care when modifying these weights."""
        # Follow reference, minimize control rates and absolute inputs
        alpha_xy = 1.0
        cost_stage = (
            alpha_xy * sumsqr(self.ref_traj[:2, :] - vertcat(self.position_x, self.position_y))
        )

        alpha_collision = self.alpha_collision
        
        cost_collision = 0
        normalizer = 1/(2.507*self.sigma)
        # TODO: vectorize this
        for t in range(len(self.obj_pixel_pos)):
            x, y = self.position_x[t], self.position_y[t]
            for i in range(len(self.obj_pixel_pos[t])):
                col_x, col_y = self.obj_pixel_pos[t][i]
                cost_collision += alpha_collision * normalizer * exp(-((x - col_x)**2 + (y - col_y)**2)/2/self.sigma**2)
        self._optimizer.minimize(cost_stage + cost_collision)

    def _set_initial_guess(self, reference_trajectory: Sequence[Pose]) -> None:
        """Set a warm-start for the solver based on the reference trajectory."""
        # Initialize state guess based on reference
        self._optimizer.set_initial(self.state[:2, :], DM(reference_trajectory).T)  # (x, y, yaw)

debug = False
def collision_optimization(sdc_traj_all, data_dict, occ_filter_range=5.0, sigma=1.0, alpha_collision=5.0):
        """
        Optimize SDC trajectory with occupancy instance mask.

        Args:
            sdc_traj_all (torch.Tensor): SDC trajectory tensor.
            occ_mask (torch.Tensor): Occupancy flow instance mask. 
        Returns:
            torch.Tensor: Optimized SDC trajectory tensor.
        """
        planning_steps = 6
        bev_h = 200
        bev_w = 200

        pos_xy_t = []
        valid_occupancy_num = 0
        
        # ## adoption for uniad
        occ_mask = data_dict["occupancy"]

        # Try susbtitute occupancy with object detection
        # occ_mask = det2occ(data_dict)
        # occ_mask = np.concatenate((occ_mask, occ_mask_temp[-2:]), axis=0)
        
        # occ_mask = np.fliplr(data_dict["gt_occ"][0][1:].cpu().numpy().transpose(1,2,0)).transpose(2,0,1)
        occ_mask = occ_mask[np.newaxis, :]
        occ_mask = np.where(occ_mask > 0.1, True, False) # NOTE: the threshold of 0.1 is also used in uniad
        
        
        if occ_mask.shape[2] == 1:
            occ_mask = occ_mask.squeeze(2)
        occ_horizon = occ_mask.shape[1]
        # assert occ_horizon == 5

        for t in range(planning_steps):
            cur_t = min(t+1, occ_horizon-1)
            pos_xy = np.nonzero(occ_mask[0][cur_t])
            pos_xy = np.stack(pos_xy, axis=-1)[:, [1, 0]].astype(np.float32)
            pos_xy[:, 0] = (pos_xy[:, 0] - bev_h//2) * 0.5 + 0.25 # TODO use geometry.py
            pos_xy[:, 1] = (pos_xy[:, 1] - bev_w//2) * 0.5 + 0.25 # TODO use geometry.py # TODO check if we should add 1 here

            if debug and t == 0:
                import matplotlib.pyplot as plt
                from agentdriver.utils.geometry import rotate_bbox
                plt.figure()
                
                plt.scatter(pos_xy[:, 0], pos_xy[:, 1], c='r', s=10)
                for obj in data_dict["objects"]:
                    x, y, z, dx, dy, dz, rotation_z, rotation_y, rotation_x = obj["bbox"]
                    cx, cy = x, y
                    rotated_corners = rotate_bbox(cx, cy, dx, dy, rotation_z)
                    for pt in rotated_corners:
                        plt.scatter(pt[0], pt[1], c='b', s=10)
                plt.scatter(sdc_traj_all[:, 0], sdc_traj_all[:, 1], c='g', s=10)

            # filter the occupancy in range
            keep_index = np.sum((sdc_traj_all[t, :2][None, :] - pos_xy[:, :2])**2, axis=-1) < occ_filter_range**2
            pos_xy_t.append(pos_xy[keep_index])
            valid_occupancy_num += np.sum(keep_index>0)
        if valid_occupancy_num == 0:
            return sdc_traj_all
        
        col_optimizer = CollisionNonlinearOptimizer(planning_steps, 0.5, sigma, alpha_collision, pos_xy_t)
        col_optimizer.set_reference_trajectory(sdc_traj_all)
        sol = col_optimizer.solve()
        sdc_traj_optim = np.stack([sol.value(col_optimizer.position_x), sol.value(col_optimizer.position_y)], axis=-1)

        if debug:
            plt.scatter(sdc_traj_optim[:, 0], sdc_traj_optim[:, 1], c='y', s=10)
            

        return sdc_traj_optim