from matplotlib import pyplot as plt
def _save_debug_images(mask_batch, masked_batch, input_batch):
    """Write the first sample's mask, masked image and input image as PNGs.

    Files are written to the current working directory as 'mask.png',
    'masked_img0.png' and 'ip_image.png'.

    NOTE(review): the indexing assumes batched, channel-first tensors --
    mask (B, 1, H, W), images (B, C, H, W) with C in {3, 4} -- and that
    *input_batch* holds 0-255 pixel values; confirm against the producers.
    """

    def _to_numpy(tensor):
        # Detach from autograd and move to host memory so NumPy can read it.
        return tensor.detach().cpu().numpy()

    def _chw_to_hwc(tensor):
        # imsave expects channel-last (H, W, C) arrays.
        return _to_numpy(tensor.permute((1, 2, 0)))

    def _save_rgb(path, rgb):
        # plt.imsave ignores vmin/vmax for RGB(A) input and raises ValueError
        # for float pixels outside [0, 1]; clipping is what enforces the range.
        plt.imsave(path, rgb.clip(0.0, 1.0))

    # Single-channel mask: here vmin/vmax DO apply, pinning the gray colormap
    # to [0, 1] instead of auto-scaling to the data range.
    plt.imsave('mask.png', _to_numpy(mask_batch[0][0]), cmap='gray', vmax=1, vmin=0)

    # Masked image scaled so its brightest value maps to 1. A non-positive
    # peak (e.g. an all-zero, fully masked image) would otherwise produce
    # NaNs (0/0) or sign-flipped pixels, so fall back to no scaling.
    masked_hwc = _chw_to_hwc(masked_batch[0])
    peak = float(masked_hwc.max())
    _save_rgb('masked_img0.png', masked_hwc / (peak if peak > 0 else 1.0))

    # Input image, rescaled from 0-255 to the [0, 1] range imsave requires.
    _save_rgb('ip_image.png', _chw_to_hwc(input_batch[0]) / 255.0)


_save_debug_images(out_mask, masked_img, ip_images)