import torch
import numpy as np


def get_new_mask(length, ratio):
    """Build a random binary ``length x length`` mask and its complement.

    Exactly ``floor(length * length * (1 - ratio))`` randomly chosen entries
    of ``mask`` are set to 1 and the rest stay 0, so ``ratio`` is the fraction
    of entries that end up 0 in ``mask`` (the original comment calls it the
    masking probability). The count is exact; entries are not drawn
    independently. ``unmask`` is the element-wise complement, i.e.
    ``mask + unmask`` is all ones.

    Positions are drawn with ``np.random.choice``, so results follow NumPy's
    global random state (``np.random.seed``), not PyTorch's.

    Args:
        length: Side length of the square mask.
        ratio: Fraction of entries left at 0 in ``mask``; must be in [0, 1].

    Returns:
        A ``(mask, unmask)`` tuple of tensors of shape ``(length, length)``
        with torch's default floating dtype, holding only 0s and 1s.

    Raises:
        ValueError: If ``ratio`` is outside ``[0, 1]`` (including NaN).
    """
    # Out-of-range ratios previously failed inside NumPy with an unrelated
    # message (negative size / sample larger than population).
    if not 0 <= ratio <= 1:
        raise ValueError(f"ratio must be within [0, 1], got {ratio!r}")

    flat_mask = torch.zeros((length, length)).reshape(-1)
    total = flat_mask.numel()

    # `1 - ratio` is inexact in binary floating point (1 - 0.9 evaluates to
    # 0.09999999999999998), so plain truncation can drop one entry. The small
    # tolerance absorbs that noise while keeping floor semantics.
    keep = min(total, int(total * (1 - ratio) + 1e-9))

    # An int population is sampled as if it were np.arange(total).
    chosen = np.random.choice(total, size=keep, replace=False)
    flat_mask[torch.as_tensor(chosen, dtype=torch.long)] = 1

    mask = flat_mask.reshape(length, length)
    unmask = torch.ones_like(mask) - mask
    return mask, unmask


# Smoke run at import time: build one 8x8 mask pair with ratio 0.1.
# The returned tensors are discarded.
get_new_mask(length=8, ratio=0.1)