import torch
import numpy as np


def _make_span_from_seeds(seeds, span, total=None):
    """Expand each seed index into a contiguous run of ``span`` indices.

    Args:
        seeds: iterable of integer start positions.
        span: number of consecutive indices produced per seed (``seed`` up to
            ``seed + span - 1``).
        total: optional exclusive upper bound; indices ``>= total`` are dropped.

    Returns:
        1-D integer numpy array of unique indices, in order of first
        appearance (seeds processed in the given order).
    """
    # A dict preserves insertion order, so it de-duplicates in O(1) per index
    # while keeping the same first-seen ordering as a list-based scan.
    unique = dict()
    for seed in seeds:
        for i in range(seed, seed + span):
            # Indices increase within a span, so nothing later in it can fit.
            if total is not None and i >= total:
                break
            unique.setdefault(int(i))
    # Explicit integer dtype: an empty ``np.array([])`` would be float64,
    # which NumPy rejects as an index array.
    return np.array(list(unique), dtype=int)


def _make_mask(shape, p, total, span, allow_no_inds=False):
    """Build a random boolean span mask, sampling each row independently.

    For each index ``i`` along dimension 0 of ``shape``, every position in
    ``range(total)`` becomes a span seed independently with probability
    ``p``. Each seed then marks itself and the following ``span - 1``
    positions (clipped at ``total``) as True along dimension 1 of row ``i``.

    Args:
        shape: shape of the mask tensor. Dimension 0 is iterated over, and
            dimension 1 is indexed by the span positions. Presumably
            ``shape[1] == total`` -- TODO confirm with callers. Seeds are
            drawn from ``range(total)``, so a smaller ``shape[1]`` would index
            out of bounds.
        p: per-position probability of starting a span. ``p <= 0`` disables
            masking entirely.
        total: number of candidate seed positions per row.
        span: length of each masked span.
        allow_no_inds: when False, a row whose draw yields no seeds is redrawn
            until at least one seed is found. When True, a single draw is kept
            even if it is empty.

    Returns:
        ``torch.bool`` tensor of the given ``shape`` (``requires_grad=False``),
        True at masked positions.
    """
    mask = torch.zeros(shape, requires_grad=False, dtype=torch.bool)

    for i in range(shape[0]):
        mask_seeds = np.empty(0, dtype=int)
        # With no candidate positions (total <= 0), a redraw can never produce
        # a seed, so skip sampling rather than loop forever.
        if p > 0 and total > 0:
            while True:
                mask_seeds = np.flatnonzero(np.random.rand(total) < p)
                # Keep an empty draw only if the caller allows empty rows.
                if allow_no_inds or len(mask_seeds) > 0:
                    break

        # Nothing to mark for an empty row; also avoids indexing with an
        # empty index array.
        if len(mask_seeds) > 0:
            mask[i, _make_span_from_seeds(mask_seeds, span, total=total)] = True

    return mask
