import sys
import numpy as np
import matplotlib.pyplot as plt
from skimage import io
from scipy.ndimage import distance_transform_edt
from util.synthetic_cell_membrane_masks import SyntheticCellMembraneMasks
# Default font sizes applied to every figure drawn in this demo
plt.rc('font', size=18)                       # default text size
plt.rc('axes', titlesize=12, labelsize=12)    # axes titles and x/y labels
plt.rc(('xtick', 'ytick'), labelsize=12)      # tick labels on both axes
plt.rc('figure', titlesize=32)                # figure title size
Before we can generate a synthetic mask, some parameters have to be specified. For demonstration purposes, the simulation is limited to a small specimen and, thus, a small number of cells. The cell density was estimated from the training set to be roughly one cell per $20\times20\times20$ voxel region, so a spherical specimen of radius $r$ contains approximately $\frac{4}{3}\pi r^3 / 20^3$ cells.
## Set parameters
# Edge length (in voxels) of the simulated volume. So far only volumes with
# equal edge lengths can be created.
img_size = 128
# Roughly one cell per 20x20x20 voxel region (estimated from the training set)
cell_density = 1/(20*20*20)
# Roundness parameter forwarded to the mask generator
roundness_gamma = 5
# Radius of the single spherical specimen: half the volume edge length
sphere_radius = img_size//2
# Expected number of cells = sphere volume (4/3 * pi * r^3) * cell density.
# The builtin `int` truncates the estimate; `np.int` no longer exists in
# NumPy >= 1.24.
cell_count = int(4/3*np.pi*(sphere_radius**3)*cell_density)
## Set up the generator
# Creating the generator produces a randomly shaped specimen: one sphere of
# radius `sphere_radius` populated with `cell_count` cells. The resulting masks
# are read below from its `label_image` / `boundary_image` attributes.
# Note: `use_cuda=True` requests GPU execution.
membrane_generator = SyntheticCellMembraneMasks(
    size=img_size,
    radius=sphere_radius,
    roundness=roundness_gamma,
    cell_count=cell_count,
    sphere_count=1,
    use_cuda=True,
)
# Show the central slice (along the first axis) of the generated instance
# labels, next to the boundaries derived from those labels.
central = (img_size // 2, slice(None), slice(None))
plt.subplot(1, 2, 1)
plt.title('Instance mask')
plt.imshow(membrane_generator.label_image[central])
plt.subplot(1, 2, 2)
plt.title('Boundary of the instance mask')
plt.imshow(
    membrane_generator.labelBoundaries(membrane_generator.label_image)[central]
)
plt.show()
As explained in the paper, boundary morphology is enhanced by morphological opening before the final instance mask is generated. The effect can be observed in the following plot:
# Compare the generator's initial boundary mask with the boundaries of the
# final (enhanced) instance labels, both at the central slice.
plt.subplot(1, 2, 1)
plt.title('Initial boundary mask')
plt.imshow(membrane_generator.boundary_image[img_size // 2, :, :])
plt.subplot(1, 2, 2)
plt.title('Enhanced boundary mask')
enhanced_boundaries = membrane_generator.labelBoundaries(membrane_generator.label_image)
plt.imshow(enhanced_boundaries[img_size // 2, :, :])
plt.show()
As input for the generative adversarial network, two distance maps have to be created. Together they encode the position of each voxel by a pair of distances, measured from the outer boundary of the specimen towards the centroid and towards the background region, respectively. Stacked with the boundary mask, they form the 3-channel image that serves as input for the generator network.
# Instance labels and the boundary (membrane) mask derived from them
instance_mask = membrane_generator.label_image
membrane_mask = membrane_generator.labelBoundaries(instance_mask)
# The smallest value occurring in the membrane mask is used as background label.
# NOTE(review): the label is then compared against `instance_mask`, which
# assumes both masks share the same background value — confirm.
background_label = np.unique(membrane_mask)[0]

# Region masks, computed once and reused by both distance maps
in_background = instance_mask == background_label
in_foreground = instance_mask != background_label

# Distance map 1: Euclidean distance of each background voxel to the nearest
# foreground voxel, inverted (max - d) and zeroed inside the foreground.
edt_background = distance_transform_edt(in_background)
dist_map1 = np.max(edt_background) - edt_background
dist_map1[in_foreground] = 0

# Distance map 2: Euclidean distance of each foreground voxel to the nearest
# background voxel, inverted (max - d) and zeroed inside the background.
edt_foreground = distance_transform_edt(in_foreground)
dist_map2 = np.max(edt_foreground) - edt_foreground
dist_map2[in_background] = 0

# Stack boundary mask and both distance maps as channels, then prepend a
# batch axis: result has shape (1, 3, *membrane_mask.shape)
membrane_patch = np.stack((membrane_mask, dist_map1, dist_map2))[np.newaxis, ...]
# Plot the central slice of each of the three network input channels.
plt.subplot(131)
plt.title('Boundary mask')
# `membrane_mask` already is the boundary mask (channel 0 of `membrane_patch`).
# Calling labelBoundaries on it again would compute boundaries of the boundary
# mask instead of showing the channel that is actually fed to the network.
plt.imshow(membrane_mask[img_size//2,:,:])
plt.subplot(132)
plt.title('Distance map 1')
plt.imshow(dist_map1[img_size//2,:,:])
plt.subplot(133)
plt.title('Distance map 2')
plt.imshow(dist_map2[img_size//2,:,:])
plt.show()
These 3-channel images are translated into the image domain by a generative adversarial network (GAN).
# Include the model class
from models.gan_3D_model import GAN3D
from util.model_setup import load_config_json, numpy_to_torch_tensor
# Set up a config file
# Load the general, image and mask configurations from the JSON file.
# NOTE(review): the script's module namespace (`globals()`) is passed to the
# loader — presumably so it can read or inject names defined here; confirm.
general_config, image_config, mask_config = load_config_json('models/gan_patchwise.json', globals())
# Set up the model and load pre-trained weights
# Build the 3D GAN from the general configuration, then restore its
# pre-trained weights identified by 'latest'.
model_synthesis = GAN3D(general_config)
model_synthesis.load_networks('latest')
# The generator network was trained on much bigger specimens, and generating
# large images would take too long for this demonstration, so an exemplary
# image from the public data set is used instead.


def _read_volume(path):
    """Read a TIFF volume and reverse its axis order (axes 2, 1, 0)."""
    return np.transpose(io.imread(path), (2, 1, 0))


test_img = _read_volume('demo_image.tif')
test_membranes = _read_volume('demo_membrane.tif')
test_dist1 = _read_volume('demo_dist1.tif')
test_dist2 = _read_volume('demo_dist2.tif')

# Construct the input image
# Domain A: real image scaled by 1/255 (presumably 8-bit data — confirm),
# with leading batch and channel axes.
inputA = test_img[np.newaxis, np.newaxis, ...] / 255
# Domain B: boundary mask plus both distance maps as three channels, with a
# leading batch axis -> shape (1, 3, ...)
inputB = np.stack((test_membranes, test_dist1, test_dist2))[np.newaxis, ...]
# Plot slice 50 (last axis) of the real image and of the first mask channel
for panel, (title, volume) in enumerate((('Real Patch', inputA), ('Real Mask', inputB)), start=1):
    plt.subplot(1, 2, panel)
    plt.title(title)
    plt.imshow(volume[0, 0, :, :, 50])
plt.show()
# Predict the corresponding synthetic image: convert both inputs to tensors,
# hand them to the model, run inference and collect the visual outputs.
tensor_a = numpy_to_torch_tensor([inputA])[0]
tensor_b = numpy_to_torch_tensor([inputB])[0]
model_synthesis.set_input([tensor_a, tensor_b])
model_synthesis.test()
images = model_synthesis.get_current_visuals()

# Plot slice 50 (last axis) of the real and the generated image.
# For the full-size image, a faded overlap technique and border cropping are
# used to reduce border artifacts and obtain continuous intensities.
# NOTE(review): 'real_B' is shown as the real image patch; confirm this
# matches the model's A/B naming.
for panel, (title, key) in enumerate((('Real Patch', 'real_B'), ('Fake Patch', 'fake_B')), start=1):
    plt.subplot(1, 2, panel)
    plt.title(title)
    plt.imshow(images[key].detach().cpu().numpy()[0, 0, :, :, 50])
plt.show()