import torch

def mask_center_patches(imgs, mask_ratio):
    """Blank out a centred rectangle of 16x16 patches in each image.

    The image is divided into a grid of ``h // 16`` by ``w // 16`` patches.
    A rectangle with odd side lengths is grown until it holds at least
    ``int(total_patches * mask_ratio)`` patches. It is then clipped to the
    grid and centred on it. Pixels inside the rectangle are set to 1 in the
    output.

    :param imgs: Tensor of shape ``(n, c, h, w)``.
    :param mask_ratio: Fraction of patches to mask; must satisfy
        ``0 < mask_ratio < 1``.
    :return: ``(imgs_masked, mask)``. ``mask`` has the same shape as ``imgs``
        and is 0 inside the masked rectangle and 1 elsewhere.
        ``imgs_masked`` equals ``imgs`` where ``mask`` is 1 and is 1 where
        ``mask`` is 0.
    :raises AssertionError: if ``mask_ratio`` is outside ``(0, 1)``.
    """
    assert 0 < mask_ratio < 1, "mask_ratio must be between 0 and 1"

    n, c, h, w = imgs.shape
    patch_size = 16

    # Size of the patch grid; pixels past the last full patch are never masked.
    num_patches_height = h // patch_size
    num_patches_width = w // patch_size

    total_patches = num_patches_height * num_patches_width
    patches_to_mask = int(total_patches * mask_ratio)

    mask = torch.ones((n, c, h, w), device=imgs.device)

    # With fewer than one patch to mask there is nothing to do. This guard
    # also avoids dividing by a zero square-root estimate below.
    if patches_to_mask > 0:
        # Start from a near-square rectangle, then force odd side lengths so
        # the rectangle has a well-defined centre patch.
        patches_to_mask_height = int(patches_to_mask ** 0.5)
        patches_to_mask_width = patches_to_mask // patches_to_mask_height
        if patches_to_mask_height % 2 == 0:
            patches_to_mask_height += 1
        if patches_to_mask_width % 2 == 0:
            patches_to_mask_width += 1

        # Grow the shorter side (keeping it odd) until the rectangle covers
        # at least the requested number of patches.
        while patches_to_mask_height * patches_to_mask_width < patches_to_mask:
            if patches_to_mask_height <= patches_to_mask_width:
                patches_to_mask_height += 2
            else:
                patches_to_mask_width += 2

        # Clip to the grid. Without this, a side longer than the grid yields
        # a negative start index. Python slicing wraps that to the far edge,
        # so only a thin strip would be masked instead of the full span.
        patches_to_mask_height = min(patches_to_mask_height, num_patches_height)
        patches_to_mask_width = min(patches_to_mask_width, num_patches_width)

        start_y = num_patches_height // 2 - patches_to_mask_height // 2
        start_x = num_patches_width // 2 - patches_to_mask_width // 2
        end_y = start_y + patches_to_mask_height
        end_x = start_x + patches_to_mask_width

        mask[:, :, start_y * patch_size:end_y * patch_size,
             start_x * patch_size:end_x * patch_size] = 0

    # Keep original pixels where mask == 1 and write 1 where mask == 0.
    imgs_masked = imgs * mask + (1 - mask)

    return imgs_masked, mask


def mask_verge_patches(imgs, mask_ratio):
    """Blank out everything except a centred rectangle of 16x16 patches.

    The centred rectangle is computed exactly as in ``mask_center_patches``:
    odd sides, grown to hold at least ``int(total_patches * mask_ratio)``
    patches, clipped to the grid. Here that rectangle is *kept* and every
    pixel outside it (the border, plus any pixels past the last full patch)
    is set to 1.

    NOTE(review): ``verge_mask`` treats its ``mask_ratio`` as the masked
    fraction (it inverts it to get the kept fraction). Here ``mask_ratio``
    sizes the kept centre, so the masked border covers roughly
    ``1 - mask_ratio`` of the patches. Confirm this is what callers expect.

    :param imgs: Tensor of shape ``(n, c, h, w)``.
    :param mask_ratio: Fraction of patches in the kept centre; must satisfy
        ``0 < mask_ratio < 1``.
    :return: ``(imgs_masked, mask)``. ``mask`` has the same shape as ``imgs``
        and is 1 inside the kept centre and 0 elsewhere. ``imgs_masked``
        equals ``imgs`` where ``mask`` is 1 and is 1 where ``mask`` is 0.
    :raises AssertionError: if ``mask_ratio`` is outside ``(0, 1)``.
    """
    assert 0 < mask_ratio < 1, "mask_ratio must be between 0 and 1"

    n, c, h, w = imgs.shape
    patch_size = 16

    num_patches_height = h // patch_size
    num_patches_width = w // patch_size

    total_patches = num_patches_height * num_patches_width
    patches_to_mask = int(total_patches * mask_ratio)

    # Start fully masked; the kept centre is switched back to 1 below.
    mask = torch.zeros((n, c, h, w), device=imgs.device)

    # An empty centre means the whole image stays masked. This guard also
    # avoids dividing by a zero square-root estimate below.
    if patches_to_mask > 0:
        # Near-square starting rectangle with odd sides (well-defined centre).
        patches_to_mask_height = int(patches_to_mask ** 0.5)
        patches_to_mask_width = patches_to_mask // patches_to_mask_height
        if patches_to_mask_height % 2 == 0:
            patches_to_mask_height += 1
        if patches_to_mask_width % 2 == 0:
            patches_to_mask_width += 1

        # Grow the shorter side (keeping it odd) until the rectangle covers
        # at least the requested number of patches.
        while patches_to_mask_height * patches_to_mask_width < patches_to_mask:
            if patches_to_mask_height <= patches_to_mask_width:
                patches_to_mask_height += 2
            else:
                patches_to_mask_width += 2

        # Clip to the grid so the start index can never go negative.
        # Python slicing would otherwise wrap a negative start to the far
        # edge and keep only a thin strip.
        patches_to_mask_height = min(patches_to_mask_height, num_patches_height)
        patches_to_mask_width = min(patches_to_mask_width, num_patches_width)

        start_y = num_patches_height // 2 - patches_to_mask_height // 2
        start_x = num_patches_width // 2 - patches_to_mask_width // 2
        end_y = start_y + patches_to_mask_height
        end_x = start_x + patches_to_mask_width

        mask[:, :, start_y * patch_size:end_y * patch_size,
             start_x * patch_size:end_x * patch_size] = 1

    # Keep original pixels in the centre (mask == 1); set the border to 1.
    imgs_masked = imgs * mask + (1 - mask)

    return imgs_masked, mask


def mask_top_patches(imgs, mask_ratio):
    """Blank out a band of whole 16-pixel patch rows along the top edge.

    The number of rows is ``int(total_patches * mask_ratio) // patches_per_row``,
    rounded up to an even number and capped at the number of patch rows.

    :param imgs: Tensor of shape ``(n, c, h, w)``.
    :param mask_ratio: Fraction of patches to mask, ``0 < mask_ratio < 1``.
    :return: ``(imgs_masked, mask)``. ``mask`` is 0 on the band and 1
        elsewhere. ``imgs_masked`` is 1 on the band and equals ``imgs``
        elsewhere.
    """
    assert 0 < mask_ratio < 1, "mask_ratio must be between 0 and 1"

    batch, channels, height, width = imgs.shape
    patch = 16
    grid_rows, grid_cols = height // patch, width // patch

    budget = int(grid_rows * grid_cols * mask_ratio)
    band_rows = budget // grid_cols
    band_rows += band_rows % 2          # odd -> next even
    band_rows = min(band_rows, grid_rows)

    mask = torch.ones((batch, channels, height, width), device=imgs.device)
    mask[:, :, :band_rows * patch, :] = 0

    # Where mask is 0 the result becomes 1; elsewhere the pixel is kept.
    return imgs * mask + (1 - mask), mask


def mask_bottom_patches(imgs, mask_ratio):
    """Blank out a band of whole 16-pixel patch rows at the bottom of the grid.

    The number of rows is ``int(total_patches * mask_ratio) // patches_per_row``,
    rounded up to an even number and capped at the number of patch rows.
    The band ends at the last full patch row, so any leftover ``h % 16``
    pixel rows below it are left untouched.

    :param imgs: Tensor of shape ``(n, c, h, w)``.
    :param mask_ratio: Fraction of patches to mask, ``0 < mask_ratio < 1``.
    :return: ``(imgs_masked, mask)``. ``mask`` is 0 on the band and 1
        elsewhere. ``imgs_masked`` is 1 on the band and equals ``imgs``
        elsewhere.
    """
    assert 0 < mask_ratio < 1, "mask_ratio must be between 0 and 1"

    batch, channels, height, width = imgs.shape
    patch = 16
    grid_rows, grid_cols = height // patch, width // patch

    budget = int(grid_rows * grid_cols * mask_ratio)
    band_rows = budget // grid_cols
    band_rows += band_rows % 2          # odd -> next even
    band_rows = min(band_rows, grid_rows)

    mask = torch.ones((batch, channels, height, width), device=imgs.device)
    band_top = (grid_rows - band_rows) * patch
    band_bottom = grid_rows * patch
    mask[:, :, band_top:band_bottom, :] = 0

    # Where mask is 0 the result becomes 1; elsewhere the pixel is kept.
    return imgs * mask + (1 - mask), mask


def mask_left_patches(imgs, mask_ratio):
    """Blank out a band of whole 16-pixel patch columns along the left edge.

    The number of columns is ``int(total_patches * mask_ratio) // patch_rows``,
    rounded up to an even number and capped at the number of patch columns.

    :param imgs: Tensor of shape ``(n, c, h, w)``.
    :param mask_ratio: Fraction of patches to mask, ``0 < mask_ratio < 1``.
    :return: ``(imgs_masked, mask)``. ``mask`` is 0 on the band and 1
        elsewhere. ``imgs_masked`` is 1 on the band and equals ``imgs``
        elsewhere.
    """
    assert 0 < mask_ratio < 1, "mask_ratio must be between 0 and 1"

    batch, channels, height, width = imgs.shape
    patch = 16
    grid_rows, grid_cols = height // patch, width // patch

    budget = int(grid_rows * grid_cols * mask_ratio)
    band_cols = budget // grid_rows
    band_cols += band_cols % 2          # odd -> next even
    band_cols = min(band_cols, grid_cols)

    mask = torch.ones((batch, channels, height, width), device=imgs.device)
    mask[:, :, :, :band_cols * patch] = 0

    # Where mask is 0 the result becomes 1; elsewhere the pixel is kept.
    return imgs * mask + (1 - mask), mask


def mask_right_patches(imgs, mask_ratio):
    """Blank out a band of whole 16-pixel patch columns at the right of the grid.

    The number of columns is ``int(total_patches * mask_ratio) // patch_rows``,
    rounded up to an even number and then capped at the number of patch
    columns. The band ends at the last full patch column, so any leftover
    ``w % 16`` pixel columns to its right are left untouched.

    :param imgs: Tensor of shape ``(n, c, h, w)``.
    :param mask_ratio: Fraction of patches to mask; must satisfy
        ``0 < mask_ratio < 1``.
    :return: ``(imgs_masked, mask)``. ``mask`` is 0 on the band and 1
        elsewhere. ``imgs_masked`` is 1 on the band and equals ``imgs``
        elsewhere.
    :raises AssertionError: if ``mask_ratio`` is outside ``(0, 1)``.
    """
    assert 0 < mask_ratio < 1, "mask_ratio must be between 0 and 1"

    n, c, h, w = imgs.shape
    patch_size = 16

    num_patches_height = h // patch_size
    num_patches_width = w // patch_size

    total_patches = num_patches_height * num_patches_width
    patches_to_mask = int(total_patches * mask_ratio)

    patches_to_mask_width = patches_to_mask // num_patches_height

    # Round an odd column count up to even FIRST, then clamp. Clamping last
    # guarantees the width never exceeds the grid, so start_x below cannot go
    # negative (which would wrap the slice). Same order as mask_left_patches.
    if patches_to_mask_width % 2 != 0:
        patches_to_mask_width += 1
    if patches_to_mask_width > num_patches_width:
        patches_to_mask_width = num_patches_width

    mask = torch.ones((n, c, h, w), device=imgs.device)

    start_x = num_patches_width - patches_to_mask_width
    end_x = num_patches_width

    mask[:, :, :, start_x * patch_size:end_x * patch_size] = 0

    # Keep original pixels where mask == 1 and write 1 where mask == 0.
    imgs_masked = imgs * mask + (1 - mask)

    return imgs_masked, mask


def center_mask(images, mask_ratio):
    """Build a mask flagging a centred square of blocks in a flattened grid.

    The second dimension of ``images`` is treated as a row-major square grid
    whose side is ``isqrt(images.shape[1])`` (16 for 256 blocks). A centred
    square whose side is ``int(side * mask_ratio)`` blocks is flagged, so
    about ``mask_ratio ** 2`` of the blocks are marked.

    :param images: 2D tensor of shape ``(n, num_blocks)``. Presumably each row
        holds the block IDs of one image; only the shape is used here.
    :param mask_ratio: Fraction of the grid *side* covered by the square.
    :return: ``(images, masks)``. ``images`` is the input object returned
        unchanged (no block IDs are overwritten). ``masks`` is a tensor shaped
        like ``images`` (floating point) with 1.0 on the centred square and
        0.0 elsewhere.
    """
    import math

    total_blocks = images.shape[1]
    # Derive the grid side from the input rather than assuming 16x16.
    side = math.isqrt(total_blocks)

    num_blocks = int(side * mask_ratio)
    start_index = (side - num_blocks) // 2
    end_index = start_index + num_blocks

    # "* 1.0" promotes integer inputs (e.g. token IDs) to a float mask.
    masks = torch.zeros_like(images) * 1.0

    # Flag the centred columns of each centred row, for all images at once.
    for row in range(start_index, end_index):
        row_offset = row * side
        masks[:, row_offset + start_index:row_offset + end_index] = 1.0

    # Note: entries equal to 1 in the mask mark erased (masked) positions.
    return images, masks


def verge_mask(images, mask_ratio):
    """Build a mask flagging every block except a centred square.

    The second dimension of ``images`` is treated as a row-major square grid
    whose side is ``isqrt(images.shape[1])`` (16 for 256 blocks). A centred
    square with side ``int(side * (1 - mask_ratio))`` is left unflagged.
    Everything around it (the border) is flagged as masked.

    :param images: 2D tensor of shape ``(n, num_blocks)``. Presumably each row
        holds the block IDs of one image; only the shape is used here.
    :param mask_ratio: Fraction of the grid *side* that the masked border
        occupies; the kept centre spans ``1 - mask_ratio`` of the side.
    :return: ``(images, masks)``. ``images`` is the input object returned
        unchanged (no block IDs are overwritten). ``masks`` has the same shape
        and dtype as ``images``, with 0 on the kept centre and 1 elsewhere.
    """
    import math

    total_blocks = images.shape[1]
    # Derive the grid side from the input rather than assuming 16x16.
    side = math.isqrt(total_blocks)

    # The kept centre covers the complement of the masked ratio.
    keep_ratio = 1 - mask_ratio
    num_blocks = int(side * keep_ratio)
    start_index = (side - num_blocks) // 2
    end_index = start_index + num_blocks

    # Start with everything masked; clear the centred square below.
    masks = torch.ones_like(images)

    for row in range(start_index, end_index):
        row_offset = row * side
        masks[:, row_offset + start_index:row_offset + end_index] = 0

    # Note: entries equal to 1 in the mask mark erased (masked) positions.
    return images, masks


def top_mask(images, mask_ratio):
    """Build a mask flagging the leading whole rows of a flattened block grid.

    The grid is treated as square with row length
    ``int(sqrt(images.shape[1]))``. ``int(row_length * mask_ratio)`` full rows
    from the start of each flattened row-major image are flagged.

    :param images: 2D tensor of shape ``(n, num_blocks)``; only the shape is
        used here.
    :param mask_ratio: Fraction of rows to flag.
    :return: ``(images, masks)``. ``images`` is the input object returned
        unchanged. ``masks`` has the same shape and dtype as ``images``, with
        1 on the flagged rows and 0 elsewhere.
    """
    n_blocks = images.shape[1]
    row_len = int(torch.sqrt(torch.tensor(n_blocks)))
    # Square grid assumed: the number of rows equals the row length.
    rows_flagged = int(row_len * mask_ratio)

    masks = torch.zeros_like(images)
    masks[:, :row_len * rows_flagged] = 1

    return images, masks


def bottom_mask(images, mask_ratio):
    """Build a mask flagging the trailing whole rows of a flattened block grid.

    The grid is treated as square with row length
    ``int(sqrt(images.shape[1]))``. ``int(row_length * mask_ratio)`` full rows
    at the end of each flattened row-major image are flagged. This mirrors
    ``top_mask``: whole rows are masked, never a partial row.

    :param images: 2D tensor of shape ``(n, num_blocks)``; only the shape is
        used here.
    :param mask_ratio: Fraction of rows to flag.
    :return: ``(images, masks)``. ``images`` is the input object returned
        unchanged. ``masks`` has the same shape and dtype as ``images``, with
        1 on the flagged rows and 0 elsewhere.
    """
    total_blocks = images.shape[1]
    blocks_per_row = int(torch.sqrt(torch.tensor(total_blocks)))  # Assuming a square image

    # Round down to whole rows so the masked region is a clean bottom band.
    # Previously int(total_blocks * mask_ratio) blocks were masked, which
    # could leave a partially masked row.
    num_rows_to_mask = int(blocks_per_row * mask_ratio)
    num_blocks_to_mask = blocks_per_row * num_rows_to_mask

    start_index = total_blocks - num_blocks_to_mask

    masks = torch.zeros_like(images)
    # Vectorized across the batch; an empty slice when nothing is masked.
    masks[:, start_index:] = 1.0

    return images, masks


def left_mask(images, mask_ratio):
    """Build a mask flagging the leading columns of a flattened block grid.

    The grid is treated as square with row length
    ``int(sqrt(images.shape[1]))``. In every row, the first
    ``int(row_length * mask_ratio)`` blocks are flagged.

    :param images: 2D tensor of shape ``(n, num_blocks)``; only the shape is
        used here.
    :param mask_ratio: Fraction of each row to flag.
    :return: ``(images, masks)``. ``images`` is the input object returned
        unchanged. ``masks`` has the same shape and dtype as ``images``, with
        1 on the flagged columns and 0 elsewhere.
    """
    n_blocks = images.shape[1]
    row_len = int(torch.sqrt(torch.tensor(n_blocks)))
    band_width = int(row_len * mask_ratio)

    masks = torch.zeros_like(images)
    # One slice assignment per grid row, applied to the whole batch.
    for row in range(row_len):
        row_start = row * row_len
        masks[:, row_start:row_start + band_width] = 1.0

    return images, masks


def right_mask(images, mask_ratio):
    """Build a mask flagging the trailing columns of a flattened block grid.

    The grid is treated as square with row length
    ``int(sqrt(images.shape[1]))``. In every row, the last
    ``int(row_length * mask_ratio)`` blocks are flagged.

    :param images: 2D tensor of shape ``(n, num_blocks)``; only the shape is
        used here.
    :param mask_ratio: Fraction of each row to flag.
    :return: ``(images, masks)``. ``images`` is the input object returned
        unchanged. ``masks`` has the same shape and dtype as ``images``, with
        1 on the flagged columns and 0 elsewhere.
    """
    n_blocks = images.shape[1]
    row_len = int(torch.sqrt(torch.tensor(n_blocks)))
    band_width = int(row_len * mask_ratio)

    masks = torch.zeros_like(images)
    # One slice assignment per grid row, applied to the whole batch.
    for row in range(row_len):
        row_end = (row + 1) * row_len
        masks[:, row_end - band_width:row_end] = 1

    return images, masks