# Attention mask convention: 1 = unmasked (kept), 0 = masked.

# Attention-guided masking: pick a random subset of queries, pool the attention
# they pay to each position, and mask the positions that receive the most.
# NOTE(review): `attentions`, `masking_ratio`, `num_per_query`, `attention_mask`
# and the helpers `mean`, `sum`, `random_permutation`, `topk` come from outside
# this snippet; confirm their exact semantics there.
num_heads, num_queries, sequence_length = attentions.shape

# Both counts are used as integers below (a slice bound and `k`), so convert
# them explicitly instead of passing floats through.
num_masked = int(round(masking_ratio * sequence_length))
# Keep at least one query so the ranking is driven by actual attention values
# rather than by an empty (all-zero) sum.
num_queries_to_use = max(1, int(sequence_length * masking_ratio / num_per_query))

# Average over the head axis -> (num_queries, sequence_length).
head_averaged_attentions = mean(attentions, axis=0)

# Select whole rows (queries) at random. The permutation is drawn over
# `num_queries`, the size of the row axis, so sampled indices stay in range
# even when num_queries != sequence_length.
chosen_queries = random_permutation(num_queries)[:num_queries_to_use]
chosen_attentions = head_averaged_attentions[chosen_queries]

# Pool over the chosen queries -> one score per position, (sequence_length,).
query_summed_attentions = sum(chosen_attentions, axis=0)

# NOTE(review): presumably `topk` returns indices here; a torch-style `topk`
# returns a (values, indices) pair, in which case use `.indices` -- confirm.
masked_indices = topk(query_summed_attentions, k=num_masked)

# Zero out the most-attended positions (0 marks a masked position).
attention_mask[masked_indices] = 0