import torch
import torch.nn as nn
from diff_gaussian_rasterization_depth import GaussianRasterizationSettings, GaussianRasterizer
from models.gaussian_splatting.scene.gaussian_model import GaussianModel
import models
import math
import numpy as np
from utils.misc import config_to_primitive, get_rank

def fov2focal(fov, pixels):
    """Convert a field-of-view angle (radians) to a focal length in pixels.

    Inverse of ``focal2fov``: ``focal = pixels / (2 * tan(fov / 2))``.
    """
    half_angle_tangent = math.tan(fov / 2)
    return pixels / (2 * half_angle_tangent)

def focal2fov(focal, pixels):
    """Convert a focal length in pixels to a field-of-view angle in radians.

    Inverse of ``fov2focal``: ``fov = 2 * atan(pixels / (2 * focal))``.
    """
    half_extent_ratio = pixels / (2 * focal)
    return 2 * math.atan(half_extent_ratio)

def getWorld2View2(R, t, translate=None, scale=1.0):
    """Build a 4x4 world-to-view matrix, optionally re-centring/scaling the camera.

    Args:
        R: 3x3 rotation tensor. Its transpose is written into the rotation
            block, i.e. ``R`` is expected in the transposed storage callers use.
        t: length-3 world-to-camera translation.
        translate: optional length-3 offset added to the camera centre before
            scaling. Any array-like is accepted and moved onto ``R.device``;
            ``None`` (default) means no offset.
        scale: factor applied to the (translated) camera centre.

    Returns:
        4x4 world-to-view tensor on ``R.device`` (default float dtype).
    """
    Rt = torch.eye(4, device=R.device)
    Rt[:3, :3] = R.transpose(0, 1)
    Rt[:3, 3] = t

    # Adjust the camera centre in camera-to-world space, then invert back.
    C2W = torch.inverse(Rt)
    cam_center = C2W[:3, 3]
    if translate is not None:
        # A CPU offset cannot be added to a CUDA centre (the old CPU default
        # tensor failed here for GPU inputs), so align device/dtype first.
        offset = torch.as_tensor(translate, dtype=C2W.dtype, device=C2W.device)
        cam_center = cam_center + offset
    C2W[:3, 3] = cam_center * scale
    return torch.inverse(C2W)

def getProjectionMatrix(znear, zfar, fovX, fovY, device="cuda"):
    """Build the 4x4 perspective projection matrix for a symmetric frustum.

    With ``z_sign = +1`` the homogeneous ``w`` equals view-space ``z``, and the
    depth row maps ``z = znear`` to 0 and ``z = zfar`` to 1 after the divide.

    Args:
        znear, zfar: near/far clipping distances.
        fovX, fovY: horizontal/vertical field of view in radians.
        device: device for the returned tensor. The default ``"cuda"`` keeps
            the previous behaviour; pass ``"cpu"`` on machines without a GPU.

    Returns:
        4x4 float tensor on ``device`` (not transposed).
    """
    tan_half_fov_y = math.tan(fovY / 2)
    tan_half_fov_x = math.tan(fovX / 2)

    top = tan_half_fov_y * znear
    bottom = -top
    right = tan_half_fov_x * znear
    left = -right

    P = torch.zeros(4, 4, device=device)

    z_sign = 1.0

    P[0, 0] = 2.0 * znear / (right - left)
    P[1, 1] = 2.0 * znear / (top - bottom)
    P[0, 2] = (right + left) / (right - left)
    P[1, 2] = (top + bottom) / (top - bottom)
    P[3, 2] = z_sign
    P[2, 2] = z_sign * zfar / (zfar - znear)
    P[2, 3] = -(zfar * znear) / (zfar - znear)
    return P

@models.register('gs')
class GSModel(nn.Module):
    """Gaussian-splatting scene model: a ``GaussianModel`` rendered by a CUDA rasterizer.

    NOTE(review): ``render`` and ``transform_c2w`` read ``self.h``, ``self.w``,
    ``self.fovx``, ``self.fovy``, ``self.tanfovx`` and ``self.tanfovy``, none of
    which are assigned in this class -- presumably set by the owning system
    before rendering; confirm. ``setup`` initialises ``self.background_color``
    to ``None``; it must be replaced by a (GPU) tensor before ``forward``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.rank = get_rank()
        self.setup()
        self.var_curve = []

    def transform_c2w(self, c2w):
        """Convert a camera-to-world pose into the rasterizer's view/projection matrices.

        Args:
            c2w: 3x4 or 4x4 camera-to-world tensor in OpenGL/Blender axes
                (Y up, Z back). The caller's tensor is never modified.

        Returns:
            ``(world_view_transform, full_proj_transform)``: both 4x4 and
            transposed (row-vector convention) on ``c2w.device``.
        """
        zfar = 100.0
        znear = 0.01
        device = c2w.device
        if c2w.shape[0] == 3:
            # Promote a 3x4 pose to homogeneous 4x4. torch.cat allocates a new
            # tensor, so the in-place axis flip below cannot alias the input.
            bottom_row = torch.tensor([[0.0, 0.0, 0.0, 1.0]], device=device)
            c2w = torch.cat((c2w, bottom_row), dim=0)
        else:
            c2w = c2w.clone()
        # Change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward)
        c2w[:3, 1:3] *= -1
        # Get the world-to-camera transform and set R, T.
        w2c = torch.inverse(c2w)
        R = w2c[:3, :3].transpose(0, 1)  # R is stored transposed due to 'glm' in CUDA code
        T = w2c[:3, 3]
        no_offset = torch.zeros(3, device=device)
        world_view_transform = getWorld2View2(R, T, no_offset, 1.0).to(device).transpose(0, 1)
        projection_matrix = getProjectionMatrix(znear, zfar, self.fovx, self.fovy).to(device).transpose(0, 1)
        # Row-vector convention: a point is projected as p @ view @ proj.
        full_proj_transform = world_view_transform @ projection_matrix
        return world_view_transform, full_proj_transform

    def render(self, c2w, bg_color: torch.Tensor, scaling_modifier=1.0):
        """Rasterize every Gaussian of the scene from camera pose ``c2w``.

        Background tensor (bg_color) must be on GPU!

        Returns:
            dict with keys:
            ``comp_rgb`` -- rendered image permuted so the rasterizer's first
                dimension becomes last (presumably CHW -> HWC);
            ``viewspace_points`` -- screen-space means whose ``.grad`` is kept
                (when autograd is enabled);
            ``visibility_filter`` -- ``radii > 0``;
            ``opacity`` -- NOTE: the same visibility mask, not Gaussian
                opacities (kept as-is for downstream compatibility);
            ``radii``, ``scales`` and ``depth`` (flattened to 1-D).
        """
        pc = self.gaussians

        # Set up rasterization configuration.
        world_view_transform, full_proj_transform = self.transform_c2w(c2w)
        # The matrix is transposed, so the translation of its inverse sits in row 3.
        camera_center = world_view_transform.inverse()[3, :3]

        raster_settings = GaussianRasterizationSettings(
            image_height=self.h,
            image_width=self.w,
            tanfovx=self.tanfovx,
            tanfovy=self.tanfovy,
            bg=bg_color,
            scale_modifier=scaling_modifier,
            viewmatrix=world_view_transform,
            projmatrix=full_proj_transform,
            sh_degree=pc.active_sh_degree,
            campos=camera_center,
            prefiltered=False,
            debug=False,
        )
        rasterizer = GaussianRasterizer(raster_settings=raster_settings)

        means3D = pc.get_xyz
        scales = pc.get_scaling
        rotations = pc.get_rotation
        opacity = pc.get_opacity
        shs = pc.get_features

        # All-True mask over Gaussians; not read in this class, presumably
        # consumed elsewhere -- confirm before removing.
        self.mc_mask = torch.ones(means3D.shape[0], dtype=torch.bool).to(self.rank)

        # Zero tensor used only so autograd yields gradients of the 2D
        # (screen-space) means. "+ 0" makes it a non-leaf, hence retain_grad().
        screenspace_points = torch.zeros_like(means3D, requires_grad=True, device=means3D.device) + 0
        # Under torch.no_grad() the sum does not require grad and retain_grad()
        # would raise; that is the only case the old bare `except` hid.
        if screenspace_points.requires_grad:
            screenspace_points.retain_grad()

        # No precomputed covariance/colours: the rasterizer derives them from
        # scaling/rotation and the SH features.
        rendered_image, radii, depth = rasterizer(
            means3D=means3D,
            means2D=screenspace_points,
            shs=shs,
            colors_precomp=None,
            opacities=opacity,
            scales=scales,
            rotations=rotations,
            cov3D_precomp=None)

        # Gaussians that were frustum culled or had a radius of 0 were not visible.
        visibility_filter = radii > 0

        return {"comp_rgb": rendered_image.permute(1, 2, 0),
                "viewspace_points": screenspace_points,
                "visibility_filter": visibility_filter,
                "opacity": visibility_filter,
                "radii": radii,
                "scales": scales,
                "depth": depth.reshape(-1)}

    def _finalize_output(self, out):
        """Add ``rays_valid``/``num_samples`` to ``out`` and mirror every key as ``<key>_full``."""
        out.update({
            'rays_valid': out["visibility_filter"],
            'num_samples': torch.as_tensor([len(out["opacity"])], dtype=torch.int32).to(self.rank),
        })
        return {
            **out,
            **{k + '_full': v for k, v in out.items()}
        }

    def forward(self, rays, c2w, ray_chunk=None, num_samples_per_ray=None):
        """Render once from ``c2w``. ``rays``/``ray_chunk``/``num_samples_per_ray`` are unused."""
        out = self.render(c2w, self.background_color)
        return self._finalize_output(out)

    def forward_k_times(self, rays, c2w, ray_chunk=None, num_samples_per_ray=None, k=1):
        """Render ``k`` times from ``c2w`` and report per-pixel mean/variance statistics.

        ``comp_rgb`` is replaced by the mean image. Statistics use torch's
        default unbiased estimator, so ``comp_var``/``comp_std``/``depth_var``
        are NaN when ``k == 1``. Non-statistic keys come from the last render.

        Raises:
            ValueError: if ``k < 1``.
        """
        if k < 1:
            raise ValueError(f"k must be >= 1, got {k}")
        rgbs = []
        depths = []
        for _ in range(k):
            out = self.render(c2w, self.background_color)
            rgbs.append(out['comp_rgb'])
            depths.append(out['depth'])
        rgbs = torch.stack(rgbs, dim=0)
        depths = torch.stack(depths, dim=0)

        out.update({'comp_rgb': rgbs.mean(dim=0),
                    'comp_rgbs': rgbs,
                    'comp_var': rgbs.var(dim=0),
                    'comp_std': rgbs.std(dim=0),
                    'depth_var': depths.var(dim=0),
                    'depth_mean': depths.mean(dim=0),
        })
        return self._finalize_output(out)

    def setup(self):
        """Create the Gaussian container (SH degree 3) and placeholder state."""
        self.sh_degree = 3
        self.gaussians = GaussianModel(self.sh_degree)
        self.background_color = None
        # Dummy parameter; presumably ensures the module exposes at least one
        # nn.Parameter (e.g. for an optimizer) -- confirm.
        self.fake_param = torch.nn.Parameter(torch.zeros(1))

    def update_step(self, epoch, global_step):
        pass

    def train(self, mode=True):
        return super().train(mode=mode)

    def eval(self):
        return super().eval()

    def regularizations(self, out):
        return {}

    @torch.no_grad()
    def export(self, export_config):
        return {}
