import torch

def mask_center_patches(imgs, mask_ratio, *, patch_size=14):
    """Mask a centred rectangular block of square patches in a batch of images.

    The image is viewed as a grid of ``patch_size`` x ``patch_size`` patches
    (trailing pixels that do not fill a whole patch are ignored by the grid).
    A block of roughly ``mask_ratio`` of all patches, centred on the grid, is
    masked. Masked pixels are set to 1 in the returned image.

    Args:
        imgs: 4-D tensor unpacked as ``(n, c, h, w)``.
        mask_ratio: Fraction of patches to mask; must satisfy ``0 < mask_ratio < 1``.
        patch_size: Side length of a square patch in pixels (default 14).

    Returns:
        ``(imgs_masked, mask)``:
          * ``mask`` has the same shape and device as ``imgs`` and torch's
            default floating dtype. It is 0 inside the masked block and 1
            elsewhere.
          * ``imgs_masked`` is ``imgs * mask + (1 - mask)``: the original
            pixels outside the block and 1 inside it.
        If the ratio rounds down to zero patches (for example, tiny images),
        nothing is masked.

    Raises:
        AssertionError: if ``mask_ratio`` is not strictly between 0 and 1.
        ValueError: if ``patch_size`` is not positive.
    """
    assert 0 < mask_ratio < 1, "mask_ratio must be between 0 and 1"
    if patch_size < 1:
        raise ValueError("patch_size must be a positive integer")

    n, c, h, w = imgs.shape

    # Patch grid dimensions and how many patches the ratio asks for.
    grid_h = h // patch_size
    grid_w = w // patch_size
    patches_to_mask = int(grid_h * grid_w * mask_ratio)

    mask = torch.ones((n, c, h, w), device=imgs.device)

    # With zero patches requested there is nothing to mask. Without this
    # guard the block-size computation would divide by zero.
    if patches_to_mask > 0:
        rows, cols = _center_block_size(patches_to_mask, grid_h, grid_w)

        # rows <= grid_h and cols <= grid_w, so these indices stay inside
        # [0, grid]. A negative start would make torch slice from the end.
        start_y = grid_h // 2 - rows // 2
        start_x = grid_w // 2 - cols // 2
        end_y = start_y + rows
        end_x = start_x + cols

        mask[:, :, start_y * patch_size:end_y * patch_size,
             start_x * patch_size:end_x * patch_size] = 0

    # Keep the original pixels where mask == 1; write 1 where mask == 0.
    imgs_masked = imgs * mask + (1 - mask)

    return imgs_masked, mask


def _center_block_size(target, grid_h, grid_w):
    """Return ``(rows, cols)`` of a near-square patch block covering >= ``target`` patches.

    Assumes ``1 <= target <= grid_h * grid_w``. Sizes start odd so the block
    can sit symmetrically around a centre patch. Both sizes are capped at the
    grid dimensions; a capped axis may therefore be even. If one axis is
    saturated, the other axis keeps growing until the target is reached.
    """
    rows = int(target ** 0.5)
    cols = target // rows

    # Odd sizes give a symmetric block around the centre patch.
    if rows % 2 == 0:
        rows += 1
    if cols % 2 == 0:
        cols += 1

    # Never exceed the grid; any excess would not be masked anyway.
    rows = min(rows, grid_h)
    cols = min(cols, grid_w)

    # Grow (by 2 to stay odd) the smaller axis that still has room, until the
    # block covers the target or fills the whole grid.
    while rows * cols < target:
        can_grow_rows = rows < grid_h
        can_grow_cols = cols < grid_w
        if can_grow_rows and (rows <= cols or not can_grow_cols):
            rows = min(rows + 2, grid_h)
        elif can_grow_cols:
            cols = min(cols + 2, grid_w)
        else:
            break

    return rows, cols