Welcome to FIX!¶

In this notebook, we give some examples of loading datasets and running benchmarks.

Part 1: Loading Datasets¶

We'll first talk about how to load the datasets described in our paper. FIX is built using the exlib library, which we load using a local version for now.

In [ ]:
# !pip install exlib
import sys; sys.path.insert(0, "src")
import exlib
import torch
import torch.nn.functional as F
import torchvision
import datasets as huggingface_datasets # Not to be confused with exlib.datasets!
import matplotlib.pyplot as plt
import numpy as np

Cholecystectomy Dataset¶

This dataset contains image data from cholecystectomy surgery (gallbladder removal). The fields are as follows:

  • image: An image of the surgery.
  • gonogo: Where it is safe or unsafe to operate. Background (0), safe (1), and unsafe (2).
  • organs: Relevant organ structures for surgery. Background (0), liver (1), gallbladder (2), and hepatocystic triangle (3). These are the expert-specified interpretable features.
In [2]:
cholec_dataset = exlib.datasets.cholec.CholecDataset(split="test")
cholec_item = cholec_dataset[0]
cholec_image = cholec_item["image"]
cholec_gonogo = cholec_item["gonogo"]
cholec_organs = cholec_item["organs"]

print("image:", cholec_image.shape, cholec_image.dtype)
print("gonogo:", cholec_gonogo.shape, cholec_gonogo.dtype)
print("organs:", cholec_organs.shape, cholec_organs.dtype)
image: torch.Size([3, 360, 640]) torch.float32
gonogo: torch.Size([360, 640]) torch.int64
organs: torch.Size([360, 640]) torch.int64
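Since gonogo and organs are integer label maps, a quick sanity check is to look at how much of the frame each class covers. The snippet below is a minimal sketch on a toy NumPy array rather than the real dataset; the helper name class_fractions is ours and is not part of exlib.

```python
import numpy as np

def class_fractions(label_map, num_classes):
    """Return the fraction of pixels assigned to each class id."""
    counts = np.bincount(np.asarray(label_map).ravel(), minlength=num_classes)
    return counts / counts.sum()

# Toy 2x4 mask standing in for a (360, 640) gonogo map:
# 0 = background, 1 = safe, 2 = unsafe
toy_gonogo = np.array([[0, 0, 1, 1],
                       [0, 2, 2, 2]])
fractions = class_fractions(toy_gonogo, num_classes=3)
print(dict(zip(["background", "safe", "unsafe"], fractions.tolist())))
```

For a torch tensor such as cholec_gonogo, calling .numpy() first should be enough to reuse the same helper.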
In [3]:
plt.clf()
fig, ax = plt.subplots(1, 3, figsize=(12,4))
for i in range(3): ax[i].set_axis_off()

cholec_image = torchvision.transforms.GaussianBlur(41, (50.0, 50.0))(cholec_image)

ax[0].imshow(cholec_image.numpy().transpose(1,2,0))
ax[1].imshow(cholec_gonogo.numpy())
ax[2].imshow(cholec_organs.numpy())

ax[0].set_title("Original Image (blurred)")
ax[1].set_title("Back (0), Safe (1), Unsafe (2)")
ax[2].set_title("Back (0), Liver (1), Gallb. (2), Hept. (3)")
Out[3]:
Text(0.5, 1.0, 'Back (0), Liver (1), Gallb. (2), Hept. (3)')
<Figure size 640x480 with 0 Axes>
[Figure: the blurred original image, the go/no-go mask, and the organ mask, side by side.]

Chest X-ray Dataset¶

This dataset contains vision data for chest X-ray pathology identification. The fields are as follows:

  • image: The image of the chest X-ray.
  • pathols: A binary vector that denotes which of the 14 pathologies are present.
  • structs: A collection of binary masks over the image for the relevant anatomical structures. These are the expert-specified interpretable features.
In [4]:
chestx_dataset = exlib.datasets.chestx.ChestXDataset(split="test")

# Find an image that has at least 2 pathologies present
for i in range(10000):
    chestx_item = chestx_dataset[i]
    if chestx_item["pathols"].sum() > 1: break

chestx_image = chestx_item["image"]
chestx_pathols = chestx_item["pathols"]
chestx_structs = chestx_item["structs"]

print("image:", chestx_image.shape, chestx_image.dtype)
print("pathols:", chestx_pathols.shape, chestx_pathols.dtype)
print("structs:", chestx_structs.shape, chestx_structs.dtype)
image: torch.Size([1, 224, 224]) torch.float32
pathols: torch.Size([14]) torch.int64
structs: torch.Size([14, 224, 224]) torch.int64
In [5]:
print("All pathologies:")
print(", ".join([f"({i}) {s}" for i,s in enumerate(chestx_dataset.pathology_names)]))

print("\nAll structures:")
print(", ".join([f"({i}) {s}" for i,s in enumerate(chestx_dataset.structure_names)]))
All pathologies:
(0) Atelectasis, (1) Cardiomegaly, (2) Consolidation, (3) Edema, (4) Effusion, (5) Emphysema, (6) Fibrosis, (7) Hernia, (8) Infiltration, (9) Mass, (10) Nodule, (11) Pleural_Thickening, (12) Pneumonia, (13) Pneumothorax

All structures:
(0) Left Clavicle, (1) Right Clavicle, (2) Left Scapula, (3) Right Scapula, (4) Left Lung, (5) Right Lung, (6) Left Hilus Pulmonis, (7) Right Hilus Pulmonis, (8) Heart, (9) Aorta, (10) Facies Diaphragmatica, (11) Mediastinum, (12) Weasand, (13) Spine
In [6]:
plt.clf()
fig, ax = plt.subplot_mosaic([
    (["image"] + [f"struct{i}" for i in range(7)]),
    (["."] + [f"struct{i}" for i in range(7,14)]),
], figsize=(14,4))

for _, a in ax.items(): a.set_axis_off()

struct_titles = chestx_dataset.structure_names
ax["image"].imshow(chestx_image.numpy().transpose(1,2,0), cmap="gray")
ax["image"].set_title("Image")

for i in range(14):
    mask_t = chestx_structs[i].unsqueeze(0)
    ax[f"struct{i}"].imshow((mask_t.numpy().transpose(1,2,0)) * 2, cmap="gray")
    ax[f"struct{i}"].set_title(f"{struct_titles[i][:10]} (T)", fontsize=10)
<Figure size 640x480 with 0 Axes>
[Figure: the chest X-ray image followed by its 14 anatomical structure masks.]
In [7]:
print("Pathologies present:")
for idx in chestx_pathols.nonzero():
    print("*", chestx_dataset.pathology_names[idx.item()])
Pathologies present:
* Cardiomegaly
* Infiltration

Cosmological Mass Maps Dataset¶

This dataset contains noise-free simulated weak lensing maps. The relevant fields are as follows:

  • input: A (1,66,66)-shaped weak lensing map.
  • label: A pair of numbers representing the cosmological parameters Omega_m and sigma_8. This dataset has no expert-specified features.
In [8]:
massmaps_dataset = huggingface_datasets.load_dataset(exlib.datasets.mass_maps.DATASET_REPO, split="test")
In [9]:
# Plot a few examples of mass maps
plt.clf()
fig, ax = plt.subplots(1,4, figsize=(16,4))
for i in range(len(ax)):
    mm_input = torch.tensor(massmaps_dataset[i]["input"])
    mm_label = massmaps_dataset[i]["label"]
    ax[i].imshow(mm_input.numpy().transpose(1,2,0))
    ax[i].set_axis_off()
    ax[i].set_title(f"Omega_m = {mm_label[0]:.3f},  sigma_8 = {mm_label[1]:.3f}")   
<Figure size 640x480 with 0 Axes>
[Figure: four mass maps, each titled with its Omega_m and sigma_8 values.]

Emotions Dataset¶

This dataset contains 58k carefully curated Reddit comments labeled for 27 emotion categories or Neutral. The fields are as follows:

  • text: The Reddit comment.
  • labels: The emotion annotations.
  • comment_id: Unique identifier of the comment.
In [10]:
emotion_dataset = exlib.datasets.emotion.EmotionDataset(split="test")
In [11]:
emotion_dataloader = torch.utils.data.DataLoader(emotion_dataset, batch_size=4, shuffle=False)
emotion_model = exlib.datasets.emotion.EmotionClassifier().eval()

for batch in emotion_dataloader: 
    input_ids = batch['input_ids']
    attention_mask = batch['attention_mask']
    output = emotion_model(input_ids, attention_mask)
    utterances = [
        emotion_dataset.tokenizer.decode(input_id, skip_special_tokens=True)
        for input_id in input_ids
    ]
    for utterance, label in zip(utterances, output.logits):
        id_str = emotion_model.model.config.id2label[label.argmax().item()]
        print("Text: {}\nEmotion: {}\n".format(utterance, id_str))
    break
SamLowe/roberta-base-go_emotions
Text: I’m really sorry about your situation :( Although I love the names Sapphira, Cirilla, and Scarlett!
Emotion: remorse

Text: It's wonderful because it's awful. At not with.
Emotion: admiration

Text: Kings fan here, good luck to you guys! Will be an interesting game to watch! 
Emotion: optimism

Text: I didn't know that, thank you for teaching me something today!
Emotion: gratitude


Multilingual Politeness Dataset¶

This dataset contains conversation snippets from Wikipedia's editor talk pages. The fields are as follows:

  • text: A conversation snippet from a Wikipedia editor talk page.
  • politeness: The politeness level, from -2 (very rude) to 2 (very polite).
In [12]:
politeness_dataset = exlib.datasets.multilingual_politeness.PolitenessDataset(split="test")
In [13]:
politeness_dataloader = torch.utils.data.DataLoader(politeness_dataset, batch_size=4, shuffle=False)
politeness_model = exlib.datasets.multilingual_politeness.PolitenessClassifier().eval()

for batch in politeness_dataloader: 
    input_ids = batch['input_ids']
    attention_mask = batch['attention_mask']
    output = politeness_model(input_ids, attention_mask)
    utterances = [
        politeness_dataset.tokenizer.decode(input_id, skip_special_tokens=True)
        for input_id in input_ids
    ]
    for utterance, label in zip(utterances, output):
        print("Text: {}\nPoliteness: {}\n".format(utterance, label.item()))
    break
Some weights of XLMRobertaForSequenceClassification were not initialized from the model checkpoint at anonymous and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Text: The intro mentions the ISO 8601 international standard adopted in most western countries. What does this even mean? Who are we suggesting has done the adoption?
Politeness: 0.17250384390354156

Text: I'm a user on PrettyCure.org, and somebody on the site said they are making a fourth season of PreCure. It's a rumuor, but is it true? That person said it's more like Tokyo Mew Mew, a group of girls.
Politeness: -0.004305824637413025

Text: Hello fellow Wikipedians, I have just added archive links to on Essen. Please take a moment to review my edit. If necessary, add after the link to keep me from modifying it.
Politeness: 0.010885253548622131

Text: I saw the template citing this issue and since there was no section here discussing it I've decided to start one. I'm a Canadian and most of our television programs are also aired in the US so my knowledge of what's on TV outside of North America is limited. So I'm not sure of how much help I can be, but I do have some ideas on how to improve this section and I'm open to feedback.
Politeness: -0.00953248143196106


Supernova Dataset¶

This dataset contains astronomical time series covering 18 types of astronomical sources. The fields are as follows:

  • label: The class of the object.
  • times_wv: A 2D array containing the observation time (modified Julian date, MJD) and filter (wavelength) for each observation.
  • target: A 2D array containing the flux (arbitrary units) and flux error for each observation.
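To make the layout of times_wv and target concrete, here is a minimal sketch that groups observations by filter wavelength. It runs on toy NumPy arrays with made-up values, and the helper name split_by_filter is ours, not part of exlib. It assumes that rows with zero flux are padding, which matches the ys != 0 filtering used in this notebook but is otherwise an assumption.

```python
import numpy as np

def split_by_filter(times_wv, target):
    """Group (time, flux) pairs by filter wavelength, skipping zero-flux rows."""
    times_wv = np.asarray(times_wv, dtype=float)
    flux = np.asarray(target, dtype=float)[:, 0]
    keep = flux != 0  # assumption: zero flux marks a padded row
    times, wavelengths, flux = times_wv[keep, 0], times_wv[keep, 1], flux[keep]
    return {
        float(wl): (times[wavelengths == wl], flux[wavelengths == wl])
        for wl in np.unique(wavelengths)
    }

# Toy data: columns of times_wv are (time, wavelength),
# columns of target are (flux, flux error). The last row is padding.
toy_times_wv = np.array([[59000.0, 3670.69],
                         [59001.0, 4826.85],
                         [59002.0, 3670.69],
                         [0.0, 0.0]])
toy_target = np.array([[1.5, 0.1],
                       [2.0, 0.2],
                       [1.2, 0.1],
                       [0.0, 0.0]])
per_filter = split_by_filter(toy_times_wv, toy_target)
for wl, (t, f) in per_filter.items():
    print(wl, t.tolist(), f.tolist())
```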
In [14]:
supernova_dataset = exlib.datasets.supernova.SupernovaDataset(split="test")
In [15]:
supernova_item = supernova_dataset.dataset[0]
times_wv = torch.tensor(supernova_item["times_wv"])
xs = times_wv[:, 0]
ys = torch.tensor(supernova_item["target"])[:, 0]
times_wv, xs, ys = times_wv[ys!=0], xs[ys!=0], ys[ys!=0]
unique_wls = [3670.69, 4826.85, 6223.24, 7545.98, 8590.9, 9710.28]
In [16]:
plt.clf()
plt.figure(figsize=(6,4))

for wl in unique_wls:
    mask = times_wv[:, 1] == wl
    plt.scatter(xs[mask], ys[mask], label=f'{wl:.2f}')

plt.title(f'Class: {supernova_item["label"]}')
plt.xlabel('Time')
plt.ylabel('Flux')
plt.legend(title='Wavelength')
Out[16]:
<matplotlib.legend.Legend at 0x75e36f6f9dd0>
<Figure size 640x480 with 0 Axes>
[Figure: scatter plot of flux against time, colored by wavelength.]

Part 2: Running benchmarks and evaluations¶

We now show how to run some benchmarks and evaluations.

Vision Example (Cholecystectomy)¶

We show an example of groups generated by the quickshift segmentation algorithm compared to those that are labeled by our surgeon experts.

In [17]:
cholec_dataset = exlib.datasets.cholec.CholecDataset(split="test")
cholec_item = cholec_dataset[0]
cholec_image = cholec_item["image"]
cholec_organs = cholec_item["organs"]

quickshift_feature_extractor = exlib.features.vision.QuickshiftGroups(max_groups=8)
quickshift_groups = quickshift_feature_extractor(cholec_image.unsqueeze(0))[0]

# Image, expert-specified groups, and quickshift-generated groups
cholec_image.shape, cholec_organs.shape, quickshift_groups.shape
Out[17]:
(torch.Size([3, 360, 640]), torch.Size([360, 640]), torch.Size([8, 360, 640]))
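The quickshift groups come back as a stack of binary masks with shape (num_groups, H, W), whereas the expert labels are a single (H, W) integer map. To plot the masks as one image, we can multiply each mask by its group index and sum over the group axis. The sketch below shows this on a toy NumPy array; the helper name masks_to_label_map is ours, and the overlap check is an extra safeguard we added for illustration.

```python
import numpy as np

def masks_to_label_map(masks):
    """Collapse (G, H, W) binary masks into an (H, W) map of group indices."""
    masks = np.asarray(masks)
    if (masks.sum(axis=0) > 1).any():
        # Summing index-weighted masks only makes sense for disjoint groups
        raise ValueError("groups overlap; a single label map cannot represent them")
    indices = np.arange(masks.shape[0]).reshape(-1, 1, 1)
    return (indices * masks).sum(axis=0)

# Three disjoint toy groups over a 2x3 image
toy_masks = np.array([
    [[1, 1, 0], [0, 0, 0]],
    [[0, 0, 1], [1, 0, 0]],
    [[0, 0, 0], [0, 1, 1]],
])
label_map = masks_to_label_map(toy_masks)
print(label_map)
```

Note that a pixel covered by no group also maps to 0, so it is indistinguishable from group 0 in the resulting image.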
In [18]:
# Let's visualize these things
plt.clf()
fig, ax = plt.subplots(1, 3, figsize=(12,4))
for a in ax: a.set_axis_off()

cholec_image = torchvision.transforms.GaussianBlur(41, (50.0, 50.0))(cholec_image)

ax[0].imshow(cholec_image.numpy().transpose(1,2,0))
ax[1].imshow(cholec_item["organs"].numpy())
ax[2].imshow((torch.arange(8).view(-1,1,1) * quickshift_groups).sum(dim=0).numpy())

ax[0].set_title("Original image (blurred)")
ax[1].set_title("Expert-specified groups")
ax[2].set_title("Quickshift-generated groups")
Out[18]:
Text(0.5, 1.0, 'Quickshift-generated groups')
<Figure size 640x480 with 0 Axes>
[Figure: the blurred original image, the expert-specified groups, and the quickshift-generated groups, side by side.]
In [19]:
# Expects a pair of (N,C,H,W)-shaped things
cholec_metric = exlib.datasets.cholec.CholecFixScore()

# Adjust things to be able to be put into cholec_metric
cholec_organs_mask = F.one_hot(cholec_organs).permute(2,0,1).unsqueeze(0) # (1,num_true_groups,H,W)
batched_qs_groups = quickshift_groups.unsqueeze(0) # (1,num_pred_groups,H,W)

# The FIX score is computed for each pixel, and we just take the average
cholec_fix_score = cholec_metric(groups_pred=batched_qs_groups, groups_true=cholec_organs_mask)
print(cholec_fix_score.shape, cholec_fix_score.mean())
torch.Size([1]) tensor(0.1941)
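To build intuition for what a score like this measures, here is a rough sketch of one possible alignment measure: each predicted group is scored by its best intersection-over-union (IoU) with any expert group, and each pixel receives the average score of the predicted groups that cover it. This is our own illustrative assumption written in NumPy on toy data. It is not the implementation of CholecFixScore and may not reproduce its values; both function names are hypothetical.

```python
import numpy as np

def one_hot_masks(label_map, num_classes):
    """Turn an (H, W) integer label map into (num_classes, H, W) binary masks."""
    label_map = np.asarray(label_map)
    return (np.arange(num_classes)[:, None, None] == label_map[None]).astype(np.int64)

def iou_alignment_score(groups_pred, groups_true):
    """Hypothetical per-pixel alignment score (NOT exlib's CholecFixScore).

    groups_pred: (P, H, W) binary masks, groups_true: (T, H, W) binary masks.
    """
    pred = np.asarray(groups_pred).reshape(len(groups_pred), -1).astype(bool)
    true = np.asarray(groups_true).reshape(len(groups_true), -1).astype(bool)
    inter = (pred[:, None, :] & true[None, :, :]).sum(axis=-1)  # (P, T)
    union = (pred[:, None, :] | true[None, :, :]).sum(axis=-1)  # (P, T)
    best_iou = (inter / np.maximum(union, 1)).max(axis=1)       # (P,)
    cover = pred.sum(axis=0)                                    # groups per pixel
    per_pixel = (pred * best_iou[:, None]).sum(axis=0) / np.maximum(cover, 1)
    return float(per_pixel.mean())

toy_labels = np.array([[0, 0, 1, 1],
                       [0, 0, 1, 1]])
toy_true = one_hot_masks(toy_labels, num_classes=2)
perfect = iou_alignment_score(toy_true, toy_true)
coarse = iou_alignment_score(np.ones((1, 2, 4), dtype=int), toy_true)
print(perfect, coarse)
```

Under this assumed measure, predicted groups that match the expert groups exactly score 1, while a single group covering the whole image scores lower because it overlaps every expert group only partially.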

Natural Language Example (Multilingual Politeness)¶

We show an example of the groups demarcated at the sentence level.

In [20]:
politeness_dataset = exlib.datasets.multilingual_politeness.PolitenessDataset(split="test")
politeness_item = politeness_dataset[0]
politeness_word_list = [w for w in politeness_item["word_list"] if w.strip()]

sentence_feature_extractor = exlib.features.text.SentenceGroups(distinct=26, scaling=1.5)
sentence_groups = sentence_feature_extractor(politeness_word_list)
In [21]:
print(" ".join(politeness_word_list), "\n")
for i, g in enumerate(sentence_groups):
    sentence = " ".join([politeness_word_list[ni.item()] for ni in g.nonzero()])
    print(f"Sentence {i+1}: {sentence}")
The intro mentions the ISO 8601 international standard adopted in most western countries. What does this even mean? Who are we suggesting has done the adoption? 

Sentence 1: Who are we suggesting has done the adoption?
Sentence 2: What does this even mean?
Sentence 3: The intro mentions the ISO 8601 international standard adopted in most western countries.
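Each sentence group is a binary mask over the word list, which is why the cell above recovers a sentence by indexing with the nonzero positions. The following is a minimal sketch of how such masks could be built by splitting at sentence-ending punctuation. It is a simplified stand-in that we wrote for illustration, not the algorithm used by SentenceGroups, and the helper name sentence_masks is hypothetical.

```python
def sentence_masks(words):
    """Return one binary mask per sentence, each the same length as `words`."""
    groups, current = [], []
    for idx, word in enumerate(words):
        current.append(idx)
        # Naive rule: a word ending in ., ?, or ! closes the sentence
        if word.endswith((".", "?", "!")):
            groups.append(current)
            current = []
    if current:  # trailing words without closing punctuation
        groups.append(current)
    return [
        [1 if i in set(group) else 0 for i in range(len(words))]
        for group in groups
    ]

toy_words = "What does this even mean? Who did it?".split()
masks = sentence_masks(toy_words)
for mask in masks:
    print(" ".join(w for w, m in zip(toy_words, mask) if m))
```

The punctuation rule is deliberately naive: abbreviations such as "e.g." would be split incorrectly, which is one reason a real extractor needs more care.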

Time-series Example (Supernova)¶

We show an example of groups based on slicing the time series.

In [22]:
supernova_dataset = exlib.datasets.supernova.SupernovaDataset(split="test")
supernova_item = supernova_dataset.dataset[0]
times_wv = torch.tensor(supernova_item["times_wv"])
xs = times_wv[:, 0]
ys = torch.tensor(supernova_item["target"])[:, 0]
times_wv, xs, ys = times_wv[ys!=0], xs[ys!=0], ys[ys!=0]
unique_wls = [3670.69, 4826.85, 6223.24, 7545.98, 8590.9, 9710.28]

slice_feature_extractor = exlib.features.time_series.SliceGroups(ngroups=10, window_size=100)
In [23]:
supernova_dataloader = exlib.datasets.supernova_helper.create_test_dataloader_raw(
    dataset = supernova_dataset,
    batch_size = 5
)

for batch in supernova_dataloader:
    slice_groups = slice_feature_extractor(**batch)
    slice_groups = slice_groups[0]
    break
original dataset size: 792
remove nans dataset size: 792
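As a rough picture of what slicing-based groups look like, the sketch below assigns each observation to a fixed-width time window and returns one binary mask per non-empty window. This is our own simplified illustration in NumPy: the helper name window_masks and the meaning given to window_size and max_groups are assumptions, and SliceGroups may define its parameters differently.

```python
import numpy as np

def window_masks(times, window_size, max_groups):
    """Return (num_windows, num_obs) binary masks, one per non-empty time window."""
    times = np.asarray(times, dtype=float)
    # Index of the fixed-width window that each observation falls into
    bins = np.floor((times - times.min()) / window_size).astype(int)
    kept_bins = np.unique(bins)[:max_groups]  # keep the earliest windows only
    return np.stack([(bins == b).astype(np.int64) for b in kept_bins])

toy_times = [0.0, 10.0, 120.0, 130.0, 250.0]
toy_groups = window_masks(toy_times, window_size=100, max_groups=10)
print(toy_groups)
```

Each row marks the observations that fall in one window, so the first and last nonzero index of a row give the time span that window covers.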
In [24]:
plt.clf()
plt.figure(figsize=(6,4))

# First plot the wavelengths like before
for wl in unique_wls:
    mask = times_wv[:, 1] == wl
    plt.scatter(xs[mask], ys[mask], label=f'{wl:.2f}')

# Then overlay the groups
cmap = plt.get_cmap("viridis")
for i, g in enumerate(slice_groups):
    if g.sum() == 0: continue
    xmin = xs[g.nonzero().min().item()].item()
    xmax = xs[g.nonzero().max().item()].item()
    plt.axvspan(xmin, xmax, alpha=0.3, facecolor=cmap(i/len(slice_groups)))

plt.title(f'Class: {supernova_item["label"]}')
plt.xlabel('Time')
plt.ylabel('Flux')
plt.legend(title='Wavelength')
Out[24]:
<matplotlib.legend.Legend at 0x75e2a1d74850>
<Figure size 640x480 with 0 Axes>
[Figure: the flux scatter plot with the slice groups overlaid as shaded time spans.]