import numpy as np
import torch

def shuffle_tensor_2d(x: torch.Tensor) -> torch.Tensor:
    """Return a copy of ``x`` with all of its elements randomly permuted.

    The permutation is over the flattened tensor, i.e. elements move freely
    across rows/columns (this is NOT a shuffle along a single dimension).
    The result has the same shape as ``x``. Despite the name, any shape is
    accepted; 2-D inputs behave exactly as before.

    The permutation is drawn from torch's global RNG (no generator is
    passed), so results are reproducible via ``torch.manual_seed``.

    Parameters:
    - x (torch.Tensor): Tensor whose elements should be shuffled.

    Returns:
    - torch.Tensor: A new tensor of shape ``x.shape`` containing the same
      elements as ``x`` in a random order.
    """
    flat = x.reshape(-1)
    # Build the permutation on the same device as the data; index_select
    # requires the index tensor to live alongside the input tensor.
    indices = torch.randperm(flat.shape[0], device=flat.device)
    shuffled = flat.index_select(0, indices)
    # Restore the original shape (works for any rank, not just 2-D).
    return shuffled.reshape(x.shape)

def top_n_mask(x, n):
    """
    Returns a boolean mask marking the top n highest values in the map x.

    Parameters:
    - x (array-like): The input map to be masked. Anything accepted by
      ``np.asarray`` works (ndarray, nested lists, ...).
    - n (int): The number of highest values to keep.

    Returns:
    - np.ndarray: A boolean array with the same shape as ``x``. It is True
      where ``x`` is greater than or equal to the n-th largest value and
      False elsewhere. Ties at that threshold are all included, so the mask
      may contain MORE than n True entries when values repeat.

    Raises:
    - ValueError: If n <= 0 or n exceeds the number of elements in x.

    NOTE(review): NaN values are not specially handled; np.partition orders
    them last (as "largest"), which can make the threshold NaN and yield
    fewer than n True entries. Confirm whether inputs can contain NaN.
    """
    if n <= 0:
        raise ValueError("n should be a positive integer")

    # Accept array-likes as well as ndarrays.
    x = np.asarray(x)

    # Flatten the array to work with it easily
    flattened_x = x.ravel()

    if n > flattened_x.size:
        raise ValueError("n should not be greater than the number of elements in x")

    # np.partition places the n-th largest value at index -n in O(size)
    # time, without fully sorting the data.
    threshold = np.partition(flattened_x, -n)[-n]

    # Inclusive comparison: every value tied with the threshold is kept.
    return x >= threshold