"""
Copyright (c) 2022 Ruilong Li, UC Berkeley.
"""

from typing import Callable, List, Union

import numpy as np
import torch
from torch.autograd import Function
from torch.cuda.amp import custom_bwd, custom_fwd
from typing import Callable, Optional
import torch.nn.functional as F

import torch.nn as nn
import math

try:
    import tinycudann as tcnn
except ImportError as e:
    # Raising SystemExit with a message prints it to stderr and exits with
    # status 1. The interactive-only `exit()` helper used previously exited
    # with status 0 (reporting success to shells/CI) and is not defined when
    # Python runs without the `site` module.
    raise SystemExit(
        f"Error: {e}! "
        "Please install tinycudann by: "
        "pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch"
    ) from e

def cart2sph(coord):
    """Convert Cartesian vectors to normalized spherical angles.

    Args:
        coord: tensor whose last dimension holds (x, y, z) in its first
            three entries.

    Returns:
        Tensor of shape ``coord.shape[:-1] + (2,)`` holding
        ``[azimuth / pi, 2 * elevation / pi]``, where
        ``azimuth = atan2(y, x)`` lies in [-pi, pi] and
        ``elevation = atan2(z, sqrt(x^2 + y^2))`` lies in [-pi/2, pi/2], so
        both outputs lie in [-1, 1]. The radius is not returned.
    """
    x = coord[..., 0:1]
    y = coord[..., 1:2]
    z = coord[..., 2:3]
    azimuth = torch.atan2(y, x)
    # Distance from the z axis (non-negative), which bounds elevation to
    # [-pi/2, pi/2]. The full radius was previously computed but never used.
    rho = torch.sqrt(x * x + y * y)
    elevation = torch.atan2(z, rho)
    return torch.cat([azimuth / math.pi, elevation * 2 / math.pi], dim=-1)

class _TruncExp(Function):  # pylint: disable=abstract-method
    """Exponential whose gradient uses an input clamped at 15.

    The forward value is the plain ``exp(x)``; the backward pass multiplies
    the incoming gradient by ``exp(min(x, 15))`` so large inputs cannot blow
    the gradient up. ``custom_fwd(cast_inputs=torch.float32)`` makes the
    forward run in float32 under CUDA autocast.

    Adapted from torch-ngp:
    https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py
    """

    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, x):  # pylint: disable=arguments-differ
        ctx.save_for_backward(x)
        return x.exp()

    @staticmethod
    @custom_bwd
    def backward(ctx, g):  # pylint: disable=arguments-differ
        (saved_input,) = ctx.saved_tensors
        clipped = saved_input.clamp(max=15)
        return g * clipped.exp()


# Functional entry point, used e.g. as ``trunc_exp(x - 1)`` for densities.
trunc_exp = _TruncExp.apply


def contract_to_unisphere(
    x: torch.Tensor,
    aabb: torch.Tensor,
    ord: Union[str, int] = 2,
    eps: float = 1e-6,
    derivative: bool = False,
):
    """Contract positions into [0, 1] (mip-NeRF 360 style contraction).

    ``x`` is first mapped so that ``aabb`` spans [-1, 1]. Points whose
    ``ord``-norm exceeds 1 are pulled in via
    ``(2 - 1 / |x|) * (x / |x|)``, and the result is mapped from [-2, 2] to
    [0, 1].

    Args:
        x: positions, coordinates in the last dim.
        aabb: concatenated ``[min..., max...]`` bounds (first half minimum,
            second half maximum).
        ord: norm order passed to ``torch.linalg.norm``.
        eps: lower clamp for the derivative output.
        derivative: if True, return the per-coordinate derivative term
            computed below instead of the contracted positions (1.0 for points
            inside the unit ball).

    Returns:
        Contracted positions in [0, 1], or the derivative tensor when
        ``derivative`` is True. Same shape as ``x`` in both cases.
    """
    aabb_min, aabb_max = torch.split(aabb, aabb.shape[-1] // 2, dim=-1)
    x = (x - aabb_min) / (aabb_max - aabb_min)
    x = x * 2 - 1  # aabb is at [-1, 1]
    mag = torch.linalg.norm(x, ord=ord, dim=-1, keepdim=True)
    mask = mag.squeeze(-1) > 1

    if derivative:
        dev = (2 * mag - 1) / mag**2 + 2 * x**2 * (
            1 / mag**3 - (2 * mag - 1) / mag**4
        )
        # Inside the unit ball positions are left unchanged.
        dev[~mask] = 1.0
        dev = torch.clamp(dev, min=eps)
        return dev

    # Contract out-of-place: the norm above saved `x` for its backward, so an
    # in-place masked write would make `backward()` fail when x needs grad.
    # Clamping the magnitude to >= 1 leaves contracted points unchanged
    # (their mag is > 1) and keeps the discarded branch free of 0/0 NaNs for
    # points at the origin, which would otherwise poison gradients.
    safe_mag = mag.clamp(min=1.0)
    contracted = (2 - 1 / safe_mag) * (x / safe_mag)
    x = torch.where(mask.unsqueeze(-1), contracted, x)
    return x / 4 + 0.5  # [-inf, inf] is at [0, 1]


class MLP(nn.Module):
    """Fully connected network with optional skip connections and optional
    per-layer, coordinate-dependent feature modulation.

    If ``coord`` is passed to ``forward``, the pre-activation output ``x`` of
    every hidden layer ``i`` is replaced by ``gamma * x + beta``. Here
    ``gamma`` and ``beta`` are sampled with ``F.grid_sample`` from the
    learnable grids in ``self.gammas[i]``, which are created according to
    ``coord_dim``:

    * ``2``: one 2-channel 4x12 grid (channel 0 = gamma, channel 1 = beta).
    * ``3``: three planes over the coordinate pairs (0,1), (0,2) and (1,2).
      gamma and beta are each the product of the three plane samples.
    * ``4``: six planes over every pair of the four coordinates. gamma and
      beta are each the product of the six plane samples.
    * anything else: no grids. ``forward`` must then be called without
      ``coord``.

    At call time the branch is chosen from ``coord.shape[-1]`` (2, 4, or
    anything else for the three-plane path). It is expected to agree with
    ``coord_dim``.
    """

    def __init__(
        self,
        input_dim: int,  # The number of input tensor channels.
        output_dim: Optional[int] = None,  # The number of output tensor channels.
        net_depth: int = 8,  # The depth of the MLP.
        net_width: int = 256,  # The width of the MLP.
        skip_layer: Optional[int] = 4,  # The layer to add skip layers to.
        hidden_init: Callable = nn.init.xavier_uniform_,
        hidden_activation: Optional[Callable] = None,  # None -> new nn.ReLU()
        output_enabled: bool = True,
        output_init: Optional[Callable] = nn.init.xavier_uniform_,
        output_activation: Optional[Callable] = None,  # None -> new nn.Identity()
        bias_enabled: bool = True,
        bias_init: Callable = nn.init.zeros_,
        coord_dim: int = 0,  # Modulation grid layout: 2, 3, 4, or none.
        grid_resolution: int = 512,  # Plane resolution for coord_dim 3 / 4.
    ):
        super().__init__()
        # Build the activation modules per instance rather than sharing one
        # module object created once at function-definition time.
        if hidden_activation is None:
            hidden_activation = nn.ReLU()
        if output_activation is None:
            output_activation = nn.Identity()

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.net_depth = net_depth
        self.net_width = net_width
        self.skip_layer = skip_layer
        self.hidden_init = hidden_init
        self.hidden_activation = hidden_activation
        self.output_enabled = output_enabled
        self.output_init = output_init
        self.output_activation = output_activation
        self.bias_enabled = bias_enabled
        self.bias_init = bias_init
        self.coord_dim = coord_dim

        # Coordinate pairs indexing each axis-aligned plane: three planes for
        # the 3-coordinate path, six (all pairs) for the 4-coordinate path.
        self.matMode = [[0, 1], [0, 2], [1, 2]]
        self.matMode4 = [[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]]
        self.vecMode = [2, 1, 0]  # not used by forward

        # Per-layer modulation grids (see _init_modulation_grid).
        self.gammas = nn.ParameterList()
        # The containers below are not used by forward. They are kept (empty)
        # so existing attribute access keeps working.
        self.betas = nn.ParameterList()
        self.means = nn.ParameterList()
        self.vars = nn.ParameterList()
        self.bns = nn.ModuleList()
        self.meanvars = nn.ParameterList()
        self.gammas1 = nn.ModuleList()
        self.gammas2 = nn.ModuleList()
        self.gammas3 = nn.ModuleList()
        self.gammas4 = nn.ModuleList()
        self.gammas5 = nn.ModuleList()
        self.gammas6 = nn.ModuleList()

        self.hidden_layers = nn.ModuleList()
        in_features = self.input_dim
        for i in range(self.net_depth):
            self.hidden_layers.append(
                nn.Linear(in_features, self.net_width, bias=bias_enabled)
            )
            grid = self._init_modulation_grid(coord_dim, grid_resolution)
            if grid is not None:
                self.gammas.append(grid)
            # The layer after a skip also receives the original inputs.
            if self._is_skip(i):
                in_features = self.net_width + self.input_dim
            else:
                in_features = self.net_width
        if self.output_enabled:
            self.output_layer = nn.Linear(
                in_features, self.output_dim, bias=bias_enabled
            )
        else:
            self.output_dim = in_features

        self.initialize()

    @staticmethod
    def _init_modulation_grid(coord_dim, resolution):
        """Return the initial modulation grid for one layer, or None.

        Gamma entries start at one and beta entries at zero, so the
        modulation is (approximately) the identity at initialization.
        """
        if coord_dim == 2:
            # One grid, two channels: channel 0 = gamma, channel 1 = beta.
            return nn.Parameter(
                torch.cat(
                    [torch.ones((1, 1, 4, 12)), torch.zeros((1, 1, 4, 12))],
                    dim=1,
                )
            )
        if coord_dim in (3, 4):
            # Gamma planes stacked first, then the same number of beta planes.
            # NOTE(review): beta is the *product* of its planes, and they all
            # start at zero. The gradient w.r.t. each beta plane is therefore
            # the product of the other (zero) planes, so beta stays 0 during
            # training. Confirm this is intended.
            n_planes = 3 if coord_dim == 3 else 6
            shape = (n_planes, 1, resolution, resolution)
            return nn.Parameter(
                torch.cat([torch.ones(shape), torch.zeros(shape)], dim=0)
            )
        return None

    def _is_skip(self, i):
        """Whether the original inputs are concatenated after hidden layer i."""
        return (
            (self.skip_layer is not None)
            and (i % self.skip_layer == 0)
            and (i > 0)
        )

    def initialize(self):
        def init_func_hidden(m):
            if isinstance(m, nn.Linear):
                if self.hidden_init is not None:
                    self.hidden_init(m.weight)
                if self.bias_enabled and self.bias_init is not None:
                    self.bias_init(m.bias)

        self.hidden_layers.apply(init_func_hidden)
        if self.output_enabled:

            def init_func_output(m):
                if isinstance(m, nn.Linear):
                    if self.output_init is not None:
                        self.output_init(m.weight)
                    if self.bias_enabled and self.bias_init is not None:
                        self.bias_init(m.bias)

            self.output_layer.apply(init_func_output)

    def norm_(self, x, mean, var):
        return (x - mean) / (var + 1e-6)

    @staticmethod
    def _plane_affine(planes, coord, pairs):
        """Sample gamma/beta from axis-aligned planes and multiply per point.

        ``planes`` stacks ``len(pairs)`` gamma planes followed by the same
        number of beta planes along dim 0. Plane ``k`` is indexed by the
        coordinate pair ``pairs[k]``. Returns ``(gamma, beta)``, each shaped
        ``coord.shape[:-1] + (1,)`` so they broadcast against features whose
        leading dims match ``coord``.
        """
        n = len(pairs)
        # Coordinates are detached: no gradient flows back into them.
        grid = torch.stack([coord[..., p] for p in pairs]).detach()
        # One sampling grid per plane, repeated so beta planes see the same
        # points as gamma planes.
        grid = grid.view(n, -1, 1, 2).repeat(2, 1, 1, 1)
        feats = F.grid_sample(planes, grid, align_corners=True)
        out_shape = tuple(coord.shape[:-1]) + (1,)
        gamma = torch.prod(feats[:n], dim=0).view(out_shape)
        beta = torch.prod(feats[n:], dim=0).view(out_shape)
        return gamma, beta

    def _modulate(self, x, coord, i):
        """Apply layer ``i``'s affine modulation, sampled at ``coord``, to ``x``."""
        grid = self.gammas[i]
        if coord.shape[-1] == 2:
            # Sample the single 2-channel grid at coord[:, 0]. After the view,
            # affine[0] (gamma) and affine[1] (beta) are each (M, 1, 1), with
            # M points in coord[:, 0]. Assumes x broadcasts against that.
            affine = F.grid_sample(
                grid, coord[:, 0].view(1, 1, -1, 2), align_corners=True
            ).view(2, -1, 1, 1)
            return affine[0] * x + affine[1]
        pairs = self.matMode4 if coord.shape[-1] == 4 else self.matMode
        gamma, beta = self._plane_affine(grid, coord, pairs)
        return gamma * x + beta

    def forward(self, x, coord=None):
        """Run the network.

        Args:
            x: input features with ``input_dim`` channels in the last dim.
            coord: optional modulation coordinates. ``grid_sample`` (with
                ``align_corners=True``) treats [-1, 1] as the grid extent.
                The 2-coordinate path reads ``coord[:, 0]``. The plane paths
                produce gamma/beta shaped ``coord.shape[:-1] + (1,)``, so
                coord's leading dims should match x's.

        Returns:
            ``output_dim`` channels after ``output_activation`` when the
            output layer is enabled, otherwise the last hidden features.
        """
        inputs = x
        for i in range(self.net_depth):
            x = self.hidden_layers[i](x)
            if coord is not None:
                x = self._modulate(x, coord, i)
            x = self.hidden_activation(x)
            if self._is_skip(i):
                x = torch.cat([x, inputs], dim=-1)
        if self.output_enabled:
            x = self.output_layer(x)
            x = self.output_activation(x)
        return x


class DenseLayer(MLP):
    """An ``MLP`` with zero hidden layers.

    With the default ``output_enabled=True`` this is a single linear
    projection from ``input_dim`` to ``output_dim``. Any extra keyword
    arguments are forwarded to ``MLP``.
    """

    def __init__(self, input_dim, output_dim, **kwargs):
        layer_args = {
            "input_dim": input_dim,
            "output_dim": output_dim,
            "net_depth": 0,  # no hidden layers
        }
        super().__init__(**layer_args, **kwargs)


class NerfMLP(nn.Module):
    """NeRF-style network with a trunk MLP and a conditioned RGB branch.

    The coordinate-modulated trunk MLP (``coord_dim=3``) produces hidden
    features, from which a dense layer predicts raw density. RGB is predicted
    from a bottleneck of those features concatenated with an optional
    condition, or directly from the features when ``condition_dim == 0``.
    """

    def __init__(
        self,
        input_dim: int,  # The number of input tensor channels.
        condition_dim: int,  # The number of condition tensor channels.
        net_depth: int = 8,  # The depth of the MLP.
        net_width: int = 256,  # The width of the MLP.
        skip_layer: int = 4,  # The layer to add skip layers to.
        net_depth_condition: int = 1,  # The depth of the second part of MLP.
        net_width_condition: int = 128,  # The width of the second part of MLP.
    ):
        super().__init__()
        self.base = MLP(
            input_dim=input_dim,
            net_depth=net_depth,
            net_width=net_width,
            skip_layer=skip_layer,
            output_enabled=False,
            coord_dim=3,
        )
        hidden_features = self.base.output_dim
        self.sigma_layer = DenseLayer(hidden_features, 1)

        if condition_dim > 0:
            self.bottleneck_layer = DenseLayer(hidden_features, net_width)
            # NOTE(review): built with coord_dim=3, but forward never passes
            # coordinates to rgb_layer, so its modulation grids go unused.
            self.rgb_layer = MLP(
                input_dim=net_width + condition_dim,
                output_dim=3,
                net_depth=net_depth_condition,
                net_width=net_width_condition,
                skip_layer=None,
                coord_dim=3,
            )
        else:
            self.rgb_layer = DenseLayer(hidden_features, 3)

    def forward(self, x, condition=None, coord=None):
        """Predict raw (pre-activation) color and density.

        Args:
            x: encoded positions fed to the trunk MLP.
            condition: optional conditioning features. If its leading shape
                differs from ``x``'s, it must be 2-D ``(num_rays, n_dim)``
                and is broadcast over ``x``'s extra leading dims.
            coord: optional sequence whose first element modulates the trunk
                MLP. Any further elements are ignored. ``None`` disables
                modulation.

        Returns:
            ``(raw_rgb, raw_sigma)``.
        """
        base_coord = None if coord is None else coord[0]
        x = self.base(x, base_coord)
        raw_sigma = self.sigma_layer(x)
        if condition is not None:
            if condition.shape[:-1] != x.shape[:-1]:
                # Broadcast per-ray conditions over the extra sample dims.
                num_rays, n_dim = condition.shape
                condition = condition.view(
                    [num_rays] + [1] * (x.dim() - condition.dim()) + [n_dim]
                ).expand(list(x.shape[:-1]) + [n_dim])
            bottleneck = self.bottleneck_layer(x)
            x = torch.cat([bottleneck, condition], dim=-1)
        raw_rgb = self.rgb_layer(x, None)
        return raw_rgb, raw_sigma


class SinusoidalEncoder(nn.Module):
    """Sinusoidal positional encoding as used in NeRF.

    Each of the ``x_dim`` input channels is scaled by the frequencies
    ``2**min_deg, ..., 2**(max_deg - 1)``. The scaled values are passed
    through ``sin`` and ``sin(. + pi/2)``, and the raw input is optionally
    prepended.
    """

    def __init__(self, x_dim, min_deg, max_deg, use_identity: bool = True):
        super().__init__()
        self.x_dim = x_dim
        self.min_deg = min_deg
        self.max_deg = max_deg
        self.use_identity = use_identity
        frequencies = [2**deg for deg in range(min_deg, max_deg)]
        self.register_buffer("scales", torch.tensor(frequencies))

    @property
    def latent_dim(self) -> int:
        """Number of channels in the encoding (see ``forward``)."""
        n_bands = self.max_deg - self.min_deg
        per_channel = int(self.use_identity) + 2 * n_bands
        return per_channel * self.x_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x``.

        Args:
            x: [..., x_dim]
        Returns:
            latent: [..., latent_dim]. ``x`` itself is returned unchanged
            when ``min_deg == max_deg``.
        """
        n_bands = self.max_deg - self.min_deg
        if n_bands == 0:
            return x
        scaled = x.unsqueeze(-2) * self.scales.unsqueeze(-1)
        flat = scaled.reshape(*x.shape[:-1], n_bands * self.x_dim)
        features = torch.sin(torch.cat([flat, flat + 0.5 * math.pi], dim=-1))
        if not self.use_identity:
            return features
        return torch.cat([x, features], dim=-1)
    
class NGPRadianceField(torch.nn.Module):
    """Radiance field returning ``(rgb, density)`` for sample positions.

    Despite the name, this field does not use a hash grid. Positions go
    through a ``SinusoidalEncoder`` and a coordinate-modulated ``NerfMLP``.
    View directions, when given, are encoded and used as the MLP condition.
    """

    def __init__(
        self,
        aabb: Union[torch.Tensor, List[float]],
        num_dim: int = 3,
        use_viewdirs: bool = True,
        density_activation: Callable = lambda x: trunc_exp(x - 1),
        unbounded: bool = False,
        net_depth: int = 8,  # The depth of the MLP.
        net_width: int = 256,  # The width of the MLP.
        skip_layer: int = 4,  # The layer to add skip layers to.
        net_depth_condition: int = 1,  # The depth of the second part of MLP.
        net_width_condition: int = 128,  # The width of the second part of MLP.
    ) -> None:
        super().__init__()
        if not isinstance(aabb, torch.Tensor):
            aabb = torch.tensor(aabb, dtype=torch.float32)
        self.register_buffer("aabb", aabb)
        self.num_dim = num_dim
        self.use_viewdirs = use_viewdirs  # NOTE(review): not read by forward
        self.density_activation = density_activation
        self.unbounded = unbounded

        self.posi_encoder = SinusoidalEncoder(3, 0, 10, True)
        self.view_encoder = SinusoidalEncoder(3, 0, 4, True)
        self.mlp = NerfMLP(
            input_dim=self.posi_encoder.latent_dim,
            condition_dim=self.view_encoder.latent_dim,
            net_depth=net_depth,
            net_width=net_width,
            skip_layer=skip_layer,
            net_depth_condition=net_depth_condition,
            net_width_condition=net_width_condition,
        )

    def forward(
        self,
        x: torch.Tensor,
        condition: Optional[torch.Tensor] = None,
    ):
        """Evaluate color and density at positions ``x``.

        Args:
            x: positions with ``num_dim`` coordinates in the last dim.
            condition: optional view directions with 3 components in the
                last dim (presumably unit vectors in [-1, 1]; TODO confirm).

        Returns:
            ``(rgb, sigma)``: rgb after a sigmoid, and density after
            ``density_activation``, zeroed for points outside the aabb.
        """
        if self.unbounded:
            # contract_to_unisphere maps all of space into [0, 1].
            x = contract_to_unisphere(x, self.aabb)
            selector = ((x > 0.0) & (x < 1.0)).all(dim=-1)
            # NOTE(review): in this branch coords stay in [0, 1], while the
            # bounded branch uses [-1, 1]; grid_sample treats [-1, 1] as the
            # full grid extent. Confirm which range the MLP grids expect.
        else:
            aabb_min, aabb_max = torch.split(self.aabb, self.num_dim, dim=-1)
            unit = (x - aabb_min) / (aabb_max - aabb_min)
            # Test aabb membership in [0, 1] space, before remapping to
            # [-1, 1]. Testing the remapped x against (0, 1) would keep only
            # one octant of the aabb.
            selector = ((unit > 0.0) & (unit < 1.0)).all(dim=-1)
            x = unit * 2.0 - 1.0

        # Raw (un-encoded) coordinates drive the MLP's grid modulation.
        coord = x.clone()
        x = self.posi_encoder(x)

        # Spherical angles of the view direction; stays None without one.
        coord2 = None
        if condition is not None:
            coord2 = cart2sph(condition.clone().detach())
            condition = (condition + 1.0) / 2.0
            condition = self.view_encoder(condition)
        rgb, sigma = self.mlp(x, condition=condition, coord=(coord, coord2))

        sigma = self.density_activation(sigma) * selector[..., None]
        return torch.sigmoid(rgb), sigma


class NGPDensityField(torch.nn.Module):
    """Instance-NGP Density Field used for resampling.

    A tinycudann multiresolution hash-grid encoding followed by a small fully
    fused MLP maps positions (normalized to [0, 1]) to one density value.
    """

    def __init__(
        self,
        aabb: Union[torch.Tensor, List[float]],
        num_dim: int = 3,
        density_activation: Callable = lambda x: trunc_exp(x - 1),
        unbounded: bool = False,
        base_resolution: int = 16,
        max_resolution: int = 128,
        n_levels: int = 5,
        log2_hashmap_size: int = 17,
    ) -> None:
        super().__init__()
        if not isinstance(aabb, torch.Tensor):
            aabb = torch.tensor(aabb, dtype=torch.float32)
        self.register_buffer("aabb", aabb)
        self.num_dim = num_dim
        self.density_activation = density_activation
        self.unbounded = unbounded
        self.base_resolution = base_resolution
        self.max_resolution = max_resolution
        self.n_levels = n_levels
        self.log2_hashmap_size = log2_hashmap_size

        # Geometric growth factor of the grid resolution from base to max
        # across levels. A single level has no growth, and the formula would
        # divide by zero.
        if n_levels > 1:
            per_level_scale = np.exp(
                (np.log(max_resolution) - np.log(base_resolution)) / (n_levels - 1)
            ).tolist()
        else:
            per_level_scale = 1.0

        self.mlp_base = tcnn.NetworkWithInputEncoding(
            n_input_dims=num_dim,
            n_output_dims=1,
            encoding_config={
                "otype": "HashGrid",
                "n_levels": n_levels,
                "n_features_per_level": 2,
                "log2_hashmap_size": log2_hashmap_size,
                "base_resolution": base_resolution,
                "per_level_scale": per_level_scale,
            },
            network_config={
                "otype": "FullyFusedMLP",
                "activation": "ReLU",
                "output_activation": "None",
                "n_neurons": 64,
                "n_hidden_layers": 1,
            },
        )

        # NOTE(review): `coordnet` is not used by `forward`. It is kept so the
        # module's parameters / state_dict keys stay unchanged.
        self.coordnet = tcnn.Encoding(3,
            {
                "otype": "Grid",
                "type": "Dense",
                "n_levels": 1,
                "n_features_per_level": 4,
                "base_resolution": 128,
            },)

    def forward(self, positions: torch.Tensor):
        """Return activated density at ``positions``.

        Args:
            positions: points with ``num_dim`` coordinates in the last dim.

        Returns:
            Tensor shaped ``positions.shape[:-1] + (1,)`` in the dtype of the
            normalized positions, zeroed for points outside the aabb.
        """
        if self.unbounded:
            positions = contract_to_unisphere(positions, self.aabb)
        else:
            aabb_min, aabb_max = torch.split(self.aabb, self.num_dim, dim=-1)
            positions = (positions - aabb_min) / (aabb_max - aabb_min)
        selector = ((positions > 0.0) & (positions < 1.0)).all(dim=-1)

        density_before_activation = (
            self.mlp_base(positions.view(-1, self.num_dim))
            .view(list(positions.shape[:-1]) + [1])
            .to(positions)  # tcnn output may be half precision
        )
        density = (
            self.density_activation(density_before_activation)
            * selector[..., None]
        )
        return density

# class NGPDensityField(torch.nn.Module):
#     """Instance-NGP Density Field used for resampling"""

#     def __init__(
#         self,
#         aabb: Union[torch.Tensor, List[float]],
#         num_dim: int = 3,
#         density_activation: Callable = lambda x: trunc_exp(x - 1),
#         unbounded: bool = False,
#         base_resolution: int = 16,
#         max_resolution: int = 128,
#         n_levels: int = 5,
#         log2_hashmap_size: int = 17,
#     ) -> None:
#         super().__init__()
#         if not isinstance(aabb, torch.Tensor):
#             aabb = torch.tensor(aabb, dtype=torch.float32)
#         self.register_buffer("aabb", aabb)
#         self.num_dim = num_dim
#         self.density_activation = density_activation
#         self.unbounded = unbounded
#         self.base_resolution = base_resolution
#         self.max_resolution = max_resolution
#         self.n_levels = n_levels
#         self.log2_hashmap_size = log2_hashmap_size

#         per_level_scale = np.exp(
#             (np.log(max_resolution) - np.log(base_resolution)) / (n_levels - 1)
#         ).tolist()

#         self.posi_encoder = SinusoidalEncoder(3, 0, 10, True)
#         self.mlp = MLP(
#             input_dim=self.posi_encoder.latent_dim,
#             output_dim=1,
#             net_depth=8,
#             net_width=256,
#             skip_layer=4,
#             output_enabled=True,
#             coord_dim=0,
#         )
        
#     def forward(self, positions: torch.Tensor):
#         if self.unbounded:
#             positions = contract_to_unisphere(positions, self.aabb)
#         else:
#             aabb_min, aabb_max = torch.split(self.aabb, self.num_dim, dim=-1)
#             positions = (positions - aabb_min) / (aabb_max - aabb_min) * 2.0 - 1.0
#         selector = ((positions > 0.0) & (positions < 1.0)).all(dim=-1)
        
#         positions = self.posi_encoder(2.0*positions-1.0)
# #         ch = positions.size(1)
# #         positions = positions + self.coordnet(positions.view(-1,3)).view(-1,ch,4)[:,:,:3]
# #         print(positions.shape)
#         density_before_activation = self.mlp(positions, None)
#         density = (
#             self.density_activation(density_before_activation)
#             * selector[..., None]
#         )
#         return density
