
import torch



def calculate_cas_elementwise(input_matrix, beta1=0.7, beta2=0.3):
    """Compute ``beta1 * mean - beta2 * std`` per position across dim 0.

    Dimension 0 of ``input_matrix`` indexes the samples being aggregated
    (called "queries" elsewhere in this file). Every remaining dimension is
    treated as a position, so the result has shape ``input_matrix.shape[1:]``.

    Args:
        input_matrix: Tensor (or tensor-like accepted by ``torch.as_tensor``)
            with the aggregation axis first. Integer/bool inputs are promoted
            to the default floating dtype.
        beta1: Weight applied to the per-position mean.
        beta2: Weight applied to the per-position population standard
            deviation.

    Returns:
        Tensor of shape ``input_matrix.shape[1:]`` holding
        ``beta1 * mean - beta2 * std`` for each position.
    """
    values = torch.as_tensor(input_matrix)
    # mean/std are only defined for floating (or complex) dtypes; promote
    # integer/bool inputs rather than failing inside the reduction.
    if not (values.is_floating_point() or values.is_complex()):
        values = values.to(torch.get_default_dtype())

    # One vectorized reduction over dim 0 covers every position at once and
    # works for any number of trailing dims (including none).
    mean = values.mean(dim=0)
    # unbiased=False -> population std (divide by N); a single sample gives 0.
    std = values.std(dim=0, unbiased=False)
    return beta1 * mean - beta2 * std


def select_kns(input_matrix, beta1=0.7, beta2=0.3, threshold_factor=0.3):
    """Return a boolean mask of positions whose score clears a relative cutoff.

    Scores come from ``calculate_cas_elementwise`` (``beta1 * mean -
    beta2 * std`` per position). A position is selected as a KN when its
    score is at least ``threshold_factor`` times the largest score.

    Note: the cutoff scales the peak score directly. If the peak is negative
    and ``threshold_factor < 1``, the cutoff lies above the peak and no
    position is selected.

    Args:
        input_matrix: Passed unchanged to ``calculate_cas_elementwise``.
        beta1: Weight on the mean term of the score.
        beta2: Weight on the std term of the score.
        threshold_factor: Fraction of the peak score used as the cutoff.

    Returns:
        Boolean tensor shaped like the score tensor; True marks a selected KN.
    """
    scores = calculate_cas_elementwise(
        input_matrix=input_matrix,
        beta1=beta1,
        beta2=beta2,
    )
    cutoff = threshold_factor * scores.max()
    return scores >= cutoff