import torch
from torch.distributions import Normal
import matplotlib.pyplot as plt

def get_l2_norm_by_row(weight):
    """Return the L2 norm of every row of a 2-D tensor.

    Args:
        weight: A 2-D tensor of shape (rows, cols).

    Returns:
        A 1-D tensor of length ``rows`` whose i-th entry is ``||weight[i, :]||_2``.

    Raises:
        ValueError: If ``weight`` is not 2-D.
    """
    # Check the rank with dim(); comparing ``weight.shape`` (a torch.Size) to the
    # int 2 is always "not equal" and would never reject anything.
    if weight.dim() != 2:
        raise ValueError("维度不为2，请重新输入")
    return torch.norm(weight, dim=1, p=2)


def get_l2_norm_by_column(weight):
    """Return the L2 norm of every column of a 2-D tensor.

    Args:
        weight: A 2-D tensor of shape (rows, cols).

    Returns:
        A 1-D tensor of length ``cols`` whose j-th entry is ``||weight[:, j]||_2``.

    Raises:
        ValueError: If ``weight`` is not 2-D.
    """
    # Check the rank with dim(); comparing ``weight.shape`` (a torch.Size) to the
    # int 2 is always "not equal" and would never reject anything.
    if weight.dim() != 2:
        raise ValueError("维度不为2，请重新输入")
    return torch.norm(weight, dim=0, p=2)


def get_hist_graph(data, bins=30):
    """Display a histogram of all elements of a tensor.

    Args:
        data: Tensor of any shape, on any device. It is detached, copied to CPU,
            and flattened before plotting.
        bins: Number of histogram bins (default 30).
    """
    # .cpu() is needed because .numpy() fails on CUDA tensors. Flattening makes
    # plt.hist treat all elements as one dataset; given a 2-D array it would
    # otherwise draw one dataset per column.
    values = data.detach().cpu().flatten().numpy()

    plt.figure(figsize=(10, 6))
    plt.hist(values, bins=bins, edgecolor='black', alpha=0.7)
    plt.title('Distribution of Tensor Elements', fontsize=15)
    plt.xlabel('Value', fontsize=12)
    plt.ylabel('Frequency', fontsize=12)
    plt.grid(True, linestyle='--', alpha=0.5)
    plt.show()


def get_top_k(weight, top_k):
    """Return the ``top_k`` largest entries of ``weight``, ordered by position.

    Selection happens along the last dimension (the ``torch.topk`` default).

    Args:
        weight: Input tensor.
        top_k: Number of entries to keep.

    Returns:
        A ``(values, indices)`` pair. ``indices`` holds the positions of the
        ``top_k`` largest entries sorted ascending. ``values[..., i]`` is the
        entry of ``weight`` at ``indices[..., i]``.
    """
    values, indices = torch.topk(weight, top_k, largest=True, sorted=True)
    sorted_indices, sorted_order = torch.sort(indices)
    # Apply the same permutation to the values so both returned tensors stay
    # aligned element-for-element. Without it, values would remain in
    # descending-value order while indices are in ascending-position order.
    values = torch.gather(values, -1, sorted_order)
    return values, sorted_indices


def get_bottle_k(weight, top_k):
    """Return the positions of the ``top_k`` smallest entries of ``weight``.

    Selection happens along the last dimension (the ``torch.topk`` default).
    The returned positions are sorted in ascending order. Judging by the name,
    "bottle" is presumably meant as "bottom".
    """
    smallest = torch.topk(weight, top_k, largest=False, sorted=True)
    return smallest.indices.sort().values


def calculate_posibility(data, large_range, low_range=None):
    """Estimate P(low_range < X <= large_range) under a normal fit to ``data``.

    A normal distribution is fitted using the mean and the (unbiased, torch
    default) standard deviation of all elements of ``data``. The probability
    mass between the bounds is then read off its CDF.

    Args:
        data: Tensor of samples. Non-floating dtypes are promoted to float.
        large_range: Upper bound (Python number or tensor).
        low_range: Lower bound. If None, the lower bound is -inf, so the result
            is P(X <= large_range).

    Returns:
        A tensor equal to ``cdf(large_range) - cdf(low_range)``. The result is
        negative if ``low_range > large_range``.
    """
    if not data.is_floating_point():
        # torch.mean / torch.std reject integer tensors, so promote first.
        data = data.float()

    mean = torch.mean(data)
    std = torch.std(data)
    normal_dist = Normal(mean, std)

    def _as_bound(bound):
        # Put the bound on the fitted distribution's dtype/device. This works
        # for CUDA data and avoids the warning torch.tensor() emits when given
        # a tensor.
        return torch.as_tensor(bound, dtype=mean.dtype, device=mean.device)

    # P(X <= large_range)
    prob_large = normal_dist.cdf(_as_bound(large_range))

    if low_range is None:
        prob_low = 0
    else:
        # P(X <= low_range)
        prob_low = normal_dist.cdf(_as_bound(low_range))

    return prob_large - prob_low
