import torch 

def denoising_score_loss(s, x, sigma):
    """Sigma-weighted denoising score matching loss.

    Draws Gaussian noise, perturbs ``x`` with it, and penalises the residual
    ``sigma * s(x_noisy) - (x - x_noisy) / sigma``.  Algebraically this equals
    ``sigma * (s(x_noisy) - (x - x_noisy) / sigma**2)``, i.e. the score error
    against the DSM target scaled by ``sigma``, so the squared norm carries a
    ``sigma**2`` weight.

    Args:
        s: score model, called on a tensor shaped like ``x`` (assumed to
            return the same shape — TODO confirm against callers).
        x: clean samples; the squared norm is taken along dim 1 only and all
            remaining dims are averaged by ``mean``.
        sigma: noise standard deviation (scalar, or anything broadcastable
            against ``x``).

    Returns:
        0-dim tensor: mean of the per-row squared residual norms.
    """
    noise = sigma * torch.randn_like(x)
    x_noisy = x + noise
    residual = sigma * s(x_noisy) - (1 / sigma) * (x - x_noisy)
    # NOTE(review): only dim 1 is reduced; for inputs with ndim > 2 the extra
    # dims are averaged, which rescales the loss relative to a full flatten.
    sq_norms = torch.linalg.norm(residual, dim=1) ** 2
    return sq_norms.mean()

def sliced_score_loss(s, x, M):
    """Sliced score matching loss with ``M`` Gaussian projections per sample.

    Each sample is replicated ``M`` times and paired with a fresh standard
    normal direction ``v``.  A forward-mode Jacobian-vector product yields both
    ``s(x)`` and ``(ds/dx) v`` in one pass.  The per-row objective is
    ``v^T (ds/dx) v + ||s(x)||^2 / 2``, the form of the variance-reduced
    sliced score matching loss when ``v`` is Gaussian.

    Args:
        s: score model; must be usable with ``torch.func.jvp``.
        x: samples with the batch along dim 0.  Reductions are over dim 1
            only; any further dims are averaged by the final ``mean``.
        M: number of random projections per sample.

    Returns:
        0-dim tensor: mean objective over all ``M * batch`` rows.
    """
    # (M * b, ...): M consecutive copies of the batch, same ordering as
    # stacking M copies along a new leading dim and flattening it.
    tiled = x.repeat(M, *([1] * (x.dim() - 1)))
    directions = torch.randn_like(tiled)
    score, jac_dir = torch.func.jvp(s, (tiled,), (directions,))
    quad_term = torch.sum(jac_dir * directions, dim=1)   # v^T (ds/dx) v
    norm_term = torch.sum(score ** 2, dim=1) / 2         # ||s||^2 / 2
    return torch.mean(quad_term + norm_term)

def wasserstein_fisher_rao_energy(v, g, p, alpha=0.1):
    """Weighted transport + growth energy of a batch of samples.

    Computes ``sum_i p_i * ||v_i||^2 + alpha * sum_i p_i * ||g_i||^2``.  Each
    squared norm is taken over every non-batch dimension.  A 1-D tensor is
    treated as one scalar per sample.

    ``alpha`` scales only the second (``g``) term; this is not a convex
    ``(1 - alpha, alpha)`` combination of the two terms.

    Args:
        v: per-sample tensor with the batch along dim 0 (presumably the
            velocity field — confirm with callers).
        g: per-sample tensor with the batch along dim 0, e.g. ``(b,)`` or
            ``(b, 1)`` (presumably the growth/reaction rate).
        p: one weight per sample in any shape, e.g. ``(b,)`` or ``(b, 1)``;
            it is flattened before use.
        alpha: weight applied to the ``g`` term.

    Returns:
        0-dim tensor.
    """
    weights = p.reshape(-1)

    def _per_sample_sq_norm(t):
        # Guard the 1-D case: torch.sum with an empty dim list reduces over
        # *all* dims, which would collapse the batch into a single scalar.
        if t.ndim <= 1:
            return t ** 2
        return (t ** 2).sum(dim=tuple(range(1, t.ndim)))

    loss1 = torch.sum(_per_sample_sq_norm(v) * weights)
    loss2 = torch.sum(_per_sample_sq_norm(g) * weights)
    return loss1 + alpha * loss2

def fisher_rao_energy(g, p):
    """Weighted squared-growth energy ``sum_i p_i * ||g_i||^2``.

    ``g`` is reduced per sample over every non-batch dimension, and a 1-D
    ``g`` is one scalar per sample.  ``p`` is flattened to one weight per
    sample.  This prevents ``(b, 1)`` against ``(b,)`` inputs from
    broadcasting into a ``(b, b)`` outer product.

    Args:
        g: per-sample tensor with the batch along dim 0, e.g. ``(b,)`` or
            ``(b, 1)`` (presumably the growth rate — confirm with callers).
        p: one weight per sample in any shape, e.g. ``(b,)`` or ``(b, 1)``.

    Returns:
        0-dim tensor.
    """
    weights = p.reshape(-1)
    sq = g ** 2
    # torch.sum with an empty dim list would reduce all dims, so 1-D g is
    # used as-is (already one value per sample).
    per_sample = sq if g.ndim <= 1 else sq.sum(dim=tuple(range(1, g.ndim)))
    return torch.sum(per_sample * weights)

def benamou_brenier_energy(v, p):
    """Weighted kinetic energy ``sum_i p_i * ||v_i||^2``.

    ``v`` is reduced per sample over every non-batch dimension, and a 1-D
    ``v`` is one scalar per sample.  ``p`` is flattened to one weight per
    sample, so a ``(b, 1)`` weight tensor no longer broadcasts against the
    ``(b,)`` norms into a ``(b, b)`` matrix.

    Args:
        v: per-sample tensor with the batch along dim 0 (presumably the
            velocity field — confirm with callers).
        p: one weight per sample in any shape, e.g. ``(b,)`` or ``(b, 1)``.

    Returns:
        0-dim tensor.
    """
    weights = p.reshape(-1)
    sq = v ** 2
    # torch.sum with an empty dim list would reduce all dims, so 1-D v is
    # used as-is (already one value per sample).
    per_sample = sq if v.ndim <= 1 else sq.sum(dim=tuple(range(1, v.ndim)))
    return torch.sum(per_sample * weights)
