In [1]:
# Uncomment the line below to install exlib (%pip installs into the kernel's environment)
# %pip install exlib

Mass Maps¶

This notebook explores the MassMaps task. The inputs are simulated weak-lensing maps of shape (batch_size, 1, 66, 66), and the model predicts the cosmological parameters $\Omega_m$ and $\sigma_8$ from each map.

In [ ]:
import sys
sys.path.insert(0, "../../src")
import exlib
import torch
from datasets import load_dataset
from exlib.datasets.mass_maps import *

# Prefer a GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Load Data¶

In [3]:
# Load the train / validation / test splits of DATASET_REPO, then expose the
# 'input' and 'label' columns of each split as torch tensors.
train_dataset, val_dataset, test_dataset = (
    load_dataset(DATASET_REPO, split=split_name)
    for split_name in ('train', 'validation', 'test')
)
for split_dataset in (train_dataset, val_dataset, test_dataset):
    split_dataset.set_format('torch', columns=['input', 'label'])

Baselines for getting expert features¶

In [4]:
# Baselines
from exlib.features.vision.mass_maps import MassMapsOracle, MassMapsOne
from exlib.features.vision.watershed import WatershedGroups
from exlib.features.vision.quickshift import QuickshiftGroups
from exlib.features.vision.patch import PatchGroups
In [5]:
# Watershed grouping baseline, moved to the selected device.
watershed_baseline = WatershedGroups(compactness=0, min_dist=10).to(device)

Show FIXScores for some examples¶

In [6]:
import math
import matplotlib.pyplot as plt
from exlib.datasets.mass_maps import MassMapsFixScore

def show_example(groups, X, img_idx=0, cols=8):
    """Plot each group mask of one image, titled with its FIX-score components.

    Scores come from ``MassMapsFixScore`` called with ``reduce='none'`` and
    ``return_dict=True``. Each subplot draws channel 0 of ``X[img_idx]`` with one
    non-empty mask from ``groups[img_idx]`` outlined in red and hatched. The
    subplot title shows that group's ``p_void_``, ``p_cluster_`` and ``purity``
    values.

    Args:
        groups: Group masks indexed as ``groups[image][group]``. ``groups.shape[1]``
            is used as the number of groups (presumably shaped
            (batch, num_groups, H, W) -- TODO confirm).
        X: Batch of maps passed to the scorer; ``X[img_idx][0]`` is displayed.
        img_idx: Index of the image in the batch to plot (default 0).
        cols: Number of subplot columns in the grid (default 8).
    """
    massmaps_align = MassMapsFixScore().to(device)
    alignment_results = massmaps_align(groups, X, reduce='none', return_dict=True)
    # Take the per-group scores of the image being drawn, so the titles describe
    # the masks shown for img_idx.
    p_void_ = alignment_results['p_void_'][img_idx]
    p_cluster_ = alignment_results['p_cluster_'][img_idx]
    purity = alignment_results['purity'][img_idx]

    m = groups.shape[1]
    # Use at least one row so plt.subplots does not reject an empty grid.
    # squeeze=False keeps a 2-D Axes array even for a 1x1 grid, so ravel() works.
    rows = max(1, math.ceil(m / cols))
    fig, axs = plt.subplots(rows, cols, figsize=(cols * 3, rows * 4), squeeze=False)
    axs = axs.ravel()

    image = X[img_idx][0].cpu().numpy()
    for idx, ax in enumerate(axs):
        if idx < m:
            mask = groups[img_idx][idx]
            if mask.sum() > 0:  # leave empty groups blank
                mask_np = mask.cpu().numpy() > 0
                ax.imshow(image)
                ax.contour(mask_np, 2, colors='red')
                ax.contourf(mask_np, 2, hatches=['//', None, None],
                            cmap='gray', extend='neither', linestyles='--', alpha=0.01)
                ax.set_title(f'void {p_void_[idx].item():.5f}\n'
                             f'cluster {p_cluster_[idx].item():.5f}\n'
                             f'purity {purity[idx].item():.5f}')
        # Hide axes on every cell, including padding cells past the last group.
        ax.axis('off')
    plt.show()
In [7]:
import torch
from exlib.datasets.mass_maps import MassMapsFixScore
# Run the watershed baseline on the first two training maps, then show two of
# its groups for the first map. Group indices 2 and 6 are hard-coded for this
# illustration.
watershed_baseline = WatershedGroups(min_dist=10, compactness=0).to(device)
train_batch = train_dataset[0:2]  # slice once; reuse for inputs and labels
X, y = train_batch['input'], train_batch['label']
groups = watershed_baseline(X)
# List indexing keeps the group axis: shape (1, 2, ...), same as stacking the two slices on dim=1.
groups_show = groups[0:1, [2, 6]]
show_example(groups_show, X[0:1])
No description has been provided for this image
In [8]:
# Every watershed group for the first map in the batch.
show_example(groups, X, img_idx=0)
No description has been provided for this image
In [9]:
# Quickshift baseline on the first two training maps. The dataset slice is
# taken once and reused for both inputs and labels.
quickshift_baseline = QuickshiftGroups(
    kernel_size=5,
    max_dist=10,
    sigma=5).to(device)
train_batch = train_dataset[0:2]
X, y = train_batch['input'], train_batch['label']
groups = quickshift_baseline(X)
show_example(groups, X)
No description has been provided for this image
In [10]:
# Patch baseline (8x8 grid) on the first two training maps. The dataset slice is
# taken once and reused for both inputs and labels.
patch_baseline = PatchGroups(grid_size=(8, 8), mode='grid').to(device)
train_batch = train_dataset[0:2]
X, y = train_batch['input'], train_batch['label']
groups = patch_baseline(X)
show_example(groups, X)
No description has been provided for this image
In [11]:
# MassMapsOracle baseline on the first two training maps. The dataset slice is
# taken once and reused for both inputs and labels.
oracle_baseline = MassMapsOracle().to(device)
train_batch = train_dataset[0:2]
X, y = train_batch['input'], train_batch['label']
groups = oracle_baseline(X)
show_example(groups, X)
No description has been provided for this image
In [12]:
# MassMapsOne baseline on the first two training maps. The dataset slice is
# taken once and reused for both inputs and labels.
one_baseline = MassMapsOne().to(device)
train_batch = train_dataset[0:2]
X, y = train_batch['input'], train_batch['label']
groups = one_baseline(X)
show_example(groups, X)
No description has been provided for this image

Compare baselines¶

In [13]:
# Compute FIX scores for all baselines (subset=True -- presumably a reduced
# evaluation split; confirm). The result maps each baseline name to its scores,
# which the next cell averages. The call also prints the model's Omega_m /
# sigma_8 losses.
mass_maps_fixscores = get_mass_maps_scores(subset=True)
100%|██████████████████████████████████████████████████████████████████████████████████| 64/64 [00:27<00:00,  2.32it/s]
Omega_m loss 0.0053, sigma_8 loss 0.0129, avg loss 0.0091

In [14]:
# Print the mean FIX score of each baseline.
# torch.as_tensor accepts lists and existing tensors alike; torch.tensor on a
# tensor makes a copy and emits a UserWarning.
for name, scores in mass_maps_fixscores.items():
    metric = torch.as_tensor(scores)
    mean_metric = metric.mean()
    print(f'{name}\t{mean_metric.item():.4f}')
identity	0.5427
random	0.5448
patch	0.5499
quickshift	0.5433
watershed	0.5523
sam	0.5474
ace	0.5457
craft	0.4004
archipelago	0.5489
In [15]:
from exlib.features.vision.watershed import WatershedGroups
# Score only the watershed baseline. get_mass_maps_scores receives just the
# baseline name, so no locally constructed WatershedGroups instance is used here.
mass_maps_fixscores_watershed = get_mass_maps_scores(['watershed'], subset=True)
100%|██████████████████████████████████████████████████████████████████████████████████| 64/64 [00:00<00:00, 90.50it/s]
Omega_m loss 0.0053, sigma_8 loss 0.0129, avg loss 0.0091

In [16]:
# Mean watershed FIX score (compare with the 'watershed' row printed above).
mass_maps_fixscores_watershed['watershed'].mean()
Out[16]:
tensor(0.5523)
In [ ]: