import torch


# 给定置信度，计算掩码。置信度越大，就越应该去除，mask=0
def get_global_FFN_mask_by_distribution(probs, rate):
    flatten_probs = probs.flatten()
    print(flatten_probs.size())

    top_k = int(flatten_probs.shape[0] * rate)
    values, indices = torch.topk(flatten_probs, k=top_k)  # 获取最大的 top_k 个值和索引
    print(values)
    # 初始化mask 全为 1
    mask = torch.ones_like(probs)

    # 将 top_k 最大值的对应位置设为 0
    # 先将 W2 展平，然后通过 indices 置零，最后恢复形状
    flat_mask = mask.flatten()
    flat_mask[indices] = 0
    mask = flat_mask.reshape(mask.shape)

    zeros_per_row = (mask == 0).sum(dim=1)
    print(zeros_per_row)
    print(zeros_per_row.sum())
    return mask


# 取norm最小的中间特征删除
def get_global_FFN_mask_by_norm(probs, rate):
    flatten_probs = probs.flatten()
    print(flatten_probs.size())

    top_k = int(flatten_probs.shape[0] * rate)
    values, indices = torch.topk(flatten_probs, largest=False, k=top_k)  # 获取最大的 top_k 个值和索引
    print(values)
    # 初始化mask 全为 1
    mask = torch.ones_like(probs)

    # 将 top_k 最大值的对应位置设为 0
    # 先将 W2 展平，然后通过 indices 置零，最后恢复形状
    flat_mask = mask.flatten()
    flat_mask[indices] = 0
    mask = flat_mask.reshape(mask.shape)

    zeros_per_row = (mask == 0).sum(dim=1)
    print(zeros_per_row)
    print(zeros_per_row.sum())
    return mask


# 给定置信度，计算掩码。置信度越大，就越应该去除，mask=0
def get_local_FFN_mask_by_norm(probs, rate):
    top_k = int(probs.shape[-1] * rate)
    # print(top_k)
    values, indices = torch.topk(probs, largest=False, k=top_k)  # 获取最大的 top_k 个值和索引
    print(values)
    mask = torch.ones_like(probs)
    for index in range(probs.shape[0]):
        mask[index, indices[index]] = 0

    zeros_per_row = (mask == 0).sum(dim=1)
    print(zeros_per_row)
    print(zeros_per_row.sum())
    return mask


# 读取配置文件，并且返回掩码
def load_config(dataset, rate):
    pth = dataset + "_rate" + str(rate) + "_config.pt"
    mask = torch.load(pth)
    return mask