import torch


def discretize_periods(periods, n_frames: int):
    """
    Unbatched convenience wrapper around ``batch_discretize_periods``.

    A size-1 batch axis is prepended to ``periods``, the batched routine does
    the work, and that leading axis is removed again from the result.
    """
    batched = batch_discretize_periods(periods[None], n_frames=n_frames)
    return batched.squeeze(0)


def batch_discretize_periods(periods, n_frames: int):
    """
    Rasterize continuous periods onto a grid of unit-length frames.

    Frame ``f`` covers the half-open interval ``[f, f + 1)`` for
    ``f = 0 .. n_frames - 1``. For every frame, the overlap length with each
    period ``[start, end)`` is computed and then summed over the periods axis.
    Overlapping periods therefore accumulate, so a frame's value can exceed 1.

    Args:
        periods: tensor of shape ``(*batch, n_periods, 2)`` whose last axis
            holds ``(start, end)`` expressed in frame units. Periods with
            ``end <= start`` contribute nothing.
        n_frames: number of frames in the output grid.

    Returns:
        Tensor of shape ``(*batch, n_frames)`` holding the summed per-frame
        coverage.
    """
    frames = torch.arange(n_frames, device=periods.device)
    # Broadcast every period against every frame -> (..., n_periods, n_frames).
    start = torch.maximum(periods[..., 0, None], frames)
    end = torch.minimum(periods[..., 1, None], frames + 1)
    # A negative length means the period and the frame do not overlap.
    overlap = (end - start).clamp_min(0)
    # Reduce over the periods axis (second to last) so any number of leading
    # batch dimensions works; identical to dim=1 for 3-D (B, P, 2) input.
    return overlap.sum(dim=-2)
