The image data used for this project is published in https://www.pnas.org/content/113/51/E8238 with the corresponding repository at https://www.repository.cam.ac.uk/handle/1810/262530.
To demonstrate how cell segmentations can be encoded by spherical harmonic (SH) coefficients, an exemplary cropped patch from the public dataset is used.
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from skimage import io
from dipy.core.geometry import sphere2cart, cart2sphere
from keras.models import model_from_json
from util.utils import get_sampling_sphere, spherical_instance_sampling, sampling2harmonics, harmonics2sampling, \
agglomerative_clustering, harmonic_non_max_suppression
from util.data_handling import descriptors2image, instances2harmonicmask, harmonicmask2sampling
from models.networks import harmonic_rcnn_3D
# Load the image and the corresponding mask
img = io.imread('demo_img.tif')
seg = io.imread('demo_seg.tif')
# Reorder to have (x,y,z)
img = np.transpose(img, (2,1,0))
seg = np.transpose(seg, (2,1,0))
# Plot the image and mask
plt.subplot(1,2,1)
plt.imshow(img[...,32])
plt.subplot(1,2,2)
plt.imshow(seg[...,32])
plt.show()
To sample each instance and encode it in a compact representation, sample positions have to be specified. These positions are chosen following the principle of electrostatic repulsion and are determined iteratively. For the proposed method, 5000 sample positions are optimized over 10000 iterations. For demonstration purposes, the acquisition of 200 angular positions within only 1000 iterations is shown in the following.
sampling_sphere = get_sampling_sphere(num_sample_points=200, num_iterations=1000, plot_sampling=True)
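The repulsion principle itself fits in a few lines: treat the sample points as equal charges on the unit sphere, let each pair push apart with a Coulomb-like force, and re-project onto the sphere after every update. The following is a minimal illustrative sketch, not the implementation behind get_sampling_sphere:
def repel_points_on_sphere(num_points=200, num_iterations=1000, step=1e-3, seed=0):
    # Toy electrostatic-repulsion sampler (illustrative only)
    rng = np.random.default_rng(seed)
    points = rng.normal(size=(num_points, 3))
    points /= np.linalg.norm(points, axis=1, keepdims=True)
    for _ in range(num_iterations):
        diff = points[:, None, :] - points[None, :, :]      # pairwise offset vectors
        dist = np.linalg.norm(diff, axis=-1)
        np.fill_diagonal(dist, np.inf)                      # no self-interaction
        force = (diff / dist[..., None] ** 3).sum(axis=1)   # 1/r^2 repulsion per pair
        points += step * force
        points /= np.linalg.norm(points, axis=1, keepdims=True)  # back onto the sphere
    return points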
The final sampling pattern is derived from the $\theta$ and $\phi$ values of each spherical coordinate. Since the calculation of a more precise sampling grid would take too long, a pre-calculated set of 5000 sample positions is loaded.
# Load angular pattern
theta_phi_sampling = np.load('theta_phi_sampling_5000points_10000iter.npy')
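For reference, such a $(\theta,\phi)$ pattern can be derived from Cartesian unit vectors with the already imported cart2sphere, which returns $(r,\theta,\phi)$. A minimal sketch, assuming get_sampling_sphere returns an $(N,3)$ array of unit vectors:
# Convert unit vectors to (theta, phi) pairs (assumes an (N, 3) array)
r, theta, phi = cart2sphere(sampling_sphere[:, 0], sampling_sphere[:, 1], sampling_sphere[:, 2])
theta_phi = np.stack([theta, phi], axis=-1)  # one (theta, phi) pair per sample direction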
# Sample each cell instance in the image crop
instances, sampled_radii, centroids = spherical_instance_sampling(seg, theta_phi_sampling, bg_values=[1])
For each instance, the instance label, the sampled radius at each angular position and the location of the centroid are derived.
print('The first cell has label {0} and its centroid is located at {1}.\n'.format(instances[0], centroids[0]) +
      'The sampled radii are a vector of length {0}, which corresponds to '.format(len(sampled_radii[0])) +
      'the total number of angular directions.')
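The sampling can be pictured as ray casting from each centroid. A minimal sketch of this idea, assuming dipy's convention that $\theta$ is the polar and $\phi$ the azimuthal angle (the helper name is hypothetical, not the implementation of spherical_instance_sampling):
def sample_radii_sketch(seg, label, centroid, theta_phi, max_radius=32, step=0.5):
    # Step outward along each (theta, phi) direction until the ray leaves the label
    radii = np.zeros(len(theta_phi))
    for i, (theta, phi) in enumerate(theta_phi):
        direction = np.array(sphere2cart(1, theta, phi))  # unit direction vector
        r = 0.0
        while r < max_radius:
            pos = np.round(np.asarray(centroid) + r * direction).astype(int)
            if np.any(pos < 0) or np.any(pos >= seg.shape) or seg[tuple(pos)] != label:
                break
            r += step
        radii[i] = r
    return radii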
Based on this sampling, spherical harmonic coefficients can be computed.
# Set up the converter
# First input is the SH order (order 8 = 81 coefficients)
s2h_converter = sampling2harmonics(8, theta_phi_sampling)
# Encode each instance segmentation
harmonic_sampling = s2h_converter.convert(sampled_radii)
harmonic_sampling[0]
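Conceptually, such a converter fits the sampled radii against a truncated real spherical harmonic basis evaluated at the sample directions, e.g. by least squares. The following is a minimal sketch using scipy.special.sph_harm; the real-basis construction and the mapping of $(\theta,\phi)$ onto scipy's (azimuthal, polar) arguments are assumptions that may differ from the util implementation:
from scipy.special import sph_harm

def build_sh_basis_sketch(order, theta_phi):
    # Real SH basis matrix of shape (num_samples, (order+1)**2)
    theta, phi = theta_phi[:, 0], theta_phi[:, 1]
    basis = []
    for l in range(order + 1):
        for m in range(-l, l + 1):
            y = sph_harm(abs(m), l, phi, theta)  # scipy: (m, l, azimuth, polar)
            if m < 0:
                basis.append(np.sqrt(2) * y.imag)
            elif m == 0:
                basis.append(y.real)
            else:
                basis.append(np.sqrt(2) * y.real)
    return np.stack(basis, axis=-1)

basis = build_sh_basis_sketch(8, theta_phi_sampling)  # (5000, 81)
# One least-squares fit per instance: basis @ coeffs ~ sampled radii
coeffs, *_ = np.linalg.lstsq(basis, np.asarray(sampled_radii).T, rcond=None)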
To reverse the encoding and obtain voxelized segmentation masks, the harmonic representation has to be decoded using the radius sampling pattern, which ultimately serves as input for a Delaunay triangulation.
# Set up the converter for decoding
h2s_converter = harmonics2sampling(8, theta_phi_sampling)
# Decode the instances
decoded_radii = h2s_converter.convert(harmonic_sampling)
# Assume perfectly confident predictions (probability 1 for each instance)
prob_objects = [1,]*len(decoded_radii)
# Reconstruct the mask via Delaunay triangulation
decoded_mask = descriptors2image([centroids, prob_objects, decoded_radii], theta_phi_sampling=theta_phi_sampling, shape=(64,64,64))
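Under the hood, the reconstruction can be pictured per instance: convert the decoded radii back to Cartesian surface points around the centroid, triangulate them, and mark all voxels inside the hull. A minimal single-instance sketch (the helper is hypothetical; descriptors2image additionally handles probabilities and overlapping instances):
from scipy.spatial import Delaunay

def radii_to_mask_sketch(radii, centroid, theta_phi, shape):
    # Surface points in Cartesian coordinates around the centroid
    x, y, z = sphere2cart(radii, theta_phi[:, 0], theta_phi[:, 1])
    surface = np.stack([x, y, z], axis=-1) + np.asarray(centroid)
    tri = Delaunay(surface)
    # Mark every voxel that falls inside one of the tetrahedra
    grid = np.stack(np.meshgrid(*[np.arange(s) for s in shape], indexing='ij'), axis=-1)
    inside = tri.find_simplex(grid.reshape(-1, 3)) >= 0
    return inside.reshape(shape)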
# Plot the image and mask
plt.subplot(1,2,1)
plt.imshow(img[...,32])
plt.subplot(1,2,2)
plt.imshow(decoded_mask[...,32])
plt.show()
In order to use a CNN to predict cell shapes, a mask is constructed that conceptually divides the input image into segments and encodes the position, confidence and shape of possible cell instances in each of those segments. For the given image patch, the resulting mask is constructed in the following.
encoded_mask = instances2harmonicmask(seg[...,np.newaxis], s2h_converter, shape=(64,64,64), cell_size=(8,8,8), dets_per_region=2, bg_values=[1])
encoded_mask.shape
The segment size is $8\times8\times8$ voxels, which is determined by the three convolutional layers with stride 2 in the network architecture. Since a patch size of $64\times64\times64$ voxels was used, the mask has a spatial size of $8\times8\times8$ voxels. Because harmonic order 8 is employed, each shape is encoded by a total of 81 coefficients. With two possible detections per segment, 3 values for position regression within each segment and one value for the probability, this results in a total of $2\cdot(81+3+1)=170$ parameters per segment and, thus, per output voxel.
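This channel count follows directly from these numbers and can be checked against the last axis of encoded_mask (the ordering of probability, offsets and coefficients within each detection depends on the implementation):
sh_order = 8
num_coeffs = (sh_order + 1) ** 2               # 81 coefficients for order 8
dets_per_region = 2
print(dets_per_region * (num_coeffs + 3 + 1))  # 170 channels per output voxel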
In the following, a pre-trained network is used to predict these masks from input images. Afterwards, the predicted detections and shapes are merged and post-processed to form the final voxelized segmentation mask.
# Load the model architecture and the pre-trained weights
with open('models/model.json', 'r') as mh:
    model = model_from_json(mh.read())
model.load_weights('models/model.hdf5')
# Normalize the patch
img_norm = img/255
# Predict the encoding mask
predicted_encoding = model.predict(img_norm[np.newaxis,..., np.newaxis])
# Filter and decode the mask into the radii sampling
# Since, for demonstration purposes, only one patch is processed, positional weighting of each detection
# is disabled and the threshold values are slightly adapted so that possibly truncated shape
# predictions at the patch borders are not excluded.
# Usually, a suitable overlap of neighbouring patches is used so that only reliable segmentations are considered.
pred_indices, pred_probs, pred_shapes = harmonicmask2sampling(predicted_encoding[0, ...],
                                                              h2s_converter=h2s_converter,
                                                              cell_size=(8, 8, 8),
                                                              dets_per_region=2,
                                                              thresh=0.5,
                                                              convert2radii=True,
                                                              positional_weighting=False)
# Perform clustering (used for each patch individually to reduce further computation)
pred_indices, pred_probs, pred_shapes = agglomerative_clustering(pred_indices, pred_probs, shape_descriptors=pred_shapes, max_dist=10)
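# For reference, such a merging step can be sketched with SciPy's hierarchical
# clustering (illustrative only; agglomerative_clustering may aggregate
# probabilities and shape descriptors differently):
from scipy.cluster.hierarchy import fcluster, linkage

def cluster_centroids_sketch(centroids, max_dist=10):
    # Single linkage: detections whose centroids lie within max_dist of each
    # other end up in the same cluster and could be merged into one detection
    Z = linkage(np.asarray(centroids, dtype=float), method='single')
    return fcluster(Z, t=max_dist, criterion='distance')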
# Perform NMS (Usually used after each patch has been processed and after they are merged to the full-size image)
# Note that this is overkill for a single patch, but it reaches its full potential on merged full-size images.
pred_indices, pred_probs, pred_shapes = harmonic_non_max_suppression(pred_indices, pred_probs, pred_shapes, overlap_thresh=0.5)
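# Conceptually, the NMS keeps the most confident detection and suppresses
# strongly overlapping neighbours. A minimal distance-based sketch (illustrative;
# harmonic_non_max_suppression derives the overlap from the shape descriptors):
def nms_sketch(centroids, probs, mean_radii, overlap_thresh=0.5):
    keep = []
    for i in np.argsort(probs)[::-1]:  # most confident detections first
        # Approximate each detection by a sphere with its mean radius and
        # suppress it if its sphere overlaps a kept detection too strongly
        if all(np.linalg.norm(np.asarray(centroids[i], dtype=float) - np.asarray(centroids[j], dtype=float))
               >= overlap_thresh * (mean_radii[i] + mean_radii[j]) for j in keep):
            keep.append(i)
    return keep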
# Reconstruct the voxelized segmentation mask
pred_mask = descriptors2image([pred_indices, pred_probs, pred_shapes], theta_phi_sampling=theta_phi_sampling, shape=(64,64,64), thresh=0.5)
# Plot the image and mask
plt.subplot(1,2,1)
plt.imshow(img[...,32])
plt.subplot(1,2,2)
plt.imshow(pred_mask[...,32])
plt.show()