import torch
from torch.special import gammaln


def collect_grad(model, device="cuda:0"):
    """Gather every available parameter gradient of *model* into one 1-D tensor.

    Args:
        model: Object exposing ``parameters()`` whose items carry a ``.grad``
            attribute (e.g. a ``torch.nn.Module``).
        device: Device on which the flattened gradients are concatenated.
            Defaults to ``"cuda:0"``, the device this function always used
            before the parameter was introduced.

    Returns:
        A 1-D tensor on *device* holding the flattened gradients in
        ``model.parameters()`` order. Parameters whose ``.grad`` is ``None``
        (frozen, unused, or not yet back-propagated) are skipped. If no
        parameter has a gradient, an empty tensor is returned.
    """
    target = torch.device(device)
    with torch.no_grad():
        grads = [
            param.grad.detach().flatten().to(target)
            for param in model.parameters()
            # Parameters that took no part in backward() have grad=None.
            if param.grad is not None
        ]
    if not grads:
        # torch.cat rejects an empty list; give callers an empty vector instead.
        return torch.empty(0, device=target)
    return torch.cat(grads)


_GAMMA_TABLE_PATH = "/mnt/hbnas/home/wujun/backslash/data/gamma_table.pt"
_R_GAMMA_TABLE_PATH = "/mnt/hbnas/home/wujun/backslash/data/r_gamma_table.pt"
# Filled on first use so the lookup tables are read from disk only once.
_GG_TABLE_CACHE = {}


def _load_gg_tables():
    """Return the (gamma_table, r_gamma_table) pair, loading from disk once."""
    if not _GG_TABLE_CACHE:
        # Load both before touching the cache so a failed load leaves it empty.
        gamma = torch.load(_GAMMA_TABLE_PATH)
        r_gamma = torch.load(_R_GAMMA_TABLE_PATH)
        _GG_TABLE_CACHE["gamma"] = gamma
        _GG_TABLE_CACHE["r_gamma"] = r_gamma
    return _GG_TABLE_CACHE["gamma"], _GG_TABLE_CACHE["r_gamma"]


def gg_param(param, gamma_table=None, r_gamma_table=None):
    """Estimate generalized Gaussian shape and scale from samples by moment matching.

    The sample ratio ``E[x^2] / E[|x|]^2`` is looked up in ``r_gamma_table``;
    the entry at the same index in ``gamma_table`` is taken as the shape.
    The scale then follows from the second moment and that shape.

    Args:
        param: Tensor of samples. It is moved to the CPU and all of its
            elements are used.
        gamma_table: Optional tensor of candidate shape values. When omitted,
            the table stored at the default path is loaded (and cached).
        r_gamma_table: Optional tensor of moment ratios, index-aligned with
            ``gamma_table`` (presumably
            ``Gamma(1/b) * Gamma(3/b) / Gamma(2/b)**2`` for each shape ``b`` —
            confirm against the script that generated the table). When
            omitted, the table at the default path is loaded (and cached).

    Returns:
        ``(shape, scale)``. Note that ``gg_entropy`` takes these in the
        opposite order, ``(scale, shape)``.
    """
    if gamma_table is None or r_gamma_table is None:
        default_gamma, default_r_gamma = _load_gg_tables()
        if gamma_table is None:
            gamma_table = default_gamma
        if r_gamma_table is None:
            r_gamma_table = default_r_gamma

    param = param.cpu()

    n = param.numel()
    sum_sq = torch.sum(torch.pow(param, 2))
    sum_abs = torch.sum(torch.abs(param))

    # (sum_sq / n) / (sum_abs / n)**2, written without the intermediate means.
    r_gamma = n * sum_sq / sum_abs**2
    pos = torch.argmin(torch.abs(r_gamma - r_gamma_table))
    shape = gamma_table[pos]

    # Scale comes from the *mean* square: E[x^2] = scale^2 * G(3/b) / G(1/b).
    # Using the raw sum of squares here would inflate the scale by sqrt(n).
    scale = torch.sqrt(sum_sq / n) * torch.exp(
        0.5 * (gammaln(1.0 / shape) - gammaln(3.0 / shape))
    )

    return shape, scale


def gg_entropy(scale, shape):
    """Return the differential entropy (in nats) of a generalized Gaussian.

    Evaluates ``1/shape + log(2 * scale) + lgamma(1/shape) - log(shape)``.

    Args:
        scale: Tensor holding the scale parameter.
        shape: Tensor holding the shape parameter.

    Returns:
        Tensor with the entropy, following the usual broadcasting of
        ``scale`` against ``shape``.
    """
    inv_shape = 1.0 / shape
    log_width = torch.log(2.0 * scale)
    log_shape = torch.log(shape)
    return inv_shape + log_width + gammaln(inv_shape) - log_shape
