import numpy as np

# 1. Create a dummy attribution map with 10 time steps (stored as a 1-D array).
# Using specific values to make verification easy
attr_map = np.array([0.1, 0.9, 0.2, 0.8, 0.3, 0.7, 0.4, 0.6, 0.5, 1.0])
print(f"Original Map: {attr_map}")

# 2. Define k (20% for this example)
k = 0.2
flat_attr = attr_map.flatten()
# 20% of 10 elements = 2. Clamped to [0, size] so that k > 1 or k < 0 cannot
# request an impossible count (np.argpartition raises on an out-of-range kth).
num_to_mask = min(max(int(len(flat_attr) * k), 0), len(flat_attr))

# 3. Locate the num_to_mask lowest-scoring elements.
# np.argpartition(arr, n) puts the index of the element with 0-based rank n at
# position n, with indices of all smaller elements before it and of all
# greater-or-equal elements after it. With n = num_to_mask - 1, the first
# num_to_mask indices are exactly the num_to_mask smallest values (in
# arbitrary order). Selecting by index rather than thresholding with "<="
# masks exactly num_to_mask elements even when scores tie.
mask_flat = np.ones(flat_attr.shape, dtype=np.float32)
threshold = None  # stays None when nothing is masked
if num_to_mask > 0:
    lowest_idx = np.argpartition(flat_attr, num_to_mask - 1)[:num_to_mask]
    mask_flat[lowest_idx] = 0.0
    # Largest value that gets masked, i.e. the value at rank num_to_mask - 1.
    threshold = flat_attr[lowest_idx].max()

print(f"Number of elements to mask: {num_to_mask}")
print(f"Threshold value (largest masked, rank {num_to_mask - 1}): {threshold}")

# 4. Generate the Mask (0 for the num_to_mask lowest values, 1 otherwise),
# restored to the original map's shape.
mask = mask_flat.reshape(attr_map.shape)
print(f"Generated Mask: {mask}")
print(f"Verification: {int(np.sum(mask == 0))} elements were masked.")
