import torch


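# Code inspired by https://github.com/MatthewAlexanderFisher/MTKSD/
# Weighted kernel Stein discrepancy (KSD) estimator using the inverse
# multiquadric (IMQ) kernel. The sum runs over all ordered pairs, including
# i == j, so it is a U-statistic only when the weights vanish on the diagonal.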

def weighted_KSD(samples, score_func, weights, gamma=1):
    """Return the weighted kernel Stein discrepancy (KSD) of ``samples``.

    Uses the inverse multiquadric (IMQ) kernel
    ``k(x, y) = (1 + ||x - y||^2 / gamma^2) ** (-1/2)`` inside the Stein kernel

        u(x, y) = s(x).s(y) k(x, y) + s(x).grad_y k(x, y)
                  + s(y).grad_x k(x, y) + trace(grad_x grad_y k(x, y)),

    and returns ``sum_{i, j} w_ij * u(x_i, x_j)`` over all ordered pairs,
    *including* ``i == j``. The result is therefore a U-statistic only if the
    weights vanish on the diagonal; any normalisation (e.g. ``1 / N^2``) must
    be folded into ``weights``.

    Args:
        samples: ``(N, d)`` tensor of points.
        score_func: callable evaluated once on ``samples``; must return an
            ``(N, d)`` tensor whose row ``i`` is ``s(x_i)`` (the score, i.e.
            the gradient of the target log-density, as the Stein kernel above
            requires).
        weights: tensor or scalar broadcastable against the flattened
            ``(N * N,)`` vector of pair values, where entry ``i * N + j``
            corresponds to the pair ``(x_i, x_j)``.
        gamma: IMQ bandwidth; must be non-zero.

    Returns:
        0-dim tensor holding the weighted sum of Stein-kernel values.
    """
    N, d = samples.size()  # number of samples, dimension
    g = (1 / gamma) ** 2

    scores = score_func(samples)  # (N, d); row i is s(x_i)

    # diffs[i, j] = x_i - x_j. Squared distances are taken from these
    # differences directly instead of torch.cdist: cdist may switch to a
    # matrix-multiplication shortcut (|x|^2 + |y|^2 - 2 x.y) that loses
    # precision to cancellation when the samples are far from the origin
    # relative to their spread, and it would be inconsistent with diffs.
    diffs = samples.unsqueeze(1) - samples.unsqueeze(0)  # (N, N, d)
    sq_dists = diffs.pow(2).sum(dim=-1)  # (N, N)

    base = 1 + g * sq_dists
    k = base ** (-1 / 2)
    # grad_x k(x_i, x_j) = grad_coef[i, j] * (x_i - x_j), and for this
    # radial kernel grad_y k(x_i, x_j) = -grad_x k(x_i, x_j).
    grad_coef = -g * base ** (-3 / 2)
    # trace(grad_x grad_y k) for the IMQ kernel.
    k_xy = -3 * g ** 2 * sq_dists * base ** (-5 / 2) + g * d * base ** (-3 / 2)

    # s(x_i).s(x_j) for every pair, without materialising (N*N, d) copies.
    score_dots = scores @ scores.T
    # s(x_i).grad_y k + s(x_j).grad_x k == grad_coef * (s_j - s_i).(x_i - x_j)
    score_diffs = scores.unsqueeze(0) - scores.unsqueeze(1)  # [i, j] = s_j - s_i
    cross = grad_coef * torch.sum(score_diffs * diffs, dim=-1)

    stein = k * score_dots + cross + k_xy  # (N, N); [i, j] = u(x_i, x_j)

    # Flatten row-major so entry i * N + j lines up with weights[i * N + j].
    return torch.sum(stein.reshape(N * N) * weights)