import math
import trimesh
import numpy as np
import torch
import torch.nn as nn
from lidarnerf import raymarching
from lidarnerf.dataset.base_dataset import get_lidar_rays


def sample_pdf(bins, weights, n_samples, det=False):
    """Draw new sample positions by inverse-transform sampling a piecewise-constant PDF.

    This is the hierarchical-sampling routine from the original NeRF.

    Args:
        bins: ``[B, T]`` bin edges (the previous z values).
        weights: ``[B, T - 1]`` weight of each bin.
        n_samples: number of positions to draw per row.
        det: if True, use evenly spaced quantiles instead of uniform noise.

    Returns:
        ``[B, n_samples]`` tensor of new z values.
    """
    # A small floor keeps every bin reachable and prevents 0/0 in the PDF.
    padded = weights + 1e-5
    pdf = padded / padded.sum(dim=-1, keepdim=True)
    cdf = torch.cat([torch.zeros_like(pdf[..., :1]), pdf.cumsum(dim=-1)], dim=-1)

    sample_shape = list(cdf.shape[:-1]) + [n_samples]
    if det:
        half_step = 0.5 / n_samples
        u = torch.linspace(0.0 + half_step, 1.0 - half_step, steps=n_samples).to(
            weights.device
        )
        u = u.expand(sample_shape)
    else:
        u = torch.rand(sample_shape).to(weights.device)
    u = u.contiguous()

    # Locate the CDF interval [below, above] that contains each u.
    idx = torch.searchsorted(cdf, u, right=True)
    last = cdf.shape[-1] - 1
    below = (idx - 1).clamp(min=0)
    above = idx.clamp(max=last)
    idx_pair = torch.stack([below, above], dim=-1)  # [B, n_samples, 2]

    expanded = [idx_pair.shape[0], idx_pair.shape[1], cdf.shape[-1]]
    cdf_pair = torch.gather(cdf.unsqueeze(1).expand(expanded), 2, idx_pair)
    bin_pair = torch.gather(bins.unsqueeze(1).expand(expanded), 2, idx_pair)

    lo_cdf, hi_cdf = cdf_pair.unbind(dim=-1)
    lo_bin, hi_bin = bin_pair.unbind(dim=-1)
    span = hi_cdf - lo_cdf
    # Near-empty intervals would divide by ~0; use a span of 1 instead.
    span = span.masked_fill(span < 1e-5, 1.0)
    frac = (u - lo_cdf) / span
    return lo_bin + frac * (hi_bin - lo_bin)
def plot_pointcloud(pc, color=None):
    """Open an interactive trimesh viewer showing a point cloud.

    The scene also contains coordinate axes (length 4) and a radius-1
    icosphere for scale reference.

    Args:
        pc: ``[N, 3]`` points (presumably a NumPy array, since ``.min(0)`` /
            ``.max(0)`` are called on it; TODO confirm).
        color: optional ``[N, 3]`` or ``[N, 4]`` per-point colors.
    """
    print("[visualize points]", pc.shape, pc.dtype, pc.min(0), pc.max(0))
    cloud = trimesh.PointCloud(pc, color)
    scene_items = [
        cloud,
        trimesh.creation.axis(axis_length=4),  # coordinate axes
        trimesh.creation.icosphere(radius=1),  # unit sphere for scale
    ]
    trimesh.Scene(scene_items).show()
class Lie:
    """Lie-group utilities for SO(3)/SE(3) poses.

    The coefficients of the exponential/log maps are evaluated with truncated
    Taylor series (``taylor_A`` / ``taylor_B`` / ``taylor_C``), so they stay
    finite at ``theta == 0``.
    """

    def compose_pair(self, pose_a, pose_b):
        """Compose two poses: ``pose_a`` is applied first, then ``pose_b``.

        Returns ``pose_b @ pose_a`` cast to float32. The inputs only need to be
        ``torch.matmul``-compatible (e.g. both ``[..., 4, 4]``).
        """
        return torch.matmul(pose_b, pose_a).to(dtype=torch.float32)

    def so3_to_SO3(self, w):  # [...,3]
        """Exponential map so(3) -> SO(3): rotation vector to ``[..., 3, 3]`` matrix.

        Rodrigues' formula ``R = I + A [w]x + B [w]x^2``, where
        ``A = sin(theta)/theta``, ``B = (1 - cos(theta))/theta^2`` and
        ``theta = |w|``.
        """
        wx = self.skew_symmetric(w)
        theta = w.norm(dim=-1)[..., None, None]
        I = torch.eye(3, device=w.device, dtype=torch.float32)
        A = self.taylor_A(theta)
        B = self.taylor_B(theta)
        R = I + A * wx + B * wx @ wx
        return R

    def SO3_to_so3(self, R, eps=1e-7):  # [...,3,3]
        """Log map SO(3) -> so(3): rotation matrix to rotation vector ``[..., 3]``.

        ``cos(theta)`` is clamped strictly inside (-1, 1). The result divides by
        ``2 * sin(theta) / theta``, which vanishes at ``theta == pi``.
        """
        trace = R[..., 0, 0] + R[..., 1, 1] + R[..., 2, 2]
        theta = ((trace - 1) / 2).clamp(-1 + eps, 1 - eps).acos_()[..., None, None] % np.pi  # ln(R) will explode if theta==pi
        lnR = 1 / (2 * self.taylor_A(theta) + 1e-8) * (R - R.transpose(-2, -1))  # FIXME: wei-chiu finds it weird
        w0, w1, w2 = lnR[..., 2, 1], lnR[..., 0, 2], lnR[..., 1, 0]
        w = torch.stack([w0, w1, w2], dim=-1)
        return w

    def se3_to_SE3(self, wu, device=None):  # [...,6]
        """Map a 6-vector ``wu = [w, u]`` to a homogeneous ``[..., 4, 4]`` pose.

        The rotation block is ``so3_to_SO3(w)``. NOTE: the translation column is
        ``u`` itself, not ``V @ u`` as in the exact SE(3) exponential map.
        ``SE3_to_se3`` (which applies ``V^-1``) is therefore not an exact inverse
        of this function; the two agree when the rotation is zero.

        Args:
            wu: ``[..., 6]`` tensor holding the rotation vector, then the translation.
            device: kept only for backward compatibility. The result is always
                built on ``wu``'s device, because concatenating tensors that live
                on different devices would fail.

        Returns:
            ``[..., 4, 4]`` tensor whose bottom row is ``[0, 0, 0, 1]``.
        """
        w, u = wu.split([3, 3], dim=-1)
        R = self.so3_to_SO3(w)
        Rt = torch.cat([R, u[..., None]], dim=-1)
        # Build the bottom row in Rt's dtype/device (not int64) and broadcast it
        # over any leading batch dimensions.
        bottom = torch.tensor([0, 0, 0, 1], dtype=Rt.dtype, device=Rt.device)
        bottom = bottom.expand(*Rt.shape[:-2], 1, 4)
        return torch.cat([Rt, bottom], dim=-2)

    def SE3_to_se3(self, Rt, eps=1e-8):  # [...,3,4]
        """Log map SE(3) -> se(3): ``[..., 3, 4]`` pose to ``[..., 6]`` vector ``[w, u]``.

        ``w = SO3_to_so3(R)``. ``u = invV @ t``, where ``invV`` is the
        closed-form inverse of ``V = I + B [w]x + C [w]x^2``.
        """
        R, t = Rt.split([3, 1], dim=-1)
        w = self.SO3_to_so3(R)
        wx = self.skew_symmetric(w)
        theta = w.norm(dim=-1)[..., None, None]
        I = torch.eye(3, device=w.device, dtype=torch.float32)
        A = self.taylor_A(theta)
        B = self.taylor_B(theta)
        invV = I - 0.5 * wx + (1 - A / (2 * B)) / (theta ** 2 + eps) * wx @ wx
        u = (invV @ t)[..., 0]
        wu = torch.cat([w, u], dim=-1)
        return wu

    def skew_symmetric(self, w):
        """Return the ``[..., 3, 3]`` cross-product matrix ``[w]x`` (``[w]x @ v == w x v``)."""
        w0, w1, w2 = w.unbind(dim=-1)
        O = torch.zeros_like(w0)
        wx = torch.stack([torch.stack([O, -w2, w1], dim=-1),
                          torch.stack([w2, O, -w0], dim=-1),
                          torch.stack([-w1, w0, O], dim=-1)], dim=-2)
        return wx

    def taylor_A(self, x, nth=10):
        """Truncated Taylor series (``nth + 1`` terms) of ``sin(x) / x``."""
        ans = torch.zeros_like(x)
        denom = 1.
        for i in range(nth + 1):
            if i > 0:
                denom *= (2 * i) * (2 * i + 1)
            ans = ans + (-1) ** i * x ** (2 * i) / denom
        return ans

    def taylor_B(self, x, nth=10):
        """Truncated Taylor series (``nth + 1`` terms) of ``(1 - cos(x)) / x^2``."""
        ans = torch.zeros_like(x)
        denom = 1.
        for i in range(nth + 1):
            denom *= (2 * i + 1) * (2 * i + 2)
            ans = ans + (-1) ** i * x ** (2 * i) / denom
        return ans

    def taylor_C(self, x, nth=10):
        """Truncated Taylor series (``nth + 1`` terms) of ``(x - sin(x)) / x^3``."""
        ans = torch.zeros_like(x)
        denom = 1.
        for i in range(nth + 1):
            denom *= (2 * i + 2) * (2 * i + 3)
            ans = ans + (-1) ** i * x ** (2 * i) / denom
        return ans

class NeRFRenderer(nn.Module):
    """Base ray-marching renderer; rays are generated with ``get_lidar_rays``.

    Subclasses must provide the field (``density``, ``color``, and
    ``background`` when ``bg_radius > 0``), the output widths
    ``out_color_dim`` / ``out_lidar_color_dim`` read by ``run``, and the pose
    accessors ``get_pose`` / ``save_pose``.
    """

    def __init__(
        self,
        bound=1,
        density_scale=1,  # scale up deltas (or sigmas), to make the density grid more sharp. larger value than 1 usually improves performance.
        min_near=0.2,
        min_near_lidar=0.2,
        density_thresh=0.01,
        bg_radius=-1,
    ):
        """
        Args:
            bound: half-extent of the cubic AABB ``[-bound, bound]^3``.
            density_scale: multiplier applied to ``deltas * sigma`` in ``run``.
            min_near: near clamp passed to ``raymarching.near_far_from_aabb``.
            min_near_lidar: near distance for LiDAR-color rays. Their far
                distance is ``81 * min_near_lidar``.
            density_thresh: stored only; not read in this class.
            bg_radius: if > 0, ``run`` queries ``self.background`` for the
                background color.
        """
        super().__init__()
        self.lie = Lie()
        self.bound = bound
        self.cascade = 1 + math.ceil(math.log2(bound))
        self.grid_size = 128
        self.density_scale = density_scale
        self.min_near = min_near
        self.min_near_lidar = min_near_lidar
        self.density_thresh = density_thresh
        self.bg_radius = bg_radius

        # prepare aabb with a 6D tensor (xmin, ymin, zmin, xmax, ymax, zmax)
        # NOTE: aabb (can be rectangular) is only used to generate points, we still rely on bound (always cubic) to calculate density grid and hashing.
        aabb_train = torch.FloatTensor([-bound, -bound, -bound, bound, bound, bound])
        aabb_infer = aabb_train.clone()
        self.register_buffer("aabb_train", aabb_train)
        self.register_buffer("aabb_infer", aabb_infer)

    def forward(self, x, d):
        raise NotImplementedError()

    def density(self, x):
        """Return a dict of per-point outputs; ``run`` requires a ``"sigma"`` key."""
        raise NotImplementedError()

    def color(self, x, d, mask=None, **kwargs):
        """Return per-point colors of width ``self.out_dim`` (set by ``run``)."""
        raise NotImplementedError()

    def save_pose(self, idx, pose):
        raise NotImplementedError()

    def get_pose(self, idx, pose):
        """Return the pose used to generate rays for sample ``idx``."""
        raise NotImplementedError()

    def image_to_rays(self, data, pose):
        """Generate LiDAR rays for ``pose`` from the intrinsics/resolution in ``data``.

        Reads ``intrinsics_lidar``, ``H_lidar``, ``W_lidar``, ``num_rays_lidar``
        and ``patch``. ``render`` consumes the ``rays_o``, ``rays_d`` and
        ``inds`` entries of the result.
        """
        rays_lidar = get_lidar_rays(
            pose,
            data["intrinsics_lidar"],
            data["H_lidar"],
            data["W_lidar"],
            data["num_rays_lidar"],
            data["patch"],
        )
        return rays_lidar

    def run(
        self,
        data,
        rays_o,
        rays_d,
        image_lidar_sample_rays,
        cal_lidar_color=False,
        num_steps=128,
        upsample_steps=128,
        bg_color=None,
        perturb=False,
        **kwargs
    ):
        """March ``num_steps`` uniformly spaced samples along each ray and composite them.

        Side effect: sets ``self.out_dim`` to ``out_lidar_color_dim`` or
        ``out_color_dim`` depending on ``cal_lidar_color``.

        Args:
            data: batch dict; only ``data["image"]`` is read (passed through).
            rays_o, rays_d: ``[..., 3]`` ray origins/directions, flattened internally.
            image_lidar_sample_rays: passed through unchanged into the result.
            cal_lidar_color: if True, near/far are fixed from ``min_near_lidar``
                and no background is blended in. Otherwise near/far come from
                the AABB.
            num_steps: samples per ray.
            upsample_steps: not used in this method.
            bg_color: background blended into ``intensity`` when
                ``cal_lidar_color`` is False. Defaults to 1 unless
                ``bg_radius > 0``.
            perturb: jitter each sample depth by up to half a step.

        Returns:
            dict with ``depth_lidar`` (shape of ``rays_o[..., 0]``),
            ``intensity`` (``[..., out_dim]``), ``weights_sum_lidar``
            (flattened, one value per ray), plus the pass-through ``image`` and
            ``image_lidar_sample_rays``.
        """
        if cal_lidar_color:
            self.out_dim = self.out_lidar_color_dim
        else:
            self.out_dim = self.out_color_dim
        prefix = rays_o.shape[:-1]
        rays_o = rays_o.contiguous().view(-1, 3)
        rays_d = rays_d.contiguous().view(-1, 3)
        N = rays_o.shape[0]  # N = B * N, in fact
        device = rays_o.device
        aabb = self.aabb_train if self.training else self.aabb_infer

        if cal_lidar_color:
            nears = (
                torch.ones(N, dtype=rays_o.dtype, device=rays_o.device)
                * self.min_near_lidar
            )
            # Far plane for these rays is fixed at 81 x the LiDAR near distance.
            fars = (
                torch.ones(N, dtype=rays_o.dtype, device=rays_o.device)
                * self.min_near_lidar
                * 81.0
            )
        else:
            nears, fars = raymarching.near_far_from_aabb(
                rays_o, rays_d, aabb, self.min_near
            )
        nears.unsqueeze_(-1)
        fars.unsqueeze_(-1)
        z_vals = torch.linspace(0.0, 1.0, num_steps, device=device).unsqueeze(0)

        z_vals = z_vals.expand((N, num_steps))
        z_vals = nears + (fars - nears) * z_vals  # [N, num_steps]
        sample_dist = (fars - nears) / num_steps
        if perturb:
            z_vals = (
                z_vals + (torch.rand(z_vals.shape, device=device) - 0.5) * sample_dist
            )

        xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * z_vals.unsqueeze(-1)

        # Keep sample points inside the AABB.
        xyzs = torch.min(torch.max(xyzs, aabb[:3]), aabb[3:])
        density_outputs = self.density(xyzs.reshape(-1, 3))
        for k, v in density_outputs.items():
            density_outputs[k] = v.view(N, num_steps, -1)
        deltas = z_vals[..., 1:] - z_vals[..., :-1]
        # The last interval has no successor; give it the nominal step size.
        deltas = torch.cat(
            [deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1
        )
        alphas = 1 - torch.exp(
            -deltas * self.density_scale * density_outputs["sigma"].squeeze(-1)
        )
        # Transmittance = exclusive cumulative product of (1 - alpha).
        alphas_shifted = torch.cat(
            [torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1
        )
        weights = alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1]

        dirs = rays_d.view(-1, 1, 3).expand_as(xyzs)
        for k, v in density_outputs.items():
            density_outputs[k] = v.view(-1, v.shape[-1])

        # Samples with weight <= 1e-4 are flagged False; how color() uses the
        # mask is up to the subclass.
        mask = weights > 1e-4

        rgbs = self.color(
            xyzs.reshape(-1, 3),
            dirs.reshape(-1, 3),
            cal_lidar_color=cal_lidar_color,
            mask=mask.reshape(-1),
            **density_outputs
        )

        rgbs = rgbs.view(N, -1, self.out_dim)

        weights_sum = weights.sum(dim=-1)
        depth = torch.sum(weights * z_vals, dim=-1)
        intensity = torch.sum(weights.unsqueeze(-1) * rgbs, dim=-2)

        if self.bg_radius > 0:
            sph = raymarching.sph_from_ray(rays_o, rays_d, self.bg_radius)
            bg_color = self.background(sph, rays_d.reshape(-1, 3))
        elif bg_color is None:
            bg_color = 1

        if not cal_lidar_color:
            intensity = intensity + (1 - weights_sum).unsqueeze(-1) * bg_color

        intensity = intensity.view(*prefix, self.out_dim)
        depth = depth.view(*prefix)

        return {
            "image": data["image"],
            "image_lidar_sample_rays": image_lidar_sample_rays,
            "depth_lidar": depth,
            "intensity": intensity,
            "weights_sum_lidar": weights_sum,
        }

    def render(
        self,
        data,
        training=True,
        cal_lidar_color=False,
        staged=False,
        max_ray_batch=4096,
        **kwargs
    ):
        """Sample LiDAR rays for ``data`` and render them with ``run``.

        Args:
            data: batch dict; reads ``index``, ``pose`` and ``image`` (assumed
                channels-last), plus the keys consumed by ``image_to_rays``.
            training: not used in this method; kept for API compatibility.
            cal_lidar_color: forwarded to ``run``; also selects the output width.
            staged: if True, render each batch item in chunks of
                ``max_ray_batch`` rays (presumably to bound peak memory).
            max_ray_batch: chunk size for the staged path.
            **kwargs: forwarded to ``run``.

        Returns:
            The ``run`` result dict. In staged mode it is reassembled with
            ``depth_lidar`` ``[B, N]``, ``intensity`` ``[B, N, out_dim]`` and
            ``weights_sum_lidar`` flattened to ``[B * N]`` (the same layout the
            non-staged path returns), plus ``image`` and
            ``image_lidar_sample_rays``.
        """
        pose = self.get_pose(data["index"], data["pose"])
        rays_lidar = self.image_to_rays(data, pose)
        image_lidar = data["image"]
        B = image_lidar.shape[0]
        C = image_lidar.shape[-1]
        # Ground-truth pixel values at the sampled ray indices, one per channel.
        image_lidar_sample_rays = torch.gather(
            image_lidar.view(B, -1, C),
            1,
            torch.stack(C * [rays_lidar["inds"]], -1),
        )
        rays_o = rays_lidar["rays_o"]
        rays_d = rays_lidar["rays_d"]

        if not staged:
            return self.run(
                data,
                rays_o,
                rays_d,
                image_lidar_sample_rays,
                cal_lidar_color=cal_lidar_color,
                max_ray_batch=max_ray_batch,
                **kwargs,
            )

        device = rays_o.device
        B, N = rays_o.shape[0:2]
        # Choose the width for both modes; it was previously only bound when
        # cal_lidar_color was True, which raised NameError otherwise.
        out_dim = self.out_lidar_color_dim if cal_lidar_color else self.out_color_dim
        depth = torch.empty((B, N), device=device)
        intensity = torch.empty((B, N, out_dim), device=device)
        weights_sum = torch.empty((B, N), device=device)
        for b in range(B):
            for head in range(0, N, max_ray_batch):
                tail = min(head + max_ray_batch, N)
                chunk = self.run(
                    data,
                    rays_o[b : b + 1, head:tail],
                    rays_d[b : b + 1, head:tail],
                    image_lidar_sample_rays,
                    cal_lidar_color=cal_lidar_color,
                    **kwargs
                )
                depth[b : b + 1, head:tail] = chunk["depth_lidar"]
                intensity[b : b + 1, head:tail] = chunk["intensity"]
                weights_sum[b, head:tail] = chunk["weights_sum_lidar"].reshape(-1)

        # Build the result without relying on the last chunk, so B == 0 or
        # N == 0 no longer leaves the result unbound.
        return {
            "image": data["image"],
            "image_lidar_sample_rays": image_lidar_sample_rays,
            "depth_lidar": depth,
            "intensity": intensity,
            "weights_sum_lidar": weights_sum.reshape(-1),
        }