import torch


def set_top_large_to_one(tensor, rho):
    """Build a 0/1 mask that marks the largest entries of ``tensor``.

    The fraction ``1 - rho`` of all elements (rounded down) with the largest
    values is set to 1; every other position is 0.

    Args:
        tensor: Input tensor of any shape.
        rho: Fraction of elements to leave at 0. Values outside ``[0, 1]``
            are tolerated: the number of kept elements is clamped to
            ``[0, tensor.numel()]``.

    Returns:
        A new tensor with the same shape and dtype as ``tensor`` (it is
        created with ``torch.zeros_like``) containing only 0s and 1s.
        Which of several equal values is picked at the cut-off is decided
        by the order ``torch.sort`` returns.
    """
    import math

    flat_tensor = tensor.flatten()
    num_elements = flat_tensor.numel()

    # Floor of (1 - rho) * n, but robust to floating-point error: e.g.
    # (1 - 0.9) * 10 evaluates to 0.9999999999999998, which a plain int()
    # would truncate to 0 instead of the intended 1.
    exact_count = (1.0 - float(rho)) * num_elements
    nearest_count = round(exact_count)
    if math.isclose(exact_count, nearest_count, rel_tol=1e-9, abs_tol=1e-9):
        num_top_elements = nearest_count
    else:
        num_top_elements = math.floor(exact_count)

    # A negative count would be read as a "drop the last k" slice below and
    # select the wrong elements, so keep the count inside [0, n].
    num_top_elements = max(0, min(num_elements, num_top_elements))

    _, indices = torch.sort(flat_tensor, descending=True)
    new_tensor = torch.zeros_like(flat_tensor)
    new_tensor[indices[:num_top_elements]] = 1
    return new_tensor.view(tensor.size())


def calculate_weight_matrix(X, W):
    """Scale the magnitude of ``W`` by the L2 norms of ``X`` taken over dim 0.

    Args:
        X: Tensor whose L2 norm is reduced along its first dimension.
            NOTE(review): presumably activations laid out as
            (samples, features) -- confirm against the caller.
        W: Tensor whose absolute values are multiplied (with broadcasting)
            by those norms.

    Returns:
        The element-wise product ``|W| * ||X||_2`` (norm over ``dim=0``).
    """
    # Norm of X along dim 0: one value per remaining index of X.
    norms_over_dim0 = torch.norm(X, p=2, dim=0)
    weight_magnitude = torch.abs(W)
    return weight_magnitude * norms_over_dim0
