import math
import torchvision.transforms.functional as TVF
from .tensorBase import *


class TensorVMSplit(TensorBase):
    def __init__(self, aabb, gridSize, device, **kargs):
        super(TensorVMSplit, self).__init__(aabb, gridSize, device, **kargs)


    def init_svd_volume(self, res, device):
        self.density_plane, self.density_line = self.init_one_svd(self.density_n_comp, self.gridSize, 0.1, device)
        self.app_plane, self.app_line = self.init_one_svd(self.app_n_comp, self.gridSize, 0.1, device)
        self.basis_mat = torch.nn.Linear(sum(self.app_n_comp), self.app_dim, bias=False).to(device)
        self.init_kernel(
            self.kernel_size, self.learnable_kernel, device, self.n_scales,
            self.density_n_comp, self.app_n_comp, self.kernel_init, init_sigma=1, init_scale=0.1
        )
        self.init_basis(
            self.learnable_basis, device, self.n_scales, self.density_n_comp,
            self.app_n_comp, self.basis_init, init_scale=0.1
        )

    def init_one_svd(self, n_component, gridSize, scale, device):
        plane_coef, line_coef = [], []
        for i in range(len(self.vecMode)):
            vec_id = self.vecMode[i]
            mat_id_0, mat_id_1 = self.matMode[i]
            plane_coef.append(torch.nn.Parameter(
                scale * torch.randn((1, n_component[i], gridSize[mat_id_1], gridSize[mat_id_0])))
            )
            line_coef.append(
                torch.nn.Parameter(scale * torch.randn((1, n_component[i], gridSize[vec_id], 1)))
            )

        return torch.nn.ParameterList(plane_coef).to(device), torch.nn.ParameterList(line_coef).to(device)
    

    def init_kernel(
            self, kernel_size, learnable, device,
            n_scales, density_n_comp, app_n_comp,
            init_type, init_sigma=None, init_scale=None
        ):
        assert density_n_comp[0] == density_n_comp[1] == density_n_comp[2]
        assert app_n_comp[0] == app_n_comp[1] == app_n_comp[2]

        den_n_comps = density_n_comp[0]
        app_n_comps = app_n_comp[0]

        def gaussian(out_channels, kernel_sizes):
            sigma = init_sigma
            assert sigma is not None
            # The gaussian kernel is the product of the
            # gaussian function of each dimension.
            meshgrids = torch.meshgrid([
                torch.arange(s, dtype=torch.float32) for s in kernel_sizes
            ], indexing="ij")
            
            kernel = 1
            for size,  mgrid in zip(kernel_sizes, meshgrids):
                mean = (size - 1) / 2
                kernel *= 1 / (sigma * math.sqrt(2 * math.pi)) * \
                        torch.exp(-0.5 * ((mgrid - mean) / sigma)**2)
            
            # to make sum of kernel equal to 1
            kernel = kernel / torch.sum(kernel)

            return kernel.repeat(out_channels, 1, 1, 1)
        
        def random_normal(out_channels, kernel_sizes):
            scale = init_scale
            assert scale is not None
            return torch.randn((out_channels, 1, *kernel_sizes)) * scale
        
        def identity(out_channels, kernel_sizes):
            assert len(kernel_sizes) == 2
            kernel = torch.zeros(out_channels, 1, *kernel_sizes)
            kernel[..., kernel_sizes[0]//2, kernel_sizes[1]//2] = 1
            return kernel

        if init_type == "identity":
            f = identity
        elif init_type == "random_normal":
            f = random_normal
        elif init_type == "gaussian":
            f = gaussian
        else:
            raise ValueError("Invalid init_type")

        density_plane_kernels = [
            f(den_n_comps * n_scales, (kernel_size, kernel_size)).to(device) for _ in range(3)
        ]
        app_plane_kernels = [
            f(app_n_comps * n_scales, (kernel_size, kernel_size)).to(device) for _ in range(3)
        ]

        density_line_kernels = [
            f(den_n_comps * n_scales, (kernel_size, 1)).to(device) for _ in range(3)
        ]
        app_line_kernels = [
            f(app_n_comps * n_scales, (kernel_size, 1)).to(device) for _ in range(3)
        ]

        if learnable:
            self.density_plane_kernels = torch.nn.ParameterList([torch.nn.Parameter(k) for k in density_plane_kernels])
            self.app_plane_kernels = torch.nn.ParameterList([torch.nn.Parameter(k) for k in app_plane_kernels])
            self.density_line_kernels = torch.nn.ParameterList([torch.nn.Parameter(k) for k in density_line_kernels])
            self.app_line_kernels = torch.nn.ParameterList([torch.nn.Parameter(k) for k in app_line_kernels])
        else:
            self.density_plane_kernels = density_plane_kernels
            self.app_plane_kernels = app_plane_kernels
            self.density_line_kernels = density_line_kernels
            self.app_line_kernels = app_line_kernels

    def init_basis(
        self, learnable, device, n_scales, density_n_comp,
        app_n_comp, init_type, init_scale=None
    ):
        den_n_comps = density_n_comp[0]
        app_n_comps = app_n_comp[0]

        def arithmetic_mean(out_channels, in_channels):
            return torch.full((out_channels, in_channels, 1, 1), 1/in_channels)
        
        def random_normal(out_channels, in_channels):
            scale = init_scale
            assert scale is not None
            return torch.randn((out_channels, in_channels, 1, 1)) * scale
        
        if init_type == "arithmetic_mean":
            f = arithmetic_mean
        elif init_type == "random_normal":
            f = random_normal
        else:
            raise ValueError("Invalid init type")
        
        density_plane_basis = [f(den_n_comps * n_scales, den_n_comps).to(device) for _ in range(3)]
        app_plane_basis = [f(app_n_comps * n_scales, app_n_comps).to(device) for _ in range(3)]

        density_line_basis = [f(den_n_comps * n_scales, den_n_comps).to(device) for _ in range(3)]
        app_line_basis = [f(app_n_comps * n_scales, app_n_comps).to(device) for _ in range(3)]

        if learnable:
            self.density_plane_basis = torch.nn.ParameterList([torch.nn.Parameter(k) for k in density_plane_basis])
            self.app_plane_basis = torch.nn.ParameterList([torch.nn.Parameter(k) for k in app_plane_basis])
            self.density_line_basis = torch.nn.ParameterList([torch.nn.Parameter(k) for k in density_line_basis])
            self.app_line_basis = torch.nn.ParameterList([torch.nn.Parameter(k) for k in app_line_basis])
        else:
            self.density_plane_basis = density_plane_basis
            self.app_plane_basis = app_plane_basis
            self.density_line_basis = density_line_basis
            self.app_line_basis = app_line_basis

    def get_optparam_groups(self, lr_init_spatialxyz = 0.02, lr_init_network = 0.001):
        grad_vars = [
            {'params': self.density_line, 'lr': lr_init_spatialxyz},
            {'params': self.density_plane, 'lr': lr_init_spatialxyz},
            {'params': self.app_line, 'lr': lr_init_spatialxyz},
            {'params': self.app_plane, 'lr': lr_init_spatialxyz},
            {'params': self.basis_mat.parameters(), 'lr':lr_init_network}
        ]
        
        if self.learnable_kernel:
            grad_vars += [
                {'params': self.density_plane_kernels.parameters(), 'lr': lr_init_network},
                {'params': self.density_line_kernels.parameters(), 'lr': lr_init_network},
                {'params': self.app_plane_kernels.parameters(), 'lr': lr_init_network},
                {'params': self.app_line_kernels.parameters(), 'lr': lr_init_network}
            ]
            if self.apply_basis:
                grad_vars += [
                    {'params': self.density_plane_basis.parameters(), 'lr': lr_init_network},
                    {'params': self.density_line_basis.parameters(), 'lr': lr_init_network},
                    {'params': self.app_plane_basis.parameters(), 'lr': lr_init_network},
                    {'params': self.app_line_basis.parameters(), 'lr': lr_init_network}
                ]

        if isinstance(self.renderModule, torch.nn.Module):
            grad_vars += [{'params':self.renderModule.parameters(), 'lr':lr_init_network}]
        
        return grad_vars


    def vectorDiffs(self, vector_comps):
        total = 0
        
        for idx in range(len(vector_comps)):
            n_comp, n_size = vector_comps[idx].shape[1:-1]
            
            dotp = torch.matmul(vector_comps[idx].view(n_comp,n_size), vector_comps[idx].view(n_comp,n_size).transpose(-1,-2))
            non_diagonal = dotp.view(-1)[1:].view(n_comp-1, n_comp+1)[...,:-1]
            total = total + torch.mean(torch.abs(non_diagonal))
        return total

    def vector_comp_diffs(self):
        return self.vectorDiffs(self.density_line) + self.vectorDiffs(self.app_line)
    
    def density_L1(self):
        total = 0
        for idx in range(len(self.density_plane)):
            total = total + torch.mean(torch.abs(self.density_plane[idx])) + torch.mean(torch.abs(self.density_line[idx]))# + torch.mean(torch.abs(self.app_plane[idx])) + torch.mean(torch.abs(self.density_plane[idx]))
        return total
    
    def TV_loss_density(self, reg):
        total = 0
        for idx in range(len(self.density_plane)):
            total = total + reg(self.density_plane[idx]) * 1e-2 #+ reg(self.density_line[idx]) * 1e-3
        return total
        
    def TV_loss_app(self, reg):
        total = 0
        for idx in range(len(self.app_plane)):
            total = total + reg(self.app_plane[idx]) * 1e-2 #+ reg(self.app_line[idx]) * 1e-3
        return total
    
    def convolve_with_kernel(
            self, planes, lines, plane_kernels, line_kernels, n_components, reshape_and_transpose
        ):
        out_planes, out_lines = [], []
        for p, pk, l, lk, n_comp in zip(planes, plane_kernels, lines, line_kernels, n_components):
            # (1, n_comp, h, w) --> (1, n_comp*n_scales, h, w)
            p = p.repeat(1, self.n_scales, 1, 1)
            l = l.repeat(1, self.n_scales, 1, 1)
            # apply kernel
            if reshape_and_transpose:
                out_plane = F.conv2d(p, pk, padding="same", groups=self.n_scales*n_comp) \
                             .reshape(1, self.n_scales, -1, *p.shape[-2:]).transpose(1, 2)
                out_line = F.conv2d(l, lk, padding="same", groups=self.n_scales*n_comp) \
                            .reshape(1, self.n_scales, -1, *l.shape[-2:]).transpose(1, 2)
            else:
                out_plane = F.conv2d(p, pk, padding="same", groups=self.n_scales*n_comp)
                out_line = F.conv2d(l, lk, padding="same", groups=self.n_scales*n_comp)
            # non-lineaerity
            if self.kernel_nonlinear:
                out_plane = F.relu(out_plane)
                out_line = F.relu(out_line)
            out_planes.append(out_plane)
            out_lines.append(out_line)
        return out_planes, out_lines
    
    def convolve_with_basis(self, planes, lines, plane_basis, line_basis, n_components):
        out_planes, out_lines = [], []
        for p, pb, l, lb, n_comp in zip(planes, plane_basis, lines, line_basis, n_components):
            assert p.shape[1] == self.n_scales * n_comp
            # p: (1, n_scales*n_comp, h, w), l: (1, n_scales*n_comp, h, 1)
            out_planes.append(
                F.conv2d(p, pb, groups=self.n_scales) \
                 .reshape(1, self.n_scales, n_comp, *p.shape[-2:]) \
                 .transpose(1, 2)
            )
            out_lines.append(
                F.conv2d(l, lb, groups=self.n_scales) \
                 .reshape(1, self.n_scales, n_comp, *l.shape[-2:]) \
                 .transpose(1, 2)
            )
        return out_planes, out_lines


    def compute_densityfeature(self, xyz_sampled, scales):
        # xy, xz, yz coordinate: (3, #points, 2)
        coordinate_plane = torch.stack((
            xyz_sampled[..., self.matMode[0]],
            xyz_sampled[..., self.matMode[1]],
            xyz_sampled[..., self.matMode[2]]
        ))

        # z, y, x coordinate: (3, #points, 1)
        coordinate_line = torch.stack((
            xyz_sampled[..., self.vecMode[0]],
            xyz_sampled[..., self.vecMode[1]],
            xyz_sampled[..., self.vecMode[2]],
        )).view(3, -1, 1)
        
        # output tensor
        sigma_features = torch.zeros((xyz_sampled.shape[0],), device=xyz_sampled.device)

        if self.apply_kernel is True and scales is not None:
            # xys, xzs, yzs: (3, #points, 1, 1, 3)
            coordinate_plane = torch.cat((
                coordinate_plane,
                torch.broadcast_to(scales, (3, -1, -1))
            ), dim=-1).view(3, -1, 1, 1, 3)

            # 0zs, 0ys, 0xs: (3, #points, 1, 1, 3)
            coordinate_line = torch.cat((
                torch.zeros_like(coordinate_line),
                coordinate_line,
                torch.broadcast_to(scales, (3, -1, -1))
            ), dim=-1).view(3, -1, 1, 1, 3)

            # convole with kernels
            planes, lines = self.convolve_with_kernel(
                self.density_plane, self.density_line, 
                self.density_plane_kernels, self.density_line_kernels,
                self.density_n_comp, not self.apply_basis
            )

            if self.apply_basis:
                planes, lines = self.convolve_with_basis(
                    planes, lines, self.density_plane_basis,
                    self.density_line_basis, self.density_n_comp
                )
        else:
            # xy, xz, yz: (3, #points, 1, 2)
            coordinate_plane = coordinate_plane.view(3, -1, 1, 2)

            # 0z, 0y, 0x: (3, #points, 1, 2)
            coordinate_line = torch.cat((
                torch.zeros_like(coordinate_line),
                coordinate_line
            ), dim=-1).view(3, -1, 1, 2)

            planes = self.density_plane
            lines = self.density_line
        
        # grid sample
        for i in range(len(planes)):
            plane_coef_point = F.grid_sample(
                planes[i], coordinate_plane[[i]], align_corners=True
            ).view(-1, *xyz_sampled.shape[:1])
            line_coef_point = F.grid_sample(
                lines[i], coordinate_line[[i]], align_corners=True
            ).view(-1, *xyz_sampled.shape[:1])
            sigma_features += torch.sum(plane_coef_point * line_coef_point, dim=0)

        return sigma_features

    def compute_appfeature(self, xyz_sampled, scales):
        # xy, xz, yz coordinate: (3, #points, 2)
        coordinate_plane = torch.stack((
            xyz_sampled[..., self.matMode[0]],
            xyz_sampled[..., self.matMode[1]],
            xyz_sampled[..., self.matMode[2]]
        ))

        # z, y, x coordinate: (3, #points, 1)
        coordinate_line = torch.stack((
            xyz_sampled[..., self.vecMode[0]],
            xyz_sampled[..., self.vecMode[1]],
            xyz_sampled[..., self.vecMode[2]],
        )).view(3, -1, 1)

        if self.apply_kernel is True:
            # xys, xzs, yzs: (3, #points, 1, 1, 3)
            coordinate_plane = torch.cat((
                coordinate_plane,
                torch.broadcast_to(scales, (3, -1, -1))
            ), dim=-1).view(3, -1, 1, 1, 3)

            # 0zs, 0ys, 0xs: (3, #points, 1, 1, 3)
            coordinate_line = torch.cat((
                torch.zeros_like(coordinate_line),
                coordinate_line,
                torch.broadcast_to(scales, (3, -1, -1))
            ), dim=-1).view(3, -1, 1, 1, 3)

            # convolve with kernels
            planes, lines = self.convolve_with_kernel(
                self.app_plane, self.app_line,
                self.app_plane_kernels, self.app_line_kernels,
                self.app_n_comp, not self.apply_basis
            )

            if self.apply_basis:
                planes, lines = self.convolve_with_basis(
                    planes, lines, self.app_plane_basis,
                    self.app_line_basis, self.app_n_comp,
                )
        else:
             # xy, xz, yz: (3, #points, 1, 2)
            coordinate_plane = coordinate_plane.view(3, -1, 1, 2)

            # 0z, 0y, 0x: (3, #points, 1, 2)
            coordinate_line = torch.cat((
                torch.zeros_like(coordinate_line),
                coordinate_line
            ), dim=-1).view(3, -1, 1, 2)

            planes = self.app_plane
            lines = self.app_line

        plane_coef_point = []
        line_coef_point = []
        for i in range(len(planes)):
            plane_coef_point.append(F.grid_sample(
                planes[i], coordinate_plane[[i]], align_corners=True
            ).view(-1, *xyz_sampled.shape[:1]))
            line_coef_point.append(F.grid_sample(
                lines[i], coordinate_line[[i]], align_corners=True
            ).view(-1, *xyz_sampled.shape[:1]))
        
        plane_coef_point = torch.cat(plane_coef_point)
        line_coef_point = torch.cat(line_coef_point)
        
        return self.basis_mat((plane_coef_point * line_coef_point).T)

    @torch.no_grad()
    def up_sampling_VM(self, plane_coef, line_coef, res_target):

        for i in range(len(self.vecMode)):
            vec_id = self.vecMode[i]
            mat_id_0, mat_id_1 = self.matMode[i]
            plane_coef[i] = torch.nn.Parameter(
                F.interpolate(plane_coef[i].data, size=(res_target[mat_id_1], res_target[mat_id_0]), mode='bilinear',
                              align_corners=True))
            line_coef[i] = torch.nn.Parameter(
                F.interpolate(line_coef[i].data, size=(res_target[vec_id], 1), mode='bilinear', align_corners=True))

        return plane_coef, line_coef

    @torch.no_grad()
    def upsample_volume_grid(self, res_target):
        self.app_plane, self.app_line = self.up_sampling_VM(self.app_plane, self.app_line, res_target)
        self.density_plane, self.density_line = self.up_sampling_VM(self.density_plane, self.density_line, res_target)

        self.update_stepSize(res_target)
        print(f'upsamping to {res_target}')

    @torch.no_grad()
    def shrink(self, new_aabb):
        print("====> shrinking ...")
        xyz_min, xyz_max = new_aabb
        t_l, b_r = (xyz_min - self.aabb[0]) / self.units, (xyz_max - self.aabb[0]) / self.units
        # print(new_aabb, self.aabb)
        # print(t_l, b_r,self.alphaMask.alpha_volume.shape)
        t_l, b_r = torch.round(torch.round(t_l)).long(), torch.round(b_r).long() + 1
        b_r = torch.stack([b_r, self.gridSize]).amin(0)

        for i in range(len(self.vecMode)):
            mode0 = self.vecMode[i]
            self.density_line[i] = torch.nn.Parameter(
                self.density_line[i].data[...,t_l[mode0]:b_r[mode0],:]
            )
            self.app_line[i] = torch.nn.Parameter(
                self.app_line[i].data[...,t_l[mode0]:b_r[mode0],:]
            )
            mode0, mode1 = self.matMode[i]
            self.density_plane[i] = torch.nn.Parameter(
                self.density_plane[i].data[...,t_l[mode1]:b_r[mode1],t_l[mode0]:b_r[mode0]]
            )
            self.app_plane[i] = torch.nn.Parameter(
                self.app_plane[i].data[...,t_l[mode1]:b_r[mode1],t_l[mode0]:b_r[mode0]]
            )


        if not torch.all(self.alphaMask.gridSize == self.gridSize):
            t_l_r, b_r_r = t_l / (self.gridSize-1), (b_r-1) / (self.gridSize-1)
            correct_aabb = torch.zeros_like(new_aabb)
            correct_aabb[0] = (1-t_l_r)*self.aabb[0] + t_l_r*self.aabb[1]
            correct_aabb[1] = (1-b_r_r)*self.aabb[0] + b_r_r*self.aabb[1]
            print("aabb", new_aabb, "\ncorrect aabb", correct_aabb)
            new_aabb = correct_aabb

        newSize = b_r - t_l
        self.aabb = new_aabb
        self.update_stepSize((newSize[0], newSize[1], newSize[2]))


class TensorCP(TensorBase):
    def __init__(self, aabb, gridSize, device, **kargs):
        super(TensorCP, self).__init__(aabb, gridSize, device, **kargs)


    def init_svd_volume(self, res, device):
        self.density_line = self.init_one_svd(self.density_n_comp[0], self.gridSize, 0.2, device)
        self.app_line = self.init_one_svd(self.app_n_comp[0], self.gridSize, 0.2, device)
        self.basis_mat = torch.nn.Linear(self.app_n_comp[0], self.app_dim, bias=False).to(device)


    def init_one_svd(self, n_component, gridSize, scale, device):
        line_coef = []
        for i in range(len(self.vecMode)):
            vec_id = self.vecMode[i]
            line_coef.append(
                torch.nn.Parameter(scale * torch.randn((1, n_component, gridSize[vec_id], 1))))
        return torch.nn.ParameterList(line_coef).to(device)

    
    def get_optparam_groups(self, lr_init_spatialxyz = 0.02, lr_init_network = 0.001):
        grad_vars = [{'params': self.density_line, 'lr': lr_init_spatialxyz},
                     {'params': self.app_line, 'lr': lr_init_spatialxyz},
                     {'params': self.basis_mat.parameters(), 'lr':lr_init_network}]
        if isinstance(self.renderModule, torch.nn.Module):
            grad_vars += [{'params':self.renderModule.parameters(), 'lr':lr_init_network}]
        return grad_vars

    def compute_densityfeature(self, xyz_sampled):

        coordinate_line = torch.stack((xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]]))
        coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach().view(3, -1, 1, 2)


        line_coef_point = F.grid_sample(self.density_line[0], coordinate_line[[0]],
                                            align_corners=True).view(-1, *xyz_sampled.shape[:1])
        line_coef_point = line_coef_point * F.grid_sample(self.density_line[1], coordinate_line[[1]],
                                        align_corners=True).view(-1, *xyz_sampled.shape[:1])
        line_coef_point = line_coef_point * F.grid_sample(self.density_line[2], coordinate_line[[2]],
                                        align_corners=True).view(-1, *xyz_sampled.shape[:1])
        sigma_feature = torch.sum(line_coef_point, dim=0)
        
        
        return sigma_feature
    
    def compute_appfeature(self, xyz_sampled):

        coordinate_line = torch.stack(
            (xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]]))
        coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach().view(3, -1, 1, 2)


        line_coef_point = F.grid_sample(self.app_line[0], coordinate_line[[0]],
                                            align_corners=True).view(-1, *xyz_sampled.shape[:1])
        line_coef_point = line_coef_point * F.grid_sample(self.app_line[1], coordinate_line[[1]],
                                                          align_corners=True).view(-1, *xyz_sampled.shape[:1])
        line_coef_point = line_coef_point * F.grid_sample(self.app_line[2], coordinate_line[[2]],
                                                          align_corners=True).view(-1, *xyz_sampled.shape[:1])

        return self.basis_mat(line_coef_point.T)
    

    @torch.no_grad()
    def up_sampling_Vector(self, density_line_coef, app_line_coef, res_target):

        for i in range(len(self.vecMode)):
            vec_id = self.vecMode[i]
            density_line_coef[i] = torch.nn.Parameter(
                F.interpolate(density_line_coef[i].data, size=(res_target[vec_id], 1), mode='bilinear', align_corners=True))
            app_line_coef[i] = torch.nn.Parameter(
                F.interpolate(app_line_coef[i].data, size=(res_target[vec_id], 1), mode='bilinear', align_corners=True))

        return density_line_coef, app_line_coef

    @torch.no_grad()
    def upsample_volume_grid(self, res_target):
        self.density_line, self.app_line = self.up_sampling_Vector(self.density_line, self.app_line, res_target)

        self.update_stepSize(res_target)
        print(f'upsamping to {res_target}')

    @torch.no_grad()
    def shrink(self, new_aabb):
        print("====> shrinking ...")
        xyz_min, xyz_max = new_aabb
        t_l, b_r = (xyz_min - self.aabb[0]) / self.units, (xyz_max - self.aabb[0]) / self.units

        t_l, b_r = torch.round(torch.round(t_l)).long(), torch.round(b_r).long() + 1
        b_r = torch.stack([b_r, self.gridSize]).amin(0)


        for i in range(len(self.vecMode)):
            mode0 = self.vecMode[i]
            self.density_line[i] = torch.nn.Parameter(
                self.density_line[i].data[...,t_l[mode0]:b_r[mode0],:]
            )
            self.app_line[i] = torch.nn.Parameter(
                self.app_line[i].data[...,t_l[mode0]:b_r[mode0],:]
            )

        if not torch.all(self.alphaMask.gridSize == self.gridSize):
            t_l_r, b_r_r = t_l / (self.gridSize-1), (b_r-1) / (self.gridSize-1)
            correct_aabb = torch.zeros_like(new_aabb)
            correct_aabb[0] = (1-t_l_r)*self.aabb[0] + t_l_r*self.aabb[1]
            correct_aabb[1] = (1-b_r_r)*self.aabb[0] + b_r_r*self.aabb[1]
            print("aabb", new_aabb, "\ncorrect aabb", correct_aabb)
            new_aabb = correct_aabb

        newSize = b_r - t_l
        self.aabb = new_aabb
        self.update_stepSize((newSize[0], newSize[1], newSize[2]))

    def density_L1(self):
        total = 0
        for idx in range(len(self.density_line)):
            total = total + torch.mean(torch.abs(self.density_line[idx]))
        return total

    def TV_loss_density(self, reg):
        total = 0
        for idx in range(len(self.density_line)):
            total = total + reg(self.density_line[idx]) * 1e-3
        return total

    def TV_loss_app(self, reg):
        total = 0
        for idx in range(len(self.app_line)):
            total = total + reg(self.app_line[idx]) * 1e-3
        return total