import torch
import kaolin
import math
from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer


class CameraModule():
    def __init__(self):
        self.bg_color = torch.tensor([1.0] * 32).float()
        self.scale_modifier = 1.0

    def perspective_camera(self, points, camera_proj):
        # camera_proj：形状为 (B, 3, 3) 的相机投影矩阵，用于将三维点投影到二维图像平面
        # 对每个批次的点集和投影矩阵进行矩阵乘法
        projected_points = torch.bmm(points, camera_proj.permute(0, 2, 1))
        # projected_points[:, :, :2] 提取投影后的 x 和 y 坐标
        # projected_points[:, :, 2:3] 提取投影后的 z 坐标，并将其形状从 (B, N) 转换为 (B, N, 1)，以便进行广播除法
        # 通过除以 z 坐标，实现透视除法，将三维点转换为二维点
        projected_2d_points = projected_points[:, :, :2] / projected_points[:, :, 2:3]

        return projected_2d_points

    def prepare_vertices(self, vertices, faces, camera_proj, camera_rot=None, camera_trans=None,
                     camera_transform=None):
        # 如果 camera_transform 为 None，则使用旋转矩阵 camera_rot 和平移向量 camera_trans 对顶点进行旋转和平移变换
        # 将顶点从世界坐标系转换到相机坐标系
        if camera_transform is None:
            assert camera_trans is not None and camera_rot is not None, \
                "camera_transform or camera_trans and camera_rot must be defined"
            vertices_camera = kaolin.render.camera.rotate_translate_points(vertices, camera_rot,
                                                            camera_trans)
        # 如果 camera_transform 不为 None，则直接使用 camera_transform 对顶点进行变换
        # 将顶点从世界坐标系转换到相机坐标系
        # 首先将顶点扩展为齐次坐标，然后进行矩阵乘法
        else:
            assert camera_trans is None and camera_rot is None, \
                "camera_trans and camera_rot must be None when camera_transform is defined"
            padded_vertices = torch.nn.functional.pad(
                vertices, (0, 1), mode='constant', value=1.
            )
            vertices_camera = (padded_vertices @ camera_transform)
        # Project the vertices on the camera image plan
        # 将顶点以及相机投影矩阵传入 perspective_camera 函数进行透视投影
        vertices_image = self.perspective_camera(vertices_camera, camera_proj)
        # 使用 kaolin.ops.mesh.index_vertices_by_faces 函数，提取相机坐标系中的面顶点坐标,用于计算法线向量
        face_vertices_camera = kaolin.ops.mesh.index_vertices_by_faces(vertices_camera, faces)
        # 使用 kaolin.ops.mesh.index_vertices_by_faces 函数，提取图像坐标系中的面顶点坐标,用于渲染
        face_vertices_image = kaolin.ops.mesh.index_vertices_by_faces(vertices_image, faces)
        # 使用 kaolin.ops.mesh.face_normals 函数，计算每个面的法线向量
        face_normals = kaolin.ops.mesh.face_normals(face_vertices_camera, unit=True)
        return face_vertices_camera, face_vertices_image, face_normals
    
    def render_mesh(self, data, resolution):
        """Rasterize every mesh in `data` from all of its camera views (kaolin DIB-R).

        Reads `data['verts_list']`, `data['faces_list']`, `data['verts_color_list']`,
        `data['intrinsics']` and `data['extrinsics']`, each indexed per batch item.
        Assumes `intrinsics[b]` is (K, 3, 3) and `extrinsics[b]` is (K, 3, 4) for
        K views - TODO confirm against the caller.

        Args:
            data: Dict with the keys above; it is updated in place.
            resolution: Height and width (square) of the rendered images.

        Returns:
            The same `data` dict with `render_images`, `render_soft_masks`,
            `render_depths` and `render_normals` (each stacked over batch items
            along a new leading axis), `face_normals_list`, and the unchanged
            `verts_list` / `faces_list` written back.
        """
        verts_list = data['verts_list']
        faces_list = data['faces_list']
        verts_color_list = data['verts_color_list']

        images_per_item = []
        soft_masks_per_item = []
        depths_per_item = []
        normals_per_item = []
        face_normals_list = []

        for b, mesh_verts in enumerate(verts_list):
            # Per-item camera intrinsics / extrinsics, one entry per view.
            intrinsics = data['intrinsics'][b]
            extrinsics = data['extrinsics'][b]
            num_views = intrinsics.shape[0]

            # Transposed so it can right-multiply row-vector homogeneous vertices.
            world_to_camera = extrinsics.permute(0, 2, 1)

            # Replicate the mesh once per view.
            verts = mesh_verts.unsqueeze(0).repeat(num_views, 1, 1)
            faces = faces_list[b]
            verts_color = verts_color_list[b].unsqueeze(0).repeat(num_views, 1, 1)
            # Gather per-vertex colours by face index: (views, faces, 3 vertices, channels).
            faces_color = verts_color[:, faces]

            face_vertices_camera, face_vertices_image, face_normals = self.prepare_vertices(
                verts, faces, intrinsics, camera_transform=world_to_camera
            )
            # Flip the image-plane y axis, and the y/z components of the normals.
            # NOTE(review): presumably converts between the data's camera
            # convention and the one kaolin expects - confirm.
            face_vertices_image[..., 1] = -face_vertices_image[..., 1]
            face_normals[..., 1:] = -face_normals[..., 1:]

            ### Perform Rasterization ###
            # Attributes the DIB-R rasterizer interpolates per pixel: colour, a
            # constant-one channel that becomes the hard mask, camera-space
            # depth, and the face normal repeated for each of the 3 vertices.
            hard_mask = torch.ones(
                (faces_color.shape[0], faces_color.shape[1], 3, 1), device=verts.device
            )
            normals_per_vertex = face_normals.unsqueeze(-2).repeat(1, 1, 3, 1)
            face_attributes = [
                faces_color,
                hard_mask,
                face_vertices_camera[..., 2:],
                normals_per_vertex,
            ]

            image_features, soft_masks, _ = kaolin.render.mesh.dibr_rasterization(
                resolution, resolution, -face_vertices_camera[..., -1],
                face_vertices_image, face_attributes, face_normals[..., -1],
                rast_backend='cuda')

            # image_features holds one interpolated map per entry of face_attributes.
            images, masks, depths, normals = image_features

            # Zero out the background with the hard mask; colours are also
            # clamped to [0, 1].
            images_per_item.append(torch.clamp(images * masks, 0., 1.))
            soft_masks_per_item.append(soft_masks)
            depths_per_item.append(depths * masks)
            normals_per_item.append(normals * masks)
            face_normals_list.append(face_normals)

        data['render_images'] = torch.stack(images_per_item, 0)
        data['render_soft_masks'] = torch.stack(soft_masks_per_item, 0)
        data['render_depths'] = torch.stack(depths_per_item, 0)
        data['render_normals'] = torch.stack(normals_per_item, 0)
        data['verts_list'] = verts_list
        data['faces_list'] = faces_list
        data['face_normals_list'] = face_normals_list

        return data
    

    def render_gaussian(self, data, resolution):
        """Rasterize a batch of Gaussian point sets, one image per batch item.

        Reads from `data`: `xyz`, `color`, `opacity`, `scales`, `rotation`,
        `fovx`, `fovy`, `world_view_transform`, `full_proj_transform` and
        `camera_center`; each is indexed along its first axis per batch item.
        `self.bg_color` is moved to the device of `data['xyz']` before use.

        Args:
            data: Dict with the keys above; it is updated in place.
            resolution: Height and width (square) of the rendered images.

        Returns:
            The same `data` dict with these keys added:
            `render_images` (stacked rasterizer outputs), `radii` (stacked
            on-screen radii), `visibility_filter` (`radii > 0`) and
            `viewspace_points` (zero tensor shaped like `xyz`, passed to the
            rasterizer as `means2D` so the gradients of the 2D screen-space
            means can be read back from it).
        """
        xyz = data['xyz']
        batch_size = xyz.shape[0]
        # Precomputed colours are used instead of spherical harmonics (sh_degree=0).
        colors_precomp = data['color']
        opacity = data['opacity']
        scales = data['scales']
        rotations = data['rotation']
        # Horizontal / vertical field of view, one value per batch item.
        fovx = data['fovx']
        fovy = data['fovy']
        world_view_transform = data['world_view_transform']
        full_proj_transform = data['full_proj_transform']
        camera_center = data['camera_center']

        # Zero tensor used to make PyTorch return gradients of the 2D
        # (screen-space) means. The "+ 0" turns it into a non-leaf tensor.
        screenspace_points = torch.zeros_like(xyz, dtype=xyz.dtype, requires_grad=True, device=xyz.device) + 0
        # retain_grad() raises when the tensor does not require grad (e.g. under
        # torch.no_grad()); check explicitly instead of swallowing all errors.
        if screenspace_points.requires_grad:
            screenspace_points.retain_grad()

        # Loop-invariant values, computed once.
        background = self.bg_color.to(xyz.device)
        side = int(resolution)

        render_images = []
        radii = []
        for b in range(batch_size):
            raster_settings = GaussianRasterizationSettings(
                image_height=side,
                image_width=side,
                tanfovx=math.tan(fovx[b] * 0.5),
                tanfovy=math.tan(fovy[b] * 0.5),
                bg=background,
                scale_modifier=self.scale_modifier,
                viewmatrix=world_view_transform[b],
                projmatrix=full_proj_transform[b],
                sh_degree=0,
                campos=camera_center[b],
                prefiltered=False,
                debug=False,
                antialiasing=True,
            )
            rasterizer = GaussianRasterizer(raster_settings=raster_settings)

            # Rasterize the Gaussians to an image and obtain their on-screen radii.
            render_images_b, radii_b, _ = rasterizer(
                means3D=xyz[b],
                means2D=screenspace_points[b],
                colors_precomp=colors_precomp[b],
                opacities=opacity[b],
                scales=scales[b],
                rotations=rotations[b])
            render_images.append(render_images_b)
            radii.append(radii_b)

        render_images = torch.stack(render_images)
        radii = torch.stack(radii)
        data['render_images'] = render_images
        data['viewspace_points'] = screenspace_points
        data['visibility_filter'] = radii > 0
        data['radii'] = radii
        return data
    
    
    def _apply_adaptive_smoothing(self, current_xyz, history_xyz=None, point_weights=None, smoothing_strength=0.7, method='exponential'):
        """
        Apply adaptive smoothing to XYZ coordinates
        
        Args:
            current_xyz: Current frame's XYZ coordinates
            history_xyz: List of previous frames' XYZ coordinates
            point_weights: Optional tensor of weights per point (1.0 = full smoothing, 0.0 = no smoothing)
            smoothing_strength: Strength of smoothing effect (0.0 to 1.0)
            method: Smoothing method - 'average', 'exponential', or 'gaussian'
            
        Returns:
            Smoothed XYZ coordinates
        """
        if history_xyz is None or len(history_xyz) == 0:
            return current_xyz
            
        # Default weights if not provided - can be based on semantic knowledge
        if point_weights is None:
            # Default: uniform weights - match the dimensionality of current_xyz
            if current_xyz.dim() == 3:  # [batch, points, 3]
                point_weights = torch.ones_like(current_xyz[:, :, 0:1])
            else:  # [points, 3]
                point_weights = torch.ones_like(current_xyz[:, 0:1])
            
        # Stack previous frames for easier computation - match the dimensionality
        history_stack = torch.stack(history_xyz, dim=0)  # [history_len, batch/points, 3]
        
        # Apply different smoothing methods
        if method == 'average':
            # Simple moving average
            history_mean = torch.mean(history_stack, dim=0)
            smoothed_xyz = (1 - smoothing_strength * point_weights) * current_xyz + \
                          (smoothing_strength * point_weights) * history_mean
                           
        elif method == 'exponential':
            # Exponential smoothing (more weight to recent frames)
            decay = torch.tensor([math.exp(-i * 0.5) for i in range(len(history_xyz))], 
                                device=current_xyz.device)
            decay = decay / decay.sum()
            
            weighted_history = torch.zeros_like(current_xyz)
            for i, hist in enumerate(history_xyz):
                weighted_history += decay[i] * hist
                
            smoothed_xyz = (1 - smoothing_strength * point_weights) * current_xyz + \
                          (smoothing_strength * point_weights) * weighted_history
                           
        elif method == 'gaussian':
            # Gaussian filter
            sigma = 1.0
            weights = torch.tensor([math.exp(-(i**2)/(2*sigma**2)) for i in range(len(history_xyz))], 
                                  device=current_xyz.device)
            weights = weights / weights.sum()
            
            weighted_history = torch.zeros_like(current_xyz)
            for i, hist in enumerate(history_xyz):
                weighted_history += weights[i] * hist
                
            smoothed_xyz = (1 - smoothing_strength * point_weights) * current_xyz + \
                          (smoothing_strength * point_weights) * weighted_history
        else:
            smoothed_xyz = current_xyz
            
        return smoothed_xyz
    
    def _get_adaptive_weights(self, xyz):
        """
        Calculate per-point smoothing weights based on semantic regions
        """
        # Default: uniform weights - match dimensionality of input
        if xyz.dim() == 3:  # [batch, points, 3]
            weights = torch.ones_like(xyz[:, :, 0:1]) 
        else:  # [points, 3]
            weights = torch.ones_like(xyz[:, 0:1])
        return weights

    def render_gaussian_smoothed(self, data, resolution, history_xyz=None, smoothing_config=None):
        """
        Render the scene with temporal smoothing of XYZ coordinates.
        
        Args:
            data: Dictionary containing rendering data
            resolution: Rendering resolution
            history_xyz: List of previous xyz coordinates for smoothing
            smoothing_config: Dict containing smoothing parameters:
                - enabled: Whether to apply smoothing
                - method: 'average', 'exponential', or 'gaussian'
                - strength: Value between 0.0 and 1.0
        
        Background tensor (bg_color) must be on GPU!
        """
        B = data['xyz'].shape[0]
        # Get 3D coordinates
        xyz = data['xyz']
        
        # Set default smoothing configuration if not provided
        if smoothing_config is None:
            smoothing_config = {
                'enabled': True,
                'method': 'exponential',
                'strength': 0.7
            }
        
        # Apply temporal smoothing to xyz coordinates if enabled
        smoothed_xyz = xyz.clone()  # Start with a copy

        
        # If we have history and smoothing is enabled, apply smoothing
        if history_xyz is not None and len(history_xyz) > 0 and smoothing_config.get('enabled', True):
            # Apply smoothing to the entire batch at once instead of per-batch item
            region_weights = self._get_adaptive_weights(xyz)
            
            # Convert history to the right format
            history_tensors = history_xyz  # Already a list of tensors
            
            smoothed_xyz = self._apply_adaptive_smoothing(
                xyz,
                history_tensors,
                region_weights,
                smoothing_config.get('strength', 0.7),
                smoothing_config.get('method', 'exponential')
            )
        
        # Store original xyz and replace with smoothed version
        data['original_xyz'] = xyz
        data['xyz'] = smoothed_xyz
        
        colors_precomp = data['color']
        opacity = data['opacity']
        scales = data['scales']
        rotations = data['rotation']
        fovx = data['fovx']
        fovy = data['fovy']
        world_view_transform = data['world_view_transform']
        full_proj_transform = data['full_proj_transform']
        camera_center = data['camera_center']
    
        screenspace_points = torch.zeros_like(smoothed_xyz, dtype=smoothed_xyz.dtype, requires_grad=True, device=smoothed_xyz.device) + 0
        try:
            screenspace_points.retain_grad()
        except:
            pass

        render_images = []
        radii = []
        for b in range(B):
            tanfovx = math.tan(fovx[b] * 0.5)
            tanfovy = math.tan(fovy[b] * 0.5)
            raster_settings = GaussianRasterizationSettings(
                image_height=int(resolution),
                image_width=int(resolution),
                tanfovx=tanfovx,
                tanfovy=tanfovy,
                bg=self.bg_color.to(smoothed_xyz.device),
                scale_modifier=self.scale_modifier,
                viewmatrix=world_view_transform[b],
                projmatrix=full_proj_transform[b],
                sh_degree=0,
                campos=camera_center[b],
                prefiltered=False,
                debug=False,
                antialiasing=True,
            )
            rasterizer = GaussianRasterizer(raster_settings=raster_settings)

            means3D = smoothed_xyz[b]
            means2D = screenspace_points[b]
            
            render_images_b, radii_b, _ = rasterizer(
                means3D = means3D,
                means2D = means2D,
                colors_precomp = colors_precomp[b],
                opacities = opacity[b],
                scales = scales[b],
                rotations = rotations[b])
            render_images.append(render_images_b)
            radii.append(radii_b)

        render_images = torch.stack(render_images)
        radii = torch.stack(radii)
        data['render_images'] = render_images
        data['viewspace_points'] =  screenspace_points
        data['visibility_filter'] = radii > 0
        data['radii'] = radii
        
        # Restore original xyz
        data['xyz'] = data.pop('original_xyz')
        
        return data