# Attach a human-readable `.label` to each tensor of interest so the memory
# report, which separates objects that have a `label` attribute from the rest,
# can name them. Each name is labelled exactly once. Attribute assignment
# overwrites, so if two names below refer to the same tensor object, only the
# later label survives.
# Every `.grad` used here must already be a tensor (presumably this runs after
# backward() — confirm); assigning an attribute on a None grad raises
# AttributeError.
render_indices.label = "render_indices"
means3D.label = "means3D"
opacity.label = "opacity"
scales.label = "scales"
rotations.label = "rotations"
features_dc.label = "features_dc"
features_rest.label = "features_rest"
shs.label = "shs"

means3D.grad.label = "means3Dgrad"
opacity.grad.label = "opacitygrad"
scales.grad.label = "scalesgrad"
rotations.grad.label = "rotationsgrad"
features_dc.grad.label = "features_dcgrad"
features_rest.grad.label = "features_restgrad"

# Each entry of `parameters` is indexed as a dict with 'name', 'exp_avgs' and
# 'exp_avgs_sqs' keys (presumably optimizer moment buffers — confirm).
for param in parameters:
    param["exp_avgs"].label = param["name"] + "_exp_avgs"
    param["exp_avgs_sqs"].label = param["name"] + "_exp_avgs_sqs"

# SPT / LOD-cut related tensors.
SPT_counts_new.label = "SPT_counts_new"
SPT_counts.label = "SPT_counts"
write_back_mask.label = "write_back_mask"
SPT_indices.label = "SPT_indices"
keep_SPT_indices.label = "keep_SPT_indices"
SPTs_prev_to_new.label = "SPTs_prev_to_new"
SPT_distances.label = "SPT_distances"
SPT_node_indices.label = "SPT_node_indices"
LOD_detail_cut.label = "LOD_detail_cut"
coarse_cut.label = "coarse_cut"
leaf_mask.label = "leaf_mask"
leaf_nodes.label = "leaf_nodes"
valid.label = "valid"
prev_distances_compare.label = "prev_distances_compare"
distances_compare.label = "distances_compare"
close_enough.label = "close_enough"
valid_non_zero.label = "valid_non_zero"
close_enough_non_zero.label = "close_enough_non_zero"
SPT_keep_counts_indices.label = "SPT_keep_counts_indices"
keep_gaussians_mask.label = "keep_gaussians_mask"
mask.label = "mask"
upper_tree_nodes_to_render.label = "upper_tree_nodes_to_render"
cut_SPTs.label = "cut_SPTs"
prev_SPT_indices.label = "prev_SPT_indices"
prev_SPT_distances.label = "prev_SPT_distances"
prev_SPT_counts.label = "prev_SPT_counts"

# Buffers stored on the `gaussians` object.
gaussians.SPT_starts.label = "SPT_starts"
gaussians.SPT_min.label = "SPT_min"
gaussians.SPT_max.label = "SPT_max"
gaussians.SPT_gaussian_indices.label = "SPT_gaussian_indices"
gaussians.upper_tree_nodes.label = "upper_tree_nodes"
gaussians.upper_tree_xyz.label = "upper_tree_xyz"
gaussians.upper_tree_scaling.label = "upper_tree_scaling"
gaussians.min_distance_squared.label = "min_distance_squared"

gt_image.label = "gt_image"
bounds.label = "bounds"
planes.label = "planes"
render_pkg["render"].label = "image"
render_pkg["viewspace_points"].label = "means2D"
render_pkg["viewspace_points"].grad.label = "means2D_grad"
gaussians.bounding_sphere_radii.label = "bounding_sphere_radii"

import gc
import sys

# CUDA memory report: list every live CUDA tensor larger than
# MIN_TENSOR_BYTES, split into "known" tensors (objects carrying a `label`
# attribute) and "unknown" ones, each group largest first, then print the byte
# totals of both groups.
#
# Sizes are nelement() * element_size() per tensor object. Storage sharing is
# not deduplicated, so views of one buffer (and a `.data` wrapper plus its
# tensor, if both are tracked) are each counted in full.

# Tensors of at most this many bytes are left out of the report.
MIN_TENSOR_BYTES = 64


def _cuda_tensor_of(obj):
    """Return the CUDA tensor behind ``obj``, or ``None``.

    ``obj`` qualifies if it is itself a tensor or has a tensor ``.data``
    attribute, and that tensor lives on a CUDA device.
    """
    if torch.is_tensor(obj):
        tensor = obj
    else:
        tensor = getattr(obj, 'data', None)
        if not torch.is_tensor(tensor):
            return None
    return tensor if tensor.is_cuda else None


def _collect_cuda_tensors(min_bytes=MIN_TENSOR_BYTES):
    """Scan gc-tracked objects for CUDA tensors larger than ``min_bytes``.

    Returns:
        ``(known, unknown)``: lists of ``(obj, tensor, nbytes)`` tuples, each
        sorted by ``nbytes`` descending. ``known`` holds objects that have a
        ``label`` attribute; ``unknown`` holds the rest. ``tensor`` is ``obj``
        itself or its tensor ``.data``.
    """
    known = []
    unknown = []
    for obj in gc.get_objects():
        try:
            tensor = _cuda_tensor_of(obj)
            if tensor is None:
                continue
            nbytes = tensor.nelement() * tensor.element_size()
            labelled = hasattr(obj, 'label')
        except Exception:
            # Attribute access on arbitrary gc-tracked objects can run
            # properties that raise; skip such objects instead of aborting
            # the whole report.
            continue
        if nbytes <= min_bytes:
            continue
        if labelled:
            known.append((obj, tensor, nbytes))
        else:
            unknown.append((obj, tensor, nbytes))
    known.sort(key=lambda entry: entry[2], reverse=True)
    unknown.sort(key=lambda entry: entry[2], reverse=True)
    return known, unknown


def _print_report(known, unknown):
    """Print one line per tensor, labelled ones first, and return their labels.

    A label already printed earlier is flagged with ``(DUPLICATE)``. The
    trailing number is ``sys.getrefcount(obj)``, which includes the temporary
    references held by the report itself.
    """
    labels = []
    seen = set()
    for obj, tensor, nbytes in known:
        print(obj.label, end=":   ")
        if obj.label in seen:
            print(" (DUPLICATE) ", end="")
        print(type(obj), tensor.size(), tensor.dtype, f" {nbytes:_} ", sys.getrefcount(obj))
        seen.add(obj.label)
        labels.append(obj.label)
    for obj, tensor, nbytes in unknown:
        print("UNKNOWN", end=":   ")
        print(type(obj), tensor.size(), tensor.dtype, f" {nbytes:_} ", sys.getrefcount(obj))
    return labels


gc.collect()
torch.cuda.empty_cache()

known_tensors, unknown_tensors = _collect_cuda_tensors()
known_size = sum(nbytes for _, _, nbytes in known_tensors)
unknown_size = sum(nbytes for _, _, nbytes in unknown_tensors)
labels = _print_report(known_tensors, unknown_tensors)
# Release the report's references so the listed tensors can actually be freed.
# All loop variables lived inside the helpers, so no module-level name keeps
# the last-visited tensor alive.
known_tensors = []
unknown_tensors = []

print(f"Known  : {known_size:_}")
print(f"Unknown: {unknown_size:_}")
