In [1]:
# Uncomment line below to install exlib
# !pip install exlib
In [ ]:
import torch
import matplotlib.pyplot as plt
from tqdm import tqdm

import sys; sys.path.append("../../src")
import exlib
from exlib.datasets.chestx import *
from exlib.features.vision import *

Overview¶

  • The objective is to predict the pathology regions (atelectasis, calcification, etc. Total 13).
  • The higher-level features are the anatomical structures (left clavicle, right clavicle, etc. Total 14).
In [3]:
dataset = ChestXDataset(split="test")

Dataset samples¶

Primary task: predict where the pathologies are

In [4]:
torch.manual_seed(105)
for i in torch.randperm(len(dataset)):
    sample = dataset[i.item()]
    image, pathols, structs = sample["image"], sample["pathols"], sample["structs"]
    if pathols.sum() > 2 and structs.sum() > 0:
        break
In [5]:
plt.clf()
plt.imshow(image.numpy().transpose(1,2,0), cmap="gray")
plt.axis("off")

print("Pathologies present:")
for idx in pathols.nonzero():
    print(dataset.pathology_names[idx.item()])
Pathologies present:
Atelectasis
Consolidation
Effusion
No description has been provided for this image

Expert-specified higher-level features: anatomical structures¶

Identifying where key structures are is important!

In [6]:
plt.clf()
fig, ax = plt.subplot_mosaic([
    (["image"] + [f"struct{i}t" for i in range(7)]),
    (["."] + [f"struct{i}t" for i in range(7,14)]),
], figsize=(14,4))

for _, a in ax.items(): a.set_axis_off()
struct_titles = ChestXDataset.structure_names

ax["image"].imshow(image.numpy().transpose(1,2,0), cmap="gray")
ax["image"].set_title("Image")

for i in range(14):
    mask_t = structs[i].unsqueeze(0)
    ax[f"struct{i}t"].imshow(((image * mask_t).numpy().transpose(1,2,0)) * 2, cmap="gray")
    ax[f"struct{i}t"].set_title(f"{struct_titles[i][:10]} (T)", fontsize=10)
<Figure size 640x480 with 0 Axes>
No description has been provided for this image

How well are the higher-level feature alignments?¶

In [7]:
all_baseline_scores = get_chestx_scores(N=100, batch_size=8)
100%|██████████████████████████████████████████████████████████████████████████████████| 13/13 [08:06<00:00, 37.41s/it]
In [8]:
for name, scores in all_baseline_scores.items():
    print(f'BASELINE {name} mean score: {scores.mean()}')
BASELINE identity mean score: 0.21211296319961548
BASELINE random mean score: 0.04255373775959015
BASELINE patch mean score: 0.09955804795026779
BASELINE quickshift mean score: 0.3367842137813568
BASELINE watershed mean score: 0.14116907119750977
BASELINE sam mean score: 0.30522453784942627
BASELINE ace mean score: 0.2603427469730377
BASELINE craft mean score: 0.11851011216640472
BASELINE archipelago mean score: 0.2134942263364792
In [ ]: