"""
Copyright (c) 2022 Ruilong Li, UC Berkeley.
"""

import functools
import math
from typing import Callable, Optional, List, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from siren_pytorch import SirenNet
import tinycudann as tcnn
from einops import rearrange
import numpy as np

# def cart2sph(coord):
#     azimuth = torch.atan2(coord[:,1:2],coord[:,0:1])
#     elevation = torch.atan2(coord[:,2:3],torch.sqrt(coord[:,0:1]*coord[:,0:1] + coord[:,1:2]*coord[:,1:2]))
#     r = torch.sqrt(coord[:,0:1]*coord[:,0:1] + coord[:,1:2]*coord[:,1:2] + coord[:,2:3]*coord[:,2:3])
# #     return torch.cat([azimuth/math.pi, elevation/math.pi],dim=-1)
#     return torch.cat([azimuth/math.pi, elevation*(-4)/math.pi-1],dim=-1)#, r

def contract_to_unisphere(
    x: torch.Tensor,
    aabb: torch.Tensor,
    ord: Union[str, int] = 2,
    eps: float = 1e-6,
    derivative: bool = False,
):
    """Contract unbounded points so that all of space lands in [0, 1].

    The box described by ``aabb`` is first mapped to [-1, 1] per axis.
    Points whose norm is at most 1 are left where they are; points further
    out are pulled in to ``(2 - 1/|x|) * x/|x|``, i.e. to a norm below 2.
    The result is finally rescaled with ``x / 4 + 0.5``.

    Args:
        x: points, last dimension of size 3.
        aabb: box as ``[min_x, min_y, min_z, max_x, max_y, max_z]`` along the
            last dimension (it is split into two chunks of 3).
        ord: norm order forwarded to ``torch.linalg.norm``.
        eps: lower bound applied to the returned values when
            ``derivative`` is True.
        derivative: when True, return the element-wise derivative term of
            the contraction instead of the contracted points.

    Returns:
        The contracted points, or (``derivative=True``) a tensor of the same
        shape holding the derivative term: 1 for points that are not
        contracted, clamped to at least ``eps`` everywhere.
    """
    lower, upper = torch.split(aabb, 3, dim=-1)
    # Normalise to the box, then stretch so the box spans [-1, 1].
    # This builds a new tensor, so the caller's ``x`` is never modified.
    x = (x - lower) / (upper - lower)
    x = x * 2 - 1
    norm = torch.linalg.norm(x, ord=ord, dim=-1, keepdim=True)
    outside = norm.squeeze(-1) > 1

    if not derivative:
        # Only points outside the unit ball are contracted.
        x[outside] = (2 - 1 / norm[outside]) * (x[outside] / norm[outside])
        return x / 4 + 0.5  # [-inf, inf] is at [0, 1]

    jac = (2 * norm - 1) / norm**2 + 2 * x**2 * (
        1 / norm**3 - (2 * norm - 1) / norm**4
    )
    # Inside the unit ball the mapping is the identity (this also overwrites
    # the inf/nan produced by a zero norm).
    jac[~outside] = 1.0
    return torch.clamp(jac, min=eps)
    
def cart2sph(coord):
    """Convert Cartesian vectors to normalised (azimuth, elevation) angles.

    Args:
        coord: tensor whose last dimension holds ``(x, y, z)``; any number of
            leading dimensions is accepted.

    Returns:
        Tensor with the same leading dimensions and a last dimension of 2:
        ``azimuth / pi`` (in [-1, 1]) and ``elevation / pi * 2`` (in [-1, 1],
        since the elevation lies in [-pi/2, pi/2]).
    """
    # Index the last axis with ``...`` so inputs with extra leading
    # dimensions are handled; 1-element slices keep the axis for the concat.
    x, y, z = coord[..., 0:1], coord[..., 1:2], coord[..., 2:3]
    azimuth = torch.atan2(y, x)
    elevation = torch.atan2(z, torch.sqrt(x * x + y * y))
    return torch.cat([azimuth / math.pi, elevation / math.pi * 2.0], dim=-1)

class MLP(nn.Module):
    """MLP with optional skip connections and coordinate-conditioned modulation.

    Every hidden layer computes ``Linear -> (modulation) -> activation``.
    When ``forward`` is given ``coord``, the pre-activation of layer ``i`` is
    scaled (and, except for 4-D coordinates, shifted) by values bilinearly
    sampled from the learnable grid ``self.gammas[i]``.  Grids are created
    according to ``coord_dim`` and initialised to scale 1 / shift 0, so the
    modulation is the identity at initialisation.

    Grid layout per layer, by ``coord_dim``:

    * ``1``: ``(1, 2, 5, 1)`` - channel 0 is the scale, channel 1 the shift.
    * ``2``: ``(1, 2, 10, 3)`` - same channel layout.
    * ``3``: ``(6, 1, 256, 256)`` - three scale planes followed by three
      shift planes; the per-plane samples are multiplied together.
    * ``4``: ``(12, 1, 256, 256)`` - six scale planes followed by six shift
      planes (only the scale is applied in ``forward``).
    * anything else: no grid is created.

    ``forward`` picks the sampling scheme from ``coord.shape[-1]``; this must
    match the ``coord_dim`` the module was built with, since ``coord_dim``
    alone decides which grids exist.

    Args:
        input_dim: number of input channels.
        output_dim: number of output channels (required when
            ``output_enabled`` is True).
        net_depth: number of hidden layers.
        net_width: width of every hidden layer.
        skip_layer: the input is concatenated to the activations after every
            hidden layer ``i > 0`` with ``i % skip_layer == 0``; ``None``
            disables skip connections.
        hidden_init: initialiser applied to hidden-layer weights.
        hidden_activation: activation applied after every hidden layer.
        output_enabled: whether to append the output linear layer.  When
            False, ``self.output_dim`` is set to the width of the last
            hidden representation.
        output_init: initialiser applied to the output-layer weight.
        output_activation: activation applied to the output layer.
        bias_enabled: whether linear layers have a bias.
        bias_init: initialiser applied to biases.
        coord_dim: dimensionality of the modulation coordinates (see above).
    """

    # Axis pairs of the six 2-D planes used for 4-D coordinates.  Kept as a
    # class attribute so it is available independently of ``__init__``.
    _PLANE_AXES_4D = [[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]]

    def __init__(
        self,
        input_dim: int,  # The number of input tensor channels.
        output_dim: int = None,  # The number of output tensor channels.
        net_depth: int = 8,  # The depth of the MLP.
        net_width: int = 256,  # The width of the MLP.
        skip_layer: int = 4,  # The layer to add skip layers to.
        hidden_init: Callable = nn.init.xavier_uniform_,
        hidden_activation: Callable = nn.ReLU(),
        output_enabled: bool = True,
        output_init: Optional[Callable] = nn.init.xavier_uniform_,
        output_activation: Optional[Callable] = nn.Identity(),
        bias_enabled: bool = True,
        bias_init: Callable = nn.init.zeros_,
        coord_dim: int = 0
    ):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.net_depth = net_depth
        self.net_width = net_width
        self.skip_layer = skip_layer
        self.hidden_init = hidden_init
        self.hidden_activation = hidden_activation
        self.output_enabled = output_enabled
        self.output_init = output_init
        self.output_activation = output_activation
        self.bias_enabled = bias_enabled
        self.bias_init = bias_init

        self.coord_dim = coord_dim

        grid_res = 256  # side length of the 2-D plane grids (coord_dim 3 / 4)

        # Axis pairs of the three planes used for 3-D coordinates.
        self.matMode = [[0, 1], [0, 2], [1, 2]]
        self.vecMode = [2, 1, 0]

        # All containers are registered in a fixed order, even those that
        # stay empty, so parameter ordering and state_dict keys stay stable.
        self.gammas = nn.ParameterList()
        self.betas = nn.ParameterList()
        self.means = nn.ParameterList()
        self.vars = nn.ParameterList()
        self.bns = nn.ModuleList()
        self.meanvars = nn.ParameterList()
        self.gammas1 = nn.ModuleList()
        self.gammas2 = nn.ModuleList()
        self.gammas3 = nn.ModuleList()
        self.gammas4 = nn.ModuleList()
        self.gammas5 = nn.ModuleList()
        self.gammas6 = nn.ModuleList()

        self.hidden_layers = nn.ModuleList()
        in_features = self.input_dim
        for i in range(self.net_depth):
            self.hidden_layers.append(
                nn.Linear(in_features, self.net_width, bias=bias_enabled)
            )
            if self.coord_dim == 2:
                self.gammas.append(self._identity_affine((1, 1, 10, 3), dim=1))
            elif self.coord_dim == 1:
                self.gammas.append(self._identity_affine((1, 1, 5, 1), dim=1))
            elif self.coord_dim == 3:
                self.gammas.append(
                    self._identity_affine((3, 1, grid_res, grid_res), dim=0)
                )
                # Not used by forward(); kept so existing checkpoints load.
                self.means.append(nn.Parameter(torch.zeros(1)))
                self.vars.append(nn.Parameter(torch.ones(1)))
            elif self.coord_dim == 4:
                self.gammas.append(
                    self._identity_affine((6, 1, grid_res, grid_res), dim=0)
                )

            # The layer following a skip layer also receives the raw input.
            if self._is_skip_layer(i):
                in_features = self.net_width + self.input_dim
            else:
                in_features = self.net_width
        if self.output_enabled:
            self.output_layer = nn.Linear(
                in_features, self.output_dim, bias=bias_enabled
            )
        else:
            self.output_dim = in_features

        self.initialize()

    @staticmethod
    def _identity_affine(shape, dim):
        """Return a parameter stacking ones (scale) and zeros (shift) along ``dim``."""
        return nn.Parameter(
            torch.cat([torch.ones(shape), torch.zeros(shape)], dim=dim)
        )

    def _is_skip_layer(self, i):
        """Whether the input is concatenated after hidden layer ``i``."""
        return (
            (self.skip_layer is not None)
            and (i % self.skip_layer == 0)
            and (i > 0)
        )

    def _init_linear(self, layer, weight_init):
        """Apply ``weight_init`` and the configured bias initialiser to ``layer``."""
        if weight_init is not None:
            weight_init(layer.weight)
        if self.bias_enabled and self.bias_init is not None:
            self.bias_init(layer.bias)

    def initialize(self):
        """(Re-)initialise the weights and biases of all linear layers."""
        for layer in self.hidden_layers:
            self._init_linear(layer, self.hidden_init)
        if self.output_enabled:
            self._init_linear(self.output_layer, self.output_init)

    def norm_(self, x, mean, var):
        """Return ``(x - mean) / (var + 1e-6)``.

        Note that the division is by ``var`` itself, not by its square root.
        This helper is not called by ``forward``.
        """
        return (x-mean) / (var + 1e-6)

    def _plane_modulation(self, layer, coord, axis_pairs):
        """Sample per-point scale and shift from the plane grids of ``layer``.

        ``coord`` is projected onto each axis pair, every projection is used
        to sample one scale plane and one shift plane, and the samples are
        multiplied across planes.  Returns ``(gamma, beta)``, each ``(N, 1)``.
        """
        n_planes = len(axis_pairs)
        grid = torch.stack([coord[..., pair] for pair in axis_pairs])
        # Gradients do not flow into the coordinates.  The grid is repeated
        # so the same locations index both the scale and the shift planes.
        grid = grid.detach().view(n_planes, -1, 1, 2).repeat(2, 1, 1, 1)
        feats = F.grid_sample(self.gammas[layer], grid, align_corners=True)
        gamma = torch.prod(feats[:n_planes], dim=0, keepdim=False).view(-1, 1)
        beta = torch.prod(feats[n_planes:], dim=0, keepdim=False).view(-1, 1)
        return gamma, beta

    def _modulate(self, x, coord, layer):
        """Apply the coordinate-conditioned scale/shift of ``layer`` to ``x``.

        The sampled factors have shape ``(N, 1)`` and broadcast over the
        feature dimension, so ``x`` is expected to be ``(N, net_width)`` with
        one row per coordinate.
        """
        n_coord = coord.shape[-1]
        if n_coord == 2:
            affine = F.grid_sample(
                self.gammas[layer], coord.view(1, 1, -1, 2), align_corners=True
            ).view(2, -1, 1)
            return affine[0] * x + affine[1]
        if n_coord == 4:
            # Only the multiplicative term is applied for 4-D coordinates.
            gamma, _ = self._plane_modulation(layer, coord, self._PLANE_AXES_4D)
            return gamma * x
        if n_coord == 1:
            # The grid has width 1, so the first sampling coordinate is
            # fixed at 0 and ``coord`` selects the position along the height.
            grid = torch.cat((torch.zeros_like(coord), coord), dim=-1)
            affine = F.grid_sample(
                self.gammas[layer], grid.view(1, -1, 1, 2), align_corners=True
            ).view(2, -1, 1)
            return affine[0] * x + affine[1]
        # Remaining case: coordinates projected onto the three matMode planes.
        gamma, beta = self._plane_modulation(layer, coord, self.matMode)
        return gamma * x + beta

    def forward(self, x, coord=None):
        """Run the network.

        Args:
            x: input features, last dimension ``input_dim``.
            coord: optional modulation coordinates with one row per row of
                ``x``; ``grid_sample`` interprets them in the range [-1, 1].
                ``None`` disables modulation.

        Returns:
            The output-layer result, or the last hidden representation when
            ``output_enabled`` is False.
        """
        inputs = x
        for i in range(self.net_depth):
            x = self.hidden_layers[i](x)
            if coord is not None:
                x = self._modulate(x, coord, i)
            x = self.hidden_activation(x)
            if self._is_skip_layer(i):
                x = torch.cat([x, inputs], dim=-1)
        if self.output_enabled:
            x = self.output_layer(x)
            x = self.output_activation(x)
        return x


class DenseLayer(MLP):
    """A single linear layer, expressed as an ``MLP`` without hidden layers.

    With ``net_depth=0`` the parent builds only its output layer, mapping
    ``input_dim`` to ``output_dim`` (followed by the parent's output
    activation).  Extra keyword arguments are forwarded to ``MLP``.
    """

    def __init__(self, input_dim, output_dim, **kwargs):
        fixed = dict(
            input_dim=input_dim,
            output_dim=output_dim,
            net_depth=0,  # no hidden layers
        )
        super().__init__(**fixed, **kwargs)


class NerfMLP(nn.Module):
    """NeRF trunk with a density head and an optionally conditioned colour head.

    ``base`` is an ``MLP`` built with ``coord_dim=1``, so its hidden layers
    can be modulated by a 1-D coordinate.  ``sigma_layer`` maps the trunk
    features to one raw density value.  When ``condition_dim > 0`` the colour
    head is a bottleneck layer followed by an ``MLP`` that takes the
    bottleneck features concatenated with the condition; otherwise it is a
    single dense layer on the trunk features.

    Args:
        input_dim: number of input channels.
        condition_dim: number of condition channels (0 disables
            conditioning).
        net_depth: depth of the trunk.
        net_width: width of the trunk.
        skip_layer: skip-connection period of the trunk.
        net_depth_condition: depth of the conditioned colour MLP.
        net_width_condition: width of the conditioned colour MLP.
    """

    def __init__(
        self,
        input_dim: int,  # The number of input tensor channels.
        condition_dim: int,  # The number of condition tensor channels.
        net_depth: int = 8,  # The depth of the MLP.
        net_width: int = 256,  # The width of the MLP.
        skip_layer: int = 4,  # The layer to add skip layers to.
        net_depth_condition: int = 1,  # The depth of the second part of MLP.
        net_width_condition: int = 128,  # The width of the second part of MLP.
    ):
        super().__init__()
        self.base = MLP(
            input_dim=input_dim,
            net_depth=net_depth,
            net_width=net_width,
            skip_layer=skip_layer,
            output_enabled=False,
            coord_dim=1,
        )
        hidden_features = self.base.output_dim
        self.sigma_layer = DenseLayer(hidden_features, 1)

        # Not used by forward(); kept so the parameter set (and therefore
        # checkpoints and optimizer state) stays unchanged.
        self.peri = torch.nn.Parameter(
            torch.cat(
                [torch.ones((1, 256, 10, 3)), torch.zeros((1, 256, 10, 3))],
                dim=1,
            )
        )

        if condition_dim > 0:
            self.bottleneck_layer = DenseLayer(hidden_features, net_width)
            self.rgb_layer = MLP(
                input_dim=net_width + condition_dim,
                output_dim=3,
                net_depth=net_depth_condition,
                net_width=net_width_condition,
                skip_layer=None,
                coord_dim=2,
            )
        else:
            self.rgb_layer = DenseLayer(hidden_features, 3)

    @staticmethod
    def _trunk_coord(coord):
        """Return the trunk's modulation coordinate, tolerating ``coord=None``."""
        return None if coord is None else coord[0]

    def query_density(self, x, coord=None):
        """Return the raw (pre-activation) density for ``x``.

        Args:
            x: encoded positions.
            coord: optional pair whose first element modulates the trunk
                (it may itself be ``None``); ``None`` disables modulation.
        """
        x = self.base(x, self._trunk_coord(coord))
        return self.sigma_layer(x)

    def forward(self, x, condition=None, coord=None):
        """Return ``(raw_rgb, raw_sigma)`` for ``x``.

        Args:
            x: encoded positions.
            condition: optional condition features.  If its leading
                dimensions differ from those of the trunk features it must be
                2-D ``(num_rays, n_dim)`` and is broadcast over the
                intermediate dimensions.
            coord: optional pair of modulation coordinates.  Only the first
                element is used (for the trunk); the colour head is always
                run without modulation.
        """
        x = self.base(x, self._trunk_coord(coord))
        raw_sigma = self.sigma_layer(x)

        if condition is not None:
            if condition.shape[:-1] != x.shape[:-1]:
                num_rays, n_dim = condition.shape
                condition = condition.view(
                    [num_rays] + [1] * (x.dim() - condition.dim()) + [n_dim]
                ).expand(list(x.shape[:-1]) + [n_dim])
            bottleneck = self.bottleneck_layer(x)
            x = torch.cat([bottleneck, condition], dim=-1)
        raw_rgb = self.rgb_layer(x, None)
        return raw_rgb, raw_sigma


class SinusoidalEncoder(nn.Module):
    """Sinusoidal positional encoding as used in NeRF.

    Each input coordinate is multiplied by ``2**d`` for every degree ``d`` in
    ``[min_deg, max_deg)`` and passed through sine and cosine (the cosine is
    computed as a sine shifted by ``pi / 2``).  With ``use_identity`` the raw
    input is prepended to the encoding.
    """

    def __init__(self, x_dim, min_deg, max_deg, use_identity: bool = True):
        super().__init__()
        self.x_dim = x_dim
        self.min_deg = min_deg
        self.max_deg = max_deg
        self.use_identity = use_identity
        frequencies = [2**deg for deg in range(min_deg, max_deg)]
        # A buffer, so it follows the module across devices.
        self.register_buffer("scales", torch.tensor(frequencies))

    @property
    def latent_dim(self) -> int:
        """Size of the last dimension produced by ``forward``."""
        n_bands = self.max_deg - self.min_deg
        per_coordinate = int(self.use_identity) + n_bands * 2
        return per_coordinate * self.x_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [..., x_dim]
        Returns:
            latent: [..., latent_dim]
        """
        n_bands = self.max_deg - self.min_deg
        if n_bands == 0:
            # No frequency bands: the input is passed through untouched.
            return x
        # [..., 1, x_dim] * [n_bands, 1] -> [..., n_bands, x_dim]
        scaled = x[..., None, :] * self.scales[:, None]
        flat_shape = list(x.shape[:-1]) + [n_bands * self.x_dim]
        phases = torch.reshape(scaled, flat_shape)
        shifted = phases + 0.5 * math.pi  # sin(a + pi/2) == cos(a)
        encoded = torch.sin(torch.cat([phases, shifted], dim=-1))
        if not self.use_identity:
            return encoded
        return torch.cat([x, encoded], dim=-1)

class VanillaNeRFRadianceField(nn.Module):
    """NeRF radiance field: positional/view encoders feeding a ``NerfMLP``.

    Positions are encoded with 10 frequency bands and view directions with
    4; the MLP is built with the view encoding as its condition.  ``coord``
    arguments are forwarded to the MLP as its modulation coordinate.

    Args:
        net_depth: depth of the MLP trunk.
        net_width: width of the MLP trunk.
        skip_layer: skip-connection period of the trunk.
        net_depth_condition: depth of the view-conditioned colour MLP.
        net_width_condition: width of the view-conditioned colour MLP.
        aabb: stored on the module as ``self.aabb``; not used by the methods
            of this class.
    """

    def __init__(
        self,
        net_depth: int = 8,  # The depth of the MLP.
        net_width: int = 256,  # The width of the MLP.
        skip_layer: int = 4,  # The layer to add skip layers to.
        net_depth_condition: int = 1,  # The depth of the second part of MLP.
        net_width_condition: int = 128,  # The width of the second part of MLP.
        aabb=None
    ) -> None:
        super().__init__()
        self.posi_encoder = SinusoidalEncoder(3, 0, 10, True)
        self.view_encoder = SinusoidalEncoder(3, 0, 4, True)
        self.mlp = NerfMLP(
            input_dim=self.posi_encoder.latent_dim,
            condition_dim=self.view_encoder.latent_dim,
            net_depth=net_depth,
            net_width=net_width,
            skip_layer=skip_layer,
            net_depth_condition=net_depth_condition,
            net_width_condition=net_width_condition,
        )

        # Not used by forward(); kept so the parameter set (and therefore
        # checkpoints and optimizer state) stays unchanged.
        self.peri = torch.nn.Parameter(torch.zeros((1,3,64,64,64)))
        self.qloss = 0
        self.aabb = aabb

    def query_opacity(self, x, step_size):
        """Approximate the opacity at ``x`` as ``density * step_size``."""
        density = self.query_density(x)
        # if the density is small enough those two are the same.
        # opacity = 1.0 - torch.exp(-density * step_size)
        opacity = density * step_size
        return opacity

    def query_density(self, x, coord=None):
        """Return the non-negative density at positions ``x``.

        Args:
            x: positions, last dimension 3.
            coord: optional modulation coordinate for the MLP trunk.
        """
        x = self.posi_encoder(x)
        sigma = self.mlp.query_density(x, coord=(coord,None))
        return F.relu(sigma)

    def forward(self, x, condition=None, coord=None):
        """Return ``(rgb, sigma)`` at positions ``x``.

        Args:
            x: positions, last dimension 3.
            condition: view directions.  ``self.mlp`` is built with a
                positive ``condition_dim``, so its colour head expects the
                encoded directions to be present.
            coord: optional modulation coordinate for the MLP trunk.

        Returns:
            ``rgb`` squashed with a sigmoid and ``sigma`` passed through ReLU.
        """
        x = self.posi_encoder(x)
        # Defined up front so the call below never sees an unbound name when
        # no view direction is given.
        coord2 = None
        if condition is not None:
            # Spherical view coordinates are taken from the raw directions,
            # before they are replaced by their encoding.
            coord2 = cart2sph(condition)
            condition = self.view_encoder(condition)
        rgb, sigma = self.mlp(x, condition=condition, coord=(coord, coord2))
        return torch.sigmoid(rgb), F.relu(sigma)


class TNeRFRadianceField(nn.Module):
    """Time-dependent NeRF: a warp MLP offsets positions before a static NeRF.

    The warp network takes the encoded position and encoded timestamp and
    predicts a 3-D offset that is added to the position.  The timestamp,
    rescaled as ``2 * t - 1``, is additionally passed to the static field as
    its modulation coordinate.
    """

    def __init__(self) -> None:
        super().__init__()
        self.posi_encoder = SinusoidalEncoder(3, 0, 4, True)
        self.time_encoder = SinusoidalEncoder(1, 0, 4, True)
        warp_input_dim = (
            self.posi_encoder.latent_dim + self.time_encoder.latent_dim
        )
        self.warp = MLP(
            input_dim=warp_input_dim,
            output_dim=3,
            net_depth=4,
            net_width=64,
            skip_layer=2,
            # Near-zero output weights: the initial offsets are tiny.
            output_init=functools.partial(torch.nn.init.uniform_, b=1e-4),
        )
        self.nerf = VanillaNeRFRadianceField()

    def _deform(self, x, t):
        """Return ``x`` shifted by the offset predicted for time ``t``."""
        warp_features = torch.cat(
            [self.posi_encoder(x), self.time_encoder(t)], dim=-1
        )
        # The warp network itself is never coordinate-modulated.
        return x + self.warp(warp_features, coord=None)

    def query_opacity(self, x, timestamps, step_size):
        """Approximate opacity at ``x``, with one random timestamp per point."""
        picks = torch.randint(
            0, len(timestamps), (x.shape[0],), device=x.device
        )
        density = self.query_density(x, timestamps[picks])
        # if the density is small enough those two are the same.
        # opacity = 1.0 - torch.exp(-density * step_size)
        return density * step_size

    def query_density(self, x, t):
        """Return the density at positions ``x`` and times ``t``."""
        time_coord = 2*t-1.0
        return self.nerf.query_density(self._deform(x, t), coord=time_coord)

    def forward(self, x, t, condition=None):
        """Return ``(rgb, sigma)`` at positions ``x`` and times ``t``."""
        time_coord = 2*t-1.0
        return self.nerf(
            self._deform(x, t), condition=condition, coord=time_coord
        )

    
class VectorQuantizer2(nn.Module):
    """
    Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
    avoids costly matrix multiplications and allows for post-hoc remapping of indices.

    Inputs to ``forward`` are channel-last: the last dimension of ``z`` holds
    the ``e_dim`` components of each vector to quantize.

    Args:
        n_e: number of codebook entries.
        e_dim: dimensionality of each codebook entry.
        beta: weight of one of the two loss terms (which one depends on
            ``legacy``).
        remap: optional path to a ``.npy`` file listing the codebook indices
            in use; returned indices are remapped to positions in that list.
        unknown_index: replacement for indices missing from the remap list:
            ``"random"``, ``"extra"`` (one additional index) or an integer.
        sane_index_shape: return indices shaped like the leading dimensions
            of ``z`` instead of flat.
        legacy: keep the historical placement of ``beta`` (see NOTE below).
    """
    # NOTE: due to a bug the beta term was applied to the wrong term. for
    # backwards compatibility we use the buggy version by default, but you can
    # specify legacy=False to fix it.
    def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random",
                 sane_index_shape=False, legacy=True):
        super().__init__()
        self.n_e = n_e
        self.e_dim = e_dim
        self.beta = beta
        self.legacy = legacy

        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

        self.remap = remap
        if self.remap is not None:
            self.register_buffer("used", torch.tensor(np.load(self.remap)))
            self.re_embed = self.used.shape[0]
            self.unknown_index = unknown_index # "random" or "extra" or integer
            if self.unknown_index == "extra":
                self.unknown_index = self.re_embed
                self.re_embed = self.re_embed+1
            print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
                  f"Using {self.unknown_index} for unknown indices.")
        else:
            self.re_embed = n_e

        self.sane_index_shape = sane_index_shape

    def remap_to_used(self, inds):
        """Map codebook indices to their positions in the ``used`` list.

        Indices that do not occur in ``used`` are replaced according to
        ``unknown_index``.  ``inds`` must have at least two dimensions.
        """
        ishape = inds.shape
        assert len(ishape)>1
        inds = inds.reshape(ishape[0],-1)
        used = self.used.to(inds)
        match = (inds[:,:,None]==used[None,None,...]).long()
        new = match.argmax(-1)
        unknown = match.sum(2)<1
        if self.unknown_index == "random":
            new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
        else:
            new[unknown] = self.unknown_index
        return new.reshape(ishape)

    def unmap_to_all(self, inds):
        """Inverse of ``remap_to_used``: map positions back to codebook indices.

        The input tensor is left unmodified.  ``inds`` must have at least two
        dimensions.
        """
        ishape = inds.shape
        assert len(ishape)>1
        inds = inds.reshape(ishape[0],-1)
        used = self.used.to(inds)
        n_used = self.used.shape[0]
        if self.re_embed > n_used: # extra token
            # simply set to zero; built out-of-place so the caller's tensor
            # (which ``inds`` may be a view of) is not overwritten.
            inds = torch.where(inds >= n_used, torch.zeros_like(inds), inds)
        back = torch.gather(used[None, :].expand(inds.shape[0], -1), 1, inds)
        return back.reshape(ishape)

    def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
        """Quantize ``z`` to its nearest codebook entries.

        Args:
            z: tensor of shape ``(..., e_dim)``.
            temp, rescale_logits, return_logits: accepted only for interface
                compatibility; they must keep their default values
                (``temp`` may also be 1.0).

        Returns:
            ``(z_q, loss, (perplexity, min_encodings, min_encoding_indices))``
            where ``z_q`` has the shape of ``z`` and passes gradients
            straight through to ``z``, ``loss`` is the codebook/commitment
            loss, and ``perplexity`` and ``min_encodings`` are always ``None``.
        """
        assert temp is None or temp==1.0, "Only for interface compatible with Gumbel"
        assert rescale_logits==False, "Only for interface compatible with Gumbel"
        assert return_logits==False, "Only for interface compatible with Gumbel"
        z_flattened = z.reshape(-1, self.e_dim)
        codebook = self.embedding.weight
        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
        d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
            torch.sum(codebook**2, dim=1) - 2 * \
            torch.einsum('bd,dn->bn', z_flattened, codebook.t())

        min_encoding_indices = torch.argmin(d, dim=1)
        # Restore the input shape so the loss and the straight-through step
        # below compare like with like for inputs of any rank.
        z_q = self.embedding(min_encoding_indices).view(z.shape)
        perplexity = None
        min_encodings = None

        # compute loss for embedding
        if not self.legacy:
            loss = self.beta * torch.mean((z_q.detach()-z)**2) + \
                   torch.mean((z_q - z.detach()) ** 2)
        else:
            loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
                   torch.mean((z_q - z.detach()) ** 2)

        # preserve gradients
        z_q = z + (z_q - z).detach()

        if self.remap is not None:
            min_encoding_indices = min_encoding_indices.reshape(z.shape[0],-1) # add batch axis
            min_encoding_indices = self.remap_to_used(min_encoding_indices)
            min_encoding_indices = min_encoding_indices.reshape(-1,1) # flatten

        if self.sane_index_shape:
            # One index per quantized vector, laid out like the leading
            # (non-channel) dimensions of the channel-last input.
            min_encoding_indices = min_encoding_indices.reshape(z.shape[:-1])

        return z_q, loss, (perplexity, min_encodings, min_encoding_indices)

    def get_codebook_entry(self, indices, shape):
        """Look up codebook vectors for ``indices``.

        Args:
            indices: codebook indices (positions in the ``used`` list when
                remapping is active).
            shape: ``(batch, height, width, channel)`` or ``None``.  When
                given, the result is viewed to this shape and then permuted
                to ``(batch, channel, height, width)``.
        """
        # shape specifying (batch, height, width, channel)
        if self.remap is not None:
            indices = indices.reshape(shape[0],-1) # add batch axis
            indices = self.unmap_to_all(indices)
            indices = indices.reshape(-1) # flatten again

        # get quantized latent vectors
        z_q = self.embedding(indices)

        if shape is not None:
            z_q = z_q.view(shape)
            # reshape back to match original input shape
            z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q