import numpy as np
import torch


def limit_period(val, offset=0.5, period=np.pi):
    """Wrap ``val`` into a single period of a periodic function.

    Args:
        val (torch.Tensor): Values to wrap.
        offset (float, optional): Fraction of ``period`` that lies below
            zero, which selects the output window. Defaults to 0.5.
        period (float, optional): Length of one period. Defaults to np.pi.

    Returns:
        torch.Tensor: ``val`` shifted by a whole number of periods so that it
            lies in ``[-offset * period, (1 - offset) * period)`` (for a
            positive ``period``, up to floating-point rounding).
    """
    # How many whole periods to remove so the result lands in the window.
    num_periods = torch.floor(val / period + offset)
    return val - num_periods * period


def rotation_3d_in_axis(points, angles, axis=0):
    """Rotate points by angles according to axis.

    For batch element ``a``, take the two coordinates other than ``axis`` in
    index order and call them ``(u, v)``. They are mapped to
    ``(u * cos + v * sin, -u * sin + v * cos)`` using ``angles[a]``. The
    coordinate along ``axis`` is left unchanged. The same formula (and hence
    the same sense of rotation) is used for every axis.

    Args:
        points (torch.Tensor): Points of shape (N, M, 3).
        angles (torch.Tensor): Vector of angles in shape (N,).
        axis (int, optional): The axis to rotate about. Negative values
            -3, -2 and -1 are accepted as aliases of 0, 1 and 2.
            Defaults to 0.

    Raises:
        ValueError: When the axis is not in [0, 1, 2] (or -3, -2, -1).

    Returns:
        torch.Tensor: Rotated points in shape (N, M, 3).
    """
    if axis in (-3, -2, -1):
        axis += 3

    rot_sin = torch.sin(angles)
    rot_cos = torch.cos(angles)
    ones = torch.ones_like(rot_cos)
    zeros = torch.zeros_like(rot_cos)
    # rot_mat_T has shape (3, 3, N): rot_mat_T[j, k, a] is the weight of
    # input coordinate j in output coordinate k for batch element a.
    if axis == 0:
        # x is kept; (y, z) -> (y*c + z*s, -y*s + z*c). This is the cyclic
        # analogue of the axis-2 matrix below (the rows used to be shifted,
        # which permuted coordinates even for a zero angle).
        rot_mat_T = torch.stack([
            torch.stack([ones, zeros, zeros]),
            torch.stack([zeros, rot_cos, -rot_sin]),
            torch.stack([zeros, rot_sin, rot_cos])
        ])
    elif axis == 1:
        # y is kept; (x, z) -> (x*c + z*s, -x*s + z*c).
        rot_mat_T = torch.stack([
            torch.stack([rot_cos, zeros, -rot_sin]),
            torch.stack([zeros, ones, zeros]),
            torch.stack([rot_sin, zeros, rot_cos])
        ])
    elif axis == 2:
        # z is kept; (x, y) -> (x*c + y*s, -x*s + y*c).
        rot_mat_T = torch.stack([
            torch.stack([rot_cos, -rot_sin, zeros]),
            torch.stack([rot_sin, rot_cos, zeros]),
            torch.stack([zeros, zeros, ones])
        ])
    else:
        raise ValueError(f'axis should in range [0, 1, 2], got {axis}')

    # Batched row-vector product: out[a] = points[a] @ rot_mat_T[..., a].
    return torch.einsum('aij,jka->aik', points, rot_mat_T)

def det11_to_xyzwhlr(det11):
    """Convert corner-based detections into 7-column boxes.

    Columns read from ``det11``:
        * ``[:, 0:8]`` -- four (x, y) corners, flattened corner by corner
        * ``[:, 8]``   -- yaw
        * ``[:, 9]``   -- bottom z (``bot``)
        * ``[:, -1]``  -- height

    Args:
        det11 (np.ndarray): 2-D detection array laid out as above
            (presumably shape (N, 11); TODO confirm there are no extra
            trailing columns, since height is taken from the last one).

    Returns:
        np.ndarray: Array of shape (N, 7) whose columns are, in order:
            centre x, centre y, z (the bottom value), w, l, height, yaw.
            ``w`` is the distance from corner 0 to corner 3 and ``l`` the
            distance from corner 0 to corner 1.
    """
    corners = det11[:, :8].reshape(-1, 4, 2)
    anchor = corners[:, 0, :]

    def _edge_length(corner_idx):
        # Euclidean distance from corner 0 to the given corner, kept 2-D.
        diff = anchor - corners[:, corner_idx, :]
        return (diff ** 2).sum(axis=-1, keepdims=True) ** 0.5

    center_xy = corners.mean(axis=1)
    width = _edge_length(3)
    length = _edge_length(1)
    bottom = det11[:, 9, None]
    heading = det11[:, 8, None]
    box_height = det11[:, -1, None]
    return np.concatenate(
        [center_xy, bottom, width, length, box_height, heading], axis=-1)
    
def xywhr2xyxyr(boxes_xywhr):
    """Turn centre/size rotated boxes into two-corner rotated boxes.

    Args:
        boxes_xywhr (torch.Tensor): Rotated boxes whose first five columns
            are (x_center, y_center, width, height, rotation).

    Returns:
        torch.Tensor: Tensor with the input's shape whose first five columns
            are (x_center - w/2, y_center - h/2, x_center + w/2,
            y_center + h/2, rotation). Any further columns are zero.
    """
    boxes = torch.zeros_like(boxes_xywhr)
    half_w, half_h = boxes_xywhr[:, 2] / 2, boxes_xywhr[:, 3] / 2
    center_x, center_y = boxes_xywhr[:, 0], boxes_xywhr[:, 1]

    boxes[:, 0], boxes[:, 2] = center_x - half_w, center_x + half_w
    boxes[:, 1], boxes[:, 3] = center_y - half_h, center_y + half_h
    boxes[:, 4] = boxes_xywhr[:, 4]
    return boxes


def get_box_type(box_type):
    """Resolve a box-type name to its box class and ``Box3DMode`` member.

    Args:
        box_type (str): Name of the box structure, matched
            case-insensitively against "LiDAR", "Camera" and "Depth".

    Returns:
        tuple: ``(box class, box mode)`` for the requested type.

    Raises:
        ValueError: If ``box_type`` is not one of the supported names.
    """
    from .box_3d_mode import (Box3DMode, CameraInstance3DBoxes,
                              DepthInstance3DBoxes, LiDARInstance3DBoxes)
    # Lower-cased name -> (box class, box mode).
    registry = {
        'lidar': (LiDARInstance3DBoxes, Box3DMode.LIDAR),
        'camera': (CameraInstance3DBoxes, Box3DMode.CAM),
        'depth': (DepthInstance3DBoxes, Box3DMode.DEPTH),
    }
    resolved = registry.get(box_type.lower())
    if resolved is None:
        raise ValueError('Only "box_type" of "camera", "lidar", "depth"'
                         f' are supported, got {box_type}')
    return resolved


def points_cam2img(points_3d, proj_mat, with_depth=False):
    """Project points from camera coordinates to image coordinates.

    Args:
        points_3d (torch.Tensor): Points in shape (..., 3), for example
            (N, 3). The last dimension holds the 3-D coordinates.
        proj_mat (torch.Tensor): Transformation matrix between coordinates,
            of shape (3, 3), (3, 4) or (4, 4).
        with_depth (bool, optional): Whether to keep depth in the output.
            Defaults to False.

    Returns:
        torch.Tensor: Points in image coordinates with shape (..., 2). When
            ``with_depth`` is True the shape is (..., 3) and the last value
            is the projected depth.
    """
    assert len(proj_mat.shape) == 2, 'The dimension of the projection'\
        f' matrix should be 2 instead of {len(proj_mat.shape)}.'
    d1, d2 = proj_mat.shape[:2]
    assert (d1 == 3 and d2 == 3) or (d1 == 3 and d2 == 4) or (
        d1 == 4 and d2 == 4), 'The shape of the projection matrix'\
        f' ({d1}*{d2}) is not supported.'
    if d1 == 3:
        # Embed the 3x3 / 3x4 matrix into a 4x4 identity so a single
        # homogeneous matmul handles every supported shape.
        proj_mat_expanded = torch.eye(
            4, device=proj_mat.device, dtype=proj_mat.dtype)
        proj_mat_expanded[:d1, :d2] = proj_mat
        proj_mat = proj_mat_expanded

    # Append the homogeneous coordinate. The size is built as a plain int
    # tuple so a single point of shape (3,) also works. (The previous
    # implementation used new_zeros; new_ones yields better results.)
    homogeneous = points_3d.new_ones((*points_3d.shape[:-1], 1))
    points_4 = torch.cat([points_3d, homogeneous], dim=-1)
    point_2d = torch.matmul(points_4, proj_mat.t())
    depth = point_2d[..., 2:3]
    point_2d_res = point_2d[..., :2] / depth

    if with_depth:
        return torch.cat([point_2d_res, depth], dim=-1)
    return point_2d_res


def mono_cam_box2vis(cam_box):
    """Convert Mono-3D boxes into a form suitable for projection drawing.

    This is a post-processing function on the bboxes from the Mono-3D task.
    To perform projection visualization it:

        1. rotates the box along the x-axis by np.pi / 2 (roll), i.e. swaps
           the second and third box dimensions;
        2. changes the orientation from local yaw to global yaw by adding
           ``atan2(x, z)`` of the box gravity center;
        3. converts yaw to ``-yaw - np.pi / 2``.

    After applying this function, the boxes can be projected and drawn on
    2D images. The input box is not modified.

    Args:
        cam_box (:obj:`CameraInstance3DBoxes`): 3D bbox in the camera
            coordinate system before conversion. It could be a gt bbox
            loaded from the dataset or a network prediction.

    Returns:
        :obj:`CameraInstance3DBoxes`: Box after conversion, built with
            ``origin=(0.5, 0.5, 0.5)`` from the gravity centers.
    """
    from . import CameraInstance3DBoxes
    assert isinstance(cam_box, CameraInstance3DBoxes), \
        'input bbox should be CameraInstance3DBoxes!'

    loc = cam_box.gravity_center
    # Index out-of-place: ``dims`` / ``yaw`` may be views into
    # ``cam_box.tensor`` (NOTE(review): depends on the box class), and
    # in-place edits here would silently modify the caller's boxes.
    # Rotate along the x-axis by np.pi / 2, i.e. swap dims 1 and 2.
    dim = cam_box.dims[:, [0, 2, 1]]
    # Change local yaw to global yaw for visualization.
    yaw = cam_box.yaw + torch.atan2(loc[:, 0], loc[:, 2])
    # Convert yaw to (-yaw - np.pi / 2): mono 3D box classes such as
    # `NuScenesBox` define rotation differently from our
    # `CameraInstance3DBoxes`.
    yaw = -yaw - np.pi / 2
    feats = cam_box.tensor[:, 7:]

    vis_tensor = torch.cat([loc, dim, yaw[:, None], feats], dim=1)
    return CameraInstance3DBoxes(
        vis_tensor, box_dim=vis_tensor.shape[-1], origin=(0.5, 0.5, 0.5))

def _box_array_module(box):
    """Return the array namespace (``torch`` or ``np``) matching ``box``.

    Raises:
        ValueError: If ``box`` is neither a ``torch.Tensor`` nor an
            ``np.ndarray``.
    """
    if isinstance(box, torch.Tensor):
        return torch
    if isinstance(box, np.ndarray):
        return np
    raise ValueError('box should be a torch.Tensor or np.ndarray, '
                     f'got {type(box).__name__}')


def _flip_heading(r):
    """Return ``-r - pi / 2``, folded by one turn if it leaves [-pi, pi].

    The map ``r -> -r - pi / 2`` is its own inverse, so the same helper
    serves both conversion directions. Only a single +/- 2 * pi correction
    is applied. The result therefore lies in [-pi, pi] whenever
    ``-r - pi / 2`` lies in [-3 * pi, 3 * pi], which covers every ``r`` in
    [-pi, pi].
    """
    heading = -r - np.pi / 2  # new array/tensor; ``r`` is not modified
    heading[heading < -np.pi] += 2 * np.pi
    heading[heading > np.pi] -= 2 * np.pi
    return heading


def waymo2kitti_box(box):
    """Convert boxes from the Waymo column layout to the KITTI layout.

    Input columns ``[x, y, z, l, w, h, r]`` become
    ``[x, y, z - h / 2, w, l, h, -r - pi / 2]`` (heading folded into
    [-pi, pi] by one +/- 2 * pi step). The ``h / 2`` shift presumably moves
    from a box-centre to a bottom-centre z origin (TODO confirm).
    Columns beyond index 6 are not copied: they are zero in the output.

    Args:
        box (torch.Tensor | np.ndarray): Boxes of shape (N, >=7).
            It is not modified.

    Returns:
        torch.Tensor | np.ndarray: New boxes of the same type, shape and
            dtype as ``box``.

    Raises:
        ValueError: If ``box`` is neither a tensor nor an ndarray.
    """
    xp = _box_array_module(box)
    length, width, height = box[:, 3], box[:, 4], box[:, 5]

    kitti_box = xp.zeros_like(box)
    kitti_box[:, 0] = box[:, 0]
    kitti_box[:, 1] = box[:, 1]
    kitti_box[:, 2] = box[:, 2] - height / 2
    # KITTI stores width before length; Waymo stores length first.
    kitti_box[:, 3] = width
    kitti_box[:, 4] = length
    kitti_box[:, 5] = height
    kitti_box[:, 6] = _flip_heading(box[:, 6])
    return kitti_box


def kitti2waymo_box(box):
    """Convert boxes from the KITTI column layout to the Waymo layout.

    This is the inverse of ``waymo2kitti_box``. Input columns
    ``[x, y, z, w, l, h, r]`` become ``[x, y, z + h / 2, l, w, h,
    -r - pi / 2]`` (heading folded into [-pi, pi] by one +/- 2 * pi step).
    Columns beyond index 6 are not copied: they are zero in the output.

    Args:
        box (torch.Tensor | np.ndarray): Boxes of shape (N, >=7).
            It is not modified.

    Returns:
        torch.Tensor | np.ndarray: New boxes of the same type, shape and
            dtype as ``box``.

    Raises:
        ValueError: If ``box`` is neither a tensor nor an ndarray.
    """
    xp = _box_array_module(box)
    width, length, height = box[:, 3], box[:, 4], box[:, 5]

    waymo_box = xp.zeros_like(box)
    waymo_box[:, 0] = box[:, 0]
    waymo_box[:, 1] = box[:, 1]
    waymo_box[:, 2] = box[:, 2] + height / 2
    # Waymo stores length before width; KITTI stores width first.
    waymo_box[:, 3] = length
    waymo_box[:, 4] = width
    waymo_box[:, 5] = height
    waymo_box[:, 6] = _flip_heading(box[:, 6])
    return waymo_box