import torch


def weighted_sd_mean(x, w):
    """Return the weighted standard deviation and weighted mean of ``x``.

    The mean is ``sum(w * x) / sum(w)``. The variance divides the weighted
    sum of squared deviations by ``(N - 1) * sum(w) / N``, where ``N`` is the
    number of weight elements. With all-equal weights this reduces to the
    usual Bessel-corrected (unbiased) sample variance.

    Args:
        x: Tensor of values.
        w: Tensor of weights, assumed to have the same shape as ``x``
            (``N`` is taken from ``w``). Presumably non-negative with a
            positive sum; this is not checked here.

    Returns:
        Tuple ``(sd, mean)`` of 0-d tensors.
    """
    w_sum = torch.sum(w)
    mean = torch.sum(w * x) / w_sum
    # Count every weight element, not only the first dimension, so that the
    # correction factor agrees with the full-tensor sums for inputs of any rank.
    n = w.numel()
    sd = torch.sqrt(torch.sum(w * (x - mean) ** 2) / ((n - 1) * w_sum / n))
    return sd, mean


def weighted_scale(x, w, max_shift=False, mean_only=False):
    """Center ``x`` on its weighted mean and optionally rescale and shift it.

    Args:
        x: Tensor of values.
        w: Tensor of weights, forwarded to ``weighted_sd_mean``.
        max_shift: If true, also subtract the maximum of the transformed
            tensor, so its largest element becomes 0.
        mean_only: If true, only remove the weighted mean and skip the
            division by the weighted standard deviation.

    Returns:
        The transformed tensor. No ``detach``/``no_grad`` is applied, so
        gradients flow through the weighted statistics as well.
    """
    sd, mean = weighted_sd_mean(x, w)
    out = x - mean
    if not mean_only:
        out = out / sd
    return out - torch.max(out) if max_shift else out


def standardize(x):
    """Return ``x`` shifted to zero mean and scaled to unit standard deviation.

    Uses ``torch.std`` with its default (Bessel-corrected) estimator. The
    computation runs under ``torch.no_grad()``, so the result is not attached
    to the autograd graph of ``x``.
    """
    with torch.no_grad():
        centre = x.mean()
        spread = x.std()
        result = (x - centre) / spread
    return result
