Welcome to FIX!¶
In this notebook, we give some examples of loading the datasets and running benchmarks.
Part 1: Loading Datasets¶
We'll first talk about how to load the datasets described in our paper.
FIX is built using the exlib
library, which we load using a local version for now.
# !pip install exlib
import sys; sys.path.insert(0, "src")
import exlib
import torch
import torch.nn.functional as F
import torchvision
import datasets as huggingface_datasets # Not to be confused with exlib.datasets!
import matplotlib.pyplot as plt
import numpy as np
Cholecystectomy Dataset¶
This dataset contains image data from cholecystectomy surgery (gallbladder removal). The fields are as follows:
- image: An image of the surgery.
- gonogo: Where it is safe or unsafe to operate. Background (0), safe (1), and unsafe (2).
- organs: Relevant organ structures for surgery. Background (0), liver (1), gallbladder (2), and hepatocystic triangle (3). These are the expert-specified interpretable features.
# Load the test split and take the first sample as a running example.
cholec_dataset = exlib.datasets.cholec.CholecDataset(split="test")
cholec_item = cholec_dataset[0]

# Pull out the three fields described above.
cholec_image = cholec_item["image"]
cholec_gonogo = cholec_item["gonogo"]
cholec_organs = cholec_item["organs"]

# Report the shape and dtype of each field.
for field_label, field_value in (
    ("image:", cholec_image),
    ("gonogo:", cholec_gonogo),
    ("organs:", cholec_organs),
):
    print(field_label, field_value.shape, field_value.dtype)
image: torch.Size([3, 360, 640]) torch.float32 gonogo: torch.Size([360, 640]) torch.int64 organs: torch.Size([360, 640]) torch.int64
# Show the (blurred) surgery image next to its two annotation maps.
# NOTE: there is deliberately no plt.clf() before plt.subplots(): with no
# figure open, clf() first creates a new empty figure, which the notebook then
# renders as a stray "<Figure size 640x480 with 0 Axes>".
fig, ax = plt.subplots(1, 3, figsize=(12, 4))
for a in ax:
    a.set_axis_off()

# Blur the image before display; the panel title notes that it is blurred.
cholec_image = torchvision.transforms.GaussianBlur(41, (50.0, 50.0))(cholec_image)

# imshow wants channels last, so move (C,H,W) -> (H,W,C) for the RGB image.
ax[0].imshow(cholec_image.numpy().transpose(1, 2, 0))
ax[1].imshow(cholec_gonogo.numpy())
ax[2].imshow(cholec_organs.numpy())

ax[0].set_title("Original Image (blurred)")
ax[1].set_title("Back (0), Safe (1), Unsafe (2)")
ax[2].set_title("Back (0), Liver (1), Gallb. (2), Hept. (3)")
Text(0.5, 1.0, 'Back (0), Liver (1), Gallb. (2), Hept. (3)')
<Figure size 640x480 with 0 Axes>
Chest X-ray Dataset¶
This dataset contains vision data for chest X-ray pathology identification. The fields are as follows:
- image: The image of the chest X-ray.
- pathols: A binary vector that denotes which of the 14 pathologies are present.
- structs: A collection of binary masks over the image for the relevant anatomical structures. These are the expert-specified interpretable features.
chestx_dataset = exlib.datasets.chestx.ChestXDataset(split="test")

# Scan the first samples in order and stop at the first one whose pathology
# vector has more than one entry set.
for i in range(10000):
    chestx_item = chestx_dataset[i]
    if chestx_item["pathols"].sum() > 1:
        break

# Pull out the fields of the selected sample.
chestx_image = chestx_item["image"]
chestx_pathols = chestx_item["pathols"]
chestx_structs = chestx_item["structs"]

# Report the shape and dtype of each field.
for field_label, field_value in (
    ("image:", chestx_image),
    ("pathols:", chestx_pathols),
    ("structs:", chestx_structs),
):
    print(field_label, field_value.shape, field_value.dtype)
image: torch.Size([1, 224, 224]) torch.float32 pathols: torch.Size([14]) torch.int64 structs: torch.Size([14, 224, 224]) torch.int64
# List every pathology name and every structure name together with its index.
pathology_listing = ", ".join(
    f"({i}) {s}" for i, s in enumerate(chestx_dataset.pathology_names)
)
structure_listing = ", ".join(
    f"({i}) {s}" for i, s in enumerate(chestx_dataset.structure_names)
)

print("All pathologies:")
print(pathology_listing)
print("\nAll structures:")
print(structure_listing)
All pathologies: (0) Atelectasis, (1) Cardiomegaly, (2) Consolidation, (3) Edema, (4) Effusion, (5) Emphysema, (6) Fibrosis, (7) Hernia, (8) Infiltration, (9) Mass, (10) Nodule, (11) Pleural_Thickening, (12) Pneumonia, (13) Pneumothorax All structures: (0) Left Clavicle, (1) Right Clavicle, (2) Left Scapula, (3) Right Scapula, (4) Left Lung, (5) Right Lung, (6) Left Hilus Pulmonis, (7) Right Hilus Pulmonis, (8) Heart, (9) Aorta, (10) Facies Diaphragmatica, (11) Mediastinum, (12) Weasand, (13) Spine
# Lay out the X-ray on the left and the 14 structure masks in a 2 x 7 grid.
# NOTE: there is deliberately no plt.clf() before plt.subplot_mosaic(): with no
# figure open, clf() first creates a new empty figure, which the notebook then
# renders as a stray "<Figure size 640x480 with 0 Axes>".
fig, ax = plt.subplot_mosaic(
    [
        ["image"] + [f"struct{i}" for i in range(7)],
        # "." leaves the cell under the image empty.
        ["."] + [f"struct{i}" for i in range(7, 14)],
    ],
    figsize=(14, 4),
)
for a in ax.values():
    a.set_axis_off()

struct_titles = chestx_dataset.structure_names

# The image is (1,H,W); imshow wants channels last, hence the transpose.
ax["image"].imshow(chestx_image.numpy().transpose(1, 2, 0), cmap="gray")
ax["image"].set_title("Image")

for i in range(14):
    # Give the (H,W) mask a leading channel axis so it is handled like the image.
    mask_t = chestx_structs[i].unsqueeze(0)
    ax[f"struct{i}"].imshow((mask_t.numpy().transpose(1, 2, 0)) * 2, cmap="gray")
    # Truncate long structure names so the titles do not overlap.
    ax[f"struct{i}"].set_title(f"{struct_titles[i][:10]} (T)", fontsize=10)
<Figure size 640x480 with 0 Axes>
# Name each pathology whose entry in the binary vector is non-zero.
print("Pathologies present:")
for pathology_index in chestx_pathols.nonzero().flatten().tolist():
    print("*", chestx_dataset.pathology_names[pathology_index])
Pathologies present: * Cardiomegaly * Infiltration
Cosmological Mass Maps Dataset¶
This dataset contains clean simulated weak lensing maps without noise. The relevant fields are as follows:
- input: A (1,66,66)-shaped weak lensing map.
- label: A pair of numbers that represents the cosmological parameters Omega_m and sigma_8.

In this dataset, the expert-specified features are absent.
massmaps_dataset = huggingface_datasets.load_dataset(exlib.datasets.mass_maps.DATASET_REPO, split="test")

# Plot a few examples of mass maps
# NOTE: there is deliberately no plt.clf() before plt.subplots(): with no
# figure open, clf() first creates a new empty figure, which the notebook then
# renders as a stray "<Figure size 640x480 with 0 Axes>".
fig, ax = plt.subplots(1, 4, figsize=(16, 4))
for i, a in enumerate(ax):
    # The dataset returns plain Python data, so wrap the map in a tensor.
    mm_input = torch.tensor(massmaps_dataset[i]["input"])
    mm_label = massmaps_dataset[i]["label"]
    # imshow wants channels last, so move the leading channel axis to the end.
    a.imshow(mm_input.numpy().transpose(1, 2, 0))
    a.set_axis_off()
    a.set_title(f"Omega_m = {mm_label[0]:.3f}, sigma_8 = {mm_label[1]:.3f}")
<Figure size 640x480 with 0 Axes>
Emotions Dataset¶
This dataset contains 58k carefully curated Reddit comments labeled for 27 emotion categories or Neutral. The fields are as follows:
- text: The Reddit comment.
- labels: The emotion annotations.
- comment_id: Unique identifier of the comment.
emotion_dataset = exlib.datasets.emotion.EmotionDataset(split="test")
emotion_dataloader = torch.utils.data.DataLoader(emotion_dataset, batch_size=4, shuffle=False)
emotion_model = exlib.datasets.emotion.EmotionClassifier().eval()

# Classify one batch and print each decoded comment with its predicted emotion.
for batch in emotion_dataloader:
    input_ids = batch['input_ids']
    attention_mask = batch['attention_mask']

    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        output = emotion_model(input_ids, attention_mask)

    # Turn the token ids back into readable text.
    utterances = [
        emotion_dataset.tokenizer.decode(input_id, skip_special_tokens=True)
        for input_id in input_ids
    ]
    for utterance, label in zip(utterances, output.logits):
        # The highest-scoring logit picks the class; id2label gives its name.
        id_str = emotion_model.model.config.id2label[label.argmax().item()]
        print(f"Text: {utterance}\nEmotion: {id_str}\n")

    # One batch is enough for the demo.
    break
SamLowe/roberta-base-go_emotions Text: I’m really sorry about your situation :( Although I love the names Sapphira, Cirilla, and Scarlett! Emotion: remorse Text: It's wonderful because it's awful. At not with. Emotion: admiration Text: Kings fan here, good luck to you guys! Will be an interesting game to watch! Emotion: optimism Text: I didn't know that, thank you for teaching me something today! Emotion: gratitude
Multilingual Politeness Dataset¶
This dataset contains conversation snippets from Wikipedia's editor talk pages. The fields are as follows:
- text: The conversation snippet from a Wikipedia editor talk page.
- politeness: The politeness level, from -2 (very rude) to 2 (very polite).
politeness_dataset = exlib.datasets.multilingual_politeness.PolitenessDataset(split="test")
politeness_dataloader = torch.utils.data.DataLoader(politeness_dataset, batch_size=4, shuffle=False)
politeness_model = exlib.datasets.multilingual_politeness.PolitenessClassifier().eval()

# Score one batch and print each decoded snippet with its predicted politeness.
for batch in politeness_dataloader:
    input_ids = batch['input_ids']
    attention_mask = batch['attention_mask']

    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        output = politeness_model(input_ids, attention_mask)

    # Turn the token ids back into readable text.
    utterances = [
        politeness_dataset.tokenizer.decode(input_id, skip_special_tokens=True)
        for input_id in input_ids
    ]
    for utterance, label in zip(utterances, output):
        print(f"Text: {utterance}\nPoliteness: {label.item()}\n")

    # One batch is enough for the demo.
    break
Some weights of XLMRobertaForSequenceClassification were not initialized from the model checkpoint at anonymous and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight'] You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Text: The intro mentions the ISO 8601 international standard adopted in most western countries. What does this even mean? Who are we suggesting has done the adoption? Politeness: 0.17250384390354156 Text: I'm a user on PrettyCure.org, and somebody on the site said they are making a fourth season of PreCure. It's a rumuor, but is it true? That person said it's more like Tokyo Mew Mew, a group of girls. Politeness: -0.004305824637413025 Text: Hello fellow Wikipedians, I have just added archive links to on Essen. Please take a moment to review my edit. If necessary, add after the link to keep me from modifying it. Politeness: 0.010885253548622131 Text: I saw the template citing this issue and since there was no section here discussing it I've decided to start one. I'm a Canadian and most of our television programs are also aired in the US so my knowledge of what's on TV outside of North America is limited. So I'm not sure of how much help I can be, but I do have some ideas on how to improve this section and I'm open to feedback. Politeness: -0.00953248143196106
Supernova Dataset¶
This dataset contains astronomical time series covering 18 types of astronomical sources. The fields are as follows:
- label: The class of the object.
- times_wv: 2D array with one row per observation, containing the observation time (modified Julian days, MJD) and the filter (wavelength).
- target: 2D array with one row per observation, containing the flux (arbitrary units) and the flux error.
supernova_dataset = exlib.datasets.supernova.SupernovaDataset(split="test")
supernova_item = supernova_dataset.dataset[0]

# Column 0 of times_wv is used as the time axis, column 1 as the wavelength;
# column 0 of target is used as the flux.
times_wv = torch.tensor(supernova_item["times_wv"])
xs = times_wv[:, 0]
ys = torch.tensor(supernova_item["target"])[:, 0]

# Keep only the observations whose flux is non-zero (mask computed once).
has_flux = ys != 0
times_wv, xs, ys = times_wv[has_flux], xs[has_flux], ys[has_flux]

unique_wls = [3670.69, 4826.85, 6223.24, 7545.98, 8590.9, 9710.28]

# NOTE: there is deliberately no plt.clf() before plt.figure(): with no figure
# open, clf() first creates a new empty figure, which the notebook then renders
# as a stray "<Figure size 640x480 with 0 Axes>".
plt.figure(figsize=(6, 4))

# One scatter series per wavelength, so each gets its own colour and legend
# entry. No cmap is passed: without `c` matplotlib ignores it and warns.
for wl in unique_wls:
    mask = times_wv[:, 1] == wl
    plt.scatter(xs[mask], ys[mask], label=f'{wl:.2f}')

plt.title(f'Class: {supernova_item["label"]}')
plt.xlabel('Time')
plt.ylabel('Flux')
plt.legend(title='Wavelength')
<matplotlib.legend.Legend at 0x75e36f6f9dd0>
<Figure size 640x480 with 0 Axes>
Part 2: Running benchmarks and evaluations¶
We now show how to run some benchmarks + evaluations.
Vision Example (Cholecystectomy)¶
We show an example of groups generated by the quickshift segmentation algorithm compared to those that are labeled by our surgeon experts.
# Reload the first test sample so this section is self-contained.
cholec_dataset = exlib.datasets.cholec.CholecDataset(split="test")
cholec_item = cholec_dataset[0]
cholec_image, cholec_organs = cholec_item["image"], cholec_item["organs"]

# Segment the image with quickshift. The extractor is called on a batch, so
# add a batch axis going in and take the single result coming out.
quickshift_feature_extractor = exlib.features.vision.QuickshiftGroups(max_groups=8)
image_batch = cholec_image.unsqueeze(0)
quickshift_groups = quickshift_feature_extractor(image_batch)[0]

# Image, expert-specified groups, and quickshift-generated groups
cholec_image.shape, cholec_organs.shape, quickshift_groups.shape
(torch.Size([3, 360, 640]), torch.Size([360, 640]), torch.Size([8, 360, 640]))
# Let's visualize these things
# NOTE: there is deliberately no plt.clf() before plt.subplots(): with no
# figure open, clf() first creates a new empty figure, which the notebook then
# renders as a stray "<Figure size 640x480 with 0 Axes>".
fig, ax = plt.subplots(1, 3, figsize=(12, 4))
for a in ax:
    a.set_axis_off()

# Blur the image before display; the panel title notes that it is blurred.
cholec_image = torchvision.transforms.GaussianBlur(41, (50.0, 50.0))(cholec_image)

# Collapse the per-group masks into one label map by weighting mask k with k
# and summing over the group axis. The number of groups is read from the
# tensor instead of being hard-coded.
num_groups = quickshift_groups.shape[0]
group_ids = torch.arange(num_groups).view(-1, 1, 1)
quickshift_label_map = (group_ids * quickshift_groups).sum(dim=0)

# imshow wants channels last, so move (C,H,W) -> (H,W,C) for the RGB image.
ax[0].imshow(cholec_image.numpy().transpose(1, 2, 0))
ax[1].imshow(cholec_item["organs"].numpy())
ax[2].imshow(quickshift_label_map.numpy())

ax[0].set_title("Original image (blurred)")
ax[1].set_title("Expert-specified groups")
ax[2].set_title("Quickshift-generated groups")
Text(0.5, 1.0, 'Quickshift-generated groups')
<Figure size 640x480 with 0 Axes>
# The metric takes two (N,C,H,W)-shaped group tensors: predicted and true.
cholec_metric = exlib.datasets.cholec.CholecFixScore()

# Expand the integer label map into one mask per class, move the class axis to
# the front, and add a batch axis: (H,W) -> (H,W,K) -> (K,H,W) -> (1,K,H,W).
organs_one_hot = F.one_hot(cholec_organs)
cholec_organs_mask = organs_one_hot.permute(2, 0, 1).unsqueeze(0)  # (1,num_true_groups,H,W)

# The predicted groups only need the batch axis.
batched_qs_groups = quickshift_groups.unsqueeze(0)  # (1,num_pred_groups,H,W)

# The FIX score is computed for each pixel, and we just take the average
cholec_fix_score = cholec_metric(
    groups_pred=batched_qs_groups,
    groups_true=cholec_organs_mask,
)
print(cholec_fix_score.shape, cholec_fix_score.mean())
torch.Size([1]) tensor(0.1941)
Natural Language Example (Multilingual Politeness)¶
We show an example of the groups demarcated at the sentence level.
politeness_dataset = exlib.datasets.multilingual_politeness.PolitenessDataset(split="test")
politeness_item = politeness_dataset[0]

# Drop entries that are empty or whitespace-only.
politeness_word_list = [w for w in politeness_item["word_list"] if w.strip()]

# Group the words by sentence.
sentence_feature_extractor = exlib.features.text.SentenceGroups(distinct=26, scaling=1.5)
sentence_groups = sentence_feature_extractor(politeness_word_list)

# Print the full text, then each group as the words it selects.
print(" ".join(politeness_word_list), "\n")
for i, g in enumerate(sentence_groups):
    selected_positions = g.nonzero().flatten().tolist()
    sentence = " ".join(politeness_word_list[pos] for pos in selected_positions)
    print(f"Sentence {i+1}: {sentence}")
The intro mentions the ISO 8601 international standard adopted in most western countries. What does this even mean? Who are we suggesting has done the adoption? Sentence 1: Who are we suggesting has done the adoption? Sentence 2: What does this even mean? Sentence 3: The intro mentions the ISO 8601 international standard adopted in most western countries.
Time-series Example (Supernova)¶
We show an example of groups based on slicing the time series.
supernova_dataset = exlib.datasets.supernova.SupernovaDataset(split="test")
supernova_item = supernova_dataset.dataset[0]

# Same preprocessing as in Part 1: column 0 of times_wv is the time axis,
# column 0 of target is the flux, and zero-flux observations are dropped.
times_wv = torch.tensor(supernova_item["times_wv"])
xs = times_wv[:, 0]
ys = torch.tensor(supernova_item["target"])[:, 0]
times_wv, xs, ys = times_wv[ys != 0], xs[ys != 0], ys[ys != 0]

unique_wls = [3670.69, 4826.85, 6223.24, 7545.98, 8590.9, 9710.28]

slice_feature_extractor = exlib.features.time_series.SliceGroups(ngroups=10, window_size=100)
supernova_dataloader = exlib.datasets.supernova_helper.create_test_dataloader_raw(
    dataset=supernova_dataset,
    batch_size=5,
)

# Run the extractor on the first batch only and keep the groups at index 0.
# NOTE(review): index 0 is presumably the same item as supernova_item above,
# which requires the dataloader to preserve dataset order - confirm.
for batch in supernova_dataloader:
    slice_groups = slice_feature_extractor(**batch)
    slice_groups = slice_groups[0]
    break
original dataset size: 792 remove nans dataset size: 792
# NOTE: there is deliberately no plt.clf() before plt.figure(): with no figure
# open, clf() first creates a new empty figure, which the notebook then renders
# as a stray "<Figure size 640x480 with 0 Axes>".
plt.figure(figsize=(6, 4))

# First plot the wavelengths like before
# No cmap is passed to scatter: without `c` matplotlib ignores it and warns.
for wl in unique_wls:
    mask = times_wv[:, 1] == wl
    plt.scatter(xs[mask], ys[mask], label=f'{wl:.2f}')

# Then overlay the groups
# plt.get_cmap replaces plt.cm.get_cmap, which recent matplotlib has removed.
cmap = plt.get_cmap("viridis")
num_groups = len(slice_groups)
for i, g in enumerate(slice_groups):
    # Skip groups that select nothing.
    if g.sum() == 0:
        continue
    # Shade the time span between the first and last selected positions.
    member_idx = g.nonzero()
    xmin = xs[member_idx.min().item()].item()
    xmax = xs[member_idx.max().item()].item()
    plt.axvspan(xmin, xmax, alpha=0.3, facecolor=cmap(i / num_groups))

plt.title(f'Class: {supernova_item["label"]}')
plt.xlabel('Time')
plt.ylabel('Flux')
plt.legend(title='Wavelength')
<matplotlib.legend.Legend at 0x75e2a1d74850>
<Figure size 640x480 with 0 Axes>