import torch
from .tineuvox import poc_fre
from pykeops.torch import LazyTensor
import os
from torch_scatter import segment_coo
from torch.utils.cpp_extension import load
from .tineuvox import Alphas2Weights, Raw2Alpha
from itertools import combinations
import numpy as np
import roma
from seaborn import color_palette
from .treeprune import merge_joints


from .utils import project_point_to_image_plane

# Directory that contains this module; the extension sources are resolved relative to it.
parent_dir = os.path.dirname(os.path.abspath(__file__))

# Build (if needed) and load the ray-sampling extension when this module is imported.
render_utils_cuda = load(
    name='render_utils_cuda',
    sources=[
        os.path.join(parent_dir, 'cuda/render_utils.cpp'),
        os.path.join(parent_dir, 'cuda/render_utils_kernel.cu'),
    ],
    verbose=True,
)

class NoPointsException(Exception):
    """Exception type declared by this module.

    It adds no behaviour to ``Exception``: positional arguments are stored in
    ``args`` exactly as the base class does.
    NOTE(review): presumably raised when a point set turns out to be empty --
    it is not raised anywhere in the visible part of the file.
    """

class TransformNet(torch.nn.Module):
    """ReLU MLP that regresses a block of parameters per component.

    The network is ``num_layers - 1`` Linear+ReLU pairs followed by one
    bias-free Linear layer with ``num_components * num_params_per_component``
    outputs. ``rotation_switch_mask`` is registered as a buffer holding
    ``0 .. num_components - 1``.
    """

    def __init__(self, input_dim, num_components, num_params_per_component, num_layers=3, hidden_dim=256):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_components = num_components
        self.num_params_per_component = num_params_per_component
        self.out_dim = num_components * num_params_per_component
        self.register_buffer('rotation_switch_mask', torch.arange(0, num_components).long())

        # Only the first hidden layer consumes the raw input width.
        stack = []
        in_features = input_dim
        for _ in range(num_layers - 1):
            stack += [torch.nn.Linear(in_features, self.hidden_dim), torch.nn.ReLU()]
            in_features = self.hidden_dim

        # Output head: always fed by hidden_dim features and without bias.
        stack.append(torch.nn.Linear(self.hidden_dim, self.out_dim, bias=False))
        self.net = torch.nn.Sequential(*stack)

    def forward(self, x):
        """Map a 2-D input ``(batch, input_dim)`` to per-component parameters.

        Returns ``(batch, num_components, num_params_per_component)`` when
        ``batch > 1``; otherwise the batch axis is dropped and the result is
        ``(num_components, num_params_per_component)``.
        """
        batch, _ = x.shape
        flat = self.net(x)
        if batch > 1:
            return flat.reshape(batch, self.num_components, self.num_params_per_component)
        return flat.reshape(self.num_components, self.num_params_per_component)

class PointWarper(torch.nn.Module):
    def __init__(self,
                 t_dim,
                 canonical_pcd,
                 H,
                 joints,
                 bones,
                 num_layers=5,
                 over_parameterized_rot=True,
                 ):
        """Set up the kinematic tree and the time-conditioned pose regressor.

        ``joints`` is only used for its length and to build the tree;
        ``bones`` is a sequence of ``(parent, child)`` index pairs (see
        ``init_tree``). The regressor gets ``len(joints) + 1`` components of
        four parameters each; ``forward`` reads the last component as the
        global translation.
        """
        super().__init__()
        num_joints = len(joints)

        # Plain configuration attributes.
        self.t_dim = t_dim
        self.params_per_compoent = 4
        self.canonical_pcd = canonical_pcd
        self.H = H
        self.num_layers = num_layers
        self.over_parameterized_rot = over_parameterized_rot

        # Parent/child lookups and padded root-to-node paths.
        self.init_tree(joints, bones, old=False)

        # Bottom row of a homogeneous 4x4 transform.
        self.hom_row = torch.tensor([0, 0, 0, 1], dtype=torch.float32)
        self.transform_net = TransformNet(
            t_dim,
            num_joints + 1,
            self.params_per_compoent,
            num_layers=self.num_layers,
        )

        # By default no rotation is frozen and every joint uses its own rotation.
        self.register_buffer('rot_mask', torch.zeros(num_joints, dtype=torch.bool))
        self.register_buffer('sibling_mask', torch.arange(0, num_joints).long())
       
    def kwargs(self):
        return {
            't_dim': self.t_dim,
            'params_per_compoent': self.params_per_compoent,
            'component_num': self.component_num,
            'over_parameterized_rot': self.over_parameterized_rot,
        }

    def init_tree(self, joints, bones, old=True):

        if old:
            self.bones = bones
            self.parent_joint = {b[1]: b[0] for b in bones}
            self.child_joints = {k: [] for k in range(len(joints))}
            for k in self.parent_joint.keys():
                parent_k = self.parent_joint[k]
                self.child_joints[parent_k].append(k)

            # Accelerated tree.
            parent_indices = []
            for i in range(len(self.bones)):
                j = i + 1
                inds = []
                while j >= 0:
                    inds += [j]
                    j = self.parent_joint.get(j, -1)
                parent_indices += [inds[::-1]]
            max_depth = np.max([len(x) for x in parent_indices])
            self.parent_indices = torch.zeros((len(self.bones), max_depth), dtype=torch.long) - 1
            for i,inds in enumerate(parent_indices):
                self.parent_indices[i,:len(inds)] = torch.from_numpy(np.array(inds)).to(self.parent_indices.device, dtype=self.parent_indices.dtype)
            self.parent_joint_ex = torch.from_numpy(np.array([self.parent_joint.get(i, -1) for i in range(len(self.bones)+1)])).to(self.parent_indices.device, dtype=self.parent_indices.dtype)
        else:
            self.bones = bones

            self.parent_joint = {b[1]: b[0] for b in self.bones}
            self.child_joints = {k: [] for k in range(len(joints))}
            for k in self.parent_joint.keys():
                parent_k = self.parent_joint[k]
                self.child_joints[parent_k].append(k)

            # Accelerated tree.
            parent_indices = [[0]]
            for i in range(len(self.bones)):
                j = i + 1
                inds = []
                while j >= 0:
                    inds += [j]
                    j = self.parent_joint.get(j, -1)
                parent_indices += [inds[::-1]]
            max_depth = np.max([len(x) for x in parent_indices])
            self.parent_indices = torch.zeros((len(parent_indices), max_depth), dtype=torch.long) - 1
            for i,inds in enumerate(parent_indices):
                self.parent_indices[i,:len(inds)] = torch.from_numpy(np.array(inds)).to(self.parent_indices.device, dtype=self.parent_indices.dtype)
            self.parent_joint_ex = torch.from_numpy(np.array([self.parent_joint.get(i, 0) for i in range(len(parent_indices))])).to(self.parent_indices.device, dtype=self.parent_indices.dtype)

    def remove_regressor_weights(self, rotations_to_keep, rotation_switch_mask):
        # times = torch.linspace(0., 1., 300).unsqueeze(-1)
        # times_embed = poc_fre(times, time_poc)
        # params = self.transform_net(times_embed)

        # gt = params[:, :-1][:, rotations_to_keep]
        ## Remove weights

        # Create mask
        rotations_to_keep = torch.tensor([[i, i, i] for i in rotations_to_keep]).flatten()
        transformations_to_keep = torch.cat([rotations_to_keep, torch.tensor([True, True, True])]) # Always Keep global translation
        out_feature_num = torch.sum(transformations_to_keep).item()

        # Actually remove weights
        self.transform_net.net[-1].weight = torch.nn.Parameter(self.transform_net.net[-1].weight[transformations_to_keep])
        self.transform_net.net[-1].out_features = out_feature_num
        self.transform_net.num_components = out_feature_num // 3

        # Account for change in order
        # Root rotation does not change order
        # rotation_switch_mask += 1
        # rotation_switch_mask = torch.concat([torch.tensor([0]), rotation_switch_mask])
        # Account for root translation (at the end of the tenosr) Root translation does not change order
        rotation_switch_mask = torch.concat([rotation_switch_mask, torch.tensor([len(rotation_switch_mask)])])
        self.transform_net.rotation_switch_mask = rotation_switch_mask.long()

        # params = self.transform_net(times_embed)[:,:-1]

    def Rodrigues(self, rvec, theta=None):
        # Neural Volumes
        if rvec.shape[-1] == 3:
            theta = torch.sqrt(1e-5 + torch.sum(rvec ** 2, dim=1))
            rvec = rvec / theta[:, None]
        elif rvec.shape[-1] == 4:
            theta = rvec[:, -1]
            rvec = rvec[:, :3]
            rvec = rvec / torch.sqrt(1e-5 + torch.sum(rvec ** 2, dim=1))[:, None]
        else:
            raise ValueError()

        costh = torch.cos(theta)
        sinth = torch.sin(theta)
        return torch.stack((
            rvec[:, 0] ** 2 + (1. - rvec[:, 0] ** 2) * costh,
            rvec[:, 0] * rvec[:, 1] * (1. - costh) - rvec[:, 2] * sinth,
            rvec[:, 0] * rvec[:, 2] * (1. - costh) + rvec[:, 1] * sinth,

            rvec[:, 0] * rvec[:, 1] * (1. - costh) + rvec[:, 2] * sinth,
            rvec[:, 1] ** 2 + (1. - rvec[:, 1] ** 2) * costh,
            rvec[:, 1] * rvec[:, 2] * (1. - costh) - rvec[:, 0] * sinth,

            rvec[:, 0] * rvec[:, 2] * (1. - costh) - rvec[:, 1] * sinth,
            rvec[:, 1] * rvec[:, 2] * (1. - costh) + rvec[:, 0] * sinth,
            rvec[:, 2] ** 2 + (1. - rvec[:, 2] ** 2) * costh), dim=1).view(-1, 3, 3), theta

    def quat_weighted_avg(self, Q, w, norm=False):
        """Return the dominant eigenvector of a weighted sum of outer products of ``Q``.

        The outer products ``Q Q^T`` are scaled by ``w``, summed over dim 1 and
        optionally divided by the per-row weight sum (``norm=True``). The real
        part of the eigenvector with the largest real eigenvalue is returned
        for every batch item.

        NOTE(review): the name suggests an eigenvector-based weighted quaternion
        average, but ``torch.bmm`` only accepts 3-D operands, so ``Q`` has to be
        2-D here while ``w`` is indexed as 2-D -- confirm the intended shapes
        before relying on this. The only visible call site is commented out.
        """
        # Weighted accumulation of the outer products, reduced over dim 1.
        A = (w[:,:,None,None] * torch.bmm(Q.unsqueeze(2), Q.unsqueeze(1))).sum(1)
        if norm:
            A /= w.sum(dim=1)[:,None,None]
        # Debug probe: prints 1 when any singular value of A is below machine epsilon.
        if torch.any(torch.svd(A, compute_uv=False)[1] < torch.finfo(A.dtype).eps) :
            print(1)
        eigenValues, eigenVectors = torch.linalg.eig(A)
        # Order the eigenpairs by the real part of the eigenvalue, largest first.
        sorted_idx = torch.argsort(torch.real(eigenValues), dim=-1, descending=True)[:,None,:]

        # torch.save(A, './A.pt')
        # torch.save(Q, './Q.pt')
        # torch.save(w, './w.pt')

        # Column 0 after reordering is the eigenvector of the largest eigenvalue.
        return torch.real(torch.take_along_dim(eigenVectors, sorted_idx, dim=2)[:,:,0])


    @classmethod
    def matrix_chain_product(cls, matrix_chain: torch.tensor) -> torch.tensor:
        """ Recursive binary tree. """
        chain_len = matrix_chain.shape[1]
        if chain_len == 1:
            return matrix_chain
        sub_a = cls.matrix_chain_product(matrix_chain[:,:chain_len//2])
        sub_b = cls.matrix_chain_product(matrix_chain[:,chain_len//2:])
        return sub_a @ sub_b


    def calc_rec_abs_T_fast(self, R_t: torch.tensor, joints: torch.tensor) -> torch.tensor:
        """
        Recalculates all transforms.
        """

        #### PREVIOUS IMPLEMENTATION ####

        # self.init_tree(joints, self.bones, old=True)

        # Start: Compatibility package.
        # This is unnecesarily complicated and could be removed.
        joints_old = torch.cat((self.hom_row[None, :3], joints), 0)
        # Align joints with their bones.
        joints_old = joints_old[self.parent_joint_ex + 1]
        # End: Compatibility package.

        M_bones_old = torch.cat((torch.cat((R_t, joints_old[...,None] + R_t @ -joints_old[...,None]), -1), self.hom_row[None,None].repeat(R_t.shape[0],1,1)), -2)
        M_bones_old = torch.cat((torch.eye(4)[None], M_bones_old), 0)
        M_paths_old = M_bones_old[self.parent_indices + 1]
        out_old = self.matrix_chain_product(M_paths_old)[:,0]

        #### NEW IMPLEMENTATION ####

        # self.init_tree(joints, self.bones, old=False)
        # # Translation is defined by the parent and hence shared by all siblings (unlike the rotation which is from the child.)
        # joints = joints[self.parent_joint_ex]

        # # Each node transform is a rotation (from the child) around the position (from the parent)
        # # hom(None, t_i) @ hom(R_i) @ hom(None, t_i, inv=True)
        # M_bones = torch.cat((torch.cat((R_t, joints[...,None] + R_t @ -joints[...,None]), -1), self.hom_row[None,None].repeat(R_t.shape[0],1,1)), -2)
        # M_bones[0,:3,-1] = 0.
        # # Handle -1 fake identity nodes by shifting index by 1.
        # M_bones = torch.cat((torch.eye(4)[None], M_bones), 0)
        # M_paths = M_bones[self.parent_indices + 1]
        # # Cummulative product.
        # out = self.matrix_chain_product(M_paths)[:,0]

        return out_old
    
    # def get_bone_vecs(self, joints):
    #     bone_vec = []
    #     for bone in self.bones:
    #         vec = joints[bone[1]] - joints[bone[0]]
    #         bone_vec.append(vec.unsqueeze(0)) # normalisation happens later
    #         # bone_vec.append((vec / torch.norm(vec)).unsqueeze(0))
    #     return torch.concat(bone_vec, dim=0)

    def get_thetas(self, ts_embed):
        params = self.transform_net(ts_embed)
        rot_params = params[:, :-1, :3]
        shape = rot_params.shape[:2]
        rot_params = rot_params.reshape(torch.mul(*shape), 3)
        _, thetas = self.Rodrigues(rot_params)
        
        return thetas.reshape(shape, 3)

    def set_rotation_mask(self, rotations_to_keep):
        # mask = ~rotations_to_keep
        # mask[0] = False # always keep root
        mask = ~rotations_to_keep
        if self.rot_mask is not None:
            mask = torch.logical_or(mask, self.rot_mask)
        self.rot_mask = mask

    def set_sibling_mask(self, sibling_mask):
        self.sibling_mask = sibling_mask.long()

    def forward(self, weights, joints, t=None, rot_params=None, global_t=None, get_frames=False, avg_procrustes=True):
        """Warp the canonical point cloud with blended per-node rigid transforms.

        Exactly one of ``t`` and ``rot_params`` must be given. With ``t`` the
        pose is regressed by ``transform_net``: the first ``len(joints)`` rows
        are rotation parameters (four columns if ``over_parameterized_rot``,
        else three) and the first three values of the last row become
        ``global_t``, overriding any ``global_t`` passed in. With
        ``rot_params`` the given parameters and ``global_t`` are used as-is.

        Args:
            weights: per-point blend weights, indexed as ``(points, nodes)``.
            joints: joint positions passed on to ``calc_rec_abs_T_fast``.
            t: time embedding (a batch axis is added before the regressor).
            rot_params: explicit rotation parameters, see ``Rodrigues``.
            global_t: translation added to every warped point when not None.
            get_frames: also return the blended per-point 4x4 transforms.
            avg_procrustes: replace the blended 3x3 block with the output of
                ``roma.special_procrustes`` (presumably the nearest rotation
                matrix -- confirm against the RoMa documentation).

        Returns:
            The warped points, plus the per-point transforms when
            ``get_frames`` is set.

        Side effects: caches ``prev_params`` (only when ``t`` is given),
        ``prev_global_t`` and ``prev_thetas`` on the module.
        """
        assert (t is None) ^ (rot_params is None)

        # Get time-dependent rotation and translation
        with torch.profiler.record_function("transform_net"):
            if rot_params is None:
                params = self.transform_net(t.unsqueeze(0))
                self.prev_params = params
                # The last component of the regressor output is the global translation.
                global_t = params[-1, :3]
                if self.over_parameterized_rot:
                    rot_params = params[:len(joints), :4]
                else:
                    rot_params = params[:len(joints), :3]
                # bone_vecs = self.get_bone_vecs(joints)
                # bone_rot_params = params[1:len(joints), -1]
                # bone_rot_params = torch.concat([bone_vecs, bone_rot_params.unsqueeze(-1)], dim=-1)

        with torch.profiler.record_function("calc_rec_abs_T"):
            self.prev_global_t = global_t
            R_t, self.prev_thetas = self.Rodrigues(rot_params)

            # Re-index the rotations through sibling_mask, then overwrite frozen ones with the identity.
            R_t = R_t[self.sibling_mask]
            if self.rot_mask is not None:
                R_t[self.rot_mask] = torch.eye(3)

            # No-op assignment; the sibling re-indexing of the angles is disabled.
            self.prev_thetas = self.prev_thetas # [self.sibling_mask]
            # R_t_perpendicular, _ = self.Rodrigues(bone_rot_params)
            # R_t_perpendicular = torch.concat([torch.eye(3).unsqueeze(0), R_t_perpendicular])
            # R_t = torch.bmm(R_t, R_t_perpendicular)
            # R_t = torch.concat([R_t[0].unsqueeze(0), R_t_perpendicular])

            # Do recursive bone transformations
            bone_Ts = self.calc_rec_abs_T_fast(R_t, joints)

        with torch.profiler.record_function("weighted_G_tw"):
            # Apply weights to transformation matrix
            weighted_G_tw = (bone_Ts * weights[:, :, None, None]).sum(dim=1)

            if avg_procrustes:
                # Keep the blended translation, but replace the blended 3x3 block.
                weighted_T = weighted_G_tw[:,:3,-1,None]
                weighted_R = roma.special_procrustes(weighted_G_tw[:,:3,:3])

                weighted_G_tw = torch.cat((weighted_R, weighted_T), -1)
                weighted_G_tw = torch.cat((weighted_G_tw, self.hom_row[None,None].repeat(weighted_G_tw.shape[0], 1, 1)), -2)

            # weighted_T = (bone_Ts[:,:3,-1] * weights[:, :, None]).sum(dim=1).unsqueeze(-1)

            # weighted_R = roma.rotmat_to_unitquat(bone_Ts[:,:3,:3])
            # weighted_R = self.quat_weighted_avg(weighted_R, weights)
            # weighted_R = roma.unitquat_to_rotmat(weighted_R)

            # if ((torch.det(weighted_R) - 1).abs() > 0.001).any():
                # print(1)
            
            

            # Transform points
            xyz = self.canonical_pcd
            # Append a homogeneous coordinate of 1 to every point.
            xyzh = torch.concat([xyz, torch.ones((len(xyz), 1))], axis=-1)
            xyzh = torch.bmm(weighted_G_tw, xyzh.unsqueeze(-1)).squeeze(-1) 
            xyz = xyzh[:,:3]

            if global_t is not None:
                xyz = xyz + global_t
        
        if not get_frames:
            return xyz.contiguous()
        else:
            return xyz.contiguous(), weighted_G_tw

class TemporalPoints(torch.nn.Module):
    def __init__(self,
            canonical_pcd,
            canonical_alpha,
            canonical_feat,
            canonical_rgbs,
            skeleton_pcd,
            joints,
            hierachy,
            bones,
            xyz_min,
            xyz_max,
            tineuvox,
            # joint_neighbours,
            neighbours=8,
            timebase_pe=8,
            eps=1e-6,
            stepsize=None,
            voxel_size=None,
            fast_color_thres=0,
            embedding='full',
            frozen_view_dir=None,
            over_parameterized_rot=True,
            avg_procrustes=True,
            **kwargs):
        """Build the articulated point-cloud model from a canonical cloud and a skeleton.

        Args:
            canonical_pcd: canonical point positions (indexed as ``(points, 3)``).
            canonical_alpha, canonical_rgbs: per-point values stored as
                non-trainable parameters (used for direct point rendering).
            canonical_feat: per-point features, stored as a trainable parameter.
            skeleton_pcd, hierachy: stored as given; ``hierachy`` is passed to
                the ``PointWarper`` as ``H``.
            joints: joint positions; kept both as a frozen copy
                (``original_joints``) and as a trainable parameter (``joints``).
            bones: ``(parent, child)`` joint-index pairs.
            xyz_min, xyz_max: bounding box, registered as buffers.
            tineuvox: source model whose sub-networks and frequency bands are
                reused by reference.
            neighbours: K for the K-nearest-neighbour graph on the canonical cloud.
            timebase_pe: number of time frequency bands; ``t_dim`` is
                ``1 + 2 * timebase_pe``.
            eps: small constant used for numerical stability.
            stepsize, voxel_size, fast_color_thres: stored for rendering.
            embedding: one of ``'full'``, ``'raw_embed'``, ``'raw'``, ``'rot'``;
                selects the input width of ``feat_net``.
            frozen_view_dir: optional view direction whose embedding becomes a
                trainable parameter.
            over_parameterized_rot, avg_procrustes: stored; the former is also
                forwarded to the ``PointWarper``.
            **kwargs: accepted and ignored.

        Raises:
            NotImplementedError: for an unknown ``embedding``.
        """
        super(TemporalPoints, self).__init__()
        self.canonical_pcd = canonical_pcd
        self.skeleton_pcd = skeleton_pcd
        self.hierachy = hierachy
        self.bones = bones
        # Flattened bone endpoints: parent0, child0, parent1, child1, ...
        self.bone_arap_mask = torch.tensor(bones).reshape(-1)
        # self.joint_neighbours = joint_neighbours
        # self.component_num = component_num
        self.register_buffer('xyz_min', torch.Tensor(xyz_min))
        self.register_buffer('xyz_max', torch.Tensor(xyz_max))
        self.register_buffer('WEIGHTS_DIRECT', torch.Tensor([1]))
        self.timebase_pe = timebase_pe
        self.eps = torch.tensor(eps)
        self.t_dim = 1 + self.timebase_pe * 2
        self.stepsize = stepsize
        self.voxel_size = voxel_size
        self.fast_color_thres = fast_color_thres
        self.embedding = embedding
        self.avg_procrustes = avg_procrustes
        self.over_parameterized_rot = over_parameterized_rot
            # self.register_buffer('viewdirs_emb', viewdirs_emb[None])

        # Blend-weight logits initialised from point-to-bone distances, with an extra leading zero column.
        self.weights = torch.nn.Parameter(self._weights_from_bones(joints, bones, canonical_pcd, add_noise=True, noise_var=0, soft_weights=True, add_zero_weight=True), requires_grad=True)
        self.forward_warp = PointWarper(canonical_pcd=canonical_pcd, t_dim=self.t_dim, H=self.hierachy, joints=joints, bones=bones, over_parameterized_rot=over_parameterized_rot)
        self.original_joints = torch.nn.Parameter(joints.to(torch.float32), requires_grad=False)
        self.joints = torch.nn.Parameter(joints.to(torch.float32), requires_grad=True)
        self.canonical_feat = torch.nn.Parameter(canonical_feat, requires_grad=True)
        self.theta = torch.nn.Parameter(torch.tensor([0.001]), requires_grad=True)
        # Softmax temperature applied to the blend-weight logits in get_weights.
        self.theta_weight = torch.nn.Parameter(torch.tensor([0.1]), requires_grad=True)
        self.merging_dict = None
        self.merging_mat = None

        # One trainable scalar per canonical point, initialised close to 1.
        gammas = torch.ones(len(self.canonical_pcd))
        self.gammas = torch.nn.Parameter(gammas + torch.randn_like(gammas) * 1e-2, requires_grad=True)

        self.pruned_joints = torch.zeros(len(self.joints), dtype=bool)

        self.register_buffer('flat_merging_rules', torch.arange(0, len(self.joints)))
        self.register_buffer('sibling_merging_rules', torch.zeros(len(self.joints), dtype=bool))

        # Only for direct point cloud rendering
        self.canonical_rgbs = torch.nn.Parameter(canonical_rgbs, requires_grad=False)
        self.canonical_alpha = torch.nn.Parameter(canonical_alpha, requires_grad=False)
 

        # NOTE: this buffer is overwritten further down with tineuvox.time_poc.
        self.register_buffer('time_poc', torch.FloatTensor([(2**i) for i in range(timebase_pe)]))

        # Define neighbourhood for canonical points and nn distances
        self.neighbours = neighbours
        xyz1 = LazyTensor(self.canonical_pcd[:, None, :])
        xyz2 = LazyTensor(self.canonical_pcd[None, :, :])
        D_ij = ((xyz1 - xyz2) ** 2).sum(-1)
        self.nn_i = D_ij.argKmin(dim=1, K=self.neighbours)
        self.nn_distance = torch.sqrt(((self.canonical_pcd[:,None,:] - self.canonical_pcd[self.nn_i,:])**2).sum(-1) + self.eps) # distance to nearest neighbour per point

        # Column 0 is presumably each point itself, so `err` is the eps-induced offset -- TODO confirm.
        self.err = self.nn_distance[:,0].mean()
        self.mean_min_distance = self.nn_distance[:,1].mean()
        # Points whose mean distance to neighbours 1..6 lies within `err` of the mean nearest-neighbour distance.
        self.cube_mask = ((self.nn_distance[:,1:7].mean(dim=1) >= (self.mean_min_distance - self.err)) & (self.nn_distance[:,1:7].mean(dim=1) <= (self.mean_min_distance + self.err)))

        # Mean neighbour spacing over the points selected above.
        self.s = self.nn_distance[self.cube_mask][:,1:7].mean()

        # Define joint neighbourhood distances
        # (difference vectors parent - child of every bone in the original pose)
        self.og_joint_distance = (self.original_joints[self.bone_arap_mask][0::2,:] 
                                - self.original_joints[self.bone_arap_mask][1::2,:])

        # Define feature net (as in PointNerf)
        if self.embedding == 'full' or self.embedding == 'raw_embed':
            feat_input_dim = self.canonical_feat.shape[-1] + 3 + 3 * tineuvox.posbase_pe * 2
        elif self.embedding == 'raw' or self.embedding == 'rot':
            feat_input_dim = self.canonical_feat.shape[-1] + 3
        else:
            raise NotImplementedError()
        feat_ouput_dim = self.canonical_feat.shape[-1]
        feat_width = self.canonical_feat.shape[-1]
        feat_depth = 4
        self.feat_net = torch.nn.Sequential(
            torch.nn.Linear(feat_input_dim, feat_width), torch.nn.LeakyReLU(inplace=True),
            *[
                torch.nn.Sequential(torch.nn.Linear(feat_width, feat_width), torch.nn.LeakyReLU(inplace=True))
                for _ in range(feat_depth-2)
            ],
            torch.nn.Linear(feat_width, feat_ouput_dim), torch.nn.LeakyReLU(inplace=True)
            )

        # Get rgb net and density from tineuvox
        self.act_shift = tineuvox.act_shift
        self.rgbnet = tineuvox.rgbnet
        self.densitynet = tineuvox.densitynet
        self.featurenet = tineuvox.featurenet
        self.timenet = tineuvox.timenet
        self.view_poc = tineuvox.view_poc
        self.time_poc = tineuvox.time_poc
        self.pos_poc = tineuvox.pos_poc
        self.grid_poc = tineuvox.grid_poc
        self.feat_only = tineuvox.feat_only
        self.no_view_dir = tineuvox.no_view_dir 
        self.tineuvox = tineuvox
        # self.deformation_net = Deformation(W=self.tineuvox.net_width, D=self.tineuvox.defor_depth, input_ch=3+3*self.tineuvox.posbase_pe*2, input_ch_time=self.tineuvox.deformation_net.input_ch_time, input_ch_views=0)
        # Axis-aligned bounds of the canonical point cloud.
        self.register_buffer('xyz_max_canonical', self.canonical_pcd.max(dim=0)[0])
        self.register_buffer('xyz_min_canonical',  self.canonical_pcd.min(dim=0)[0])
        # self.deformation_net_loaded = False
        self.deformation_net = None

        self.frozen_view_dir = frozen_view_dir
        if frozen_view_dir is not None:
            # Embed the fixed view direction once and make the embedding trainable.
            viewdirs_emb = poc_fre(frozen_view_dir, self.view_poc)
            self.viewdirs_emb = torch.nn.Parameter(viewdirs_emb[None], requires_grad=True)


    def get_kwargs(self):
        return {
            'canonical_pcd': self.canonical_pcd,
            'skeleton_pcd': self.skeleton_pcd,
            'canonical_alpha': self.canonical_alpha,
            'canonical_feat': self.canonical_feat,
            'canonical_rgbs': self.canonical_rgbs,
            'joints': self.joints,
            'hierachy': self.hierachy,
            'bones': self.bones,
            'neighbours': self.neighbours,
            'timebase_pe': self.timebase_pe,
            'eps': self.eps,
            'stepsize': self.stepsize,
            'weights': self.weights,
            'xyz_min': self.xyz_min.cpu().numpy(),
            'xyz_max': self.xyz_max.cpu().numpy(),
            'tineuvox': self.tineuvox,
            'voxel_size': self.voxel_size,
            'fast_color_thres': self.fast_color_thres,
            'deformation_net': self.deformation_net,
            'embedding': self.embedding,
            'frozen_view_dir': self.frozen_view_dir,
            'over_parameterized_rot':self.over_parameterized_rot,
            'avg_procrustes': self.avg_procrustes,
        }

    def reinitialise_weights(self):
        self.weights.data = self._weights_from_bones(self.joints, self.bones, self.canonical_pcd, add_noise=True, noise_var=0, soft_weights=True, add_zero_weight=False)
        self.theta_weight.data = torch.tensor([0.1])

    def dist_batch(self, p, a, b):
        assert len(a) == len(b), "Same batch size needed for a and b"

        p = p[None, :, :]
        s = b - a
        w = p - a[:, None, :]
        ps = (w * s[:, None, :]).sum(-1)
        res = torch.zeros((a.shape[0], p.shape[1]), dtype=p.dtype)

        # ps <= 0
        ps_smaller_mask = ps <= 0
        lower_mask = torch.where(ps_smaller_mask)
        res[lower_mask] += torch.norm(w[lower_mask], dim=-1)

        l2 = (s * s).sum(-1)
        # ps > 0 and ps >= l2
        ps_mask = ~ps_smaller_mask

        temp_mask_l2 = ps >= l2[:, None]
        upper_mask = torch.where(ps_mask & temp_mask_l2)
        res[upper_mask] += torch.norm(p[0][upper_mask[1]] - b[upper_mask[0]], dim=-1)

        # ps > 0 and ps < l2
        within_mask = torch.where(ps_mask & ~temp_mask_l2)
        res[within_mask] += torch.norm(
            p[0][within_mask[1]] - (a[within_mask[0]] + (ps[within_mask] / l2[within_mask[0]]).unsqueeze(-1) * s[within_mask[0]]), dim=-1)

        return res
    
    def _weights_from_bones(self, joints, bones, pcd, add_noise=False, noise_var=0, val=1, soft_weights=True, add_zero_weight=False):
        bone_distances = self.dist_batch(
            pcd,
            torch.cat([joints[bone[0]].unsqueeze(0) for bone in bones], dim=0),
            torch.cat([joints[bone[1]].unsqueeze(0) for bone in bones], dim=0))
            
        if soft_weights:
            # weights = torch.softmax((1 / (bone_distances + 1e-6)).T, dim=-1)
            # weights = (1 / (bone_distances + 1e-6)).T
            weights = (1 / (0.5 * torch.e ** bone_distances + self.eps)).T.contiguous()

            # weight_new = torch.zeros_like(weights)
            # vals, indices = weights.topk(dim=-1, k=4)

            # for i in range(len(weights)):

            #     weight_new[i][indices[i]] = vals[i]
            # # weights = torch.softmax((1 / (bone_distances**2 + 1e-6)).T, dim=-1)
            # weights = weight_new
        else:
            bone_argmin = torch.argmin(bone_distances, axis=0)
            weights = torch.zeros((len(bone_argmin), len(bones)))
            weights[torch.arange(len(bone_argmin)), bone_argmin] = val
        
        if add_zero_weight:
            weights = torch.cat([torch.zeros((len(weights), 1)), weights], dim=-1)
        
        if add_noise:
            weights = weights + torch.randn_like(weights) * noise_var

        return weights
    
    def simplify_skeleton(self, times, deg_threshold=10, mass_threshold=0.0, update_skeleton=False, five_percent_heuristic=False, visualise_canonical=False):
        """Detect joints that barely rotate over ``times`` and merge them away.

        Steps: regress the pose for every time step, compare all joint pairs
        for similar rotations, flag joints with (near) zero motion, hand both
        results to ``merge_joints`` and either rebuild the skeleton
        (``update_skeleton``) or only install rotation/sibling masks on the
        warper. Finally the merging rules are stored on the module.

        Args:
            times: time stamps passed to ``poc_fre`` together with ``time_poc``.
            deg_threshold: angle threshold in degrees for both heuristics.
            mass_threshold: unused (the mass-based pruning is commented out).
            update_skeleton: replace joints/bones and prune regressor weights.
            five_percent_heuristic: a joint counts as static if at most 5% of
                the time steps exceed ``deg_threshold``; otherwise a
                mean-square criterion is used.
            visualise_canonical: overwrite joints/bones with the merged skeleton.

        Returns:
            ``(joints, bones, new_joints, new_bones, prune_bones,
            merging_rules, rotations_to_keep, res)``.

        NOTE(review): ``res`` is whatever was assigned last -- the per-joint
        exceedance counts with ``five_percent_heuristic``, otherwise the result
        of the final pairwise similarity test (and undefined for fewer than
        two joints). Confirm what callers expect.
        """
        # Calculate time-dependent rotation angles
        times_embed = poc_fre(times, self.time_poc)
        params = self.forward_warp.transform_net(times_embed)
        if self.over_parameterized_rot:
            # Over-parameterised form: the last parameter is the angle itself.
            rot_angles = params[:, :len(self.joints), -1]
        else:
            # Rotation-vector form: the angle is the vector norm, wrapped to [0, 2*pi).
            rot_angles = (params[:, :len(self.joints), :3]**2).sum(-1).sqrt() % (2 * np.pi)
        rotations = torch.zeros((len(times), len(self.joints), 3, 3))
        for ji in range(len(self.joints)):
            if self.over_parameterized_rot:
                temp_rot, _ = self.forward_warp.Rodrigues(params[:, ji, :])
            else:
                temp_rot, _ = self.forward_warp.Rodrigues(params[:, ji, :3])
            rotations[:, ji, :] = temp_rot
        
        # Symmetric matrix: entry [i, j] is True when joints i and j rotate alike over time.
        rotation_similarity_mat = torch.eye(len(self.joints)).type(torch.bool)
        for i in range(len(self.joints)):
            for j in range(len(self.joints)):
                if j >= i:
                    continue
                res = self._are_rotations_similar(
                    rotations[:,i,...], rotations[:,j,...], 
                    deg_threshold=deg_threshold, five_percent_heuristic=five_percent_heuristic)
                rotation_similarity_mat[i,j] = res
                rotation_similarity_mat[j,i] = res
        
        if five_percent_heuristic:
            # 5% heuristic
            th = int(len(times) * 0.05)
            res = (torch.rad2deg(rot_angles.abs()) >= deg_threshold).sum(dim=0) #
            zero_motion = res <= th
        else:
            # Avg heuristic
            # NOTE(review): this converts the mean *squared* angle to degrees without a square
            # root, unlike _are_rotations_similar -- confirm whether an RMS was intended.
            deg_stds = torch.rad2deg((rot_angles**2).mean(dim=0)) # How much it differs from the canonical pose (zero rad)
            zero_motion = deg_stds <= deg_threshold

        # Create pruning masks based on rotation and mass of component
        # (the value assigned to `weights` here is not used below)
        weights = self.get_weights()
        # no_mass = (weights.sum(dim=0) / len(self.weights)) < mass_threshold
        # # no_mass = torch.zeros_like(no_mass)
        # prune_bones = torch.logical_or(zero_motion, no_mass)
        prune_bones = zero_motion
        prune_bones[0] = False # never prune (imaginary) root (bone)

        # Prune joints
        joints = self.joints.detach().cpu().numpy()
        bones = self.bones
        new_joints, new_bones, merging_rules, joints_to_keep, rotations_to_keep, rotation_switch_mask, sibling_transfer_rules = merge_joints(joints, bones, prune_bones, rotation_similarity_mat, convert_merging_rules=False)
        rotations_to_keep = torch.tensor(rotations_to_keep)

        if visualise_canonical:
            self.joints.data = torch.tensor(new_joints)
            self.bones = new_bones

        if update_skeleton:
            # NOTE: Not working
            # raise NotImplementedError()
            self.joints.data = torch.tensor(new_joints)
            self.bones = new_bones

            # Re-initialise kinematic tree and set rotations mask (it is easier to use
            # a mask for rotations instead of pruning the MLP weights directly due to the 
            # strided reshaping that happens in the pose regressor)
            # Simply masking (should) work as the joints are in the same order as before
            self.forward_warp.init_tree(new_joints, new_bones, old=False)
            self.forward_warp.remove_regressor_weights(rotations_to_keep, torch.tensor(rotation_switch_mask))
        else:
            # Keep the skeleton; only freeze the pruned rotations and redirect sibling rotations.
            self.forward_warp.set_rotation_mask(~prune_bones)
            self.forward_warp.set_sibling_mask(torch.tensor(sibling_transfer_rules))

        # Solve for pruned weights
        current_weights = self.get_weights()
        target_weights = torch.zeros_like(current_weights)
        # merging_rules[from_idx] is the index that from_idx is merged into.
        for from_idx, to_idx in enumerate(merging_rules): 
            target_weights[:, to_idx] += current_weights[:, from_idx]
        
        if update_skeleton:
            # Remove pruned weights
            target_weights = target_weights[:, rotations_to_keep]
            current_weights = self.weights[:, rotations_to_keep]

        # Merging Mat approach
        num_weights = self.weights.shape[-1]
        self.flat_merging_rules = torch.tensor(self.flatten_merging_rules(merging_rules))
        self.sibling_merging_rules = torch.tensor(sibling_transfer_rules)
        # merging_mat[i] is a diagonal 0/1 matrix selecting the weights that end up in index i.
        self.merging_mat = torch.zeros(num_weights, num_weights, num_weights)
        for i in range(num_weights):
            mask = (self.flat_merging_rules == i)
            self.merging_mat[i] = torch.eye(num_weights) * mask

        # Solve for new weights
        # theta_weight = torch.max(self.eps, self.theta_weight)
        # new_weights = solve_x_softmax_iteratively(target_weights, current_weights, alpha=theta_weight, iters=100)
        # if update_skeleton:
        #     new_weights = new_weights[:, rotation_switch_mask] # account for changed order
        # self.weights.data = new_weights

        print(f"frozen joints {prune_bones.sum()} from {len(prune_bones)} ")
        print(f"Pruned {len(joints) - len(new_joints)} joints")

        return joints, bones, new_joints, new_bones, prune_bones, merging_rules, rotations_to_keep, res
    
    def flatten_merging_rules(self, merging_rules):
        """Resolve every index to the final target of its merge chain.

        ``merging_rules[i]`` is the index that ``i`` is merged into; an index
        that maps to itself is final. Chains such as ``2 -> 1 -> 0`` are
        followed until a final index is reached.

        Args:
            merging_rules: sequence of target indices, one per source index.

        Returns:
            List with the final target for every index.

        Raises:
            ValueError: if the rules contain a cycle without a self-mapping
                entry (this used to loop forever).
        """
        num_rules = len(merging_rules)
        endpoints = []
        for start in range(num_rules):
            node = merging_rules[start]
            hops = 0
            while node != merging_rules[node]:
                hops += 1
                # An acyclic chain visits each index at most once.
                if hops > num_rules:
                    raise ValueError(f"merging_rules contains a cycle reachable from index {start}")
                node = merging_rules[node]
            endpoints.append(node)
        return endpoints


    def _are_rotations_similar(self, rot1, rot2, deg_threshold=20, five_percent_heuristic=False):
        """Decide whether two rotation sequences stay close to each other.

        The per-step angle of the relative rotation ``rot1 @ rot2^T`` is
        compared against ``deg_threshold`` (degrees): either its RMS over the
        sequence must not exceed the threshold, or -- with
        ``five_percent_heuristic`` -- at most 5% of the steps may reach it.
        """
        relative = rot1 @ torch.transpose(rot2, 1, 2)
        angle = torch.norm(roma.rotmat_to_rotvec(relative), dim=-1)

        if five_percent_heuristic:
            budget = int(len(rot1) * 0.05)
            exceeding = torch.rad2deg(angle) >= deg_threshold
            return exceeding.sum(dim=0) <= budget

        rms_deg = torch.rad2deg(torch.sqrt((angle ** 2).mean(dim=0)))
        return rms_deg <= deg_threshold
    
    def repose(self, rot_params):
        """Run the forward warp with explicit rotation parameters.

        Uses the current merged weights from ``get_weights()`` and
        ``self.joints``; whatever ``self.forward_warp`` returns is passed
        through unchanged.
        """
        skinning_weights = self.get_weights()
        joints = self.joints
        return self.forward_warp(skinning_weights, joints, rot_params=rot_params)

    def sample_ray(self, rays_o, rays_d, near, far, stepsize, xyz_min=None, xyz_max=None, **render_kwargs):
        '''Sample query points along rays and keep those inside the bounding box.

        Input:
            rays_o, rays_d:   ray origins and directions (made contiguous before
                              being handed to the CUDA sampler).
            near, far:        the near and far distance of the rays.
            stepsize:         step length in voxels; multiplied by self.voxel_size.
            xyz_min, xyz_max: bounding box; fall back to self.xyz_min / self.xyz_max
                              when None.
            render_kwargs:    accepted and ignored, so a full render-kwargs dict
                              can be splatted into this call.
        Output:
            ray_pts:          sampled points that lie inside the bounding box.
            ray_id:           index of the ray of each kept point.
            step_id:          the i'th step on a ray of each kept point.
            mask_inbbox:      boolean mask over all sampled points (True = kept).
        '''
        bbox_lo = self.xyz_min if xyz_min is None else xyz_min
        bbox_hi = self.xyz_max if xyz_max is None else xyz_max
        origins = rays_o.contiguous()
        directions = rays_d.contiguous()

        sampled = render_utils_cuda.sample_pts_on_rays(
            origins, directions, bbox_lo, bbox_hi, near, far, stepsize * self.voxel_size)
        ray_pts, mask_outbbox, ray_id, step_id, _, _, _ = sampled

        mask_inbbox = ~mask_outbbox
        return ray_pts[mask_inbbox], ray_id[mask_inbbox], step_id[mask_inbbox], mask_inbbox
    
    def get_weights(self):
        """Return the softmax-normalised weights after applying the merging rules.

        ``self.weights`` is divided by a temperature (``self.theta_weight``,
        clamped from below by ``self.eps``), soft-maxed over its last dimension
        and then merged: the weight of every component ``j`` is added to the
        component ``flat_merging_rules[j]`` it is merged into.

        On first use (``self.merging_mat is None``) an identity merging is
        built and cached in ``self.merging_mat`` / ``self.flat_merging_rules``.

        Returns:
            Tensor with the same layout as ``self.weights`` (components on the
            last dimension).
        """
        theta_weight = torch.max(self.eps, self.theta_weight)
        # Components first: (num_components, num_points).
        weights = torch.softmax(self.weights / theta_weight, dim=-1).permute(1, 0)
        num_components = len(weights)

        # `is None` instead of `== None`: once the cache holds a tensor, `==`
        # would dispatch to Tensor.__eq__ rather than do an identity check.
        if self.merging_mat is None:
            self.flat_merging_rules = torch.arange(0, len(self.joints))
            self.merging_mat = torch.zeros(num_components, num_components, num_components)
            identity = torch.eye(num_components)
            for i in range(num_components):
                # Diagonal selector for every component that is merged into i.
                self.merging_mat[i] = identity * (self.flat_merging_rules == i)

        # merging_mat[i] @ weights keeps only the rows merged into i; summing
        # over those rows yields the merged weight of component i. `expand`
        # creates a broadcast view instead of materialising the copies.
        stacked = weights.unsqueeze(0).expand(num_components, -1, -1)
        merged_weights = torch.bmm(self.merging_mat, stacked).sum(1)

        return merged_weights.permute(1, 0)
    
    def load_3DConv(self, path):
        """Placeholder for loading pretrained 3D-conv weights; always raises.

        Raises:
            NotImplementedError: unconditionally.
        """
        # The retired implementation loaded ``torch.load(path)['model_state_dict']``,
        # kept only parameters whose name contains 'feat_net', 'deformation_net'
        # or 'canonical_feat', applied them with ``load_state_dict(strict=False)``
        # and set ``self.deformation_net_loaded = True``.
        raise NotImplementedError()

    def aggregate_pts(self, t_hat_pcd, rotated_frames, flow, query_radius, cam_per_ray, render_pcd_direct, render_pcd_direct_only, direct_eps, render_weights, render_kwargs, calc_min_max=True):
        """Sample points on the rays and aggregate per-sample colour/alpha from the warped point cloud.

        Steps: (1) sample points along the rays (``sample_ray``), (2) find the
        ``self.neighbours`` nearest points of ``t_hat_pcd`` for every sample and
        drop samples whose farthest neighbour is too far away, (3) turn the
        neighbour features into per-sample rgb / alpha, either directly from the
        stored ``canonical_rgbs`` / ``canonical_alpha`` or through
        ``feat_net`` -> ``densitynet`` / ``rgbnet``.

        Args:
            t_hat_pcd: warped point cloud that is queried for neighbours.
            rotated_frames: per-point frames; the upper-left 3x3 part is applied
                to the relative positions for the 'full' / 'rot' embeddings.
            flow: optional per-point flow, indexed as ``flow[:, s_i, :]``; may be None.
            query_radius: neighbour cut-off and padding of the sampling bounding box.
            cam_per_ray: per-ray index used to gather ``flow`` (only read when
                ``flow`` is not None).
            render_pcd_direct: also compute the direct (stored rgb/alpha) rendering.
            render_pcd_direct_only: return the direct rendering as ``rgbs`` / ``alpha``
                and skip the networks.
            direct_eps: NOTE(review): currently overwritten with 0.05 below, so
                the passed value has no effect.
            render_weights: also aggregate ``self._last_weights`` per sample.
            render_kwargs: dict that must contain 'rays_o' plus the arguments of
                ``sample_ray``; 'stepsize' and (depending on the view-direction
                mode) 'viewdirs' are read as well.
            calc_min_max: if True the sampling bounding box is the extent of
                ``t_hat_pcd`` padded by ``query_radius``; otherwise ``sample_ray``
                falls back to ``self.xyz_min`` / ``self.xyz_max``.

        Returns:
            Tuple ``(rgbs, alpha, rgbs_direct, alpha_direct, render_lbs_weights,
            flow, ray_pts, ray_id, step_id, N)`` where ``N`` is the number of rays.
            ``rgbs_direct`` / ``alpha_direct`` are None unless ``render_pcd_direct``;
            ``render_lbs_weights`` is None unless ``render_weights`` (and not
            direct-only).

        Raises:
            NoPointsException: if no sample survives the ray sampling or the
                neighbour-radius filter.
        """
        N = len(render_kwargs['rays_o'])
        render_neighbours = self.neighbours
        render_lbs_weights = None

        with torch.profiler.record_function("sample_ray"):
            assert render_kwargs is not None
            if calc_min_max:

                xyz_min, xyz_max = torch.min(t_hat_pcd, dim=0)[0] - query_radius, torch.max(t_hat_pcd, dim=0)[0] + query_radius
            else:
                xyz_min, xyz_max = None, None
            ray_pts, ray_id, step_id, _ = self.sample_ray(xyz_min=xyz_min, xyz_max=xyz_max, **render_kwargs)

        if len(ray_pts) == 0:
            raise NoPointsException("No points.")
        
        ## Find nearest neighbours
        with torch.profiler.record_function("knn"):
            # KeOps symbolic distance matrix on detached inputs; D_ij holds
            # squared distances, so `to_nn` is squared as well.
            ray_pts_lazy = LazyTensor(ray_pts[:, None, :].detach())
            shifted_xyz_lazy = LazyTensor(t_hat_pcd[None, :, :].detach())
            D_ij = ((ray_pts_lazy - shifted_xyz_lazy) ** 2).sum(-1)
            to_nn, s_i = D_ij.Kmin_argKmin(dim=1, K=render_neighbours)

        with torch.profiler.record_function("knn-post"):
            # Keep samples whose last returned neighbour is within the cut-off.
            # NOTE(review): `to_nn` is a squared distance but is compared to
            # `query_radius` un-squared; confirm this is intended.
            nn_mask = torch.where(to_nn[:,-1] <= query_radius)[0]                
            s_i = s_i[nn_mask, :]
            ray_id = ray_id[nn_mask]
            step_id = step_id[nn_mask]
            ray_pts = ray_pts[nn_mask]
            # to_nn = to_nn[nn_mask] # WARNING: Has no gradients!!!
            # Recompute the squared distances from the (non-detached) tensors so
            # that gradients can flow.
            rel_p = ray_pts[:,None,:] - t_hat_pcd[s_i,:]
            to_nn = (rel_p**2).sum(-1)
        
        if len(s_i) == 0:
            raise NoPointsException("No points.")

        # Everything below, including the final `return`, runs inside this
        # profiler scope.
        with torch.profiler.record_function("feat_net"):
            features_k = self.canonical_feat[s_i,:]
            rotated_frames_k = rotated_frames[s_i,:,:]
        
            ## Directly render point cloud based on frozen rgb and alpha
            rgbs_direct = None
            alpha_direct = None
            if render_pcd_direct:
                # NOTE(review): overrides the `direct_eps` argument.
                direct_eps = 0.05
                sig = torch.tensor(self.mean_min_distance * direct_eps)
                # NOTE(review): `to_nn` is already a squared distance and is
                # squared again here; confirm.
                w_direct = torch.exp(-(to_nn**2) / (2 * (sig)**2 + 1e-12))
                # NOTE(review): this `w` is never read; it is overwritten in
                # both branches below.
                w = 1 / (to_nn + self.eps)
                # Un-normalised kernel scaled by 1/K weights the alpha values;
                # the normalised kernel weights the colours.
                w_direct_density = (torch.tensor(1./render_neighbours) * w_direct).unsqueeze(-1)
                w_direct = w_direct / (w_direct.sum(dim=-1) + 1e-12)[:, None]
                w_direct = w_direct.unsqueeze(-1)
                alpha_k_direct = self.canonical_alpha[s_i].unsqueeze(-1)
                rgbs_k_direct = self.canonical_rgbs[s_i,:]
                
                rgbs_direct = (w_direct * rgbs_k_direct).sum(dim=1)
                alpha_direct = (w_direct_density * alpha_k_direct).sum(dim=1).squeeze(-1)

            if render_pcd_direct_only:
                # NOTE(review): with render_pcd_direct=False the two values below
                # are None and `w_direct` is undefined (NameError).
                rgbs = rgbs_direct
                alpha = alpha_direct
                w = w_direct
            else:
                ## POINT-NERF APPROACH
                # Inverse (squared) distance weights, normalised per sample.
                w = 1 / (to_nn + self.eps)
                # w = torch.exp(-(to_nn**2) / (2 * (sig)**2 + 1e-12))
                w = w / w.sum(dim=-1)[:, None]
                w = w.unsqueeze(-1)
                features_k_hat = []

                # Build the positional input of feat_net according to
                # self.embedding. NOTE(review): any value other than 'full',
                # 'rot', 'raw_embed' or 'raw' leaves `rel_p_emb` undefined.
                if self.embedding == 'full' or self.embedding == 'rot':
                    rotated_frames_flat = rotated_frames_k[...,:3,:3].reshape(-1,3,3)
                    rel_p_flat = rel_p.reshape(-1, 3).unsqueeze(-1)
                    rel_p_canonical = torch.bmm(rotated_frames_flat, rel_p_flat).squeeze(-1)
                
                    if self.embedding == 'full': 
                        rel_p_emb = poc_fre(rel_p_canonical, self.pos_poc)
                    else:
                        rel_p_emb = rel_p_canonical
                
                if self.embedding == 'raw_embed':
                    rel_p_emb = poc_fre(rel_p, self.pos_poc)
                    rel_p_emb = rel_p_emb.reshape(-1, rel_p_emb.shape[-1])

                elif self.embedding == 'raw':
                    rel_p_emb = rel_p.reshape(-1, 3)
                
                feat_input = torch.concat([rel_p_emb, features_k.reshape(-1, features_k.shape[-1])], dim=-1)
                out = self.feat_net(feat_input)

                # Back to (samples, neighbours, feature) and blend the neighbours.
                features_k_hat = out.reshape(*rotated_frames_k.shape[:2], out.shape[-1])
                h_feature = (features_k_hat * w).sum(dim=1)

                if flow is not None:
                    flow_k = flow[:, s_i,:]
                    gather_indx = cam_per_ray[ray_id].long() 
                    flow_k = torch.gather(flow_k, 0, gather_indx[None, ..., None].expand((-1, -1, *flow_k.shape[-2:])).to(flow_k.device)).squeeze() # This is not working
                    flow = (flow_k * w).sum(dim=1)

                with torch.profiler.record_function("densitynet"):
                    density_result = self.densitynet(h_feature)
                    interval = render_kwargs['stepsize'] * self.tineuvox.voxel_size_ratio
                    alpha = self.tineuvox.activate_density(density_result, interval)

                with torch.profiler.record_function("rgbnet"):
                    # Three view-direction modes: none, a fixed embedding
                    # (self.viewdirs_emb), or the per-ray 'viewdirs'.
                    if self.no_view_dir:
                        rgb_logit = self.rgbnet(h_feature)

                    elif self.frozen_view_dir is not None:
                        viewdirs_emb_reshape = self.viewdirs_emb.expand(len(ray_id), -1)
                        rgb_logit = self.rgbnet(h_feature, viewdirs_emb_reshape)

                    else:
                        viewdirs_emb = poc_fre(render_kwargs['viewdirs'], self.view_poc)
                        viewdirs_emb_reshape = viewdirs_emb[ray_id]
                        rgb_logit = self.rgbnet(h_feature, viewdirs_emb_reshape)
                    rgbs = torch.sigmoid(rgb_logit)

                if render_weights:
                    # Blend the weights cached by forward() (self._last_weights)
                    # over the neighbours of every sample.
                    render_lbs_weights = self._last_weights[s_i,:]
                    render_lbs_weights = (render_lbs_weights * w).sum(dim=1)

            return rgbs, alpha, rgbs_direct, alpha_direct, render_lbs_weights, flow, ray_pts, ray_id, step_id, N 

    def sample_thetas(self, tmin, tmax, num=50, reduction='five_percent', deg_threshold=15):
        """Sample random times in ``[tmin, tmax)`` and reduce the resulting thetas.

        The sampled times are embedded with ``poc_fre`` and passed to
        ``self.forward_warp.get_thetas``; the result is reduced over the
        sample dimension.

        Args:
            tmin, tmax: bounds of the uniform time samples. Previously these
                were computed and then overwritten by ``torch.rand(num)``, i.e.
                silently ignored.
            num: number of time samples.
            reduction: 'five_percent' returns a boolean tensor that is True
                where at most ``int(num * 0.05)`` samples reach
                ``deg_threshold`` degrees; 'mean' returns the mean theta.
            deg_threshold: threshold in degrees for 'five_percent'.

        Returns:
            The reduced thetas (boolean for 'five_percent', float for 'mean').

        Raises:
            NotImplementedError: for any other ``reduction``.
        """
        # torch.rand is uniform on [0, 1); rescale to [tmin, tmax) and add the
        # trailing feature dimension that was used for the embedding before.
        ts = ((tmax - tmin) * torch.rand(num) + tmin)[:, None]
        ts_embed = poc_fre(ts, self.time_poc)
        thetas = self.forward_warp.get_thetas(ts_embed)

        if reduction == 'five_percent':
            th = int(num * 0.05)
            # NOTE(review): no abs() is taken, so negative thetas never count
            # as exceeding the threshold; confirm thetas are non-negative.
            res = torch.rad2deg(thetas) >= deg_threshold
            thetas = res.sum(dim=0) <= th
        elif reduction == 'mean':
            thetas = thetas.mean(dim=0)
        else:
            raise NotImplementedError()

        return thetas


    def forward(self, t, render_image=False, render_depth=False, render_kwargs=None, query_radius=0.01,
                render_weights=False, benchmark=False, rot_params=None, render_pcd_direct=False, 
                kinematic_warp=True, direct_eps=1e-1, flow_t_delta=None, poses=None, Ks=None, cam_per_ray=None, calc_min_max=True, render_conv=False):
        """Warp the canonical point cloud and optionally volume-render it.

        Exactly one of ``t`` and ``rot_params`` must be given (asserted).

        Args:
            t: time passed to ``poc_fre``; None when ``rot_params`` is used.
            render_image: if True, render rgb (and the optional outputs below).
            render_depth: additionally return 'depth' (only with ``render_image``).
            render_kwargs: ray dictionary forwarded to ``aggregate_pts``; 'bg'
                is read here as background colour.
            query_radius: neighbour radius forwarded to ``aggregate_pts``.
            render_weights: additionally render the blended weights as colours.
                NOTE(review): reads variables that only exist after the
                ``render_image`` branch ran.
            benchmark: not read in this method.
            rot_params: explicit rotation parameters forwarded to ``forward_warp``.
            render_pcd_direct: NOTE(review): overwritten with True inside the
                ``render_image`` branch.
            kinematic_warp: NOTE(review): when False nothing assigns
                ``t_hat_pcd`` and the code below fails with a NameError.
            direct_eps: forwarded to ``aggregate_pts``.
            flow_t_delta: if given, also warp at ``t + flow_t_delta`` and compute
                the projected 2D flow (needs ``poses`` and ``Ks``, and ``t`` not None).
            poses, Ks: camera parameters for ``project_point_to_image_plane``.
            cam_per_ray: forwarded to ``aggregate_pts``.
            calc_min_max: forwarded to ``aggregate_pts``.
            render_conv: also run ``self.deformation_net`` and use its point
                cloud for the direct rendering.

        Returns:
            dict with 't_hat_pcd', 'rgb_marched', 'alphainv_last', 'grid',
            'conv_t_hat_pcd', 'rgb_marched_direct', 'flow_diff' and, depending
            on the flags, 'depth', 'flow' and 'weights'. If no sample point is
            found a placeholder dict is returned instead.
        """
        assert (t is None) ^ (rot_params is None)
        flow = None
        grid = None
        # Forward-warp canonical pcd
        with torch.profiler.record_function("poc_fre"):
            if rot_params is None:
                t_embed = poc_fre(t, self.time_poc)
                times_features = self.timenet(t_embed)
            else:
                # Explicit rotations: no time embedding; the time features are
                # evaluated at t = 0.
                t_embed = None
                times_features = self.timenet(poc_fre(torch.tensor([0.]), self.time_poc))
        with torch.profiler.record_function("forward_warp"):
            
            # if self.deformation_net_loaded or not kinematic_warp:
            #     times_features = self.timenet(t_embed)
            #     conv_t_hat_pcd, conv_rotated_frames, grid = self.deformation_net(times_features[None,None,:], self.canonical_pcd)
            if render_conv:
                conv_t_hat_pcd, conv_rotated_frames, grid = self.deformation_net(times_features[None,None,:], self.canonical_pcd)
            else:
                conv_t_hat_pcd, conv_rotated_frames, grid = None, None, None 
            flow_diff = None
            if kinematic_warp:
                if self.WEIGHTS_DIRECT[0]: # Sample weights directly from the points
                    weights = self.get_weights()
                else:
                    weights = self._weights_from_bones(self.joints, self.bones, self.canonical_pcd, add_noise=False, noise_var=0, soft_weights=True)
                # Cached for the weight-based losses and for aggregate_pts.
                self._last_weights = weights
                t_hat_pcd, rotated_frames = self.forward_warp(weights, self.joints, t_embed, get_frames=True, rot_params=rot_params, avg_procrustes=self.avg_procrustes)
                if render_conv:
                    # Difference to the deformation-net result; the latter is
                    # detached so no gradient reaches it through flow_diff.
                    flow_diff = t_hat_pcd - conv_t_hat_pcd.clone().detach()
                    # flow_diff = t_hat_pcd - conv_t_hat_pcd

                rotated_frames = torch.inverse(rotated_frames)
                # rotated_frames = torch.transpose(rotated_frames, 2, 1)
                if flow_t_delta is not None:
                    # 2D flow: project the cloud at t and at t + flow_t_delta
                    # and take the difference of the image positions.
                    t2_embed = poc_fre(t + flow_t_delta, self.time_poc)
                    t2_hat_pcd, _ = self.forward_warp(weights, self.joints, t2_embed, get_frames=True, rot_params=rot_params)

                    p2 = project_point_to_image_plane(t2_hat_pcd, poses, Ks)
                    p1 = project_point_to_image_plane(t_hat_pcd, poses, Ks)
                    # Mask compatibility: flip the first image coordinate.
                    # NOTE(review): 511 is hard-coded; presumably 512-pixel images, confirm.
                    p2[:,:,0] = 511 - p2[:,:,0]
                    # p2 = p2.flip(-1)

                    p1[:,:,0] = 511 - p1[:,:,0]
                    # p1 = p1.flip(-1)
                    # Mask compatibility end

                    flow = p2 - p1
            else:
                # times_features = torch.repeat_interleave(self.timenet(t_embed).unsqueeze(0), len(self.canonical_pcd), dim=0)
                # pcd_embed = poc_fre(self.canonical_pcd, self.tineuvox.pos_poc)
                # t_delta = self.deformation_net(pcd_embed, times_features)
                # t_hat_pcd = self.canonical_pcd + t_delta
                # rotated_frames = torch.eye(4).unsqueeze(0).repeat(len(t_hat_pcd), 1, 1)
                # self._last_weights = torch.ones(len(t_hat_pcd), 1)

                # times_features = self.timenet(t_embed)
                # conv_t_hat_pcd, conv_rotated_frames, grid = self.deformation_net(times_features[None,None,:], self.canonical_pcd)
                # t_hat_pcd = conv_t_hat_pcd
                # rotated_frames = conv_rotated_frames
                pass
            # self._last_weights = torch.ones(len(t_hat_pcd), 1)

            
            # if self.deformation_net_loaded:
            #     flow_diff = t_hat_pcd - conv_t_hat_pcd.detach()

        rgb_marched = None
        alphainv_last = None
        rgb_teacher = None
        alpha_teacher = None
        rgb_marched_direct = None
        ret_dict = {}
        if render_image:
            try:
                # Render kinematic
                # NOTE(review): the `render_pcd_direct` argument is overridden here.
                render_pcd_direct_only = False
                render_pcd_direct = True
                rgbs, alpha, rgbs_direct, alpha_direct, render_lbs_weights, flow, ray_pts, ray_id, step_id, N  = self.aggregate_pts(
                    t_hat_pcd, rotated_frames, flow, query_radius, cam_per_ray,
                    render_pcd_direct, render_pcd_direct_only, direct_eps, render_weights, render_kwargs, calc_min_max=calc_min_max)
                ray_id_direct = torch.clone(ray_id)

                if render_conv:
                    # Render Conv direct
                    # The direct outputs of the kinematic pass are replaced by a
                    # direct-only rendering of the deformation-net point cloud.
                    render_pcd_direct_only = True
                    render_pcd_direct = True
                    rgbs_conv, alpha_conv, _, _, _, _, _, ray_id_conv, _, _  = self.aggregate_pts(
                            conv_t_hat_pcd, conv_rotated_frames, flow, query_radius, cam_per_ray,
                            render_pcd_direct, render_pcd_direct_only, direct_eps, render_weights, render_kwargs, calc_min_max=calc_min_max)
                    rgbs_direct = rgbs_conv
                    alpha_direct = alpha_conv
                    ray_id_direct = ray_id_conv
                # else:
                #     rgbs_direct, alpha_direct, ray_id_direct = None, None, None

            except NoPointsException as npe:
                # No sample hit the point cloud: return placeholder outputs.
                # NOTE(review): the key 'alpa_teacher' is spelled this way;
                # check callers before renaming it.
                return {
                    'rgb_marched': torch.ones(len(render_kwargs['rays_o']), 3),
                    'rgb_marched_direct': torch.ones(len(render_kwargs['rays_o']), 3),
                    'flow': torch.zeros(len(render_kwargs['rays_o']), 2),
                    'depth': torch.zeros(len(render_kwargs['rays_o'])),
                    'weights': torch.ones(len(render_kwargs['rays_o']), 3),
                    't_hat_pcd': t_hat_pcd,
                    'conv_t_hat_pcd': conv_t_hat_pcd,
                    'flow_diff': None,
                    'alphainv_last': None,
                    'rgb_teacher': None,
                    'alpa_teacher': None,
                    'grid': grid,
                }
            # rgb_teacher = (rgbs_direct - rgbs).abs()
            # alpha_teacher = (alpha_direct - alpha).abs()
            
            # ray_id_direct = torch.clone(ray_id)
                    
            with torch.profiler.record_function("pre-mask"):
                # Drop samples whose alpha is at or below the threshold before
                # accumulating transmittance.
                if self.fast_color_thres > 0:
                    mask = torch.where(alpha > self.fast_color_thres)[0]
                    ray_id = ray_id[mask]
                    step_id = step_id[mask]
                    alpha = alpha[mask]
                    rgbs = rgbs[mask]

                    if flow is not None:
                        flow = flow[mask]

                    if render_weights:
                        render_lbs_weights = render_lbs_weights[mask]
                    
                    if alpha_direct is not None:
                        mask_direct = torch.where(alpha_direct > self.fast_color_thres)[0]
                        ray_id_direct = ray_id_direct[mask_direct]
                        alpha_direct = alpha_direct[mask_direct]
                        rgbs_direct = rgbs_direct[mask_direct]

                # compute accumulated transmittance
            with torch.profiler.record_function("Alphas2Weights"):
                weights, alphainv_last = Alphas2Weights.apply(alpha, ray_id, N)
                if alpha_direct is not None:
                    weights_direct, alphainv_last_direct = Alphas2Weights.apply(alpha_direct, ray_id_direct, N)

            with torch.profiler.record_function("post-mask"):
                # Second pruning pass, now on the compositing weights.
                if self.fast_color_thres > 0:
                    mask = torch.where(weights > self.fast_color_thres)[0]
                    weights = weights[mask]
                    alpha = alpha[mask]
                    ray_id = ray_id[mask]
                    step_id = step_id[mask]
                    rgbs = rgbs[mask]

                    if flow is not None:
                        flow = flow[mask]

                    if render_weights:
                        render_lbs_weights = render_lbs_weights[mask]
                    
                    if alpha_direct is not None:
                        mask_direct = torch.where(weights_direct > self.fast_color_thres)[0]
                        ray_id_direct = ray_id_direct[mask_direct]
                        alpha_direct = alpha_direct[mask_direct]
                        rgbs_direct = rgbs_direct[mask_direct]
                        weights_direct = weights_direct[mask_direct]

            with torch.profiler.record_function("segment_coo"):
                # Sum the weighted sample colours per ray, then add the
                # background scaled by alphainv_last.
                rgb_marched = segment_coo(
                    src=(weights.unsqueeze(-1) * rgbs),
                    index=ray_id,
                    out=torch.zeros([N, 3]),
                    reduce='sum')

                rgb_marched += (alphainv_last.unsqueeze(-1) * render_kwargs['bg'])

                if alpha_direct is not None:
                    rgb_marched_direct = segment_coo(
                        src=(weights_direct.unsqueeze(-1) * rgbs_direct),
                        index=ray_id_direct,
                        out=torch.zeros([N, 3]),
                        reduce='sum')

                    rgb_marched_direct += (alphainv_last_direct.unsqueeze(-1) * render_kwargs['bg'])

                if render_depth:
                    # Depth is the weighted sum of the step indices along the
                    # ray; computed without gradients.
                    with torch.no_grad():
                        depth = segment_coo(
                                src=(weights * step_id),
                                index=ray_id,
                                out=torch.zeros([N]),
                                reduce='sum')
                    ret_dict.update({'depth': depth})

                if flow is not None:
                    flow_marched = segment_coo(
                            src=(weights.unsqueeze(-1) * flow),
                            index=ray_id,
                            out=torch.zeros([N, 2]),
                            reduce='sum')
                    # flow_marched += (alphainv_last.unsqueeze(-1) * 0)
                    
                    ret_dict.update({'flow': flow_marched})

        ret_dict.update({
            't_hat_pcd': t_hat_pcd,
            'rgb_marched': rgb_marched,
            'alphainv_last': alphainv_last,
            'grid': grid,
            'conv_t_hat_pcd': conv_t_hat_pcd,
            # 'rgb_teacher': rgb_teacher,
            # 'alpha_teacher': alpha_teacher,
            'rgb_marched_direct': rgb_marched_direct,
            'flow_diff': flow_diff
        })

        if render_weights:
            # Visualise the weights: one palette colour per component with a
            # non-zero total weight, shuffled with a fixed seed, blended per
            # sample and composited like the rgb image.
            # NOTE(review): uses render_lbs_weights, rgbs, weights, ray_id and N
            # from the render_image branch above.
            weight_mask = self.get_weights().sum(dim=0) > 0
            cols = torch.tensor(color_palette("hls", weight_mask.sum()))
            gen  = torch.Generator(device=render_lbs_weights.device)
            gen.manual_seed(0)
            cols = cols[torch.randperm(cols.shape[0], generator=gen)]
            # 
            # cols = torch.rand((render_lbs_weights.shape[-1], 3),)
            col_per_weight = 0
            for ci, wi in enumerate(torch.where(weight_mask)[0]):
                col_per_weight += cols[ci, None] * render_lbs_weights[:, wi, None]

            # NOTE(review): type(rgbs) is the Python class of `rgbs`, not its
            # dtype; confirm this cast does what is intended.
            col_per_weight = col_per_weight.type(type(rgbs))

            # soft march
            w_marched = segment_coo(
                src=(weights.unsqueeze(-1) * col_per_weight),
                index=ray_id,
                out=torch.zeros([N, 3]),
                reduce='sum')
            w_marched += (alphainv_last.unsqueeze(-1) * render_kwargs['bg'])
            ret_dict.update({'weights': w_marched})

        return ret_dict
    
    def get_neighbour_weight_tv_loss(self):
        """Mean absolute difference between each point's weights and its neighbours' weights.

        Reads the weights cached in ``self._last_weights`` and the neighbour
        index table ``self.nn_i``.
        """
        point_weights = self._last_weights
        neighbour_weights = point_weights[self.nn_i, :]
        return (point_weights[:, None, :] - neighbour_weights).abs().mean()
    
    def get_time_tv_loss(self, t):
        """Sum of squared differences between the transform parameters at ``t`` and the cached ones.

        The parameters at ``t`` come from ``forward_warp.transform_net`` applied
        to the time embedding; the reference is ``forward_warp.prev_params``.
        """
        warp = self.forward_warp
        embedded_t = poc_fre(t, self.time_poc).unsqueeze(0)
        current_params = warp.transform_net(embedded_t)
        return ((warp.prev_params - current_params)**2).sum()
    
    def get_weight_sparsity_loss(self):
        """Mean binary entropy of the cached weights ``self._last_weights``.

        ``self.eps`` is added inside both logarithms to keep them finite.
        """
        p = self._last_weights
        entropy_terms = p * torch.log(p + self.eps) + (1 - p) * torch.log(1 - p + self.eps)
        return -entropy_terms.mean()

    def collision_loss(self, warped_pcd):
        """Penalise weight differences between points that are close after warping.

        For every point the ``self.neighbours`` nearest points in ``warped_pcd``
        are looked up (KeOps ``argKmin``). The squared difference of the
        (detached) weights of a point and each neighbour is divided by their
        distance and averaged. Because the weights are detached, gradients only
        flow through the distances, i.e. through ``warped_pcd``.

        NOTE(review): the query and the reference set are the same cloud, so
        each point is presumably returned as its own nearest neighbour; that
        pair contributes zero because its weight difference is zero.

        Args:
            warped_pcd: warped point cloud, indexed as ``[point, xyz]``.

        Returns:
            Scalar tensor.
        """
        xyz1 = LazyTensor(warped_pcd[:, None, :])
        xyz2 = LazyTensor(warped_pcd[None, :, :])
        D_ij = ((xyz1 - xyz2) ** 2).sum(-1)
        nn_i = D_ij.argKmin(dim=1, K=self.neighbours)
        # eps inside the sqrt keeps the gradient finite at zero distance.
        warped_nn_distance = torch.sqrt((warped_pcd[:, None, :] - warped_pcd[nn_i, :]).pow(2).sum(-1) + self.eps)

        W = self.get_weights().clone().detach()
        Wnn = W[nn_i]

        # The statements that used to follow this return were unreachable and
        # have been removed.
        return ((W[:, None, :] - Wnn).pow(2).sum(-1) * 1 / (warped_nn_distance + self.eps)).mean()

    def get_arap_loss(self, warped_pcd, c=0.03):
        """Sum of absolute changes of the neighbour distances after warping.

        Compares the distances between each warped point and its neighbours
        (``self.nn_i``) with the reference distances ``self.nn_distance``.
        ``c`` is accepted but not used.
        """
        offsets = warped_pcd[:, None, :] - warped_pcd[self.nn_i, :]
        warped_dist = (offsets.pow(2).sum(-1) + self.eps).sqrt()
        return (self.nn_distance - warped_dist).abs().sum()
    
    def get_arap_area_loss(self, warped_pcd, c=0.03):
        """Sum of squared changes of the neighbour distances after warping.

        Same distances as ``get_arap_loss`` but with a squared instead of an
        absolute penalty. ``c`` is accepted but not used.
        """
        offsets = warped_pcd[:, None, :] - warped_pcd[self.nn_i, :]
        warped_dist = (offsets.pow(2).sum(-1) + self.eps).sqrt()
        residual = self.nn_distance - warped_dist
        return residual.pow(2).sum()

    def get_joint_arap_loss(self):
        """Sum of squared deviations of the joint offsets from ``self.og_joint_distance``.

        The joints selected by ``self.bone_arap_mask`` are read as consecutive
        pairs (even rows minus odd rows) and the resulting offset vectors are
        compared with the stored reference.
        """
        selected = self.joints[self.bone_arap_mask]
        pair_offsets = selected[0::2, :] - selected[1::2, :]
        # Absolute-value variant kept for reference:
        # (self.og_joint_distance - pair_offsets).abs().sum()
        return ((self.og_joint_distance - pair_offsets)**2).sum()

    def get_joint_chamfer_loss(self):
        """One-sided chamfer term from the joints to ``self.skeleton_pcd``.

        Takes the raw distances of ``get_chamfer_loss`` and sums only the second
        output (nearest ``skeleton_pcd`` point for every joint).
        """
        _, joint_to_skeleton = self.get_chamfer_loss(
            self.skeleton_pcd, self.joints, c=None, get_raw=True)
        return joint_to_skeleton.sum()

    def _rho(self, x, c):
        return (2 * (x / c)**2) / ((x / c)**2 + 4)

    def get_chamfer_loss(self, pcd1, pcd2, N=None, M=None, c=0.03, get_raw=False):
        """Symmetric chamfer loss between two point clouds.

        Args:
            pcd1, pcd2: point clouds indexed as ``[point, coordinate]``.
            N, M: if given, that many points are drawn (with replacement) from
                ``pcd1`` / ``pcd2`` before the loss is computed.
            c: scale of the robust kernel ``_rho``; ``None`` uses the plain
                squared distances.
            get_raw: return the two per-point squared nearest-neighbour
                distance tensors instead of the reduced loss.

        Returns:
            ``(d_1to2, d_2to1)`` when ``get_raw`` is set, otherwise the sum of
            the two (optionally robustified) means.
        """
        if N is not None:
            pcd1 = pcd1[torch.randint(0, pcd1.shape[0], (N,)).long().to(pcd1.device)]
        if M is not None:
            pcd2 = pcd2[torch.randint(0, pcd2.shape[0], (M,)).long().to(pcd2.device)]

        # Symbolic pairwise squared distances; only the arg-min indices are
        # taken from KeOps, the distances are recomputed with torch.
        pairwise_sq = ((LazyTensor(pcd1[:, None, :]) - LazyTensor(pcd2[None, :, :])) ** 2).sum(-1)
        idx_1to2 = pairwise_sq.argKmin(dim=1, K=1)
        idx_2to1 = pairwise_sq.argKmin(dim=0, K=1)
        d_1to2 = ((pcd1[:, None, :] - pcd2[idx_1to2, :])**2).sum(-1)
        d_2to1 = ((pcd2[:, None, :] - pcd1[idx_2to1, :])**2).sum(-1)

        if get_raw:
            return d_1to2, d_2to1

        if c is None:
            return d_1to2.mean() + d_2to1.mean()
        return self._rho(d_1to2, c).mean() + self._rho(d_2to1, c).mean()
    
    def get_we_entropy_loss(self, t_min, t_max):
        """Cross-entropy style penalty coupling the cached weights to sampled thetas.

        ``sample_thetas`` (50 samples, default reduction) yields one value per
        component. It is cast to float, broadcast over all points and combined
        as ``-sum(log(target + eps) * (1 - weights))`` with the weights cached
        in ``self._last_weights``.
        """
        # Earlier experiments derived the target directly from the rotation
        # angle (hard threshold or sigmoid around a 15 degree offset).
        per_component = self.sample_thetas(t_min, t_max, num=50).to(torch.float32)
        num_points = len(self._last_weights)
        target = per_component[None, :].repeat(num_points, 1)
        log_target = torch.log(target + self.eps)
        return -(log_target * (1 - self._last_weights)).sum()
    
    def get_batch_chamfer_loss(self, pcd1, pcd2, N=None, M=None):
        """Batch-wise symmetric chamfer loss.

        Args:
            pcd1 (torch.Tensor): (B, N, 3) tensor of N 3D points.
            pcd2 (torch.Tensor): (B, M, 3) tensor of M 3D points.
            N (int): Number of points to sample (with replacement) from pcd1.
            M (int): Number of points to sample (with replacement) from pcd2.

        Returns:
            Scalar tensor: mean squared nearest-neighbour distance from pcd1 to
            pcd2 plus the same from pcd2 to pcd1. The same sample indices are
            used for every batch element.
        """
        assert len(pcd1) == len(pcd2)

        if N is not None:
            pcd1 = pcd1[:, torch.randint(0, pcd1.shape[1], (N,), device=pcd1.device).long()]
        if M is not None:
            # Indices are created on pcd1's device, as before.
            pcd2 = pcd2[:, torch.randint(0, pcd2.shape[1], (M,), device=pcd1.device).long()]

        pairwise_sq = ((LazyTensor(pcd1[:, :, None, :]) - LazyTensor(pcd2[:, None, :, :])) ** 2).sum(-1)
        idx_1to2 = pairwise_sq.argKmin(dim=2, K=1)
        idx_2to1 = pairwise_sq.argKmin(dim=1, K=1)

        # Gather the nearest neighbour of every point along the point dimension.
        gather_1to2 = idx_1to2.unsqueeze(-1).expand(-1, -1, -1, pcd2.shape[-1])
        sq_diff_1to2 = (pcd1[:, :, None, :] - torch.gather(pcd2[:, :, None, :], 1, gather_1to2)).pow(2)

        gather_2to1 = idx_2to1.unsqueeze(-1).expand(-1, -1, -1, pcd1.shape[-1])
        sq_diff_2to1 = (pcd2[:, :, None, :] - torch.gather(pcd1[:, :, None, :], 1, gather_2to1)).pow(2)

        return sq_diff_1to2.sum(-1).mean() + sq_diff_2to1.sum(-1).mean()


    def get_transformation_regularisation_loss(self, d=0.0873):
        """L1 penalty on the cached global translation and thetas, divided by the number of thetas.

        Reads ``forward_warp.prev_global_t`` and ``forward_warp.prev_thetas``.
        ``d`` is accepted but not used.
        """
        warp = self.forward_warp
        translation_l1 = warp.prev_global_t.abs().sum()
        thetas = warp.prev_thetas.abs()
        # NOTE(review): the original divisor was `len(thetas + 1)`, which equals
        # `len(thetas)`; `len(thetas) + 1` may have been intended. Value kept.
        return (translation_l1 + thetas.sum()) / len(thetas)


    def feature_consistency_loss(self, pred_pcd, gt_pcd, pred_feat, gt_feat, N=2000, M=2000):
        """Feature-guided soft correspondence loss between two point clouds.

        Every ground-truth point attends to all predicted points with a softmax
        over the cosine similarity of their features (temperature
        ``self.theta``, clamped from below by ``self.eps``). The loss is the
        mean squared difference between the resulting soft position and the
        ground-truth position.

        Args:
            pred_pcd, pred_feat: predicted points and their features.
            gt_pcd, gt_feat: ground-truth points and their features.
            N, M: if not None, a random subset of at most that many predicted /
                ground-truth points is used.

        Returns:
            Scalar tensor.
        """
        if N is not None:
            keep_pred = torch.randperm(len(pred_pcd))[:N]
            pred_pcd, pred_feat = pred_pcd[keep_pred], pred_feat[keep_pred]

        if M is not None:
            keep_gt = torch.randperm(len(gt_pcd))[:M]
            gt_pcd, gt_feat = gt_pcd[keep_gt], gt_feat[keep_gt]

        # Unit-length features, so the dot product is a cosine similarity.
        pred_unit = pred_feat / (torch.norm(pred_feat, dim=-1, keepdim=True) + self.eps)
        gt_unit = gt_feat / (torch.norm(gt_feat, dim=-1, keepdim=True) + self.eps)

        similarity = (gt_unit[:, None, :] * pred_unit[None, :, :]).sum(-1)
        temperature = torch.max(self.eps, self.theta)
        attention = torch.softmax(similarity / temperature, dim=-1)
        soft_position = (attention.unsqueeze(-1) * pred_pcd.unsqueeze(0)).sum(1)

        return ((soft_position - gt_pcd)**2).mean()
