import torch
from utils import marginal_prob_std


def score_matching_loss(model, x, sigma, eps=1e-5):
    """Return the noise-perturbed score-matching loss for one batch.

    For every sample a time ``t`` is drawn uniformly from ``[eps, 1)``, the
    sample is perturbed with Gaussian noise scaled by
    ``marginal_prob_std(t, sigma)``, and the model's output on the perturbed
    sample is compared against the injected noise:

        mean over batch of  sum over features of  (score * std + z) ** 2

    Args:
        model: Callable invoked as ``model(perturbed_x, t)``. Its output must
            broadcast against ``x``.
        x: Batch tensor whose first axis is the batch axis. Any number of
            trailing feature axes is accepted (``(B, D)``, ``(B, C, H, W)``,
            ...).
        sigma: Forwarded unchanged to ``marginal_prob_std``.
        eps: Lower bound of the sampled time range.

    Returns:
        Scalar tensor holding the batch-averaged loss.
    """
    device = x.device
    batch_size = x.size(0)

    # rand() is in [0, 1), so t lies in [eps, 1).
    # NOTE(review): presumably eps keeps t away from 0 where the std would
    # vanish -- confirm against marginal_prob_std.
    random_t = torch.rand(batch_size, device=device) * (1 - eps) + eps
    z = torch.randn_like(x)

    # NOTE(review): assumes marginal_prob_std returns one value per sample,
    # i.e. shape (B,) -- confirm against utils.marginal_prob_std.
    std = marginal_prob_std(random_t, sigma).to(device)

    # Append one singleton axis per non-batch axis of x so the per-sample std
    # broadcasts over all features. For 2-D x this is exactly std[:, None].
    std = std[(slice(None),) + (None,) * (x.dim() - 1)]
    perturbed_x = x + std * z

    score = model(perturbed_x, random_t)
    sq_residual = (score * std + z) ** 2

    # Reduce over every non-batch axis (just axis 1 for 2-D x). A 1-D x has
    # no feature axes, so each element already is its own per-sample term.
    feature_dims = tuple(range(1, x.dim()))
    if feature_dims:
        sq_residual = torch.sum(sq_residual, dim=feature_dims)
    return torch.mean(sq_residual)
