'''
@Author: 
@Email:
@Date: 2020-07-09 13:51:09
LastEditTime: 2021-05-30 19:44:50
@Description: 
    A differentiable renderer implemented with PyTorch.
    In this project, we don't use the gradient information.
'''

import numpy as np
import torch

import matplotlib.pyplot as plt
from utils import CUDA, CPU, COLOR, read_obj, rotate_X, rotate_Z, pc_to_rangemap


class Renderer(object):
    """ A differentiable renderer implemented in PyTorch without any trainable parameters.
    """
    def __init__(self, args):
        """Load the ground and vehicle meshes and pre-compute the lidar rays.

        Args:
            args: configuration object. The attributes read here are
                ``upper_fov``, ``lower_fov``, ``left_fov``, ``right_fov``
                (converted with ``np.deg2rad`` when the rays are generated),
                ``height``, ``width``, ``eps``, ``max_range`` and
                ``use_background``.

        Raises:
            ValueError: if ``self.vehicle_mesh_idx`` does not select one of
                the known vehicle meshes.
        """
        # configuration of lidar
        self.upper_fov = args.upper_fov
        self.lower_fov = args.lower_fov
        self.left_fov = args.left_fov
        self.right_fov = args.right_fov
        self.height = args.height
        self.width = args.width
        self.eps = args.eps
        self.max_range = args.max_range  # the max range of lidar
        self.use_background = args.use_background
        self.vehicle_mesh_idx = 0

        # load the ground obj
        self.obj_filename = './data/ground_big.obj'
        ground_vertices, ground_faces_idx = read_obj(self.obj_filename)
        ground_vertices *= 2
        ground_vertices[:, 0] += 10  # x
        ground_vertices[:, 2] -= 1.9  # z
        # cast to float32 like the vehicle vertices and the rays below, so the
        # ground faces can be combined with them without a dtype mismatch
        # NOTE(review): assumes read_obj returns NumPy arrays - confirm
        self.grounf_F = CUDA(torch.tensor(ground_vertices[ground_faces_idx]).float())

        if self.vehicle_mesh_idx == 0:
            # load vehicle obj model
            self.obj_filename = './data/car6.obj'
            vertices, faces_idx = read_obj(self.obj_filename)
            # NOTE: the initial transform of the object is case-by-case.
            self.object_scale = 16
            vertices = vertices/self.object_scale
            # NOTE: the z offset is determined by kitti coordinate
            vertices[:, 2] -= 1.7
            self.half_vehicle_length = 3
        elif self.vehicle_mesh_idx == 1:
            # the original dimension: x: 0.42, y: 0.31, z: 0.85
            # average car length is 4.5m, therefore the scale is 4.5/0.85 = 5.3
            self.obj_filename = './data/car3.obj'
            vertices, faces_idx = read_obj(self.obj_filename)
            # NOTE: the initial transform of the object is case-by-case.
            self.object_scale = 5.3
            vertices = vertices * self.object_scale
            vertices = vertices.dot(rotate_X(90))
            vertices = vertices.dot(rotate_Z(90))
            # NOTE: the z offset is determined by kitti coordinate
            vertices[:, 2] -= 1.22
            self.half_vehicle_length = 3
        else:
            # fail early instead of hitting an undefined `vertices` below
            raise ValueError('Unknown vehicle_mesh_idx: {}'.format(self.vehicle_mesh_idx))

        # for estimating the worst-case bounding box of the vehicle
        self.z_min = np.min(vertices[:, 2])-1.0
        self.z_max = np.max(vertices[:, 2])+1.0

        self.original_V = CUDA(torch.from_numpy(vertices).float())
        self.faces_idx = CUDA(torch.LongTensor(faces_idx))

        # generate lidar rays, stored as a [W, H, 3] tensor
        rays = self._precompute_lidar_ray()
        self.rays = CUDA(torch.from_numpy(rays).float())

        # pre-compute the shuffle index to make the inference of target model stable
        self.shuffle_idx = torch.randperm(self.width*self.height)

        print(COLOR.GREEN+'Render Info:')
        print('\tVehicle object vertice number:', vertices.shape[0])
        print('\tVehicle object face number:', self.faces_idx.shape[0])
        print('\tLidar channel:', self.height)
        print('\tLidar beam number:', self.width)
        print('\tLidar max range:', args.max_range)
        print('\tLidar lower fov:', self.lower_fov)
        print('\tLidar upper fov:', self.upper_fov)
        # the horizontal limits were previously printed under the vertical labels
        print('\tLidar left fov:', self.left_fov)
        print('\tLidar right fov:', self.right_fov)
        print(COLOR.WHITE+'')

    def raycast(self, code, background_w_label=None):
        """ Batched ray-casting for given transformations.

        Args:
            code: the physical parameters of the objects, one row per object.
            background_w_label: optional indexable pair; element 0 is the
                background range map and element 1 its label map (may be
                None). When it is not given and ``self.use_background`` is
                set, the background is rendered from the ground mesh instead.
        Return:
            xyz: pointcloud, [S, 3]
            label: point-wise label, [S,]. It is all zeros when no background
                is used at all.
        """
        # compute the foreground range map of all objects
        foreground = self._compute_foreground(code)

        # pick the background: an explicit one wins, otherwise fall back to the
        # ground mesh when requested. Previously the mesh background was
        # rendered but never mixed into the result.
        background = None
        background_label = None
        if background_w_label is not None:
            background = background_w_label[0]
            background_label = background_w_label[1]
        elif self.use_background:
            background = self._compute_background()

        if background is not None:
            # keep the closest hit of foreground and background for every ray
            mixed_rangemap, label_mask = self._mixed_foreground_background(foreground, background, background_label)

            # convert the distance t to absolute coordinates
            xyz = self._rangemap_to_pc(mixed_rangemap)
            # the label comes from a boolean mask, so it carries no gradient
            label = label_mask.reshape(-1, 1).float()
        else:
            xyz = self._rangemap_to_pc(foreground)
            label = CUDA(torch.zeros(xyz.shape[0], 1))

        # shuffle the pc and delete the out-of-range points. A fixed shuffle is
        # used because an iteration-wise shuffle causes a large variation of
        # the inference result
        self.xyz, label = self._postprocess(xyz, label)
        return self.xyz, label.view(-1) # [S, 3], # [S,]

    def _pc_to_rangemap(self, points):
        """Project a point cloud onto the range image of this lidar.

        Args:
            points: point cloud handed unchanged to ``pc_to_rangemap``
                (the original comment describes it as [x, y, z, r, g, b]).
        Returns:
            A tuple ``(range_image, label)`` of transposed tensors; the
            original comment gives both shapes as [2048, 64].
        """
        horizontal_span = self.right_fov - self.left_fov
        range_np, label_np = pc_to_rangemap(points, horizontal_span, self.lower_fov, self.upper_fov, self.width, self.height, self.max_range)
        # the helper output is transposed to match the layout used by the renderer
        range_tensor = CUDA(torch.from_numpy(range_np)).T
        label_tensor = CUDA(torch.from_numpy(label_np)).T
        return (range_tensor, label_tensor)

    def _compute_background(self):
        F = self.grounf_F
        N = F.shape[0]

        # get the valid laser directions for this object, this will speed up the rendering
        crop_direction = self._precompute_valid_region(None, full=True)

        # divide the directions into batches
        D_list, M = self._batched_directions(crop_direction)
        assert len(D_list) > 0, 'There is no valid laser direction'

        # compute the triangle matrix 
        self._precompute_triangle(F, M)

        # ray casting with batch
        all_t = []
        all_mask = []
        all_distance = []
        for D in D_list:
            one_t, one_mask, one_dist = self._MT_raycast_batch(D, N) 
            all_t.append(one_t[None])
            all_mask.append(one_mask[None])
            all_distance.append(one_dist[None])
        all_t = torch.cat(all_t, axis=0)        # [Q, M, N]
        all_mask = torch.cat(all_mask, axis=0)  # [Q, M, N]
        all_distance = torch.cat(all_distance, axis=0)  # [Q, M, N]

        # get the final foreground with compensated gradient
        background = self._gradient_compensation(all_t, all_mask, all_distance, N)
        return background

    def _compute_foreground(self, code):
        """ Compute the range map of the foreground objects.

        Each row of ``code`` describes one object: element 0 is passed to
        ``Rz`` and elements 1 and 2 to ``Translation``. Every object is
        rendered on its own cropped set of rays, pasted into a full-size range
        map, and the per-pixel minimum over all objects is returned.

        Args:
            code: the physical parameters of objects, one row per object.
        Return:
            foreground_min: the closest depth of all objects, [W, H]
        """
        per_object_maps = []
        for obj_idx in range(code.shape[0]):
            params = code[obj_idx]

            # rigid transformation of the vertices, then gather the faces
            rotation = self.Rz(params[0])
            offset = self.Translation(params[1], params[2])
            V = self.original_V.matmul(rotation) + offset
            faces = V[self.faces_idx]
            num_faces = faces.shape[0]

            # only the rays around the object are cast; the crop is computed
            # from detached parameters so it stays out of the autograd graph
            cropped_rays = self._precompute_valid_region(params.detach(), full=False)

            # one batch per laser channel
            batches, batch_size = self._batched_directions(cropped_rays)
            assert len(batches) > 0, 'There is no valid laser direction'

            # cache the per-triangle terms used by the intersection test
            self._precompute_triangle(faces, batch_size)

            # ray casting, batch by batch
            hits = [self._MT_raycast_batch(directions, num_faces) for directions in batches]
            all_t = torch.stack([hit[0] for hit in hits], dim=0)         # [Q, M, N]
            all_mask = torch.stack([hit[1] for hit in hits], dim=0)      # [Q, M, N]
            all_distance = torch.stack([hit[2] for hit in hits], dim=0)  # [Q, M, N]

            # reduce over the faces to get the cropped range map
            cropped_map = self._gradient_compensation(all_t, all_mask, all_distance, num_faces)

            # paste the crop into a full-size map filled with max_range
            full_map = self.max_range*CUDA(torch.ones(self.width, self.height))
            rows = slice(self.rowId_lower, self.rowId_upper)
            if self.two_piece_flag:
                # colId_left_1 is always 0, so colId_right_1 is also the number
                # of columns that belong to the first piece
                split = self.colId_right_1
                full_map[self.colId_left_1:self.colId_right_1, rows] = cropped_map[:split, :]
                full_map[self.colId_left_2:self.colId_right_2, rows] = cropped_map[split:, :]
            else:
                full_map[self.colId_left:self.colId_right, rows] = cropped_map
            per_object_maps.append(full_map)

        # the closest object wins at every pixel
        stacked_maps = torch.stack(per_object_maps, dim=0)
        foreground_min = stacked_maps.min(dim=0)[0]
        return foreground_min # [W, H]

    def _gradient_compensation(self, all_t, all_mask, all_distance, N):
        # get the exact intersection with a hard mask (invalid point is 0 or max_range)
        t_valid = all_t * all_mask + self.max_range * torch.logical_not(all_mask) # [Q, M, N]  
        #t_valid = all_t * all_mask
        foreground, t_min_idx = torch.min(t_valid, dim=2, keepdim=False)  # [Q, M], [Q, M]

        # dont use any compensation
        return foreground.T

    def _MT_raycast_batch(self, D, N):
        """ This is an implementation of the renderer in CVPR2020 paper 'Physically Realizable Adversarial Examples for LiDAR Object Detection'.
            This is a vectorized version of MT algorithms that can run on GPU with auto-grad libraries like Pytorch and Tensorflow.
            The origins of rays are not required since they are all zeros. We ignore them to increase efficiency.

        Args:
            D: [M, 3]: the direction of the ray, should be a normalized one
            N: number of faces
        Returns:
            t_min: [M, 1]: the distance between intersection point and origin (0, 0, 0)
            xyz: [M, 1]: the 3D coordinates of the points
        """

        # [M, 3] DOT [3, N]
        D_ = D.unsqueeze(1).repeat(1, N, 1)          # [M, N, 3]
        D_d_E1_c_E2 = torch.matmul(D, self.E1_c_E2.T)         # [M, N]
        #D_d_E1_c_E2 = torch.sum(D_*self.E1_c_E2_, dim=2)  # [M, N]
        D_c_E2 = torch.cross(D_, self.E2_, dim=2)         # [M, N, 3], no grad because only used in u
        E1_c_D = torch.cross(self.E1_, D_, dim=2)         # [M, N, 3], no grad because only used in v

        # a simple way to do batch dot product, time efficient but storage ineifficient
        # NOTE: divide D_d_E1_c_E2 will cause some numerical problem when they are too small
        # NOTE: D_d_E1_c_E2 could be negative, we cannot directly move it to the RHS of u < 1 and v < 1
        A_d_E1_C_E2 = torch.sum(self.A_*self.E1_c_E2_, dim=2)
        t = A_d_E1_C_E2/D_d_E1_c_E2     # [M, N]
        u = torch.sum(D_c_E2*self.A_, dim=2)/D_d_E1_c_E2            # [M, N], no grad
        v = torch.sum(self.A_*E1_c_D, dim=2)/D_d_E1_c_E2            # [M, N], no grad
        w = 1.0 - u - v

        # check condition: 0.0 < u < 1.0, 0.0 < v < 1.0, 0.0 < w < 1.0 (w = 1 - u - v)
        # t should be positive since it represent the distance
        u_mask = u.ge(0.0) & u.le(1.0)
        v_mask = v.ge(0.0) & v.le(1.0)
        w_mask = w.ge(0.0) & w.le(1.0) 
        t_mask = t.ge(0.0)
        #mask = u_mask & v_mask & w_mask & t_mask & filter_mask
        mask = u_mask & v_mask & w_mask & t_mask  # [M, N]

        # calculate weights according to the barycentric coordinate
        dist, _ = torch.min(torch.stack([u, v, w]), dim=0) # the shorest distance from p_i to f_j
        #dist = torch.sqrt(u**2 + v**2 + w**2)
        return t, mask, dist

    def _mixed_foreground_background(self, foreground, background, background_label):
        # select the minimal value for each pixel
        mixed_rangemap = torch.cat([foreground[None], background[None]], dim=0)
        mixed_rangemap_, _ = torch.min(mixed_rangemap, dim=0, keepdim=False)

        # when foreground is close than background, the pixel is the object
        label_mask = foreground < background
        # the max range points should also be deleted from the label mask
        label_mask[foreground == self.max_range] = 0

        # combine the label from the background
        if background_label is not None:
            label_mask = torch.logical_or(label_mask, torch.logical_and(torch.logical_not(label_mask), background_label))
        return mixed_rangemap_, label_mask

    def _rangemap_to_pc(self, rangemap):
        """ Use range map and ray directions to recover the 3D points.
        """
        xyz = (rangemap.unsqueeze(2)*self.rays).contiguous().view((-1, 3))
        return xyz

    def _compute_vertical_angle(self, xy_dis, z):
        ang_res_y = (self.upper_fov-self.lower_fov)/float(self.height-1) # vertical resolution
        vertical_angle = np.rad2deg(np.arctan2(z, xy_dis))
        relative_vertical_angle = vertical_angle - self.lower_fov
        # NOTE: the x-axis of range map is the opposite of z-axis of pointcloud
        rowId = self.height - np.int_(np.round(relative_vertical_angle / ang_res_y))
        return rowId

    def _compute_horizontal_angle(self, x, y):
        horizontal_angle = np.arctan2(y, x)
        horizontal_angle += np.pi  # the range of arctan2 is [-pi, pi], we want to convert to [0, 2*pi]
        colId = np.int_(horizontal_angle*self.width/(2*np.pi))
        return colId

    def _precompute_valid_region(self, code, full=True):
        """ This function is only for decreasing the memory, the whole pipeline works without it.
            This function should not influence the gradient graph.
            Basically, we use a cude (l * l * z_max-z_min) to crop the ray, where l is the half of the length of the vehicle.
            This is the worst case and the best we can do for an rigid body.
        
        Args:
            code: the physical parameters of one object
        Returns:
            crop_direction: the valid laser directions.
        """
        if not full:
            x = CPU(code[1])
            y = CPU(code[2])

            ''' horizontal crop '''
            # NOTE: when we create the ray in _precompute_lidar_ray, the mapping is colId=2048 -> theta=0.
            #       Also, when we create the range map of background (utils), the mapping is the same.
            #       Here, we should keep the same setting to make all three mapping exactly the same.
            # NOTE: detach from the graph to avoid gradient accumulation
            perpendicular_vec = np.array([y, -x])/((x**2+y**2)**0.5)
            left_xy = np.array([x, y]) + perpendicular_vec*self.half_vehicle_length
            right_xy = np.array([x, y]) - perpendicular_vec*self.half_vehicle_length
            self.colId_left = self._compute_horizontal_angle(left_xy[0], left_xy[1])
            self.colId_right = self._compute_horizontal_angle(right_xy[0], right_xy[1])

            # if the worst-case box is divided into two pieces, we should put them together
            if self.colId_right < self.colId_left:
                self.two_piece_flag = True
                self.colId_left_1 = 0
                self.colId_right_1 = self.colId_right
                self.colId_left_2 = self.colId_left
                self.colId_right_2 = self.width
            else:
                self.two_piece_flag = False

            ''' vertical crop '''
            # add a offset of the center point, the offset is the worst case (half of the length of the vehicle)
            xy_dis = np.sqrt(x**2 + y**2) - self.half_vehicle_length
            rowId_upper = self._compute_vertical_angle(xy_dis, self.z_min)
            rowId_lower = self._compute_vertical_angle(xy_dis, self.z_max)

            self.rowId_lower = 0 if rowId_lower < 0 else rowId_lower
            self.rowId_upper = self.height if rowId_upper > self.height else rowId_upper
            # when the rowId_lower is larger than rowId_upper, this is an invalid situation
            if self.rowId_upper <= self.rowId_lower:
                self.rowId_lower = 0
                self.rowId_upper = 64

            ''' divide the rays into batches in advance '''
            # convert from [W, H, 3] to the shape [H*W, 3]
            if self.two_piece_flag:
                col_idx = np.r_[self.colId_left_1:self.colId_right_1, self.colId_left_2:self.colId_right_2]
            else:
                col_idx = np.r_[self.colId_left:self.colId_right]
            crop_rays = self.rays[col_idx, self.rowId_lower:self.rowId_upper, :]
        else:
            # use the whole area to do rendering
            crop_rays = self.rays
            self.rowId_lower = 0
            self.rowId_upper = 64 
            self.colId_left = 0
            self.colId_right = 2048
            self.two_piece_flag = False

        return crop_rays

    def _batched_directions(self, crop_direction):
        """ This function divides the valid directions into batches
        Args:
            crop_direction:
        Returns:
            D_list: the batched directions
            M: the number of laser in one batch. it is not necessary to be equal to the first dimension of directions (K).
        """
        # each batch contains one laser scan
        D_list = [crop_direction[:, l_i, :] for l_i in range(crop_direction.shape[1])]
        M = crop_direction.shape[0]
        return D_list, M

    def _backface_culling(self, D_list, T):
        """ Remove the back faces that will not have chances to intersect with lasers
        """
        D = torch.cat(D_list, dim=0)
        A = T[:, 0, :]
        B = T[:, 1, :] 
        C = T[:, 2, :]
        E1 = B - A                                            # [N*Q, 3]
        E2 = C - A                                            # [N*Q, 3]
        E1_c_E2 = torch.cross(E1, E2, dim=1).transpose(1, 0)  # [3, N]
        D_d_E1_c_E2 = torch.matmul(D, E1_c_E2)                # [M*Q, N]

        # if all directions think one face is a back face, then it is
        max_dot, _ = torch.max(D_d_E1_c_E2, dim=0)   # [N]
        frontface_mask = max_dot < 0.0
        T = T[frontface_mask]
        self.N = T.shape[0]  
        return T

    def _precompute_triangle(self, F, M):
        """ Pre-compute some variables to increase efficiency.

        Args: 
            F: [N, 3, 3]: the second dim is the three points of a triangle, the third dim is the xyz coordinate
            M: the number of laser in one batch. it is not necessary to be equal to the first dimension of directions.
        Returns:
            None
        """

        A = F[:, 0, :]                                            # [N, 3]
        B = F[:, 1, :]                                            # [N, 3]
        C = F[:, 2, :]                                            # [N, 3]
        E1 = B - A                                                # [N, 3]
        E2 = C - A                                                # [N, 3]
        self.E1_c_E2 = torch.cross(E1, E2, dim=1)                 # [N, 3]
        self.A_ = A.unsqueeze(0).repeat(M, 1, 1)                  # [M, N, 3]                         
        self.E1_ = E1.unsqueeze(0).repeat(M, 1, 1)                # [M, N, 3], no grad because only used in v
        self.E2_ = E2.unsqueeze(0).repeat(M, 1, 1)                # [M, N, 3], no grad because only used in u
        self.E1_c_E2_ = self.E1_c_E2.unsqueeze(0).repeat(M, 1, 1) # [M, N, 3]

    def _precompute_lidar_ray(self):
        """ Generate all rays that we need and divide them into groups.
        """
        horizontal_angle = np.linspace(np.deg2rad(self.left_fov), np.deg2rad(self.right_fov), self.width)
        vertical_angle = np.linspace(np.deg2rad(self.upper_fov), np.deg2rad(self.lower_fov), self.height)

        # get the coordinates of each beam
        z = np.repeat(np.sin(vertical_angle)[None], self.width, axis=0)
        dist = np.cos(vertical_angle)
        x = np.cos(horizontal_angle)[:, None] * dist
        y = np.sin(horizontal_angle)[:, None] * dist

        # concate x, y and z. No need to normalize the directions since they are normalized
        rays = np.concatenate([x[:, :, None], y[:, :, None], z[:, :, None]], axis=2)  # [W, H, 3]
        return rays 

    def _postprocess(self, xyz, label):
        xyzl = torch.cat([xyz, label], dim=1)

        # shuffle the generated pointcloud
        xyzl = xyzl[self.shuffle_idx, :]

        # delete points that are larger than max_range, a tolernce is used to counter numerical problem
        # delete zero distance points
        range_xyz = (xyzl[:, 0]**2 + xyzl[:, 1]**2 + xyzl[:, 2]**2)**0.5
        valid_range_idx = (range_xyz < (self.max_range - 1)) & (range_xyz > 1.0)
        xyzl = xyzl[valid_range_idx]

        # divide the pointcloud and the label
        xyz = xyzl[:, 0:3]
        label = xyzl[:, 3]
        return xyz, label

    @staticmethod
    def Rx(rad):
        """Build a 3x3 rotation matrix about the x-axis.

        Args:
            rad: rotation angle as a tensor, in radians.
        Returns:
            rotation: 3x3 matrix with the cos/sin terms in rows and columns 1 and 2.
        """
        # the angle is wrapped modulo pi
        # NOTE(review): angles that differ by pi give the same matrix - confirm this is intended
        rad = rad % np.pi
        cos_r = torch.cos(rad)
        sin_r = torch.sin(rad)

        # NOTE: the entries are written into an identity matrix instead of
        # building a new tensor from values, otherwise the gradient flow breaks
        rotation = CUDA(torch.eye(3))
        rotation[1, 1] = cos_r
        rotation[1, 2] = sin_r
        rotation[2, 1] = -sin_r
        rotation[2, 2] = cos_r
        return rotation

    @staticmethod
    def Rz(rad):
        """Build a 3x3 rotation matrix about the z-axis.

        Args:
            rad: rotation angle as a tensor, in radians.
        Returns:
            rotation: 3x3 matrix with the cos/sin terms in rows and columns 0 and 1.
        """
        # the angle is wrapped modulo pi
        # NOTE(review): angles that differ by pi give the same matrix - confirm this is intended
        rad = rad % np.pi
        cos_r = torch.cos(rad)
        sin_r = torch.sin(rad)

        # NOTE: the entries are written into an identity matrix instead of
        # building a new tensor from values, otherwise the gradient flow breaks
        rotation = CUDA(torch.eye(3))
        rotation[0, 0] = cos_r
        rotation[0, 1] = sin_r
        rotation[1, 0] = -sin_r
        rotation[1, 1] = cos_r
        return rotation

    @staticmethod
    def Translation(x, y):
        """Build the translation vector (x, y, 0).

        Args:
            x: scalar tensor, offset along x.
            y: scalar tensor, offset along y.
        Returns:
            translation: tensor with three elements; z is always zero.
        """
        # NOTE: x and y are stacked rather than copied into a new tensor,
        # otherwise the gradient flow breaks
        zero_z = CUDA(torch.tensor(0.0))
        return torch.stack([x, y, zero_z])
