#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact  george.drettakis@inria.fr
#

import torch
from lietorch import SO3, SE3, Sim3, LieGroupParameter
import numpy as np
from utils.general_utils import inverse_sigmoid, get_expon_lr_func, build_rotation
from torch import nn
import os
import matplotlib.pyplot as plt
from utils.system_utils import mkdir_p
from plyfile import PlyData, PlyElement
from utils.sh_utils import RGB2SH, SH2RGB
from simple_knn._C import distCUDA2
from utils.graphics_utils import BasicPointCloud
from utils.general_utils import strip_symmetric, build_scaling_rotation
from utils.sh_utils import eval_sh
from scipy.spatial.transform import Rotation as R
import math



from diff_gaussian_rasterization import (
    GaussianRasterizationSettings,
    GaussianRasterizer,
)
from utils.camera_conversion import (matrix_to_quaternion,)


class GaussianModel_4D:

    def setup_functions(self):
        """Attach the activation / inverse-activation callables to the model.

        Raw parameters are stored un-activated; the callables bound here map
        them to (and, where an inverse exists, back from) the values that are
        actually used when reading the model.
        """
        def _covariance_from_scale_and_rotation(scaling, scaling_modifier, rotation):
            # factor @ factor^T is the full covariance; strip_symmetric then
            # reduces it (presumably to its unique entries - see utils).
            factor = build_scaling_rotation(scaling_modifier * scaling, rotation)
            covariance = factor @ factor.transpose(1, 2)
            return strip_symmetric(covariance)

        # Spatial and temporal extents both use exp / log.
        self.scaling_activation, self.scaling_inverse_activation = torch.exp, torch.log
        self.scaling_t_activation, self.scaling_t_inverse_activation = torch.exp, torch.log

        self.covariance_activation = _covariance_from_scale_and_rotation

        # Opacity uses sigmoid / inverse_sigmoid.
        self.opacity_activation, self.inverse_opacity_activation = torch.sigmoid, inverse_sigmoid

        # Rotations are normalised on read.
        self.rotation_activation = torch.nn.functional.normalize

    def __init__(self, sh_degree: int, view_dependent=True, time_duration=[0.0, 1.0],
                 no_time_split=True, t_grad=True, contract=True, t_init=0.001, big_point_threshold=0.4, 
                 cycle=0.2, velocity_decay=1.0, random_init_point=200000):
        self.active_sh_degree = 0
        self.max_sh_degree = sh_degree
        self._xyz = torch.empty(0)
        self._features_dc = torch.empty(0)
        self._features_rest = torch.empty(0)
        self._scaling = torch.empty(0)
        self._rotation = torch.empty(0)
        self._opacity = torch.empty(0)
        self.max_radii2D = torch.empty(0)
        self.xyz_gradient_accum = torch.empty(0)
        self.denom = torch.empty(0)
        self.optimizer = None
        self.percent_dense = 0
        self.spatial_lr_scale = 0
        self.rotate_xyz = False
        self.rotate_seq = False
        self.seq_idx = 0
        self.view_dependent = view_dependent
        self._omega = torch.empty(0)
        self._motion = torch.empty(0)

        self._t = torch.empty(0)
        self._scaling_t = torch.empty(0)
        self.t_gradient_accum = torch.empty(0)

        self.time_duration = time_duration
        self.no_time_split = no_time_split
        self.t_grad = t_grad
        self.contract = contract
        self.t_init = t_init
        self.big_point_threshold = big_point_threshold

        self.T = cycle 
        self.random_init_point = random_init_point

        self.localize = False

        self.setup_functions()

    def capture(self):
        return (
            self.active_sh_degree,
            self._xyz,
            self._features_dc,
            self._features_rest,
            self._scaling,
            self._rotation,
            self._opacity,
            self._t,
            self._scaling_t,
            self._motion,
            self._omega,
            self.max_radii2D,
            self.xyz_gradient_accum,
            self.t_gradient_accum,
            self.denom,
            self.optimizer.state_dict(),
            self.spatial_lr_scale,
            self.P,
            self.T,
        )

    def restore(self, model_args, training_args):
        """Load a state tuple produced by capture() and rebuild the optimizer.

        Args:
            model_args: 19-element tuple in capture() order.
            training_args: forwarded to training_setup().
        """
        (sh_degree, xyz, features_dc, features_rest, scaling, rotation,
         opacity, t, scaling_t, motion, omega, max_radii2D,
         saved_xyz_grad, saved_t_grad, saved_denom, optimizer_state,
         spatial_lr_scale, pose, cycle) = model_args

        # Parameters must be in place before training_setup builds the
        # optimizer groups around them.
        self.active_sh_degree = sh_degree
        self._xyz = xyz
        self._features_dc = features_dc
        self._features_rest = features_rest
        self._scaling = scaling
        self._rotation = rotation
        self._opacity = opacity
        self._t = t
        self._scaling_t = scaling_t
        self._motion = motion
        self._omega = omega
        self.max_radii2D = max_radii2D
        self.spatial_lr_scale = spatial_lr_scale
        self.P = pose
        self.T = cycle

        self.training_setup(training_args)

        # training_setup zeroes the accumulators; put the saved ones back.
        self.xyz_gradient_accum = saved_xyz_grad
        self.t_gradient_accum = saved_t_grad
        self.denom = saved_denom
        self.optimizer.load_state_dict(optimizer_state)

    def prune_with_mask(self, new_mask=None):
        """Prune every Gaussian flagged in self.mask, then install a fresh mask.

        Args:
            new_mask: mask to use afterwards. When omitted, the surviving
                mask is filled with 1 (everything updatable).
        """
        # Entries whose mask value is 1 are the ones removed.
        self.prune_points(self.mask)

        if new_mask is None:
            self.mask[:] = 1  # all updatable
        else:
            self.mask = new_mask

        # Re-register the gradient mask for the new point set.
        self.remove_grad_mask()
        self.apply_grad_mask(self.mask)
        self.update_anchor()

    @property
    def get_scaling(self):
        if self.localize:
            return self.scaling_activation(self._scaling[self.mask])
        else:
            return self.scaling_activation(self._scaling)
    
    @property
    def get_scaling_t(self):
        if self.localize:
            return self.scaling_t_activation(self._scaling_t[self.mask])
        else:
            return self.scaling_t_activation(self._scaling_t)

    @property
    def get_rotation(self):
        if self.localize:
            return self.rotation_activation(self._rotation[self.mask])
        else:
            return self.rotation_activation(self._rotation)

    def get_rotation_t(self, t):
        rotation = self._rotation + (t - self._t) * self._omega
        if self.localize:
            return self.rotation_activation(rotation[self.mask])
        else:
            return self.rotation_activation(rotation)


    def get_xyz_motion(self, t):
        xyz = self._xyz + self._motion[:, 0:3] * (t - self._t) + self._motion[:, 3:6] * ((t - self._t) ** 2) + self._motion[:, 6:9] * ((t - self._t)**3)
        if self.rotate_xyz:
            xyz = self.P[0].retr().act(xyz)
            if self.localize:
                return xyz[self.mask]
            else:
                return xyz
        elif self.rotate_seq:
            xyz = self.P[self.seq_idx].retr().act(xyz)
            if self.localize:
                return xyz[self.mask]
            else:
                return xyz
        else:
            if self.localize:
                return xyz[self.mask]
            else:
                return xyz

    @property
    def get_xyz(self):
        xyz = self._xyz.clone()
        if self.rotate_xyz:
            xyz = self.P[0].retr().act(xyz)
            if self.localize:
                return xyz[self.mask]
            else:
                return xyz
        elif self.rotate_seq:
            xyz = self.P[self.seq_idx].retr().act(xyz)
            if self.localize:
                return xyz[self.mask]
            else:
                return xyz
        else:
            if self.localize:
                return xyz[self.mask]
            else:
                return xyz

    @property
    def get_t(self):
        if self.localize:
            return self._t[self.mask]
        else:
            return self._t
        

    def get_RT(self, idx=None):
        """Return the current pose as a matrix.

        Args:
            idx: sequence index to read when rotate_xyz is not set; defaults
                to ``self.seq_idx``.

        Returns:
            ``torch.eye(4)`` on CUDA when no pose parameter exists, otherwise
            the squeezed matrix of the selected pose.
        """
        if getattr(self, "P", None) is None:
            return torch.eye(4, device="cuda")

        # A single shared pose lives in slot 0; otherwise pick the frame.
        if self.rotate_xyz:
            slot = 0
        elif idx is None:
            slot = self.seq_idx
        else:
            slot = idx

        return self.P[slot].retr().matrix().squeeze()

    def set_seq_idx(self, idx):
        if idx < 0:
            self.rotate_seq = False
            self.rotate_xyz = False
        else:
            self.seq_idx = idx

    @property
    def get_features(self):
        if self.localize:
            features_dc = self._features_dc[self.mask]
            features_rest = self._features_rest[self.mask]
        else:
            features_dc = self._features_dc
            features_rest = self._features_rest

        return torch.cat((features_dc, features_rest), dim=1)

    @property
    def get_features_noview(self):
        features_dc = self._features_dc.squeeze()
        return features_dc

    @property
    def get_opacity(self):
        if self.localize:
            return self.opacity_activation(self._opacity[self.mask])
        else:
            return self.opacity_activation(self._opacity)

    def get_covariance(self, scaling_modifier=1):
        if self.localize:
            return self.covariance_activation(
                self.get_scaling[self.mask], scaling_modifier, self._rotation[self.mask]
            )
        else:
            return self.covariance_activation(
                self.get_scaling, scaling_modifier, self._rotation
            )
    
    def get_marginal_t(self, timestamp):
        return torch.exp(-0.5 * (self.get_t - timestamp) ** 2 / self.get_scaling_t ** 2)

    def oneupSHdegree(self):
        if self.active_sh_degree < self.max_sh_degree:
            self.active_sh_degree += 1

    def create_from_pcd(self, pcd: BasicPointCloud, spatial_lr_scale: float):
        """Initialise all per-Gaussian parameters from a point cloud.

        Args:
            pcd: point cloud; ``pcd.points`` and ``pcd.colors`` are read
                (``pcd.points`` must expose ``.shape``, i.e. be array-like).
            spatial_lr_scale: stored on the model and later used to scale
                the position learning rates.

        Everything is created on the CUDA device. The method finishes by
        installing an all-True mask via set_mask / apply_grad_mask.
        """
        self.spatial_lr_scale = spatial_lr_scale
        fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda()
        # Colours go through RGB2SH only in view-dependent mode; otherwise
        # the raw colours are stored in the DC slot.
        if self.view_dependent:
            fused_color = RGB2SH(torch.tensor(
                np.asarray(pcd.colors)).float().cuda())
        else:
            fused_color = torch.tensor(np.asarray(pcd.colors)).float().cuda()
        features = torch.zeros(
            (fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2)).float().cuda()
        features[:, :3, 0] = fused_color[:, :3]
        # NOTE(review): dim 1 has size 3, so the 3: slice below is empty and
        # this assignment is a no-op (the tensor is already zero-initialised).
        features[:, 3:, 1:] = 0.0
        # features = torch.cat([features, features], dim=1)

        # Random time centre per point, drawn from the time range widened by
        # 10% of its length on each side.
        time = (np.random.rand(pcd.points.shape[0], 1) * 1.2 - 0.1) * (
                self.time_duration[1] - self.time_duration[0]) + self.time_duration[0]
  
        if self.t_init < 1:
            # NOTE(review): fused_point_cloud is built from pcd.points, so the
            # row count requested here is 0 and random_times is empty.
            random_times = (torch.rand(fused_point_cloud.shape[0]-pcd.points.shape[0], 1, device="cuda") * 1.2 - 0.1) * (
                    self.time_duration[1] - self.time_duration[0]) + self.time_duration[0]
            pts_times = torch.from_numpy(time.copy()).float().cuda()
            fused_times = torch.cat([pts_times, random_times], dim=0)
        else:
            # Otherwise every Gaussian is centred at the middle of the range.
            fused_times = torch.full_like(fused_point_cloud[..., :1],
                                            0.5 * (self.time_duration[1] + self.time_duration[0]))

        # print("Number of points at initialisation : ", fused_point_cloud.shape[0])

        # Initial spatial scale: log of sqrt of the distCUDA2 result (clamped
        # away from zero), repeated for the three axes.
        dist2 = torch.clamp_min(distCUDA2(torch.from_numpy(
            np.asarray(pcd.points)).float().cuda()), 0.0000001)
        scales = torch.log(torch.sqrt(dist2))[..., None].repeat(1, 3)
        # Rotation initialised to (1, 0, 0, 0) for every point.
        rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda")
        rots[:, 0] = 1

        # Initial temporal scale derived from t_init times the duration.
        dist_t = torch.full_like(fused_times, (self.time_duration[1] - self.time_duration[0])*self.t_init)
        scales_t = self.scaling_t_inverse_activation(torch.sqrt(dist_t))

        # Rotation-rate and motion coefficients start at zero.
        omega = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda")
        self._omega = nn.Parameter(omega.requires_grad_(True))

        motion = torch.zeros((fused_point_cloud.shape[0], 9), device="cuda")# x1, x2, x3,  y1,y2,y3, z1,z2,z3
        self._motion = nn.Parameter(motion.requires_grad_(True))

        # Every Gaussian starts with an activated opacity of 0.1.
        opacities = inverse_sigmoid(
            0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda"))

        self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True))
        # Features are stored as (N, coeffs, 3): DC coefficient first, rest after.
        self._features_dc = nn.Parameter(features[:, :, 0:1].transpose(
            1, 2).contiguous().requires_grad_(True))
        self._features_rest = nn.Parameter(
            features[:, :, 1:].transpose(1, 2).contiguous().requires_grad_(True))
        self._scaling = nn.Parameter(scales.requires_grad_(True))
        self._rotation = nn.Parameter(rots.requires_grad_(True))
        self._opacity = nn.Parameter(opacities.requires_grad_(True))

        self.max_radii2D = torch.zeros((self._xyz.shape[0]), device="cuda")
        self._t = nn.Parameter(fused_times.requires_grad_(True))
        self._scaling_t = nn.Parameter(scales_t.requires_grad_(True))

        # Start with every Gaussian selected / updatable.
        self.set_mask(
            torch.ones(
                self._opacity.shape[0],
                dtype=torch.bool,
                device="cuda",
                requires_grad=False,
            )
        )
        self.apply_grad_mask(self.mask)

    def training_setup(self, training_args, fix_pos=False,
                       fix_feat=False, fit_pose=False):
        """Create the Adam optimizer(s) and learning-rate schedules.

        Args:
            training_args: object providing percent_dense, the per-parameter
                learning rates, ``movelr``, ``max_iterations`` and the
                ``position_lr_*`` schedule settings.
            fix_pos: when True the "xyz" group gets a learning rate of 0.
            fix_feat: when True the f_dc, f_rest, opacity, scaling and
                rotation groups get a learning rate of 0.
            fit_pose: when True also build optimizer(s) for the pose
                parameter(s) ``self.P``; otherwise ``camera_optimizer`` is None.

        ``self.camera_optimizer`` is always a list when fit_pose is set: one
        optimizer per pose in sequence mode, otherwise a single optimizer
        covering the shared pose. ``self.P`` may be a list (as created by
        init_RT without a pose) or a bare parameter in the shared-pose case.
        """
        self.percent_dense = training_args.percent_dense
        num_points = self._xyz.shape[0]
        self.xyz_gradient_accum = torch.zeros((num_points, 1), device="cuda")
        self.t_gradient_accum = torch.zeros((num_points, 1), device="cuda")
        self.denom = torch.zeros((num_points, 1), device="cuda")

        xyz_lr = 0.0 if fix_pos else training_args.position_lr_init * self.spatial_lr_scale
        feat_lr_factor = 0.0 if fix_feat else 1.0

        # One parameter per group; the "name" key is how the pruning /
        # densification helpers and the schedulers find each group.
        l = [
            {'params': [self._xyz], 'lr': xyz_lr, "name": "xyz"},
            {'params': [self._features_dc], 'lr': training_args.feature_lr * feat_lr_factor, "name": "f_dc"},
            {'params': [self._features_rest], 'lr': training_args.feature_lr / 20.0 * feat_lr_factor, "name": "f_rest"},
            {'params': [self._opacity], 'lr': training_args.opacity_lr * feat_lr_factor, "name": "opacity"},
            {'params': [self._scaling], 'lr': training_args.scaling_lr * feat_lr_factor, "name": "scaling"},
            {'params': [self._rotation], 'lr': training_args.rotation_lr * feat_lr_factor, "name": "rotation"},
            {'params': [self._t], 'lr': training_args.t_lr_init, "name": "t"},
            {'params': [self._scaling_t], 'lr': training_args.scaling_t_lr, "name": "scaling_t"},
            {'params': [self._motion], 'lr': training_args.position_lr_init * self.spatial_lr_scale * 0.5 * training_args.movelr, "name": "motion"},
            {'params': [self._omega], 'lr': training_args.omega_lr, "name": "omega"},
        ]

        self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15)
        self.xyz_scheduler_args = get_expon_lr_func(lr_init=training_args.position_lr_init*self.spatial_lr_scale,
                                                    lr_final=training_args.position_lr_final*self.spatial_lr_scale,
                                                    lr_delay_mult=training_args.position_lr_delay_mult,
                                                    max_steps=training_args.position_lr_max_steps)

        # The time schedule decays by the same ratio as the position schedule.
        final_decay = training_args.position_lr_final / training_args.position_lr_init

        self.t_scheduler_args = get_expon_lr_func(lr_init=training_args.t_lr_init,
                                                    lr_final=training_args.t_lr_init * final_decay,
                                                    lr_delay_mult=training_args.position_lr_delay_mult,
                                                    max_steps=training_args.max_iterations)

        if not fit_pose:
            self.camera_optimizer = None
            return

        if self.rotate_seq:
            # Independent optimizer per frame pose.
            self.camera_optimizer = [
                torch.optim.Adam(
                    [{'params': [pose_param],
                      'lr': training_args.rotation_lr, "name": "R"}],
                    lr=0.0, eps=1e-15)
                for pose_param in self.P
            ]
        else:
            # self.P is a list when created by init_RT(); wrapping that list
            # again would hand Adam a nested list, which it rejects.
            if isinstance(self.P, (list, tuple)):
                pose_params = list(self.P)
            else:
                pose_params = [self.P]
            l_cam = [{'params': pose_params,
                      'lr': training_args.rotation_lr, "name": "R"}]
            self.camera_optimizer = [torch.optim.Adam(l_cam, lr=0.0, eps=1e-15)]

        self.camera_scheduler_args = get_expon_lr_func(lr_init=training_args.rotation_lr,
                                                       lr_final=training_args.rotation_lr * 0.1,
                                                       lr_delay_mult=0.1,
                                                       max_steps=training_args.position_lr_max_steps)

    def training_setup_fix_position(self, training_args, gaussian_rot=True):
        """Set up an optimizer that only fits the pose (and optionally rotations).

        Args:
            training_args: object providing percent_dense, rotation_lr,
                t_lr_init, max_iterations and the ``position_lr_*`` settings.
            gaussian_rot: when True the Gaussian rotations are optimised
                together with ``P[0]``.
        """
        self.percent_dense = training_args.percent_dense

        # Accumulators are sized from get_xyz, i.e. the masked count in
        # localize mode.
        accum_shape = (self.get_xyz.shape[0], 1)
        self.xyz_gradient_accum = torch.zeros(accum_shape, device="cuda")
        self.t_gradient_accum = torch.zeros(accum_shape, device="cuda")
        self.denom = torch.zeros(accum_shape, device="cuda")

        param_groups = [
            {'params': [self.P[0]],
             'lr': training_args.rotation_lr, "name": "R"},
        ]
        if gaussian_rot:
            param_groups.append(
                {'params': [self._rotation],
                 'lr': training_args.rotation_lr, "name": "rotation"}
            )
        self.optimizer = torch.optim.Adam(param_groups, lr=0.0, eps=1e-15)

        scale = self.spatial_lr_scale
        self.xyz_scheduler_args = get_expon_lr_func(
            lr_init=training_args.position_lr_init*scale,
            lr_final=training_args.position_lr_final*scale,
            lr_delay_mult=training_args.position_lr_delay_mult,
            max_steps=training_args.position_lr_max_steps)

        # Time schedule decays by the same ratio as the position schedule.
        final_decay = training_args.position_lr_final / training_args.position_lr_init

        self.t_scheduler_args = get_expon_lr_func(
            lr_init=training_args.t_lr_init,
            lr_final=training_args.t_lr_init * final_decay,
            lr_delay_mult=training_args.position_lr_delay_mult,
            max_steps=training_args.max_iterations)

    def init_RT(self, pcd=None, pose=None):
        """Create a single shared pose parameter and enable rotate_xyz mode.

        Args:
            pcd: unused; kept for interface compatibility.
            pose: optional pose matrix whose ``[:3, :3]`` block is the
                rotation and ``[:3, 3]`` the translation. When omitted the
                pose starts at identity.

        ``self.P`` is always a one-element list, because every consumer
        (get_xyz, get_xyz_motion, get_RT, training_setup_fix_position) reads
        the shared pose as ``self.P[0]``.
        """
        if pose is None:
            # Layout: translation (3) followed by a scalar-last quaternion.
            pose_init = torch.as_tensor(
                [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]).cuda().requires_grad_(True)
            self.P = [LieGroupParameter(SE3(pose_init[None]))]
        else:
            quat = matrix_to_quaternion(pose[:3, :3])
            # Reorder to scalar-last, exactly as update_RT_seq does after the
            # same conversion, to match the identity layout used above.
            quat = quat[..., [1, 2, 3, 0]]
            pose = torch.cat((pose[:3, 3], quat), -
                             1).cuda().requires_grad_(True)
            self.P = [LieGroupParameter(SE3(pose[None]))]

        self.rotate_xyz = True
        self.rotate_seq = False

    def init_RT_seq(self, seq_len, pose=None):
        """Create one pose parameter per frame and enable rotate_seq mode.

        Args:
            seq_len: number of pose parameters to create.
            pose: optional stack of pose matrices indexable by frame; must
                support ``.numpy()`` (i.e. live on the CPU). When omitted all
                poses start at identity.
        """
        if pose is None:
            identity = torch.as_tensor(
                [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]).cuda().requires_grad_(True)
            frame_poses = []
            for _ in range(seq_len):
                frame_poses.append(LieGroupParameter(SE3(identity[None])))  # 0-131: torch.size([1,6])
            self.P = frame_poses
        else:
            # Rotation -> quaternion through scipy, then append to translation.
            rotations = pose[..., :3, :3].numpy()
            quat = torch.from_numpy(R.from_matrix(rotations).as_quat()).float()
            translations = pose[..., :3, 3]
            pose = torch.cat((translations, quat), -1).cuda().requires_grad_(True)
            frame_poses = []
            for idx in range(seq_len):
                frame_poses.append(LieGroupParameter(SE3(pose[idx][None])))
            self.P = frame_poses

        self.rotate_seq = True
        self.rotate_xyz = False

    def update_RT_seq(self, pose, idx):
        """Replace the pose parameter of frame ``idx`` with the given pose matrix.

        NOTE(review): a new parameter object is created, so any optimizer
        already built around the old ``P[idx]`` will not see it - confirm
        callers rebuild the camera optimizer afterwards.
        """
        rotation, translation = pose[:3, :3], pose[:3, 3]
        # Move the first quaternion component to the end before packing.
        quat = matrix_to_quaternion(rotation)[..., [1, 2, 3, 0]]
        packed = torch.cat((translation, quat.float()), -1)
        pose = packed.cuda().requires_grad_(True)
        self.P[idx] = LieGroupParameter(SE3(pose[None]))
        self.P[idx].group = SE3(pose[None])

    def update_learning_rate(self, iteration):
        ''' Learning rate scheduling per step '''
        for param_group in self.optimizer.param_groups:
            if param_group["name"] == "xyz":
                lr = self.xyz_scheduler_args(iteration)
                param_group['lr'] = lr
                return lr
            if param_group["name"] == "t":
                lr = self.t_scheduler_args(iteration)
                param_group['lr'] = lr

    def update_learning_rate_camera(self, cam_idx, iteration):
        ''' Learning rate scheduling per step '''
        if isinstance(self.camera_optimizer, list):
            for param_group in self.camera_optimizer[cam_idx].param_groups:
                lr = self.camera_scheduler_args(iteration)
                param_group['lr'] = lr
        else:
            for param_group in self.camera_optimizer.param_groups:
                lr = self.camera_scheduler_args(iteration)
                param_group['lr'] = lr

    def freeze_camera(self):
        for param_group in self.camera_optimizer.param_groups:
            param_group['lr'] = 0.0

    def construct_list_of_attributes(self):
        l = ['x', 'y', 'z', 'nx', 'ny', 'nz']
        # All channels except the 3 DC
        for i in range(self._features_dc.shape[1]*self._features_dc.shape[2]):
            l.append('f_dc_{}'.format(i))
        for i in range(self._features_rest.shape[1]*self._features_rest.shape[2]):
            l.append('f_rest_{}'.format(i))
        l.append('opacity')
        for i in range(self._scaling.shape[1]):
            l.append('scale_{}'.format(i))
        for i in range(self._rotation.shape[1]):
            l.append('rot_{}'.format(i))
        return l

    def save_ply(self, path, timestamp):
        """Write positions, features, opacity, scale and rotation to a PLY file.

        Args:
            path: output file; its parent directory is created if needed.
            timestamp: unused by this method.

        Positions come from ``get_xyz`` (pose-transformed); the remaining
        columns are the raw stored parameters. Normals are written as zeros.
        """
        mkdir_p(os.path.dirname(path))

        def _as_array(tensor):
            return tensor.detach().cpu().numpy()

        def _flat_features(tensor):
            # Swap the last two axes back and flatten to one row per point.
            return tensor.detach().transpose(
                1, 2).flatten(start_dim=1).contiguous().cpu().numpy()

        positions = _as_array(self.get_xyz)
        columns = (
            positions,
            np.zeros_like(positions),
            _flat_features(self._features_dc),
            _flat_features(self._features_rest),
            _as_array(self._opacity),
            _as_array(self._scaling),
            _as_array(self._rotation),
        )

        vertex_dtype = [(attribute, 'f4')
                        for attribute in self.construct_list_of_attributes()]
        vertices = np.empty(positions.shape[0], dtype=vertex_dtype)
        vertices[:] = list(map(tuple, np.concatenate(columns, axis=1)))

        PlyData([PlyElement.describe(vertices, 'vertex')]).write(path)

    def export_gaussian(self, ):
        gassuian = {}
        xyz = self._xyz.detach().cpu().numpy()
        normals = np.zeros_like(xyz)
        f_dc = self._features_dc.detach().contiguous().cpu().numpy()
        f_rest = self._features_rest.detach().contiguous().cpu().numpy()
        opacities = self._opacity.detach().cpu().numpy()
        scale = self._scaling.detach().cpu().numpy()
        rotation = self._rotation.detach().cpu().numpy()
        max_radii2D = self.max_radii2D.detach().cpu().numpy()
        t = self._t.detach().cpu().numpy()
        scale_t = self._scaling_t.detach().cpu().numpy()
        motion = self._motion.detach().cpu().numpy()
        omega = self._omega.detach().cpu().numpy()

        gassuian['xyz'] = xyz
        gassuian['normals'] = normals
        gassuian['f_dc'] = f_dc
        gassuian['f_rest'] = f_rest
        gassuian['opacities'] = opacities
        gassuian['scale'] = scale
        gassuian['rotation'] = rotation
        gassuian['max_radii2D'] = max_radii2D
        gassuian['t'] = t
        gassuian['scale_t'] = scale_t
        gassuian['motion'] = motion
        gassuian['omega'] = omega

        return gassuian

    def reset_opacity(self):
        """Cap every activated opacity at 0.01 and reinstall it in the optimizer."""
        current = self.get_opacity
        ceiling = torch.ones_like(current) * 0.01
        capped_raw = inverse_sigmoid(torch.min(current, ceiling))
        replaced = self.replace_tensor_to_optimizer(capped_raw, "opacity")
        self._opacity = replaced["opacity"]

    def load_ply(self, path, num_gauss=-1):
        """Load positions, features, opacity, scale and rotation from a PLY file.

        Args:
            path: PLY file to read (first element is used).
            num_gauss: keep only the first ``num_gauss`` points; -1 keeps all.

        The SH degree is set to its maximum and an all-True mask is installed.

        NOTE(review): this assigns ``self.T`` (a zero translation parameter),
        which replaces the ``cycle`` value stored under the same name by
        ``__init__`` - confirm this is intended.
        """
        plydata = PlyData.read(path)
        vertex = plydata.elements[0]

        def _columns(names):
            # One float column per named property, in the given order.
            table = np.zeros((len(vertex[names[0]]) if names else xyz.shape[0], len(names)))
            for col, prop in enumerate(names):
                table[:, col] = np.asarray(vertex[prop])
            return table

        def _sorted_props(prefix):
            # Properties sharing a prefix, ordered by their numeric suffix.
            found = [p.name for p in vertex.properties if p.name.startswith(prefix)]
            return sorted(found, key=lambda x: int(x.split('_')[-1]))

        xyz = np.stack((np.asarray(vertex["x"]),
                        np.asarray(vertex["y"]),
                        np.asarray(vertex["z"])), axis=1)
        opacities = np.asarray(vertex["opacity"])[..., np.newaxis]

        # DC term: three channels with a trailing coefficient axis of size 1.
        features_dc = _columns(["f_dc_0", "f_dc_1", "f_dc_2"])[:, :, np.newaxis]

        extra_f_names = _sorted_props("f_rest_")
        assert len(extra_f_names) == 3*(self.max_sh_degree + 1) ** 2 - 3
        # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC)
        features_extra = _columns(extra_f_names).reshape(
            (xyz.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1))

        scales = _columns(_sorted_props("scale_"))
        rots = _columns(_sorted_props("rot"))

        if num_gauss == -1:
            num_gauss = xyz.shape[0]

        def _param(array):
            return torch.tensor(array[:num_gauss], dtype=torch.float, device="cuda")

        self._xyz = nn.Parameter(_param(xyz).requires_grad_(True))
        self._features_dc = nn.Parameter(
            _param(features_dc).transpose(1, 2).contiguous().requires_grad_(True))
        self._features_rest = nn.Parameter(
            _param(features_extra).transpose(1, 2).contiguous().requires_grad_(True))
        self._opacity = nn.Parameter(_param(opacities).requires_grad_(True))
        self._scaling = nn.Parameter(_param(scales).requires_grad_(True))
        self._rotation = nn.Parameter(_param(rots).requires_grad_(True))

        rotation_init = torch.eye(3, device="cuda")
        translation_init = torch.zeros(1, 3, device="cuda")
        self.R = nn.Parameter(rotation_init.requires_grad_(True))
        self.T = nn.Parameter(translation_init.requires_grad_(True))
        self.active_sh_degree = self.max_sh_degree

        # Start with every Gaussian selected / updatable.
        full_mask = torch.ones(
            self._opacity.shape[0],
            dtype=torch.bool,
            device="cuda",
            requires_grad=False,
        )
        self.set_mask(full_mask)
        self.apply_grad_mask(self.mask)

    def replace_tensor_to_optimizer(self, tensor, name):
        optimizable_tensors = {}
        for group in self.optimizer.param_groups:
            if group["name"] == name:
                stored_state = self.optimizer.state.get(group['params'][0], None)
                stored_state["exp_avg"] = torch.zeros_like(tensor)
                stored_state["exp_avg_sq"] = torch.zeros_like(tensor)

                del self.optimizer.state[group['params'][0]]
                group["params"][0] = nn.Parameter(tensor.requires_grad_(True))
                self.optimizer.state[group['params'][0]] = stored_state

                optimizable_tensors[group["name"]] = group["params"][0]
        return optimizable_tensors

    def _prune_optimizer(self, mask):
        optimizable_tensors = {}
        for group in self.optimizer.param_groups:
            stored_state = self.optimizer.state.get(group['params'][0], None)
            # if group["name"] not in ["xyz", "f_dc", "f_rest", "opacity", "scaling", "rotation"]:
            #     continue
            if stored_state is not None:
                stored_state["exp_avg"] = stored_state["exp_avg"][mask]
                stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask]

                del self.optimizer.state[group['params'][0]]
                group["params"][0] = nn.Parameter(
                    (group["params"][0][mask].requires_grad_(True)))
                self.optimizer.state[group['params'][0]] = stored_state

                optimizable_tensors[group["name"]] = group["params"][0]
            else:
                group["params"][0] = nn.Parameter(
                    group["params"][0][mask].requires_grad_(True))
                optimizable_tensors[group["name"]] = group["params"][0]
        return optimizable_tensors

    def prune_points(self, mask):
        valid_points_mask = ~mask
        optimizable_tensors = self._prune_optimizer(valid_points_mask)

        self._xyz = optimizable_tensors["xyz"]
        self._features_dc = optimizable_tensors["f_dc"]
        self._features_rest = optimizable_tensors["f_rest"]
        self._opacity = optimizable_tensors["opacity"]
        self._scaling = optimizable_tensors["scaling"]
        self._rotation = optimizable_tensors["rotation"]
        self._motion = optimizable_tensors["motion"]
        self._omega = optimizable_tensors["omega"]

        self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask]

        self.denom = self.denom[valid_points_mask]
        self.max_radii2D = self.max_radii2D[valid_points_mask]

        self._t = optimizable_tensors['t']
        self._scaling_t = optimizable_tensors['scaling_t']
        self.t_gradient_accum = self.t_gradient_accum[valid_points_mask]

        self.mask = self.mask[valid_points_mask]

    def cat_tensors_to_optimizer(self, tensors_dict):
        """Append new rows to every optimizer-managed parameter.

        Args:
            tensors_dict: maps each param-group ``name`` to the tensor that is
                concatenated (along dim 0) onto that group's single parameter.
                Every group name of ``self.optimizer`` must be present.

        Returns:
            dict mapping group name to the newly created ``nn.Parameter``.

        When the optimizer already holds state for a parameter, its
        ``exp_avg`` / ``exp_avg_sq`` buffers are zero-padded for the new rows
        and re-keyed to the new parameter object.
        """
        extended = {}
        for group in self.optimizer.param_groups:
            assert len(group["params"]) == 1
            name = group["name"]
            old_param = group["params"][0]
            addition = tensors_dict[name]

            state = self.optimizer.state.get(old_param, None)
            has_state = state is not None
            if has_state:
                for moment in ("exp_avg", "exp_avg_sq"):
                    state[moment] = torch.cat(
                        (state[moment], torch.zeros_like(addition)), dim=0)
                # The state dict is keyed by the parameter object, so the old
                # entry has to go before the parameter is replaced.
                del self.optimizer.state[old_param]

            new_param = nn.Parameter(
                torch.cat((old_param, addition), dim=0).requires_grad_(True))
            group["params"][0] = new_param
            if has_state:
                self.optimizer.state[new_param] = state

            extended[name] = new_param

        return extended

    def densification_postfix(self, new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling,
                              new_rotation, new_t, new_scaling_t, new_motion, new_omega):
        """Register freshly created Gaussians with the optimizer and the model.

        Each ``new_*`` tensor holds the rows to append to the corresponding
        parameter. After the parameters are extended, the densification
        statistics (``xyz_gradient_accum``, ``t_gradient_accum``, ``denom``,
        ``max_radii2D``) are re-allocated as zeros for ALL points, i.e. the
        accumulated statistics of existing points are discarded as well.

        Note: ``self.mask`` is not touched here; callers extend it themselves.
        """
        additions = {
            "xyz": new_xyz,
            "f_dc": new_features_dc,
            "f_rest": new_features_rest,
            "opacity": new_opacities,
            "scaling": new_scaling,
            "rotation": new_rotation,
            "t": new_t,
            "scaling_t": new_scaling_t,
            "motion": new_motion,
            "omega": new_omega,
        }
        merged = self.cat_tensors_to_optimizer(additions)

        # (attribute on self, optimizer group name)
        bindings = (
            ("_xyz", "xyz"),
            ("_features_dc", "f_dc"),
            ("_features_rest", "f_rest"),
            ("_opacity", "opacity"),
            ("_scaling", "scaling"),
            ("_rotation", "rotation"),
            ("_t", "t"),
            ("_scaling_t", "scaling_t"),
            ("_motion", "motion"),
            ("_omega", "omega"),
        )
        for attr, key in bindings:
            setattr(self, attr, merged[key])

        # Fresh statistics sized to the new point count.
        num_points = self.get_xyz.shape[0]
        self.t_gradient_accum = torch.zeros((num_points, 1), device="cuda")
        self.xyz_gradient_accum = torch.zeros((num_points, 1), device="cuda")
        self.denom = torch.zeros((num_points, 1), device="cuda")
        self.max_radii2D = torch.zeros(num_points, device="cuda")


    def densify_and_split(self, grads, grad_threshold, scene_extent, grads_t, grad_t_threshold, N=2, time_split=False,
                          joint_sample=True):
        """Replace large, high-gradient Gaussians by several smaller samples.

        Args:
            grads: averaged view-space gradient per point (may cover fewer
                points than currently exist; it is zero-padded).
            grad_threshold: minimum spatial gradient for a point to be split.
            scene_extent: scene scale used in the size threshold.
            grads_t: averaged temporal gradient per point (zero-padded too).
            grad_t_threshold: minimum temporal gradient when ``time_split``.
            N: number of samples per selected point; incremented by one
                unless ``self.no_time_split`` is set.
            time_split: also select points by temporal gradient / extent.
            joint_sample: when ``time_split`` is on, OR the temporal selection
                into the spatial one. If False the temporal mask is computed
                but not used.

        Selected points are sampled ``N`` times in space and time, appended
        via ``densification_postfix``, and the originals are pruned.
        """
        n_init_points = self.get_xyz.shape[0]
        # Extract points that satisfy the gradient condition
        # grads can be shorter than the current point count (densify_and_prune
        # runs densify_and_clone first with the same grads), so pad with zeros.
        padded_grad = torch.zeros((n_init_points), device="cuda")
        padded_grad[:grads.shape[0]] = grads.squeeze()
        selected_pts_mask = torch.where(
            padded_grad >= grad_threshold, True, False)

        # Per-point factor applied to the size threshold. Without contraction
        # it is 1/scene_extent, which cancels the scene_extent factor below.
        if self.contract:
            scale_factor = self._xyz.norm(dim=-1)*scene_extent-1 # -0
            scale_factor = torch.where(scale_factor<=1, 1, scale_factor)/scene_extent
        else:
            scale_factor = torch.ones_like(self._xyz)[:,0]/scene_extent

        # Only split points whose largest spatial scale exceeds the threshold.
        selected_pts_mask = torch.logical_and(selected_pts_mask,
                                              torch.max(self.get_scaling,
                                                        dim=1).values > self.percent_dense * scene_extent*scale_factor)
        # NOTE(review): decay_factor is taken from N *before* the increment
        # below, so the temporal scale shrinks by 0.8*N_in while the spatial
        # scale shrinks by 0.8*N_used - confirm this asymmetry is intended.
        decay_factor = N*0.8
        if not self.no_time_split:
            N = N+1

        if time_split:
            padded_grad_t = torch.zeros((n_init_points), device="cuda")
            padded_grad_t[:grads_t.shape[0]] = grads_t.squeeze()
            selected_time_mask = torch.where(padded_grad_t >= grad_t_threshold, True, False)
            extend_thresh = self.percent_dense

            # Temporal counterpart of the size test above.
            selected_time_mask = torch.logical_and(selected_time_mask,
                                                   torch.max(self.get_scaling_t, dim=1).values > extend_thresh)
            if joint_sample:
                selected_pts_mask = torch.logical_or(selected_pts_mask, selected_time_mask)

        # Draw N offsets per selected point from a zero-mean normal whose std
        # is the point's scale, rotate them and add them to the point centre.
        stds = self.get_scaling[selected_pts_mask].repeat(N, 1)
        means = torch.zeros((stds.size(0), 3), device="cuda")
        samples = torch.normal(mean=means, std=stds)
        rots = build_rotation(
            self._rotation[selected_pts_mask]).repeat(N, 1, 1)
        new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + \
            self.get_xyz[selected_pts_mask].repeat(N, 1)
        # Children get a reduced scale, stored through the inverse activation.
        new_scaling = self.scaling_inverse_activation(
            self.get_scaling[selected_pts_mask].repeat(N, 1) / (0.8*N))
        # All remaining attributes are copied unchanged from the parent.
        new_rotation = self._rotation[selected_pts_mask].repeat(N, 1)
        new_features_dc = self._features_dc[selected_pts_mask].repeat(N, 1, 1)
        new_features_rest = self._features_rest[selected_pts_mask].repeat(
            N, 1, 1)
        new_opacity = self._opacity[selected_pts_mask].repeat(N, 1)
        new_motion = self._motion[selected_pts_mask].repeat(N,1)
        new_omega = self._omega[selected_pts_mask].repeat(N,1)

        # Same sampling along the time axis (done regardless of time_split).
        new_t = None
        new_scaling_t = None
        stds_t = self.get_scaling_t[selected_pts_mask].repeat(N, 1)
        means_t = torch.zeros((stds_t.size(0), 1), device="cuda")
        samples_t = torch.normal(mean=means_t, std=stds_t)
        new_t = samples_t+self.get_t[selected_pts_mask].repeat(N, 1)

        new_scaling_t = self.scaling_t_inverse_activation(
            self.get_scaling_t[selected_pts_mask].repeat(N, 1)/ (decay_factor))


        self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacity, new_scaling, new_rotation,
                                   new_t, new_scaling_t, new_motion, new_omega)

        # Children inherit the edit mask of their parent.
        new_mask = torch.cat([self.mask[selected_pts_mask]] * N, dim=0)
        self.mask = torch.cat([self.mask, new_mask], dim=0)

        # Remove the parents; the appended children (trailing zeros) are kept.
        prune_filter = torch.cat((selected_pts_mask, torch.zeros(
            N * selected_pts_mask.sum(), device="cuda", dtype=bool)))
        self.prune_points(prune_filter)

    def densify_and_clone(self, grads, grad_threshold, scene_extent, grads_t, grad_t_threshold, time_clone=False):
        """Duplicate small Gaussians that show a large accumulated gradient.

        Args:
            grads: averaged view-space gradient, one row per point.
            grad_threshold: minimum gradient norm for a point to be cloned.
            scene_extent: scene scale used in the size threshold.
            grads_t: averaged temporal gradient, one row per point.
            grad_t_threshold: minimum temporal gradient norm; only used when
                ``time_clone`` is true.
            time_clone: additionally clone points whose temporal gradient is
                large while their temporal scale is at most
                ``self.percent_dense``.

        The selected points are appended unchanged through
        ``densification_postfix`` and ``self.mask`` is extended with the
        parents' mask values.
        """
        # Per-point factor applied to the size threshold. Without contraction
        # it is 1/scene_extent, which cancels the scene_extent factor below.
        if self.contract:
            scale_factor = self._xyz.norm(dim=-1) * scene_extent - 1
            scale_factor = torch.where(scale_factor <= 1, 1, scale_factor) / scene_extent
        else:
            scale_factor = torch.ones_like(self._xyz)[:, 0] / scene_extent

        # Extract points that satisfy the gradient condition and are small.
        size_limit = self.percent_dense * scene_extent * scale_factor
        selected_pts_mask = torch.logical_and(
            torch.norm(grads, dim=-1) >= grad_threshold,
            torch.max(self.get_scaling, dim=1).values <= size_limit)

        if time_clone:
            extend_thresh = self.percent_dense
            selected_time_mask = torch.logical_and(
                torch.norm(grads_t, dim=-1) >= grad_t_threshold,
                torch.max(self.get_scaling_t, dim=1).values <= extend_thresh)
            selected_pts_mask = torch.logical_or(selected_pts_mask, selected_time_mask)

        # Clones are exact copies of the raw (pre-activation) parameters.
        new_xyz = self._xyz[selected_pts_mask]
        new_features_dc = self._features_dc[selected_pts_mask]
        new_features_rest = self._features_rest[selected_pts_mask]
        new_opacities = self._opacity[selected_pts_mask]
        new_scaling = self._scaling[selected_pts_mask]
        new_rotation = self._rotation[selected_pts_mask]
        new_t = self._t[selected_pts_mask]
        new_scaling_t = self._scaling_t[selected_pts_mask]
        new_motion = self._motion[selected_pts_mask]
        new_omega = self._omega[selected_pts_mask]

        self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling,
                                   new_rotation, new_t, new_scaling_t, new_motion, new_omega)
        # Clones inherit the edit mask of their parent.
        self.mask = torch.cat([self.mask, self.mask[selected_pts_mask]], dim=0)

    def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size, max_grad_t=None, prune_only=False):
        """Run one densification round (clone, split) followed by pruning.

        Args:
            max_grad: spatial gradient threshold handed to clone/split.
            min_opacity: points with opacity below this value are pruned.
            extent: scene scale used in the size thresholds.
            max_screen_size: if truthy, also prune points whose recorded
                screen radius or world-space scale is too large.
            max_grad_t: temporal gradient threshold handed to clone/split.
            prune_only: NOTE(review): accepted but never read here, so
                densification always runs - confirm whether that is intended.

        Only points with ``self.mask`` set take part: gradients of unmasked
        points are zeroed before densification and unmasked points are never
        pruned. The gradient-mask hooks are re-installed at the end because
        the parameter tensors were replaced.
        """
        frozen = ~self.mask

        # Average the accumulated gradients; 0/0 entries become 0.
        grads = self.xyz_gradient_accum / self.denom
        grads[grads.isnan()] = 0.0
        grads[frozen] = 0.0
        grads_t = self.t_gradient_accum / self.denom
        grads_t[grads_t.isnan()] = 0.0
        grads_t[frozen] = 0.0

        # The temporal criterion is switched on and off together for both steps.
        use_time = bool(self.t_grad)
        self.densify_and_clone(grads, max_grad, extent, grads_t, max_grad_t, time_clone=use_time)
        self.densify_and_split(grads, max_grad, extent, grads_t, max_grad_t, time_split=use_time)

        prune_mask = (self.get_opacity < min_opacity).squeeze()

        if self.contract:
            scale_factor = self._xyz.norm(dim=-1) * extent - 1
            scale_factor = torch.where(scale_factor <= 1, 1, scale_factor) / extent
        else:
            scale_factor = torch.ones_like(self._xyz)[:, 0] / extent

        if max_screen_size:
            too_big_on_screen = self.max_radii2D > max_screen_size
            world_limit = self.big_point_threshold * extent * scale_factor  ## ori 0.1
            too_big_in_world = self.get_scaling.max(dim=1).values > world_limit
            prune_mask = torch.logical_or(
                torch.logical_or(prune_mask, too_big_on_screen), too_big_in_world)

        # Never remove points outside the editable region.
        prune_mask = torch.logical_and(prune_mask, self.mask)
        self.prune_points(prune_mask)

        self.remove_grad_mask()
        self.apply_grad_mask(self.mask)

        torch.cuda.empty_cache()

    def add_densification_stats(self, viewspace_point_tensor, update_filter):
        """Accumulate per-point gradient statistics for densification.

        Args:
            viewspace_point_tensor: tensor whose ``.grad`` holds the gradient
                of the screen-space means; only its first two columns are used.
            update_filter: index/mask selecting the points to update.

        For the selected points this adds the 2D gradient norm to
        ``xyz_gradient_accum``, the raw (signed) gradient of ``_t`` to
        ``t_gradient_accum`` and increments ``denom`` by one.
        """
        screen_grad = viewspace_point_tensor.grad[update_filter, :2]
        self.xyz_gradient_accum[update_filter] += torch.norm(screen_grad, dim=-1, keepdim=True)
        self.denom[update_filter] += 1

        time_grad = self._t.grad.clone()
        self.t_gradient_accum[update_filter] += time_grad[update_filter]

    def apply_weights(self, camera, weights, weights_cnt, image_weights):
        """Forward per-Gaussian weights to the rasterizer built for ``camera``.

        A rasterizer is created with ``camera2rasterizer`` using an all-zero
        RGB tensor as its second argument, and its ``apply_weights`` method is
        called with the model's positions, opacities, scales and rotations
        together with ``weights``, ``weights_cnt`` and ``image_weights``.

        NOTE(review): the meaning of the three ``None`` slots depends on the
        rasterizer's ``apply_weights`` signature, which is not visible here.
        """
        zero_rgb = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32, device="cuda")
        rasterizer = camera2rasterizer(camera, zero_rgb)

        # Positional on purpose: the parameter names of the extension's
        # apply_weights are not known from this file.
        call_args = (
            self.get_xyz,
            None,
            self.get_opacity,
            None,
            weights,
            self.get_scaling,
            self.get_rotation,
            None,
            weights_cnt,
            image_weights,
        )
        rasterizer.apply_weights(*call_args)

    def set_mask(self, mask):
        """Store ``mask`` as the model's per-point mask (``self.mask``).

        The object is stored as-is; no copy or validation is performed.
        """
        self.mask = mask

    def apply_grad_mask(self, mask):
        """Install ``mask`` and register hooks that zero masked-out gradients.

        Args:
            mask: tensor with one entry per Gaussian (first dimension must
                equal the number of points). Gradients of points whose entry
                is 0/False are multiplied by zero.

        Raises:
            AssertionError: if ``mask`` does not have one entry per point, or
                if one of the hooked tensors is not a leaf requiring grad.

        The hook handles are stored in ``self.hooks`` so that
        ``remove_grad_mask`` can detach them later. The hooks read
        ``self.mask`` at backward time, so they follow later ``set_mask``
        calls.

        NOTE(review): ``_rotation`` is not in the hooked field list, so its
        gradient is not masked - confirm whether that is intentional.
        """
        # Validate the mask that is about to be installed. Checking the
        # previous self.mask would miss a wrongly sized argument and would
        # fail outright when no mask has been set yet.
        assert mask.shape[0] == self._xyz.shape[0]
        self.set_mask(mask)

        def hook(grad):
            # Broadcast the per-point mask over the trailing dimensions.
            final_grad = grad * (
                self.mask[:, None] if grad.ndim == 2 else self.mask[:, None, None]
            )
            return final_grad

        fields = ["_xyz", "_features_dc", "_features_rest", "_opacity", "_scaling", "_t", '_scaling_t', '_omega', '_motion']

        self.hooks = []

        for field in fields:
            this_field = getattr(self, field)
            assert this_field.is_leaf and this_field.requires_grad
            self.hooks.append(this_field.register_hook(hook))

    def remove_grad_mask(self):
        """Detach the gradient-mask hooks installed by ``apply_grad_mask``.

        Every handle in ``self.hooks`` is removed and the attribute is
        deleted. Calling this when no hooks are registered is a no-op, so
        callers such as ``densify_and_prune`` can invoke it unconditionally.
        """
        handles = getattr(self, "hooks", None)
        if handles is None:
            return

        for handle in handles:
            handle.remove()

        del self.hooks

    def get_near_gaussians_by_mask(self, mask, dist_thresh: float = 0.1):
        """Find unmasked Gaussians that lie close to the masked ones.

        Args:
            mask: boolean selection of the "object" points (squeezed first).
            dist_thresh: maximum nearest-neighbour distance to an object
                point for a remaining point to be selected.

        Returns:
            Boolean tensor indexed over the points NOT selected by ``mask``
            (i.e. over ``self._xyz[~mask]``, not over all points); ``True``
            marks remaining points inside the enlarged object box whose
            nearest object point is within ``dist_thresh``.
        """
        mask = mask.squeeze()
        object_xyz = self._xyz[mask]
        remaining_xyz = self._xyz[~mask]

        # Robust per-axis bounds of the object: 3% / 97% quantiles.
        lower = torch.stack([torch.quantile(object_xyz[:, axis], 0.03) for axis in range(3)])
        upper = torch.stack([torch.quantile(object_xyz[:, axis], 0.97) for axis in range(3)])

        # Enlarge the box by 30% around its centre.
        center = (upper + lower) / 2
        half_extent = (upper - lower) * 1.3 / 2
        lower = center - half_extent
        upper = center + half_extent

        # Cheap pre-filter: only points inside the box go to the NN query.
        in_bbox = ((remaining_xyz >= lower) & (remaining_xyz <= upper)).all(dim=1)
        candidates = remaining_xyz[in_bbox]

        _, _, nn_dist = K_nearest_neighbors(
            object_xyz, 1, query=candidates, return_dist=True
        )
        valid_mask = nn_dist.squeeze() <= dist_thresh

        # Map the surviving candidates back to rows of remaining_xyz.
        candidate_rows = torch.nonzero(in_bbox)
        close_rows = candidate_rows[valid_mask, 0]
        mask_to_update = torch.zeros_like(remaining_xyz[:, 0], dtype=torch.bool)
        mask_to_update[close_rows] = True

        return mask_to_update


class EditPro_Render:
    def __init__(self, sh_degree=3, white_background=False,
                 radius=1, view_dependent=False):
        """Create a renderer that owns a fresh ``GaussianModel_4D``.

        Args:
            sh_degree: spherical-harmonics degree passed to the model.
            white_background: use a white (1, 1, 1) background colour
                instead of black (0, 0, 0).
            radius: stored as ``self.radius``.
            view_dependent: stored and forwarded to the model.

        The background colour tensor is allocated on the CUDA device.
        """
        self.sh_degree = sh_degree
        self.white_background = white_background
        self.radius = radius
        self.view_dependent = view_dependent

        self.gaussians = GaussianModel_4D(
            sh_degree, view_dependent=self.view_dependent)

        if white_background:
            background_rgb = [1, 1, 1]
        else:
            background_rgb = [0, 0, 0]
        self.bg_color = torch.tensor(
            background_rgb, dtype=torch.float32, device="cuda")

    def init_model(self, input=None, num_pts=10000, radius=1.0):
        """Initialise ``self.gaussians`` from one of three sources.

        Args:
            input: ``None`` to start from random points, a
                ``BasicPointCloud`` to start from that cloud, or anything
                else, which is handed to ``load_ply``.
            num_pts: number of random points (random initialisation only).
            radius: radius of the random ball (random initialisation only).

        ``self.radius`` is updated in the first two cases and left untouched
        when loading from a ply.
        """
        if input is None:
            # Random points inside a ball: random direction, and a radius
            # scaled by the cube root of a uniform sample.
            azimuth = np.random.random((num_pts,)) * 2 * np.pi
            cos_polar = np.random.random((num_pts,)) * 2 - 1
            polar = np.arccos(cos_polar)
            radial = radius * np.cbrt(np.random.random((num_pts,)))

            sin_polar = np.sin(polar)
            xyz = np.stack(
                (
                    radial * sin_polar * np.cos(azimuth),
                    radial * sin_polar * np.sin(azimuth),
                    radial * np.cos(polar),
                ),
                axis=1,
            )

            shs = np.random.random((num_pts, 3)) / 255.0
            cloud = BasicPointCloud(
                points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3))
            )
            self.gaussians.create_from_pcd(cloud, 10)
            self.radius = radial.max()
            return

        if isinstance(input, BasicPointCloud):
            # Use the farthest point from the origin as the scene radius.
            # TODO: check if this is correct with radius
            cloud_radius = np.linalg.norm(input.points, axis=1).max()
            self.gaussians.create_from_pcd(input, cloud_radius)
            self.radius = cloud_radius
            return

        # Anything else is handed to load_ply (e.g. a path to a saved ply).
        self.gaussians.load_ply(input)

    def reset_model(self):
        """Discard the current Gaussians and start from an empty model.

        The replacement ``GaussianModel_4D`` is built with the renderer's
        stored ``sh_degree`` and ``view_dependent`` settings.
        """
        fresh_model = GaussianModel_4D(
            self.sh_degree, view_dependent=self.view_dependent)
        self.gaussians = fresh_model

    def render(
        self,
        viewpoint_camera,
        scaling_modifier=1.0,
        invert_bg_color=False,
        override_color=None,
        compute_cov3D_python=False,
        convert_SHs_python=False,
        time_shift=None, 
        other=[],
        mask=None,
    ):
        """
        Render the scene at ``viewpoint_camera.timestamp``.

        Background tensor (bg_color) must be on GPU!

        Args:
            viewpoint_camera: camera providing FoV, image size, transforms,
                ``camera_center``, ``timestamp`` and ``uid``.
            scaling_modifier: global scale multiplier for the Gaussians.
            invert_bg_color: render against ``1 - self.bg_color``.
            override_color: optional precomputed per-Gaussian colours; when
                given they are used instead of the SH features.
            compute_cov3D_python: precompute the 3D covariance in Python
                instead of passing scales/rotations to the rasterizer.
            convert_SHs_python: evaluate colours in Python instead of passing
                SH coefficients to the rasterizer.
            time_shift, other, mask: accepted for interface compatibility;
                not used by this method.

        Returns:
            dict with ``image`` (clamped to [0, 1]), ``depth``,
            ``viewspace_points``, ``visibility_filter`` (``radii > 0``) and
            ``radii``; plus ``alpha`` when the rasterizer returns four
            outputs. ``None`` if the rasterizer returns any other number of
            outputs.
        """
        gaussians = self.gaussians

        # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
        screenspace_points = (
            torch.zeros_like(
                gaussians.get_xyz,
                dtype=gaussians.get_xyz.dtype,
                requires_grad=True,
                device="cuda",
            )
            + 0
        )
        try:
            screenspace_points.retain_grad()
        except Exception:
            # Best effort: retaining the gradient can fail, e.g. when
            # autograd is disabled; rendering itself is unaffected.
            pass

        # Set up rasterization configuration
        tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
        tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)

        raster_settings = GaussianRasterizationSettings(
            image_height=int(viewpoint_camera.image_height),
            image_width=int(viewpoint_camera.image_width),
            tanfovx=tanfovx,
            tanfovy=tanfovy,
            bg=self.bg_color if not invert_bg_color else 1 - self.bg_color,
            scale_modifier=scaling_modifier,
            viewmatrix=viewpoint_camera.world_view_transform,
            projmatrix=viewpoint_camera.full_proj_transform,
            sh_degree=gaussians.active_sh_degree,
            campos=viewpoint_camera.camera_center,
            prefiltered=False,
            debug=False,
        )

        rasterizer = GaussianRasterizer(raster_settings=raster_settings)

        # Positions and opacities are evaluated at the camera's timestamp.
        timestamp = viewpoint_camera.timestamp
        means3D = gaussians.get_xyz_motion(timestamp)
        means2D = screenspace_points
        opacity = gaussians.get_opacity * gaussians.get_marginal_t(timestamp)

        # If precomputed 3d covariance is provided, use it. If not, then it will be computed from
        # scaling / rotation by the rasterizer.
        scales = None
        rotations = None
        cov3D_precomp = None
        if compute_cov3D_python:
            cov3D_precomp = gaussians.get_covariance(scaling_modifier)
        else:
            scales = gaussians.get_scaling
            rotations = gaussians.get_rotation_t(timestamp)

        # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
        # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
        shs = None
        colors_precomp = None
        if override_color is not None:
            colors_precomp = override_color
        elif convert_SHs_python:
            if self.view_dependent:
                shs_view = gaussians.get_features.transpose(1, 2).view(
                    -1, 3, (gaussians.max_sh_degree + 1) ** 2
                )
                fidx = viewpoint_camera.uid
                camera_center = gaussians.get_RT(fidx).inverse()[
                    :3, 3].detach()
                camera_center = camera_center[None].repeat(
                    gaussians.get_features.shape[0], 1)
                dir_pp = gaussians._xyz - camera_center
                dir_pp_normalized = dir_pp / \
                    dir_pp.norm(dim=1, keepdim=True)
                sh2rgb = eval_sh(
                    gaussians.active_sh_degree, shs_view, dir_pp_normalized
                )
                colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)
            else:
                colors_precomp = gaussians.get_features_noview
        else:
            shs = gaussians.get_features

        # Rasterize visible Gaussians to image, obtain their radii (on screen).
        out = rasterizer(
            means3D=means3D,
            means2D=means2D,
            shs=shs,
            colors_precomp=colors_precomp,
            opacities=opacity,
            scales=scales,
            rotations=rotations,
            cov3D_precomp=cov3D_precomp,
        )

        # The rasterizer build may or may not return an alpha channel.
        if len(out) == 4:
            rendered_image, radii, rendered_depth, rendered_alpha = out
            has_alpha = True
        elif len(out) == 3:
            rendered_image, radii, rendered_depth = out
            has_alpha = False
        else:
            return None

        rendered_image = rendered_image.clamp(0, 1)

        # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
        # They will be excluded from value updates used in the splitting criteria.
        result = {
            "image": rendered_image,
            "depth": rendered_depth,
        }
        if has_alpha:
            result["alpha"] = rendered_alpha
        result["viewspace_points"] = screenspace_points
        result["visibility_filter"] = radii > 0
        result["radii"] = radii
        return result
    
    def render_flow(
        self,
        viewpoint_camera,
        time_delta,
        scaling_modifier=1.0,
        invert_bg_color=False,
        override_color=None,
        compute_cov3D_python=False,
        convert_SHs_python=False,
        time_shift=None, 
        other=[],
        mask=None,
    ):
        """
        Render the per-Gaussian motion between two timestamps as an image.

        Background tensor (bg_color) must be on GPU!

        The displacement of every Gaussian between
        ``viewpoint_camera.timestamp`` and ``timestamp + time_delta`` is
        converted to pixel units for its x/y components (the z component is
        left as the raw displacement) and rasterized as the colour channels.
        Geometry and opacity are detached, so gradients only flow through
        the displacement at ``timestamp + time_delta``.

        Args:
            viewpoint_camera: camera providing FoV, image size, transforms,
                ``camera_center``, ``timestamp`` and ``uid``.
            time_delta: time offset of the second sample.
            scaling_modifier: global scale multiplier for the Gaussians.
            invert_bg_color: render against ``1 - self.bg_color``.
            compute_cov3D_python: precompute the 3D covariance in Python
                instead of passing scales/rotations to the rasterizer.
            override_color, convert_SHs_python, time_shift, other, mask:
                accepted for interface compatibility; not used here.

        Returns:
            dict with ``image`` (the rasterized flow, not clamped),
            ``depth``, ``viewspace_points``, ``visibility_filter``
            (``radii > 0``) and ``radii``; plus ``alpha`` when the rasterizer
            returns four outputs. ``None`` for any other output count.
        """
        gaussians = self.gaussians

        # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
        screenspace_points = (
            torch.zeros_like(
                gaussians.get_xyz,
                dtype=gaussians.get_xyz.dtype,
                requires_grad=True,
                device="cuda",
            )
            + 0
        )
        try:
            screenspace_points.retain_grad()
        except Exception:
            # Best effort: retaining the gradient can fail, e.g. when
            # autograd is disabled; rendering itself is unaffected.
            pass

        # Set up rasterization configuration
        tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
        tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)

        raster_settings = GaussianRasterizationSettings(
            image_height=int(viewpoint_camera.image_height),
            image_width=int(viewpoint_camera.image_width),
            tanfovx=tanfovx,
            tanfovy=tanfovy,
            bg=self.bg_color if not invert_bg_color else 1 - self.bg_color,
            scale_modifier=scaling_modifier,
            viewmatrix=viewpoint_camera.world_view_transform,
            projmatrix=viewpoint_camera.full_proj_transform,
            sh_degree=gaussians.active_sh_degree,
            campos=viewpoint_camera.camera_center,
            prefiltered=False,
            debug=False,
        )

        rasterizer = GaussianRasterizer(raster_settings=raster_settings)

        # Positions and opacities are evaluated at the camera's timestamp.
        timestamp = viewpoint_camera.timestamp
        means3D = gaussians.get_xyz_motion(timestamp)
        means2D = screenspace_points
        opacity = gaussians.get_opacity * gaussians.get_marginal_t(timestamp)

        # If precomputed 3d covariance is provided, use it. If not, then it will be computed from
        # scaling / rotation by the rasterizer.
        scales = None
        rotations = None
        cov3D_precomp = None
        if compute_cov3D_python:
            cov3D_precomp = gaussians.get_covariance(scaling_modifier)
        else:
            scales = gaussians.get_scaling
            rotations = gaussians.get_rotation_t(timestamp)

        # The "colour" of each Gaussian is its displacement over time_delta.
        shs = None
        x_t1 = gaussians.get_xyz_motion(timestamp)
        x_t2 = gaussians.get_xyz_motion(timestamp + time_delta)

        flow = x_t2 - x_t1.detach()
        focal_y = int(viewpoint_camera.image_height) / (2.0 * tanfovy)
        focal_x = int(viewpoint_camera.image_width) / (2.0 * tanfovx)
        fidx = viewpoint_camera.uid
        viewmatrix = gaussians.get_RT(fidx).detach().cuda() #w2c

        # Transformed positions; used (detached) as the linearisation point.
        x_t = torch.matmul(means3D, viewmatrix[:3, :3]) + viewmatrix[3, :3]
        t = x_t.detach()
        # First-order projection of the displacement to pixel units:
        # d(f*x/z) = f/z * dx - f*x/z^2 * dz (and likewise for y).
        flow[:, 0] = flow[:, 0] * focal_x / t[:, 2]  + flow[:, 2] * -(focal_x * t[:, 0]) / (t[:, 2]*t[:, 2])
        flow[:, 1] = flow[:, 1] * focal_y / t[:, 2]  + flow[:, 2] * -(focal_y * t[:, 1]) / (t[:, 2]*t[:, 2])

        colors_precomp = flow

        # Geometry is detached so only the flow receives gradients. Exactly
        # one of (scales, rotations) / cov3D_precomp is set, so the unset
        # side must stay None instead of being detached.
        out = rasterizer(
            means3D = means3D.detach(),
            means2D = means2D.detach(),
            shs = shs,
            colors_precomp = colors_precomp,
            opacities = opacity.detach(),
            scales = scales.detach() if scales is not None else None,
            rotations = rotations.detach() if rotations is not None else None,
            cov3D_precomp = cov3D_precomp.detach() if cov3D_precomp is not None else None,
        )

        # The rasterizer build may or may not return an alpha channel.
        if len(out) == 4:
            rendered_image, radii, rendered_depth, rendered_alpha = out
            has_alpha = True
        elif len(out) == 3:
            rendered_image, radii, rendered_depth = out
            has_alpha = False
        else:
            return None

        # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
        # They will be excluded from value updates used in the splitting criteria.
        result = {
            "image": rendered_image,
            "depth": rendered_depth,
        }
        if has_alpha:
            result["alpha"] = rendered_alpha
        result["viewspace_points"] = screenspace_points
        result["visibility_filter"] = radii > 0
        result["radii"] = radii
        return result
