Demonstration of the Synthesis Pipeline

Generating Synthetic Membrane Masks

In [1]:
import sys
import numpy as np
import matplotlib.pyplot as plt
from skimage import io
from scipy.ndimage import distance_transform_edt
from util.synthetic_cell_membrane_masks import SyntheticCellMembraneMasks

plt.rc('font', size=18)         # controls default text sizes
plt.rc('axes', titlesize=12)    # fontsize of the axes title
plt.rc('axes', labelsize=12)    # fontsize of the x and y labels
plt.rc('xtick', labelsize=12)   # fontsize of the tick labels
plt.rc('ytick', labelsize=12)   # fontsize of the tick labels
plt.rc('figure', titlesize=32)  # fontsize of the figure title

Before a synthetic mask can be generated, a few parameters have to be specified. For demonstration purposes, the simulation is limited to a small specimen and, thus, a small number of cells. The cell density was estimated from the training set and amounts to roughly one cell per $20\times20\times20$ pixel region.

In [2]:
## Set parameters

img_size = 128 # So far, only cubic volumes can be created
cell_density = 1/(20*20*20)
roundness_gamma = 5
sphere_radius = img_size//2
cell_count = int(4/3*np.pi*(sphere_radius**3)*cell_density) # expected number of cells in the spherical specimen
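
With these settings, the spherical specimen of radius $64$ covers $\tfrac{4}{3}\pi\cdot 64^3 \approx 1.1\cdot 10^6$ voxels, so at one cell per $8000$ voxels roughly $137$ cells are placed.
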
In [3]:
## Set up the generator

# Any time this generator is called, a randomly shaped specimen is generated
membrane_generator = SyntheticCellMembraneMasks(size=img_size, radius=sphere_radius, roundness=roundness_gamma, cell_count=cell_count, sphere_count=1, use_cuda=True)
In [4]:
# Plot 2D crops of the results

plt.subplot(121)
plt.title('Instance mask')
plt.imshow(membrane_generator.label_image[img_size//2,:,:])
plt.subplot(122)
plt.title('Boundary of the instance mask')
plt.imshow(membrane_generator.labelBoundaries(membrane_generator.label_image)[img_size//2,:,:])
plt.show()

As explained in the paper, the boundary morphology is enhanced by applying morphological opening before the final instance mask is generated. The change can be observed in the following plot:

In [5]:
plt.subplot(121)
plt.title('Initial boundary mask')
plt.imshow(membrane_generator.boundary_image[img_size//2,:,:])
plt.subplot(122)
plt.title('Enhanced boundary mask')
plt.imshow(membrane_generator.labelBoundaries(membrane_generator.label_image)[img_size//2,:,:])
plt.show()
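
The opening itself happens inside SyntheticCellMembraneMasks. Purely as a rough sketch of the general idea (assuming label 0 denotes background and using scipy's binary_opening rather than the repository's actual implementation), each cell region could be smoothed like this:

import numpy as np
from scipy.ndimage import binary_opening

# Illustrative sketch only: smooth every cell region by binary opening and
# rebuild the label image from the smoothed regions. The enhancement used by
# SyntheticCellMembraneMasks may differ in structuring element and details.
labels = membrane_generator.label_image
smoothed = np.zeros_like(labels)
for lbl in np.unique(labels):
    if lbl == 0:  # assumed background label
        continue
    region = binary_opening(labels == lbl, iterations=2)
    smoothed[region] = lbl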

Data Preparation

As input for the generative adversarial network, two distance maps have to be created. These maps encode the position of each pixel as a tuple of distances measured from the outer boundary of the specimen towards the centroid and towards the background region, respectively. Stacked with the boundary mask, the resulting 3-channel image serves as input for the generator network.
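
In terms of the Euclidean distance transform, the code in the next cell computes the following (the notation is introduced here for illustration and is not taken from the paper): with $d_{\mathrm{bg}}$ and $d_{\mathrm{fg}}$ denoting the distance transforms of the background and foreground regions, the first map is $D_1(p)=\max_q d_{\mathrm{bg}}(q)-d_{\mathrm{bg}}(p)$ on background pixels and $0$ elsewhere, and the second map is $D_2(p)=\max_q d_{\mathrm{fg}}(q)-d_{\mathrm{fg}}(p)$ on foreground pixels and $0$ elsewhere.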

In [6]:
# Get membrane mask and background label
instance_mask = membrane_generator.label_image
membrane_mask = membrane_generator.labelBoundaries(instance_mask)
background_label = np.unique(membrane_mask)[0]

# Calculate distance maps
dist_map1 = distance_transform_edt(instance_mask==background_label)
dist_map1 = np.max(dist_map1)-dist_map1 # Invert the distances
dist_map1[instance_mask!=background_label] = 0 # Remove foreground distances
dist_map2 = distance_transform_edt(instance_mask!=background_label)
dist_map2 = np.max(dist_map2)-dist_map2 # Invert the distances
dist_map2[instance_mask==background_label] = 0 # Remove background distances

# Concatenate the image
membrane_patch = np.concatenate((membrane_mask[np.newaxis,np.newaxis,...],
                                 dist_map1[np.newaxis,np.newaxis,...],
                                 dist_map2[np.newaxis,np.newaxis,...]), axis=1)


# Plot 2D slices
plt.subplot(131)
plt.title('Boundary mask')
plt.imshow(membrane_mask[img_size//2,:,:])
plt.subplot(132)
plt.title('Distance map 1')
plt.imshow(dist_map1[img_size//2,:,:])
plt.subplot(133)
plt.title('Distance map 2')
plt.imshow(dist_map2[img_size//2,:,:])
plt.show()

Generating Synthetic Microscopy Images

Those 3-channel images are translated into the image domain by a generative adversarial network.

In [7]:
# Include the model class
from models.gan_3D_model import GAN3D
from util.model_setup import load_config_json, numpy_to_torch_tensor

# Set up a config file
general_config, image_config, mask_config = load_config_json('models/gan_patchwise.json', globals())

# Set up the model and load pre-trained weights
model_synthesis = GAN3D(general_config)
model_synthesis.load_networks('latest')
initialize network with normal
loading the model from /home/staff/eschweiler/Publications/CVPR2020_HarmonicSegmentationSynthesis/Supplementary/Synthesis/models/latest_net_G.pth
In [8]:
# Since the generator network was trained on much larger specimens,
# and generating full-size images would take too long for this demonstration,
# an exemplary image from the public data set is used instead

test_img = io.imread('demo_image.tif')
test_img = np.transpose(test_img, (2,1,0))

test_membranes = io.imread('demo_membrane.tif')
test_membranes = np.transpose(test_membranes, (2,1,0))

test_dist1 = io.imread('demo_dist1.tif')
test_dist1 = np.transpose(test_dist1, (2,1,0))

test_dist2 = io.imread('demo_dist2.tif')
test_dist2 = np.transpose(test_dist2, (2,1,0))
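
As a quick optional check (not part of the original demo), the four loaded volumes should share the same shape after transposition:

# Optional sanity check on the loaded demo volumes
for name, vol in [('image', test_img), ('membranes', test_membranes),
                  ('dist1', test_dist1), ('dist2', test_dist2)]:
    print(name, vol.shape, vol.dtype)
assert test_img.shape == test_membranes.shape == test_dist1.shape == test_dist2.shape
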
In [9]:
# Construct the input image
inputA = test_img[np.newaxis,np.newaxis,...]/255 # scale raw intensities to [0,1] (assuming 8-bit data)

inputB = np.concatenate((test_membranes[np.newaxis,np.newaxis,...],\
                         test_dist1[np.newaxis,np.newaxis,...],\
                         test_dist2[np.newaxis,np.newaxis,...]), axis=1)

# Plot the patch
plt.subplot(121)
plt.title('Real Patch')
plt.imshow(inputA[0,0,:,:,50])
plt.subplot(122)
plt.title('Real Mask')
plt.imshow(inputB[0,0,:,:,50])
plt.show()
In [10]:
# Predict the corresponding synthetic image 
model_synthesis.set_input([numpy_to_torch_tensor([inputA])[0], numpy_to_torch_tensor([inputB])[0]])
model_synthesis.test()
images = model_synthesis.get_current_visuals()
In [11]:
# Plot the generated image
# For the full-size image, a faded overlap technique and border cropping are used to
# reduce border artifacts and obtain continuous intensities.
plt.subplot(121)
plt.title('Real Patch')
plt.imshow(images['real_B'].detach().cpu().numpy()[0,0,:,:,50])
plt.subplot(122)
plt.title('Fake Patch')
plt.imshow(images['fake_B'].detach().cpu().numpy()[0,0,:,:,50])
plt.show()
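
The faded-overlap stitching mentioned in the comment above is not part of this demo. Purely as an illustration of the idea, and with function and parameter names made up for this sketch rather than taken from the repository, overlapping patches could be blended with linearly fading border weights like this:

import numpy as np

def fade_weights(patch_shape, fade=8):
    # Weight map that decays linearly within `fade` voxels of every patch border
    # (assumes each patch dimension is larger than 2*fade)
    w = np.ones(patch_shape, dtype=np.float32)
    ramp = np.linspace(1.0/fade, 1.0, fade, dtype=np.float32)
    for axis, size in enumerate(patch_shape):
        profile = np.ones(size, dtype=np.float32)
        profile[:fade] = ramp
        profile[-fade:] = ramp[::-1]
        shape = [1] * len(patch_shape)
        shape[axis] = size
        w = w * profile.reshape(shape)
    return w

def stitch_patches(patches, positions, out_shape, fade=8):
    # Accumulate weighted patches and normalize by the accumulated weights
    acc = np.zeros(out_shape, dtype=np.float32)
    norm = np.zeros(out_shape, dtype=np.float32)
    for patch, pos in zip(patches, positions):
        sl = tuple(slice(p, p + s) for p, s in zip(pos, patch.shape))
        w = fade_weights(patch.shape, fade)
        acc[sl] += patch * w
        norm[sl] += w
    return acc / np.maximum(norm, 1e-8)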