# Copyright (c) Facebook, Inc. and its affiliates.
# # Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/position_encoding.py
"""
Various positional encodings for the transformer.
"""
import math

import torch
from torch import nn


class PositionEmbeddingSine(nn.Module):
    """
    Sinusoidal 2-D positional encoding in the spirit of "Attention Is All You Need",
    applied separately to the row index and the column index of every pixel.

    The result has ``2 * num_pos_feats`` channels: the first ``num_pos_feats`` encode the
    row (y) position and the remaining ``num_pos_feats`` the column (x) position, each as
    interleaved sin/cos pairs (``num_pos_feats`` is assumed to be even so the sin and cos
    halves line up).
    """

    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        # A custom scale is only meaningful when the positions are normalized first.
        if normalize is False and scale is not None:
            raise ValueError("normalize should be True if scale is passed")
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        self.scale = 2 * math.pi if scale is None else scale

    def forward(self, x, mask=None):
        """
        Args:
            x: feature map; only ``x.size(0)``, ``x.size(2)``, ``x.size(3)`` and ``x.device``
                are used.
            mask: optional bool tensor ``(B, H, W)``. Entries that are True do not advance
                the running position counts (presumably padding). Defaults to all False.

        Returns:
            Tensor of shape ``(B, 2 * num_pos_feats, H, W)``.
        """
        if mask is None:
            mask = torch.zeros(
                (x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool
            )
        valid = ~mask

        # 1-based running index of every valid pixel down the rows / across the columns.
        row_pos = valid.cumsum(1, dtype=torch.float32)
        col_pos = valid.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            # Divide by the final count along each axis so positions span roughly [0, scale].
            row_pos = row_pos / (row_pos[:, -1:, :] + eps) * self.scale
            col_pos = col_pos / (col_pos[:, :, -1:] + eps) * self.scale

        # Channel pair (2k, 2k + 1) shares the frequency temperature ** (2k / num_pos_feats).
        freqs = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        freqs = self.temperature ** (2 * (freqs // 2) / self.num_pos_feats)

        encoded = []
        for positions in (row_pos, col_pos):
            angles = positions.unsqueeze(-1) / freqs
            # Interleave sin(even channels) with cos(odd channels): [sin0, cos1, sin2, ...].
            pairs = torch.stack((angles[..., 0::2].sin(), angles[..., 1::2].cos()), dim=4)
            encoded.append(pairs.flatten(3))
        return torch.cat(encoded, dim=3).permute(0, 3, 1, 2)

    def __repr__(self, _repr_indent=4):
        indent = " " * _repr_indent
        entries = [
            "num_pos_feats: {}".format(self.num_pos_feats),
            "temperature: {}".format(self.temperature),
            "normalize: {}".format(self.normalize),
            "scale: {}".format(self.scale),
        ]
        header = "Positional encoding " + type(self).__name__
        return "\n".join([header] + [indent + entry for entry in entries])


class SplitPositionEmbedding(nn.Module):
    """
    Sine positional encoding with separate y / x halves and an optional calibration-based
    bird's-eye-view (BEV) mode.

    * Image mode (``bev=False``): positions are 1-based pixel indices (optionally
      normalized to ``scale``), exactly like ``PositionEmbeddingSine``.
    * BEV mode (``bev=True``): every pixel (u, v) is mapped through the calibration to
      ``depth = cam_height * f / (v - c_y + 0.1)`` and ``lateral = depth * (u - c_x) / f``.
      Pixels whose depth is negative or above ``y_limit`` are marked invalid. The values
      are clipped, accumulated along rows/columns, normalized and scaled by ``scale``.

    The output has ``num_pos_feats`` channels (first half y, second half x).
    ``num_pos_feats`` is assumed to be a multiple of 4; otherwise the sin/cos halves
    differ in size and ``torch.stack`` fails.
    """

    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None,
                 x_limit=40, y_limit=50, cam_height=1.7):
        super().__init__()
        # Half of the output channels encode y, the other half x.
        self.num_pos_feats = num_pos_feats / 2
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

        # BEV clipping ranges for the lateral (x) and depth (y) values; the units are those
        # of cam_height (presumably metres -- confirm against the calibration source).
        self.x_limit = x_limit
        self.y_limit = y_limit
        # Camera height used by the BEV depth formula.
        self.cam_height = cam_height

    def forward(self, x, img_size, calib=None, bev=False, abs_bev=True):
        """
        Args:
            x: feature map. Image mode uses its shape ``(B, C, H, W)``; both modes use its device.
            img_size: ``[w, h]`` of the BEV grid (only used when ``bev`` is True).
            calib: camera matrix, already resized to ``img_size``. Indexed as ``[0, 0]`` (f),
                ``[0, -1]`` (c_x) and ``[1, -1]`` (c_y). Required when ``bev`` is True.
            bev: use the calibration-based BEV positions instead of pixel indices.
            abs_bev: in BEV mode, accumulate ``log`` of the offset values (True) or the raw
                offset values (False).

        Returns:
            ``(B, num_pos_feats, H, W)`` in image mode, ``(1, num_pos_feats, h, w)`` in BEV mode.

        Raises:
            ValueError: if ``bev`` is True and ``calib`` is None.
        """
        if bev:
            if calib is None:
                raise ValueError("calib is required when bev=True")
            y_embed, x_embed = self._bev_embed(x, img_size, calib, abs_bev)
        else:
            y_embed, x_embed = self._image_embed(x)
        return self._sine_encode(y_embed, x_embed, x.device)

    def _image_embed(self, x):
        """Return 1-based (y, x) pixel-index embeddings of shape (B, H, W), optionally normalized."""
        not_mask = torch.ones_like(x)[:, 0, ...]
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
        return y_embed, x_embed

    def _bev_embed(self, x, img_size, calib, abs_bev):
        """Return calibration-derived (y, x) embeddings of shape (1, img_size[1], img_size[0])."""
        # Build the grid on x's device instead of a hard-coded .cuda().
        range_x = torch.arange(img_size[0], device=x.device)
        range_y = torch.arange(img_size[1], device=x.device)
        cur_y, cur_x = torch.meshgrid(range_y, range_x)

        f = calib[0, 0]
        # Per-row depth; the +0.1 keeps the denominator away from zero at v == c_y.
        y_embed = self.cam_height * f / (cur_y - calib[1, -1] + 0.1)
        # Lateral offset: depth * (u - c_x) / f.
        x_embed = (y_embed * cur_x - calib[0, -1] * y_embed) / f

        # Negative or too-distant depths are invalid and get a constant embedding below.
        to_remove = (y_embed < 0) | (y_embed > self.y_limit)

        offset = 2 if abs_bev else 1
        y_embed = y_embed.clamp(-self.y_limit, self.y_limit) / self.y_limit + offset
        x_embed = x_embed.clamp(-self.x_limit, self.x_limit) / self.x_limit + offset
        # Invalid pixels must add nothing to the cumulative sums: in abs_bev mode they are
        # set to 1 before the log (log(1) == 0), otherwise directly to 0.
        fill = 1 if abs_bev else 0
        y_embed[to_remove] = fill
        x_embed[to_remove] = fill
        if abs_bev:
            y_embed = torch.log(y_embed)
            x_embed = torch.log(x_embed)

        eps = 1e-6
        # y: reverse cumulative sum over rows (from the last row upwards), normalized by its
        # value in the first row; x: cumulative sum over columns, normalized by the last column.
        y_embed = torch.flip(
            torch.flip(y_embed, dims=[0]).unsqueeze(0).cumsum(1, dtype=torch.float32), dims=[1]
        )
        x_embed = x_embed.unsqueeze(0).cumsum(2, dtype=torch.float32)
        y_embed = y_embed / (y_embed[:, :1, :] + eps)
        x_embed = x_embed / (x_embed[:, :, -1:] + eps)

        y_embed[0, to_remove] = 1
        x_embed[0, to_remove] = 1
        return y_embed * self.scale, x_embed * self.scale

    def _sine_encode(self, y_embed, x_embed, device):
        """Turn (N, H, W) y/x positions into an (N, 2 * len(dim_t), H, W) sin/cos encoding."""
        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)


def get_world_coordinates(img_size, H, W, intrinsics, extrinsics, d_bound):
    """
    Back-project an (H, W) pixel grid at several depths and return unit vectors pointing
    at the resulting points.

    Each pixel (u, v) at depth d is mapped to ``R @ K^-1 @ (u*d, v*d, d) + T``. ``K`` are
    the intrinsics rescaled from ``img_size`` to ``(W, H)``; ``R`` / ``T`` are the rotation /
    translation blocks of ``extrinsics``. Each point is then divided by its Euclidean norm.

    Args:
        img_size: ``(width, height)`` of the image the intrinsics refer to.
        H, W: height and width of the output grid.
        intrinsics: tensor whose first two dims are read as ``(batch, views)`` and whose
            elements reshape to ``(-1, 3, 3)``. The caller's tensor is not modified.
        extrinsics: tensor indexed as ``[..., :3, :3]`` (rotation) and ``[..., :3, 3]``
            (translation), one per (batch, view).
        d_bound: arguments for ``torch.arange`` producing the D depth values.

    Returns:
        float32 tensor of shape ``(batch, views, H, W, 3 * D)``. The last axis is
        coordinate-major: index ``c * D + d``.
    """
    batch, views = intrinsics.size(0), intrinsics.size(1)
    # clone(): for a contiguous float32 input, reshape()/.float() return views of the
    # caller's tensor, and the in-place rescaling below would otherwise modify it.
    intrinsics = intrinsics.reshape(-1, 3, 3).float().clone()
    R = extrinsics[..., :3, :3].reshape(-1, 3, 3).float()
    T = extrinsics[..., :3, 3].reshape(-1, 3, 1).float()

    # Rescale focal lengths and principal point from img_size to (W, H).
    ratio_w = W / img_size[0]
    ratio_h = H / img_size[1]
    intrinsics[:, 0, 0] *= ratio_w
    intrinsics[:, 0, 2] *= ratio_w
    intrinsics[:, 1, 1] *= ratio_h
    intrinsics[:, 1, 2] *= ratio_h

    # Depth and pixel grids, each (D, H, W).
    depth_grid = torch.arange(*d_bound, dtype=torch.float)
    n_depth_slices = depth_grid.shape[0]
    depth_grid = depth_grid.view(-1, 1, 1).expand(-1, H, W)
    x_grid = torch.linspace(0, W - 1, W, dtype=torch.float).view(1, 1, W).expand(n_depth_slices, H, W)
    y_grid = torch.linspace(0, H - 1, H, dtype=torch.float).view(1, H, 1).expand(n_depth_slices, H, W)

    # Depth-scaled homogeneous pixel coordinates (u*d, v*d, d), shape (1, D, H, W, 3, 1).
    # Cast to the already-converted float32 intrinsics so float64 inputs no longer make
    # the matmul below mix dtypes.
    frustum = torch.stack((x_grid * depth_grid, y_grid * depth_grid, depth_grid), -1)
    frustum = frustum.unsqueeze(0).unsqueeze(-1).to(intrinsics)

    # Image coordinates -> extrinsics frame: R @ K^-1 @ p + T, shape (B*N, D, H, W, 3).
    combined_transformation = R.matmul(torch.inverse(intrinsics))
    points = combined_transformation.view(batch * views, 1, 1, 1, 3, 3).matmul(frustum).squeeze(-1)
    points = points + T.view(batch * views, 1, 1, 1, 3)

    # Unit vectors; normalize() clamps the norm (eps=1e-12) so a point at the origin
    # yields zeros instead of NaN.
    points = nn.functional.normalize(points, dim=-1)
    return points.permute(0, 2, 3, 4, 1).reshape(batch, views, H, W, 3 * n_depth_slices)


def get_world_coordinates_(img_size, downsample, intrinsics, extrinsics, d_bound):
    """
    Back-project a downsampled pixel grid at several depths and return the min-max
    normalized third coordinate of every resulting point.

    Pixel coordinates are sampled evenly over the full-resolution image (0 .. w-1,
    0 .. h-1), so the intrinsics are used without rescaling. Each (u, v, d) is mapped to
    ``R @ K^-1 @ (u*d, v*d, d) + t``.

    Args:
        img_size: ``(w, h)`` of the full-resolution image.
        downsample: integer factor; the grid has ``h // downsample`` x ``w // downsample`` cells.
        intrinsics: camera matrices broadcastable against the ``(B, N, 3, 3)`` rotations.
        extrinsics: tensor indexed as ``[..., :3, :3]`` / ``[..., :3, 3]``; the translation
            part must be ``(B, N, 3)``.
        d_bound: arguments for ``torch.arange`` producing the D depth values.

    Returns:
        Tensor ``(B, N, h // downsample, w // downsample, D)``, with values in [0, 1]. The
        min/max are taken over the whole tensor (all batches, views and depths). If every
        value is identical, the result is all zeros.
    """
    w, h = img_size
    downsampled_h, downsampled_w = h // downsample, w // downsample

    depth_grid = torch.arange(*d_bound, dtype=torch.float)
    depth_grid = depth_grid.view(-1, 1, 1).expand(-1, downsampled_h, downsampled_w)
    n_depth_slices = depth_grid.shape[0]

    # x and y grids spanning the full-resolution image.
    x_grid = torch.linspace(0, w - 1, downsampled_w, dtype=torch.float)
    x_grid = x_grid.view(1, 1, downsampled_w).expand(n_depth_slices, downsampled_h, downsampled_w)
    y_grid = torch.linspace(0, h - 1, downsampled_h, dtype=torch.float)
    y_grid = y_grid.view(1, downsampled_h, 1).expand(n_depth_slices, downsampled_h, downsampled_w)
    frustum = torch.stack((x_grid, y_grid, depth_grid), -1)
    # (1, 1, D, h', w', 3, 1)
    frustum = frustum.unsqueeze(0).unsqueeze(0).unsqueeze(-1)
    frustum = frustum.to(intrinsics)

    rotation, translation = extrinsics[..., :3, :3], extrinsics[..., :3, 3]
    B, N, _ = translation.shape

    # (u*d, v*d, d) -> R @ K^-1 @ p + t, shape (B, N, D, h', w', 3).
    frustum = torch.cat((frustum[:, :, :, :, :, :2] * frustum[:, :, :, :, :, 2:3], frustum[:, :, :, :, :, 2:3]), 5)
    combined_transformation = rotation.matmul(torch.inverse(intrinsics))
    frustum = combined_transformation.view(B, N, 1, 1, 1, 3, 3).matmul(frustum).squeeze(-1)
    frustum += translation.view(B, N, 1, 1, 1, 3)

    # Only the third coordinate is encoded (the original author's note describes this as
    # encoding only the height). Min-max normalize it to [0, 1]; clamp_min prevents a
    # 0/0 -> NaN when every point shares the same value (e.g. a single depth slice).
    coord = frustum[..., 2:3]
    coord_min, coord_max = coord.min(), coord.max()
    coord = (coord - coord_min) / (coord_max - coord_min).clamp_min(1e-6)

    # (B, N, D, h', w', 1) -> (B, N, h', w', D)
    return coord.permute(0, 1, 3, 4, 2, 5).reshape(B, N, downsampled_h, downsampled_w, n_depth_slices)


def get_rays_new_(img_size, downsample, intrinsics, extrinsics, d_bound):
    """
    Unit ray directions from the camera centre through a downsampled pixel grid, repeated
    once per depth slice.

    Each (u, v, d) is back-projected to ``R @ K^-1 @ (u*d, v*d, d) + t``. The ray direction
    is that point minus the ray origin ``t``. The translation therefore cancels and only
    ``R @ K^-1 @ p`` is computed. Pixel coordinates span the full-resolution image, so the
    intrinsics are used without rescaling.

    Args:
        img_size: ``(w, h)`` of the full-resolution image.
        downsample: integer factor; the grid has ``h // downsample`` x ``w // downsample`` cells.
        intrinsics: camera matrices broadcastable against the ``(B, N, 3, 3)`` rotations.
        extrinsics: tensor indexed as ``[..., :3, :3]`` / ``[..., :3, 3]``; the translation
            part must be ``(B, N, 3)``.
        d_bound: arguments for ``torch.arange`` producing the D depth values.

    Returns:
        Tensor ``(B, N, h // downsample, w // downsample, D * 3)``. The last axis is
        depth-major: index ``d * 3 + c``. A depth-0 slice has no direction and yields zeros
        instead of NaN.
    """
    w, h = img_size
    downsampled_h, downsampled_w = h // downsample, w // downsample

    depth_grid = torch.arange(*d_bound, dtype=torch.float)
    depth_grid = depth_grid.view(-1, 1, 1).expand(-1, downsampled_h, downsampled_w)
    n_depth_slices = depth_grid.shape[0]

    # x and y grids spanning the full-resolution image.
    x_grid = torch.linspace(0, w - 1, downsampled_w, dtype=torch.float)
    x_grid = x_grid.view(1, 1, downsampled_w).expand(n_depth_slices, downsampled_h, downsampled_w)
    y_grid = torch.linspace(0, h - 1, downsampled_h, dtype=torch.float)
    y_grid = y_grid.view(1, downsampled_h, 1).expand(n_depth_slices, downsampled_h, downsampled_w)
    frustum = torch.stack((x_grid, y_grid, depth_grid), -1)
    # (1, 1, D, h', w', 3, 1)
    frustum = frustum.unsqueeze(0).unsqueeze(0).unsqueeze(-1)
    frustum = frustum.to(intrinsics)

    rotation, translation = extrinsics[..., :3, :3], extrinsics[..., :3, 3]
    B, N, _ = translation.shape

    # Ray direction = (R @ K^-1 @ p + t) - t = R @ K^-1 @ p, shape (B, N, D, h', w', 3).
    frustum = torch.cat((frustum[:, :, :, :, :, :2] * frustum[:, :, :, :, :, 2:3], frustum[:, :, :, :, :, 2:3]), 5)
    combined_transformation = rotation.matmul(torch.inverse(intrinsics))
    rays_d = combined_transformation.view(B, N, 1, 1, 1, 3, 3).matmul(frustum).squeeze(-1)

    # Unit-normalize; normalize() clamps the norm (eps=1e-12) so depth-0 slices give zeros
    # rather than NaN.
    rays_d = nn.functional.normalize(rays_d, dim=-1)
    # (B, N, D, h', w', 3) -> (B, N, h', w', D * 3)
    return rays_d.permute(0, 1, 3, 4, 2, 5).reshape(B, N, downsampled_h, downsampled_w, n_depth_slices * 3)


def get_rays_new(img_size, H, W, intrinsics, extrinsics, d_bound):
    """
    Unit ray directions from the camera centre through an (H, W) pixel grid, repeated once
    per depth slice.

    The intrinsics are rescaled once from ``img_size`` to ``(W, H)``. Each (u, v, d) is
    back-projected to ``R @ K^-1 @ (u*d, v*d, d) + T``, and the ray direction is that point
    minus the origin ``T``. The translation cancels, so only ``R @ K^-1 @ p`` is computed.

    Args:
        img_size: ``(width, height)`` of the image the intrinsics refer to.
        H, W: height and width of the output grid.
        intrinsics: tensor whose first two dims are read as ``(batch, views)`` and whose
            elements reshape to ``(-1, 3, 3)``. The caller's tensor is not modified.
        extrinsics: tensor whose ``[..., :3, :3]`` rotation blocks reshape to ``(-1, 3, 3)``.
        d_bound: arguments for ``torch.arange`` producing the D depth values.

    Returns:
        float32 tensor ``(batch, views, H, W, 3 * D)``. The last axis is coordinate-major:
        index ``c * D + d``. A depth-0 slice yields zeros instead of NaN.
    """
    # Read batch/views before reshaping; re-reading them afterwards yields (B*N, 3).
    batch, views = intrinsics.size(0), intrinsics.size(1)
    # clone(): reshape()/.float() may return views of the caller's tensor; the in-place
    # rescaling below must not write into it.
    intrinsics = intrinsics.reshape(-1, 3, 3).float().clone()
    R = extrinsics[..., :3, :3].reshape(-1, 3, 3).float()

    # Rescale focal lengths and principal point (exactly once) from img_size to (W, H).
    ratio_w = W / img_size[0]
    ratio_h = H / img_size[1]
    intrinsics[:, 0, 0] *= ratio_w
    intrinsics[:, 0, 2] *= ratio_w
    intrinsics[:, 1, 1] *= ratio_h
    intrinsics[:, 1, 2] *= ratio_h

    # Depth and pixel grids, each (D, H, W).
    depth_grid = torch.arange(*d_bound, dtype=torch.float)
    n_depth_slices = depth_grid.shape[0]
    depth_grid = depth_grid.view(-1, 1, 1).expand(-1, H, W)
    x_grid = torch.linspace(0, W - 1, W, dtype=torch.float).view(1, 1, W).expand(n_depth_slices, H, W)
    y_grid = torch.linspace(0, H - 1, H, dtype=torch.float).view(1, H, 1).expand(n_depth_slices, H, W)

    # Depth-scaled homogeneous pixel coordinates (u*d, v*d, d), shape (1, D, H, W, 3, 1).
    frustum = torch.stack((x_grid * depth_grid, y_grid * depth_grid, depth_grid), -1)
    frustum = frustum.unsqueeze(0).unsqueeze(-1).to(intrinsics)

    # Ray directions R @ K^-1 @ p, shape (B*N, D, H, W, 3).
    combined_transformation = R.matmul(torch.inverse(intrinsics))
    rays_d = combined_transformation.view(batch * views, 1, 1, 1, 3, 3).matmul(frustum).squeeze(-1)

    # Unit-normalize; normalize() clamps the norm (eps=1e-12) so depth-0 slices give zeros.
    rays_d = nn.functional.normalize(rays_d, dim=-1)
    return rays_d.permute(0, 2, 3, 4, 1).reshape(batch, views, H, W, 3 * n_depth_slices)


class PolarPositionEmbedding(nn.Module):
    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()