import os
import time
import functools
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init

from torch_scatter import segment_coo

from torch.utils.cpp_extension import load
# Directory containing this file; the extension sources are resolved relative to it.
parent_dir = os.path.dirname(os.path.abspath(__file__))
# Build/load the custom extension at import time via torch.utils.cpp_extension.load.
# This module uses it for: sample_pts_on_rays, maskcache_lookup, raw2alpha(_backward)
# and alpha2weight(_backward).
render_utils_cuda = load(
        name='render_utils_cuda',
        sources=[
            os.path.join(parent_dir, path)
            for path in ['cuda/render_utils.cpp', 'cuda/render_utils_kernel.cu']],
        verbose=True)

# Second extension, used by VoxelMlp.density_total_variation_add_grad.
total_variation_cuda = load(
        name='total_variation_cuda',
        sources=[
            os.path.join(parent_dir, path)
            for path in ['cuda/total_variation.cpp', 'cuda/total_variation_kernel.cu']],
        verbose=True)

import lib.networks as networks
from lib_extra.attention import TransformerBlock2 as TransformerBlock, SlotAttention2
# from lib_extra.GCN import GCN
from lib_extra.network import GaussianStateInit

import pdb

'''Model'''
class VoxelMlp(torch.nn.Module):
    def __init__(self, xyz_min, xyz_max,
                 num_voxels=0, num_voxels_base=0,
                 alpha_init=None,
                 mask_cache_path=None, mask_cache_thres=1e-3,
                 fast_color_thres=0,
                 rgbnet_dim=0, rgbnet_direct=False, rgbnet_full_implicit=False,
                 rgbnet_depth=3, rgbnet_width=128,
                 viewbase_pe=4,
                 **kwargs):
        '''Build the multi-channel density grid, decoder, time network and slot attention.

        @xyz_min, xyz_max: scene bounding box corners (stored as buffers).
        @num_voxels:       target voxel count of the initial grid.
        @num_voxels_base:  voxel count defining the reference voxel size.
        @alpha_init:       initial alpha used to derive the density bias `act_shift`.
        @mask_cache_path, mask_cache_thres: stored and later handed to MaskCache.
        @fast_color_thres: alpha/weight threshold used by forward() to drop samples.
        The rgbnet_* and viewbase_pe arguments are accepted but not read here.

        Keys read from **kwargs in this constructor (missing ones raise KeyError):
            max_instances, n_freq, n_freq_view, z_dim, n_layers, out_ch,
            n_freq_t, n_freq_time, timenet_layers, timenet_hidden, skips,
            encoder_dim, num_iterations, hidden, kernel_size, stride,
            and `timesteps` (only when `curr_episode_o` is not supplied).
        Optional keys: last_episode_o, curr_episode_o.

        NOTE(review): the defaults alpha_init=None and num_voxels_base=0 are not
        usable as-is (`1-alpha_init` fails on None; the base voxel size divides
        by num_voxels_base) -- callers are expected to pass real values.
        '''
        super(VoxelMlp, self).__init__()
        print('Slots Decouple without slots_m.')
        # NOTE(review): `device` is computed but never used in this method.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.register_buffer('xyz_min', torch.Tensor(xyz_min))
        self.register_buffer('xyz_max', torch.Tensor(xyz_max))
        self.fast_color_thres = fast_color_thres

        # determine based grid resolution
        self.num_voxels_base = num_voxels_base
        self.voxel_size_base = ((self.xyz_max - self.xyz_min).prod() / self.num_voxels_base).pow(1/3)

        # determine the density bias shift: act_shift = log(1/(1-alpha_init) - 1)
        self.alpha_init = alpha_init
        self.act_shift = np.log(1/(1-alpha_init) - 1)
        print('voxelMlp: set density bias shift to', self.act_shift)

        # determine init grid resolution (sets world_size, voxel_size, world_pos, ...)
        self._set_grid_resolution(num_voxels)

        # init density voxel grid: one channel per instance, zero-initialised
        self.density = torch.nn.Parameter(torch.zeros([1, kwargs['max_instances'], *self.world_size]))

        # decoder input width = positional embedding (n_freq*6+3) + slot latent (z_dim)
        self.decoder = networks.init_net(Decoder(n_freq=kwargs['n_freq'], n_freq_view=kwargs['n_freq_view'], input_dim=kwargs['n_freq']*6+3+kwargs['z_dim'], 
                               input_ch_dim=6*kwargs['n_freq_view']+3, z_dim=kwargs['z_dim'], n_layers=kwargs['n_layers'], out_ch=kwargs['out_ch']))
        
        # local dynamics -- w/ slots: MLP consumed by query_time(), which outputs a 3-vector
        self._time, self._time_out = self.create_time_net(input_dim=kwargs['n_freq_t']*6+3,
                                                          input_dim_time=kwargs['n_freq_time']*2+1, D=kwargs['timenet_layers'], W=kwargs['timenet_hidden'], skips=kwargs['skips'])
        self.skips = kwargs['skips']
        self.n_freq_t = kwargs['n_freq_t']
        self.n_freq_time = kwargs['n_freq_time']


        # slots initialization
        # NOTE(review): slots_o is a plain tensor attribute (neither Parameter nor
        # buffer), so it is re-drawn at every construction and is not in state_dict.
        self.num_slots = kwargs['max_instances']
        self.slots_o = torch.randn(1, self.num_slots, kwargs['z_dim'])

        self.z_dim = kwargs['z_dim']

        # The motion slots (slots_m / m_dim) are disabled in this variant.

        # slot statistics carried over from the previous episode (zeros if not given)
        if 'last_episode_o' in kwargs.keys():
            self.last_episode_o = kwargs['last_episode_o']
        else:
            self.last_episode_o = torch.zeros_like(self.slots_o)

        # per-timestep slot store for the current episode: [timesteps, 1, num_slots, z_dim]
        if 'curr_episode_o' in kwargs.keys():
            self.curr_episode_o_episode = kwargs['curr_episode_o']
        else:
            self.curr_episode_o_episode = torch.zeros((kwargs['timesteps'], 1, self.num_slots, kwargs['z_dim']))

        self.slot_attention = SlotAttention2(
            voxel_dim=kwargs['max_instances'],
            in_dim=kwargs['encoder_dim'],
            slot_dim=kwargs['z_dim'],
            iters=kwargs['num_iterations'],
            hidden_dim=kwargs['hidden'],
            kernel_size=kwargs['kernel_size'],
            stride=kwargs['stride'],
        )

        # The TransformerBlock predictor is disabled in this variant.

        # kept so get_kwargs() can echo the construction arguments
        self.kwargs = kwargs
        # time index seen by the previous forward() call (-1 = none yet)
        self.last_timestep = -1

        self.mask_cache = None
        self.nonempty_mask = None 
        self.mask_cache_path = mask_cache_path
        self.mask_cache_thres = mask_cache_thres

    def create_time_net(self, input_dim, input_dim_time, D, W, skips, memory=[]):
        layers = [nn.Linear(input_dim + input_dim_time, W)]
        for i in range(D - 1):
            if i in memory:
                raise NotImplementedError
            else:
                layer = nn.Linear

            in_channels = W
            if i in skips:
                in_channels += input_dim

            layers += [layer(in_channels, W)]
        return nn.ModuleList(layers), nn.Linear(W, 3)

    def query_time(self, new_pts, t, net, net_final, pdb_flag=0):
        '''Run the time network on points `new_pts` at time `t`.

        @new_pts:   points to query; the first dimension is the point count.
        @t:         time value, expanded to one row per point.
        @net:       iterable of hidden layers (see create_time_net).
        @net_final: output layer applied to the last hidden feature.
        @pdb_flag:  unused.
        Both inputs are embedded with `sin_emb` (defined elsewhere in this file;
        presumably a sinusoidal positional embedding -- confirm there).
        '''
        pts_emb = sin_emb(new_pts, n_freq=self.n_freq_t)
        time_emb = sin_emb(t.expand([new_pts.shape[0], 1]), n_freq=self.n_freq_time)
        feat = torch.cat([pts_emb, time_emb], dim=-1)
        for idx, layer in enumerate(net):
            feat = F.relu(layer(feat))
            # re-inject the embedded point after every skip layer
            if idx in self.skips:
                feat = torch.cat([pts_emb, feat], -1)
        return net_final(feat)

    def _set_grid_resolution(self, num_voxels):
        # Determine grid resolution
        # pdb.set_trace()
        self.num_voxels = num_voxels
        self.voxel_size = ((self.xyz_max - self.xyz_min).prod() / num_voxels).pow(1/3)
        self.world_size = ((self.xyz_max - self.xyz_min) / self.voxel_size).long()
        self.voxel_size_ratio = self.voxel_size / self.voxel_size_base
        # define the xyz positions of grid
        self.world_pos = torch.zeros([1, 3, *self.world_size])
        xcoord = torch.linspace(start=self.xyz_min[0], end=self.xyz_max[0], steps=self.world_size[0])
        ycoord = torch.linspace(start=self.xyz_min[1], end=self.xyz_max[1], steps=self.world_size[1])
        zcoord = torch.linspace(start=self.xyz_min[2], end=self.xyz_max[2], steps=self.world_size[2])
        grid = torch.meshgrid(xcoord, ycoord, zcoord)
        for i in range(3):
            self.world_pos[0, i, :, :, :] = grid[i]    
        print('voxelMlp: voxel_size      ', self.voxel_size)
        print('voxelMlp: world_size      ', self.world_size)
        print('voxelMlp: voxel_size_base ', self.voxel_size_base)
        print('voxelMlp: voxel_size_ratio', self.voxel_size_ratio)

    def get_kwargs(self):
        return {
            'xyz_min': self.xyz_min.cpu().numpy(),
            'xyz_max': self.xyz_max.cpu().numpy(),
            'num_voxels': self.num_voxels,
            'num_voxels_base': self.num_voxels_base,
            'alpha_init': self.alpha_init,
            'act_shift': self.act_shift,
            'voxel_size_ratio': self.voxel_size_ratio,
            'mask_cache_path': self.mask_cache_path,
            'mask_cache_thres': self.mask_cache_thres,
            'fast_color_thres': self.fast_color_thres,
            'n_freq': self.kwargs['n_freq'],
            'n_freq_view': self.kwargs['n_freq_view'],
            'n_freq_dynamics': self.kwargs['n_freq_dynamics'],
            'z_dim': self.kwargs['z_dim'],
            'm_dim': self.kwargs['m_dim'],
            'hidden': self.kwargs['hidden'],
            'n_layers': self.kwargs['n_layers'],
            "out_ch": self.kwargs['out_ch'],
            "max_instances": self.kwargs['max_instances'],
            "dropout": self.kwargs['dropout'],
            "encoder_dim": self.kwargs['encoder_dim'],
            "num_iterations": self.kwargs['num_iterations'],
            "weight_init": self.kwargs['weight_init'],
            "kernel_size": self.kwargs['kernel_size'],
            "stride": self.kwargs['stride'],
            "n_freq_t": self.kwargs['n_freq_t'],
            "n_freq_time": self.kwargs['n_freq_time'],
            "timesteps": self.kwargs['timesteps'],
            "timenet_layers": self.kwargs['timenet_layers'],
            "timenet_hidden": self.kwargs['timenet_hidden'],
            "skips": self.kwargs['skips'],
            "last_episode_o": torch.mean(self.curr_episode_o_episode, dim=0)
            # "last_episode_o": self.last_episode_o,
            # "curr_episode_o": self.curr_episode_o_episode
            # "curr_episode_o": self.curr_episode_o
            # **self.rgbnet_kwargs,
        }

    @torch.no_grad()
    def maskout_near_cam_vox(self, cam_o, near):
        self_grid_xyz = torch.stack(torch.meshgrid(
            torch.linspace(self.xyz_min[0], self.xyz_max[0], self.density.shape[2]),
            torch.linspace(self.xyz_min[1], self.xyz_max[1], self.density.shape[3]),
            torch.linspace(self.xyz_min[2], self.xyz_max[2], self.density.shape[4]),
        ), -1)
        nearest_dist = torch.stack([
            (self_grid_xyz.unsqueeze(-2) - co).pow(2).sum(-1).sqrt().amin(-1)
            for co in cam_o.split(100)  # for memory saving
        ]).amin(0)

        nearest_dist = nearest_dist[None, None].expand(-1, self.density.shape[1], -1, -1, -1)
        # self.density[nearest_dist[None,None] <= near] = -100
        self.density[nearest_dist <= near] = -100

    @torch.no_grad()
    def scale_volume_grid(self, num_voxels):
        '''Resample the density grid to a new voxel count and rebuild the mask cache.

        The density is trilinearly interpolated to the new world_size and
        re-wrapped as a fresh Parameter. The new mask keeps vertices that are
        set in the coarse mask loaded from `mask_cache_path` and whose
        max-pooled alpha exceeds `fast_color_thres`.
        NOTE(review): only channel [0, 0] of the pooled alpha is used, although
        the density has one channel per instance -- confirm this is intended.
        '''
        print('voxelMlp: scale_volume_grid start')
        prev_world_size = self.world_size
        self._set_grid_resolution(num_voxels)
        print('voxelMlp: scale_volume_grid scale world_size from', prev_world_size, 'to', self.world_size)

        resized = F.interpolate(
            self.density.data, size=tuple(self.world_size), mode='trilinear', align_corners=True)
        self.density = torch.nn.Parameter(resized)

        coarse_mask = MaskCache(
                path=self.mask_cache_path,
                mask_cache_thres=self.mask_cache_thres).to(self.xyz_min.device)
        axes = [
            torch.linspace(self.xyz_min[k], self.xyz_max[k], self.density.shape[k + 2])
            for k in range(3)
        ]
        grid_xyz = torch.stack(torch.meshgrid(*axes), -1)

        # 3x3x3 max-pool so a vertex counts as occupied if any neighbour is
        alpha = self.activate_density(self.density)
        pooled_alpha = F.max_pool3d(alpha, kernel_size=3, padding=1, stride=1)[0,0]
        keep = coarse_mask(grid_xyz) & (pooled_alpha>self.fast_color_thres)
        self.mask_cache = MaskCache(
                path=None, mask=keep,
                xyz_min=self.xyz_min, xyz_max=self.xyz_max)

        print('voxelMlp: scale_volume_grid finish')

    def voxel_count_views(self, rays_o_tr, rays_d_tr, imsz, near, far, stepsize, downrate=1, irregular_shape=False):
        '''Count, per density voxel, how many views touch it.

        @rays_o_tr, rays_d_tr: ray origins/directions of all training views,
                               split into per-view groups by `imsz`.
        @near, far:            clamp range for the bbox entry distance.
        @stepsize:             sampling step in units of voxel_size.
        @downrate:             pixel stride used when rays are image-shaped.
        @irregular_shape:      if True the per-view rays are already flat lists.
        Return:
            tensor shaped like self.density; each voxel is incremented once for
            every view whose accumulated interpolation gradient on it exceeds 1.
        '''
        print('voxelMlp: voxel_count_views start')
        eps_time = time.time()
        # enough steps to cross the grid diagonal
        N_samples = int(np.linalg.norm(np.array(self.density.shape[2:])+1) / stepsize) + 1
        rng = torch.arange(N_samples)[None].float()
        count = torch.zeros_like(self.density.detach())
        device = rng.device
        for rays_o_, rays_d_ in zip(rays_o_tr.split(imsz), rays_d_tr.split(imsz)):
            # fresh all-ones grid per view; its .grad accumulates the sampling weights
            ones = torch.ones_like(self.density).requires_grad_()
            if irregular_shape:
                rays_o_ = rays_o_.split(10000)
                rays_d_ = rays_d_.split(10000)
            else:
                rays_o_ = rays_o_[::downrate, ::downrate].to(device).flatten(0,-2).split(10000)
                rays_d_ = rays_d_[::downrate, ::downrate].to(device).flatten(0,-2).split(10000)

            for rays_o, rays_d in zip(rays_o_, rays_d_):
                # avoid division by zero for axis-aligned rays
                vec = torch.where(rays_d==0, torch.full_like(rays_d, 1e-6), rays_d)
                rate_a = (self.xyz_max - rays_o) / vec
                rate_b = (self.xyz_min - rays_o) / vec
                t_min = torch.minimum(rate_a, rate_b).amax(-1).clamp(min=near, max=far)
                # NOTE(review): t_max is computed but not used below.
                t_max = torch.maximum(rate_a, rate_b).amin(-1).clamp(min=near, max=far)
                step = stepsize * self.voxel_size * rng
                interpx = (t_min[...,None] + step/rays_d.norm(dim=-1,keepdim=True))
                rays_pts = rays_o[...,None,:] + rays_d[...,None,:] * interpx[...,None]
                # backward through the interpolation deposits the weights into ones.grad
                self.grid_sampler(rays_pts, ones).sum().backward()
            with torch.no_grad():
                count += (ones.grad > 1)
        eps_time = time.time() - eps_time
        print('voxelMlp: voxel_count_views finish (eps time:', eps_time, 'sec)')
        return count

    def density_total_variation_add_grad(self, weight, dense_mode):
        '''Add the total-variation gradient of the density grid into density.grad.

        `weight` is rescaled by (largest grid dimension / 128) and used for all
        three weight arguments of the CUDA kernel.
        '''
        scaled = weight * self.world_size.max() / 128
        total_variation_cuda.total_variation_add_grad(
            self.density, self.density.grad, scaled, scaled, scaled, dense_mode)

    def activate_density(self, density, interval=None):
        '''Convert raw density to alpha via Raw2Alpha, keeping the input shape.

        `interval` defaults to self.voxel_size_ratio when not given.
        '''
        if interval is None:
            interval = self.voxel_size_ratio
        flat_alpha = Raw2Alpha.apply(density.flatten(), self.act_shift, interval)
        return flat_alpha.reshape(density.shape)

    def activate_density_multiple(self, density, interval=None, dens_noise=0):
        interval = interval if interval is not None else self.voxel_size_ratio
        raw_masks = F.softplus(density + self.act_shift, True)

        raw_sigma = raw_masks + dens_noise * torch.randn_like(raw_masks)

        masks = raw_masks / (raw_masks.sum(dim=-1)[:,None] + 1e-5)  # PxK

        sigma = (raw_sigma * masks).sum(dim=-1)
        alpha = 1 - torch.exp(-sigma * interval)
        return alpha
 
    def grid_sampler(self, xyz, *grids, mode=None, align_corners=True):
        '''Wrapper for the interp operation'''
        mode = 'bilinear'
        shape = xyz.shape[:-1]
        xyz = xyz.reshape(1,1,1,-1,3)
        ind_norm = ((xyz - self.xyz_min) / (self.xyz_max - self.xyz_min)).flip((-1,)) * 2 - 1
        ret_lst = [
            # TODO: use `rearrange' to make it readable
            F.grid_sample(grid, ind_norm, mode=mode, align_corners=align_corners).reshape(grid.shape[1],-1).T.reshape(*shape,grid.shape[1])
            for grid in grids
        ]
        for i in range(len(grids)):
            if ret_lst[i].shape[-1] == 1:
                ret_lst[i] = ret_lst[i].squeeze(-1)
        if len(ret_lst) == 1:
            return ret_lst[0]
        return ret_lst

    def hit_coarse_geo(self, rays_o, rays_d, near, far, stepsize, **render_kwargs):
        '''Return a bool mask (shaped like rays_o without its last axis) that is
        True for rays with at least one in-bbox sample flagged by the mask cache.'''
        out_shape = rays_o.shape[:-1]
        rays_o = rays_o.reshape(-1, 3).contiguous()
        rays_d = rays_d.reshape(-1, 3).contiguous()
        sampled = render_utils_cuda.sample_pts_on_rays(
                rays_o, rays_d, self.xyz_min, self.xyz_max, near, far, stepsize * self.voxel_size)
        ray_pts, mask_outbbox, ray_id = sampled[:3]
        inside = ~mask_outbbox
        occupied = self.mask_cache(ray_pts[inside])
        hit = torch.zeros([len(rays_o)], dtype=torch.bool)
        # mark every ray that owns an occupied in-bbox sample
        hit[ray_id[inside][occupied]] = 1
        return hit.reshape(out_shape)

    def sample_ray(self, rays_o, rays_d, near, far, stepsize, is_train=False, **render_kwargs):
        '''Sample query points on rays, keeping only those inside the bounding box.
        All the output points are sorted from near to far.
        Input:
            rays_o, rays_d:   both in [N, 3] indicating ray configurations.
            near, far:        the near and far distance of the rays.
            stepsize:         the number of voxels of each sample step.
            is_train:         accepted but not used here.
        Output:
            ray_pts:          [M, 3] storing all the sampled points.
            ray_id:           [M]    the index of the ray of each point.
            step_id:          [M]    the i'th step on a ray of each point.
        '''
        sampled = render_utils_cuda.sample_pts_on_rays(
            rays_o.contiguous(), rays_d.contiguous(),
            self.xyz_min, self.xyz_max, near, far, stepsize * self.voxel_size)
        # the kernel also returns N_steps, t_min and t_max, which are not needed
        ray_pts, mask_outbbox, ray_id, step_id = sampled[:4]
        inside = ~mask_outbbox
        return ray_pts[inside], ray_id[inside], step_id[inside]

    def update_density(self, frame_time):
        '''Return the density grid resampled at the time-displaced grid vertices.

        Every vertex in world_pos is offset by query_time(frame_time) and the
        density is interpolated there; the result has the shape of self.density.
        NOTE(review): the permute assumes a multi-channel density; with a single
        channel grid_sampler returns a 1-D tensor and this would fail.
        '''
        grid_pts = self.world_pos[0].flatten(start_dim=1).permute(1,0)
        offset = self.query_time(grid_pts, frame_time, self._time, self._time_out)
        warped = self.grid_sampler(grid_pts + offset, self.density)
        return warped.permute(1,0).reshape(self.density.shape)

    def forward(self, rays_o, rays_d, viewdirs, frame_time, time_index, global_step=None, start=False, training_flag=True, first_episode=False, stc_data=False, **render_kwargs):
        '''Volume rendering
        @rays_o:        [N, 3] the starting point of the N shooting rays.
        @rays_d:        [N, 3] the shooting direction of the N rays.
        @viewdirs:      [N, 3] viewing direction to compute positional embedding for MLP.
        @frame_time:    time value fed to the time network.
        @time_index:    index of the current timestep in the per-episode slot store.
        @global_step:   only used to derive the is_train flag passed to sample_ray.
        @start:         during training, refresh last_episode_o from the stored
                        slots of the current episode before writing new ones.
        @training_flag: True while training; at inference last_episode_o is
                        used as the slot latent.
        @first_episode: during training, use the fresh slots alone instead of
                        averaging them with last_episode_o.
        @stc_data:      static-scene path; not implemented.
        @render_kwargs: must contain near, far, stepsize, bg; optional
                        render_depth (default False), segmentation (default True).
        Return:
            dict with alphainv_last, weights, rgb_marched [N, 3], raw_alpha,
            raw_rgb, ray_id, mean_of_slots, and optionally depth [N] and
            segmentation [N] (0 = background, k+1 = slot k).
        Raises:
            NotImplementedError if stc_data is True.
        '''
        assert len(rays_o.shape)==2 and rays_o.shape[-1]==3, 'Only suuport point queries in [N, 3] format'

        if stc_data:
            # This branch used to be a leftover pdb.set_trace() followed by
            # unbound variables; fail explicitly instead.
            raise NotImplementedError('stc_data=True (static scene) rendering is not implemented')

        ret_dict = {}
        N = len(rays_o)

        # sample points on rays
        ray_pts, ray_id, step_id = self.sample_ray(
            rays_o=rays_o, rays_d=rays_d, is_train=global_step is not None, **render_kwargs)
        interval = render_kwargs['stepsize'] * self.voxel_size_ratio

        # slots only need recomputing at inference when the timestep changed
        if self.last_timestep != time_index:
            self.update_flag = 1
            self.last_timestep = time_index
        else:
            self.update_flag = 0

        # displace the samples with the time network
        dx = self.query_time(ray_pts, frame_time, self._time, self._time_out)
        ray_pts_ = ray_pts + dx

        slots_updated = None
        if training_flag or self.update_flag:
            dynamics_density = self.update_density(frame_time)
            slots_updated, _ = self.slot_attention(self.slots_o, dynamics_density.detach())
            density = self.grid_sampler(ray_pts_, self.density)
            self.slots = slots_updated.detach()

            if training_flag:
                if start:
                    self.last_episode_o = torch.mean(self.curr_episode_o_episode, dim=0).detach()
                # store a detached copy: the buffer is only read through
                # detached means, and keeping the graph would chain autograd
                # nodes across iterations
                self.curr_episode_o_episode[time_index] = slots_updated.detach()
        else:
            density = self.grid_sampler(ray_pts_, self.density)

        # decoder expects the instances on the first axis
        if self.density.shape[1] == 1:
            odensity = density[None, :]
        else:
            odensity = density.permute(1,0)

        # pick the slot latent handed to the decoder
        if training_flag:
            if first_episode:
                mean_slots_o = slots_updated
            else:
                temp_slots_o = torch.cat((self.last_episode_o, slots_updated), dim=0)
                mean_slots_o = torch.mean(temp_slots_o, dim=0, keepdim=True)
        else:
            mean_slots_o = self.last_episode_o

        rgb_all, density_all, multi_rgb, multi_density = self.decoder(ray_pts_, viewdirs, mean_slots_o, odensity, ray_id, self.act_shift)

        density = density_all
        rgb = rgb_all

        # per-sample share of every slot in the total density
        slots_prob_ori = (multi_density / (torch.sum(multi_density,dim=0,keepdim = True) + 1e-10))   #[K,M]
        slots_prob = slots_prob_ori.permute(1,0) #[M,K]

        alpha = 1 - torch.exp(-density * interval)

        # ray_id_ must exist even when no thresholding happens
        # (fast_color_thres == 0 is the default)
        ray_id_ = ray_id
        if self.fast_color_thres > 0:
            mask = (alpha > self.fast_color_thres)
            ray_id_ = ray_id[mask]
            step_id = step_id[mask]
            density = density[mask]
            alpha = alpha[mask]
            rgb = rgb[mask]
            slots_prob = slots_prob[mask]

        # compute accumulated transmittance
        weights, alphainv_last = Alphas2Weights.apply(alpha, ray_id_, N)
        if self.fast_color_thres > 0:
            mask = (weights > self.fast_color_thres)
            weights = weights[mask]
            alpha = alpha[mask]
            ray_id_ = ray_id_[mask]
            step_id = step_id[mask]
            density = density[mask]
            rgb = rgb[mask]
            slots_prob = slots_prob[mask]

        # Ray marching; output buffers follow the source's device and dtype
        weighted_rgb = weights.unsqueeze(-1) * rgb
        rgb_marched = segment_coo(
                src=weighted_rgb,
                index=ray_id_,
                out=weighted_rgb.new_zeros([N, 3]),
                reduce='sum')

        rgb_marched += (alphainv_last.unsqueeze(-1) * render_kwargs['bg'])
        ret_dict.update({
            'alphainv_last': alphainv_last,
            'weights': weights,
            'rgb_marched': rgb_marched,
            'raw_alpha': alpha,
            'raw_rgb': rgb,
            'ray_id': ray_id_,
            'mean_of_slots': self.last_episode_o[0],
        })

        if render_kwargs.get('render_depth', False):
            with torch.no_grad():
                weighted_step = weights * step_id
                depth = segment_coo(
                        src=weighted_step,
                        index=ray_id_,
                        out=weighted_step.new_zeros([N]),
                        reduce='sum')
            ret_dict.update({'depth': depth})

        if render_kwargs.get('segmentation', True):
            weighted_prob = weights.unsqueeze(-1) * slots_prob
            contribution = segment_coo(
                src=weighted_prob,
                index=ray_id_,
                out=weighted_prob.new_zeros([N, multi_density.shape[0]]),
                reduce='sum') # [N, slots]

            # column 0 is the background (leftover transmittance)
            seg_contri = torch.cat([alphainv_last.unsqueeze(-1), contribution], dim=-1) # [N, slots+1]
            segmentation = torch.argmax(seg_contri, dim=-1)

            ret_dict.update({'segmentation': segmentation})   #[N]

        return ret_dict


''' Module for the searched coarse geometry
It supports query for the known free space and unknown space.
'''
class MaskCache(nn.Module):
    '''Boolean occupancy grid with a world-to-index transform.

    Built either from a saved checkpoint (`path`) by thresholding its
    max-pooled alpha at `mask_cache_thres`, or directly from a given `mask`
    together with `xyz_min` / `xyz_max`.
    '''

    def __init__(self, path=None, mask_cache_thres=None, mask=None, xyz_min=None, xyz_max=None):
        super().__init__()

        if path is None:
            # explicit mask: only normalise types
            mask = mask.bool()
            xyz_min = torch.Tensor(xyz_min)
            xyz_max = torch.Tensor(xyz_max)
        else:
            ckpt = torch.load(path)
            self.mask_cache_thres = mask_cache_thres
            density = F.max_pool3d(ckpt['model_state_dict']['density'], kernel_size=3, padding=1, stride=1)
            model_kwargs = ckpt['model_kwargs']
            act_shift = model_kwargs['act_shift']
            ratio = model_kwargs['voxel_size_ratio']
            if density.shape[1] == 1:
                alpha = 1 - torch.exp(-F.softplus(density + act_shift) * ratio)
            else:
                # several instance channels: weight each by its share of the total
                raw_masks = F.softplus(density + act_shift)
                masks = raw_masks / (raw_masks.sum(dim=1)[:,None] + 1e-5)
                sigma = (raw_masks * masks).sum(dim=1)
                alpha = 1.-torch.exp(-sigma*ratio)[:, None]
            mask = (alpha >= self.mask_cache_thres).squeeze(0).squeeze(0)
            xyz_min = torch.Tensor(model_kwargs['xyz_min'])
            xyz_max = torch.Tensor(model_kwargs['xyz_max'])

        self.register_buffer('mask', mask)
        # affine map from world xyz to (fractional) grid index: ijk = xyz * scale + shift
        grid_extent = torch.Tensor(list(mask.shape)) - 1
        self.register_buffer('xyz2ijk_scale', grid_extent / (xyz_max - xyz_min))
        self.register_buffer('xyz2ijk_shift', -xyz_min * self.xyz2ijk_scale)

    @torch.no_grad()
    def forward(self, xyz):
        '''Look up the mask at the given points.
        @xyz:   [..., 3] the xyz in global coordinate.
        Returns the kernel's lookup result reshaped to xyz.shape[:-1].
        '''
        out_shape = xyz.shape[:-1]
        flat = render_utils_cuda.maskcache_lookup(
            self.mask, xyz.reshape(-1, 3), self.xyz2ijk_scale, self.xyz2ijk_shift)
        return flat.reshape(out_shape)


''' Misc
'''
class Raw2Alpha(torch.autograd.Function):
    '''Autograd wrapper around the CUDA raw-density -> alpha kernels.'''

    @staticmethod
    def forward(ctx, density, shift, interval):
        '''
        alpha = 1 - exp(-softplus(density + shift) * interval)
              = 1 - exp(-log(1 + exp(density + shift)) * interval)
              = 1 - exp(log((1 + exp(density + shift)) ^ (-interval)))
              = 1 - (1 + exp(density + shift)) ^ (-interval)
        '''
        exp, alpha = render_utils_cuda.raw2alpha(density, shift, interval)
        # only keep state when a gradient can actually be requested
        if density.requires_grad:
            ctx.interval = interval
            ctx.save_for_backward(exp)
        return alpha

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, grad_back):
        '''
        alpha' = interval * ((1 + exp(density + shift)) ^ (-interval-1)) * exp(density + shift)'
               = interval * ((1 + exp(density + shift)) ^ (-interval-1)) * exp(density + shift)
        Only `density` receives a gradient; shift and interval get None.
        '''
        (exp,) = ctx.saved_tensors
        grad_density = render_utils_cuda.raw2alpha_backward(
            exp, grad_back.contiguous(), ctx.interval)
        return grad_density, None, None

class Alphas2Weights(torch.autograd.Function):
    '''Autograd wrapper around the CUDA alpha -> rendering-weight kernels.

    Returns (weights, alphainv_last); forward() in VoxelMlp uses alphainv_last
    as the per-ray factor for the background colour.
    '''

    @staticmethod
    def forward(ctx, alpha, ray_id, N):
        outputs = render_utils_cuda.alpha2weight(alpha, ray_id, N)
        weights, T, alphainv_last, i_start, i_end = outputs
        # only keep state when a gradient can actually be requested
        if alpha.requires_grad:
            ctx.n_rays = N
            ctx.save_for_backward(alpha, weights, T, alphainv_last, i_start, i_end)
        return weights, alphainv_last

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, grad_weights, grad_last):
        alpha, weights, T, alphainv_last, i_start, i_end = ctx.saved_tensors
        grad_alpha = render_utils_cuda.alpha2weight_backward(
                alpha, weights, T, alphainv_last,
                i_start, i_end, ctx.n_rays, grad_weights, grad_last)
        # ray_id and N are not differentiable
        return grad_alpha, None, None




class Decoder(nn.Module):
    """
    Slot-conditioned colour decoder.

    For every slot and every sample point, an MLP with one skip connection maps
    (position embedding, slot vector) to a latent, which is combined with the
    view-direction embedding to predict RGB. Density is not predicted here: it
    is supplied by the caller (`raw_density`), activated with softplus, and
    used to blend the per-slot outputs.
    """

    def __init__(self, n_freq=5, n_freq_view=3, input_dim=33+64, input_ch_dim=21, z_dim=64, n_layers=3, out_ch=3):
        """
        n_freq: number of frequency bands passed to sin_emb for positions
        n_freq_view: number of frequency bands passed to sin_emb for view directions
        input_dim: position embedding dim + slot dim
                   (default 33 + 64; 33 = 3 + 3*2*n_freq for 3-d points)
        input_ch_dim: view embedding dim (default 21 = 3 + 3*2*n_freq_view)
        z_dim: network latent dim
        n_layers: #layers before/after skip connection.
        out_ch: stored on the module; not otherwise used by this class
        """
        super().__init__()
        self.n_freq = n_freq
        self.n_freq_view = n_freq_view
        self.out_ch = out_ch
        # Layer creation order and Sequential indices are kept as-is so that
        # state_dict keys and seeded initialisation do not change.
        before_skip = [nn.Linear(input_dim, z_dim), nn.ReLU(True)]
        after_skip = [nn.Linear(z_dim+input_dim, z_dim), nn.ReLU(True)]
        for _ in range(n_layers-1):
            before_skip.append(nn.Linear(z_dim, z_dim))
            before_skip.append(nn.ReLU(True))
            after_skip.append(nn.Linear(z_dim, z_dim))
            after_skip.append(nn.ReLU(True))
        self.before = nn.Sequential(*before_skip)
        self.after = nn.Sequential(*after_skip)
        self.after_latent = nn.Linear(z_dim, z_dim) # feature_linear

        self.views_linears = nn.Sequential(nn.Linear(input_ch_dim + z_dim, z_dim//2),
                                           nn.ReLU(True))
        self.color = nn.Sequential(nn.Linear(z_dim//2, z_dim//4), # rgb_linear
                                     nn.ReLU(True),
                                     nn.Linear(z_dim//4, 3))


    def forward(self, sampling_coor, sampling_view, slots, raw_density, ray_id, act_shift, dens_noise=0.):
        """
        1. pos emb by Fourier
        2. for each slot, decode the colour of all points; blend slots by density
        input:
            sampling_coor: Px3, P = #points
            sampling_view: Nx3 view directions; rows are gathered with ray_id
            slots: 3-d tensor, permuted to KxBxC and expanded to KxPxC
                   (assumes shape 1xKxC -- TODO confirm against caller)
            raw_density: KxP, K = #slots, pre-activation density
            ray_id: index into the rows of sampling_view, one entry per point
            act_shift: added to raw_density before the softplus
            dens_noise: scale of the Gaussian noise added to the activated density
        output:
            rgb blended over slots: Px3
            sigma blended over slots: P
            per-slot rgb: KxPx3
            per-slot sigma: KxP
        """
        K = raw_density.shape[0]
        P = sampling_coor.shape[0]

        # Embed once, then broadcast over the K slots (expand does not copy).
        coor_emb = sin_emb(sampling_coor, n_freq=self.n_freq)
        query_ex = coor_emb.expand(K, coor_emb.shape[0], coor_emb.shape[1]).flatten(end_dim=1)  # (K*P)x33

        view_emb = sin_emb(sampling_view, n_freq=self.n_freq_view)[ray_id, :]  # Px21
        query_view = view_emb.expand(K, view_emb.shape[0], view_emb.shape[1]).flatten(end_dim=1)

        slots_ex = slots.permute(1, 0, 2).expand(-1, P, -1).flatten(end_dim=1)  # (K*P)xC
        features = torch.cat([query_ex, slots_ex], dim=1)  # (K*P)x(33+C)

        hidden = self.before(features)
        hidden = self.after(torch.cat([features, hidden], dim=1))  # skip connection, (K*P)xz_dim
        latent = self.after_latent(hidden)  # (K*P)xz_dim

        h = torch.cat([latent, query_view], -1)
        h = self.views_linears(h)
        raw_rgb = self.color(h).view([K, P, 3]).contiguous()  # (K*P)x3 -> KxPx3

        raws = torch.cat([raw_rgb, raw_density[..., None]], dim=-1)  # KxPx4

        # softplus takes (input, beta, threshold); the former positional `True`
        # was interpreted as beta and only worked because True == 1. The
        # default beta=1 gives the same result.
        raw_masks = F.softplus(raws[:, :, -1:] + act_shift)
        raw_sigma = raw_masks + dens_noise * torch.randn_like(raw_masks)
        sigma_per_slot = raw_sigma.squeeze(-1)

        raw_rgb = (raws[:, :, :3].tanh() + 1) / 2  # map to [0, 1]

        if K == 1:
            # Single slot: nothing to blend.
            return raw_rgb[0], sigma_per_slot[0], raw_rgb, sigma_per_slot

        # Normalise the noise-free density over slots to get blending weights.
        masks = raw_masks / (raw_masks.sum(dim=0) + 1e-5)  # KxPx1

        raw_rgb_all = (raw_rgb * masks).sum(dim=0)
        raw_sigma_all = (raw_sigma * masks).sum(dim=0)

        return raw_rgb_all, raw_sigma_all.squeeze(-1), raw_rgb, sigma_per_slot



class Dynamics_v2(nn.Module):
    """
    Point-wise MLP applied to every (voxel, channel) entry of a 5-d grid.

    Each row fed to the MLP is the concatenation of one grid value, one slot
    vector and the sinusoidal embedding of the voxel coordinate.
    """

    def __init__(self, n_freq=3, input_dim=1+21+64, hidden=128, out_dim=4):
        """
        n_freq: number of frequency bands passed to sin_emb for coordinates
        input_dim: grid value (1) + coordinate embedding dim + slot dim
        hidden: width of the two hidden layers
        out_dim: width of the final linear layer
        """
        super().__init__()
        self.n_freq = n_freq
        layers = [
            nn.Linear(input_dim, hidden),
            nn.ReLU(True),
            nn.Linear(hidden, hidden),
            nn.ReLU(True),
            nn.Linear(hidden, out_dim),
        ]
        self.network = nn.Sequential(*layers)

    def forward(self, x, coord, slots):
        """
        input:
            x: 5-d grid, unpacked as N x C x X x Y x Z
            coord: 5-d grid, moved to channels-last and flattened to rows
            slots: 3-d, expanded along its first axis to X*Y*Z
                   (assumes shape 1 x C x slot_dim -- TODO confirm against caller)
        output:
            N x C x X x Y x Z

        NOTE(review): the final reshape to (N, X, Y, Z, C) has as many elements
        as the MLP has input rows, so it only succeeds when out_dim == 1; the
        constructor default out_dim=4 would raise here -- confirm intent.
        """
        assert coord.dim() == 5
        assert x.dim() == 5
        assert slots.dim() == 3
        N, C, X, Y, Z = x.shape
        n_voxels = X * Y * Z

        # Channels-last, then one scalar per row.
        values = x.permute(0, 2, 3, 4, 1).flatten().unsqueeze(-1)

        # One embedding per voxel, repeated for each of the C channels.
        voxel_coord = coord.permute(0, 2, 3, 4, 1).flatten(end_dim=-2)
        coord_emb = sin_emb(voxel_coord, n_freq=self.n_freq)
        query = coord_emb.unsqueeze(1).expand(-1, C, -1).flatten(end_dim=-2)

        # The same set of slot vectors is repeated for every voxel.
        slot_rows = slots.expand(n_voxels, -1, -1).flatten(end_dim=-2)

        features = torch.cat([values, slot_rows, query], dim=-1)
        prediction = self.network(features)
        return prediction.reshape(N, X, Y, Z, C).permute(0, 4, 1, 2, 3)



def sin_emb(x, n_freq=5, keep_ori=True):
    """
    Sinusoidal (Fourier) embedding of coordinates.
    input:
        x: Px3 (any 2-d tensor works; pieces are concatenated along dim 1)
        n_freq: number of frequency bands, 2^0 ... 2^(n_freq-1)
        keep_ori: if True, the raw coordinates are the first columns of the output
    output:
        Px(D*(keep_ori + 2*n_freq)) for a PxD input, laid out as
        [x, sin(1x), cos(1x), sin(2x), cos(2x), ...]
    """
    bands = 2. ** torch.linspace(0., n_freq - 1, steps=n_freq)
    pieces = [x] if keep_ori else []
    # Frequency is the outer loop and sin/cos the inner one, which fixes the
    # column order of the result.
    pieces.extend(fn(band * x) for band in bands for fn in (torch.sin, torch.cos))
    return torch.cat(pieces, dim=1)
