from pathlib import Path
from typing import Optional, Sequence, Union

import numpy as np
import torch
import torchvision
from matplotlib import pyplot as plt
from torch import Tensor


def make_recon_img(slot, mask):
    """Compose per-slot images into one image via a mask-weighted sum over slots.

    Args:
        slot (Tensor): Slot-wise images of shape ``(B, S, C, H, W)``. The
            unpacking below requires exactly five dimensions. ``C`` is
            typically 3 (RGB), but this is not checked.
        mask (Tensor): Per-slot weights of shape ``(B, S, 1, H, W)``. The
            singleton channel axis broadcasts against ``C``. The weights are
            meant to sum to 1 along the slot axis, but this is not enforced.

    Returns:
        Tensor of shape ``(B, C, H, W)``: the sum over slots of
        ``slot * mask``.

    Raises:
        AssertionError: If ``mask`` does not have shape ``(B, S, 1, H, W)``
            matching ``slot``.
    """
    batch, n_slots, _channels, height, width = slot.shape
    # The mask carries one weight per pixel per slot. Its channel axis of
    # size 1 is broadcast across all image channels in the product below.
    expected_mask_shape = (batch, n_slots, 1, height, width)
    assert mask.shape == expected_mask_shape
    weighted_slots = slot * mask  # (B, S, C, H, W)
    # Collapse the slot axis to get the composited image.
    return weighted_slots.sum(dim=1)