import torch
import roipool3d_cuda
import numpy as np
import lib.utils.kitti_utils as kitti_utils


def roipool3d_gpu(pts, pts_feature, boxes3d, pool_extra_width, sampled_pt_num=512):
    """
    Run the roipool3d_cuda forward routine on a batch of point clouds and boxes.

    The boxes are first passed through kitti_utils.enlarge_box3d with
    pool_extra_width; the enlarged boxes are what the extension receives.

    :param pts: (B, N, 3)
    :param pts_feature: (B, N, C)
    :param boxes3d: (B, M, 7)
    :param pool_extra_width: float, forwarded to kitti_utils.enlarge_box3d
    :param sampled_pt_num: int, size of the per-box point dimension of the output
    :return:
        pooled_features: (B, M, sampled_pt_num, 3 + C), float32
        pooled_empty_flag: (B, M), int32
    """
    batch_size, boxes_num, feature_len = pts.shape[0], boxes3d.shape[1], pts_feature.shape[2]
    pooled_boxes3d = kitti_utils.enlarge_box3d(boxes3d.view(-1, 7), pool_extra_width).view(batch_size, -1, 7)

    # Allocate the output buffers on the same device as the inputs. The legacy
    # torch.cuda.*Tensor constructors always use the *current* CUDA device, which
    # need not be the device the inputs live on.
    pooled_features = torch.zeros((batch_size, boxes_num, sampled_pt_num, 3 + feature_len),
                                  dtype=torch.float32, device=pts.device)
    pooled_empty_flag = torch.zeros((batch_size, boxes_num), dtype=torch.int32, device=pts.device)

    # The two zero-filled buffers are handed to the extension as output arguments
    # and returned afterwards.
    roipool3d_cuda.forward(pts.contiguous(), pooled_boxes3d.contiguous(),
                           pts_feature.contiguous(), pooled_features, pooled_empty_flag)

    return pooled_features, pooled_empty_flag


def pts_in_boxes3d_cpu(pts, boxes3d):
    """
    Build one boolean point mask per box using the CPU routine of roipool3d_cuda.

    :param pts: (N, 3) in rect-camera coords, must be a CPU tensor
    :param boxes3d: (M, 7)
    :return: boxes_pts_mask_list: (M), list with [(N), (N), ..]; entry k is pts_flag[k] > 0
    :raises NotImplementedError: if pts is a CUDA tensor
    """
    if pts.is_cuda:
        raise NotImplementedError('pts_in_boxes3d_cpu only supports CPU tensors')

    pts = pts.float().contiguous()
    boxes3d = boxes3d.float().contiguous()

    # Zero-filled (rather than uninitialized) so that any entry the extension
    # does not write reads as "not inside", matching the other buffers in this file.
    pts_flag = torch.zeros((boxes3d.size(0), pts.size(0)), dtype=torch.long, device='cpu')  # (M, N)
    roipool3d_cuda.pts_in_boxes3d_cpu(pts_flag, pts, boxes3d)

    # Iterating a 2-D tensor yields its rows, i.e. one (N,) flag vector per box.
    return [box_flag > 0 for box_flag in pts_flag]


def roipool_pc_cpu(pts, pts_feature, boxes3d, sampled_pt_num):
    """
    Call the CPU pooling routine of roipool3d_cuda on a single point cloud.

    All three tensor inputs are moved to the CPU, cast to float32 and made
    contiguous before the call.

    :param pts: (N, 3)
    :param pts_feature: (N, C)
    :param boxes3d: (M, 7)
    :param sampled_pt_num: int
    :return:
        pooled_pts: (M, sampled_pt_num, 3), float32
        pooled_features: (M, sampled_pt_num, C), float32
        pooled_empty_flag: (M), int64
    """
    pts, pts_feature, boxes3d = (t.cpu().float().contiguous() for t in (pts, pts_feature, boxes3d))

    same_point_count = pts.shape[0] == pts_feature.shape[0]
    assert same_point_count and pts.shape[1] == 3, '%s %s' % (pts.shape, pts_feature.shape)
    assert not pts.is_cuda

    # Zero-initialized output buffers, passed to the extension as output arguments.
    num_boxes = boxes3d.shape[0]
    pooled_pts = torch.zeros((num_boxes, sampled_pt_num, 3), dtype=torch.float32, device='cpu')
    pooled_features = torch.zeros((num_boxes, sampled_pt_num, pts_feature.shape[1]),
                                  dtype=torch.float32, device='cpu')
    pooled_empty_flag = torch.zeros(num_boxes, dtype=torch.int64, device='cpu')

    roipool3d_cuda.roipool3d_cpu(pts, boxes3d, pts_feature, pooled_pts, pooled_features, pooled_empty_flag)
    return pooled_pts, pooled_features, pooled_empty_flag


def roipool3d_cpu(boxes3d, pts, pts_feature, pts_extra_input, pool_extra_width, sampled_pt_num=512,
                  canonical_transform=True):
    """
    CPU RoI pooling on numpy inputs, optionally moving the pooled points into each box's frame.

    :param boxes3d: (M, 7) ndarray; columns 0:3 are used as the box center, column 6 as the rotation
    :param pts: (N, 3) ndarray
    :param pts_feature: (N, C) ndarray
    :param pts_extra_input: (N, C2) ndarray
    :param pool_extra_width: constant, forwarded to kitti_utils.enlarge_box3d
    :param sampled_pt_num: constant, forwarded to roipool_pc_cpu
    :param canonical_transform: if True, subtract the box center from the pooled xyz and rotate each
        box's points with kitti_utils.rotate_pc_along_y
    :return:
        NOTE: the tuple length depends on canonical_transform.
        canonical_transform=True:  (sampled_pts_input, sampled_pts_feature)
        canonical_transform=False: (sampled_pts_input, sampled_pts_feature, pooled_empty_flag)
        sampled_pts_input: (num_boxes, sampled_pt_num, 3 + C2) ndarray
        sampled_pts_feature: (num_boxes, sampled_pt_num, C) ndarray
        pooled_empty_flag: (num_boxes) ndarray
        (num_boxes is the row count of the enlarged boxes - presumably M; confirm in enlarge_box3d)
    """
    enlarged_boxes = kitti_utils.enlarge_box3d(boxes3d, pool_extra_width)

    # Extra-input channels are placed first so they can be split off again by index below.
    stacked_feature = np.concatenate((pts_extra_input, pts_feature), axis=1)

    #  Note: if pooled_empty_flag[i] > 0, the pooled_pts[i], pooled_features[i] will be zero
    pooled_pts, pooled_features, pooled_empty_flag = roipool_pc_cpu(
        torch.from_numpy(pts),
        torch.from_numpy(stacked_feature),
        torch.from_numpy(enlarged_boxes),
        sampled_pt_num,
    )

    split_at = pts_extra_input.shape[1]
    pooled_extra = pooled_features[:, :, :split_at]
    sampled_pts_input = torch.cat((pooled_pts, pooled_extra), dim=2).numpy()
    sampled_pts_feature = pooled_features[:, :, split_at:].numpy()

    if not canonical_transform:
        return sampled_pts_input, sampled_pts_feature, pooled_empty_flag.numpy()

    # Translate to the roi coordinates
    roi_ry = boxes3d[:, 6] % (2 * np.pi)  # 0~2pi
    roi_center = boxes3d[:, 0:3]

    # shift to center (uses the original, non-enlarged boxes)
    centered_xyz = sampled_pts_input[:, :, 0:3] - roi_center[:, np.newaxis, :]
    sampled_pts_input[:, :, 0:3] = centered_xyz

    for k, roi_pts in enumerate(sampled_pts_input):
        sampled_pts_input[k] = kitti_utils.rotate_pc_along_y(roi_pts, roi_ry[k])

    return sampled_pts_input, sampled_pts_feature


# There is no standalone entry point: running this file as a script is a no-op.
if __name__ == '__main__':
    pass
