import torch
def entropy_loss(dist, target, start, samples):
    """Return the negated log-probability of ``samples`` under ``dist``.

    ``dist`` is first conditioned on ``start`` (via ``dist.condition``), so
    the returned value is ``-dist.log_prob(samples)`` for that conditioning.

    Args:
        dist: Object exposing ``condition(x)`` and ``log_prob(x)``.
        target: Unused here; presumably kept so every loss in this module
            shares one call signature.
        start: Value passed to ``dist.condition``.
        samples: Points at which ``dist.log_prob`` is evaluated.

    Returns:
        Whatever ``dist.log_prob(samples)`` returns, negated (no reduction
        is applied here).
    """
    dist.condition(start)
    log_q = dist.log_prob(samples)
    return -log_q

def kl_loss(dist, target, start, samples):
    """Return ``target.log_prob(samples) - dist.log_prob(samples)``.

    ``dist`` is conditioned on ``start`` before its log-probability is
    evaluated. The result is the per-sample log-ratio of target to ``dist``
    (i.e. the negated integrand of a KL(dist || target) estimate); no
    reduction over samples is performed here.

    Args:
        dist: Object exposing ``condition(x)`` and ``log_prob(x)``.
        target: Object exposing ``log_prob(x)``.
        start: Value passed to ``dist.condition``.
        samples: Points at which both log-probabilities are evaluated.

    Returns:
        The elementwise difference of the two ``log_prob`` results.
    """
    dist.condition(start)
    log_q = dist.log_prob(samples)
    log_p = target.log_prob(samples)
    return log_p - log_q

def log_der_kl_loss(dist, target, start, samples):
    """Return a log-derivative style surrogate of the KL log-ratio.

    ``samples`` are detached first, so no gradient flows back through the
    sampling path. With ``log_q = dist.log_prob(samples)`` (after
    conditioning ``dist`` on ``start``) and
    ``loss = -log_q + target.log_prob(samples)``, the result is
    ``log_q * loss.detach() + loss``: the first term carries gradient only
    through ``log_q``, weighted by the constant (detached) loss value.

    Args:
        dist: Object exposing ``condition(x)`` and ``log_prob(x)``.
        target: Object exposing ``log_prob(x)``.
        start: Value passed to ``dist.condition``.
        samples: Tensor of sample points; it is detached before use.

    Returns:
        The per-sample surrogate value (no reduction is applied here).
    """
    frozen_samples = samples.detach()
    dist.condition(start)
    # Evaluate dist.log_prob once and reuse it: the original evaluated it
    # twice on identical input, doubling the cost of that forward pass.
    log_q = dist.log_prob(frozen_samples)
    loss = -log_q + target.log_prob(frozen_samples)
    return log_q * loss.detach() + loss

def acc_rate_loss(dist, target, start, samples):
    """Return ``exp(min(0, log_ratio))`` for a move ``start -> samples``.

    ``log_ratio`` is assembled as::

        target.log_prob(samples) - target.log_prob(start)
        - dist.log_prob(samples | start) + dist.log_prob(start | samples)

    where "``| x``" means ``dist`` was conditioned on ``x`` immediately
    before the call. This has the form of a Metropolis-Hastings acceptance
    probability, capped at 1 by the clamp.

    Note: ``dist`` is left conditioned on ``samples`` when this returns.

    Args:
        dist: Object exposing ``condition(x)`` and ``log_prob(x)``.
        target: Object exposing ``log_prob(x)``.
        start: Starting points of the move.
        samples: Proposed points of the move.

    Returns:
        Tensor of per-sample values in ``(0, 1]`` (no reduction applied).
    """
    dist.condition(start)
    target_log_ratio = target.log_prob(samples) - target.log_prob(start)
    log_q_forward = dist.log_prob(samples)

    # Re-condition on the proposed points to score the reverse move.
    dist.condition(samples)
    log_q_reverse = dist.log_prob(start)

    log_ratio = target_log_ratio - log_q_forward + log_q_reverse
    capped = log_ratio.clamp(max=0)
    return capped.exp()

def log_acc_rate_loss(dist, target, start, samples):
    """Return ``min(0, log_ratio)`` for a move ``start -> samples``.

    ``log_ratio`` is assembled as::

        target.log_prob(samples) - target.log_prob(start)
        - dist.log_prob(samples | start) + dist.log_prob(start | samples)

    where "``| x``" means ``dist`` was conditioned on ``x`` immediately
    before the call. The value is the logarithm of what ``acc_rate_loss``
    computes from the same quantities.

    Note: ``dist`` is left conditioned on ``samples`` when this returns.

    Args:
        dist: Object exposing ``condition(x)`` and ``log_prob(x)``.
        target: Object exposing ``log_prob(x)``.
        start: Starting points of the move.
        samples: Proposed points of the move.

    Returns:
        Tensor of per-sample non-positive values (no reduction applied).
    """
    dist.condition(start)
    target_log_ratio = target.log_prob(samples) - target.log_prob(start)
    log_q_forward = dist.log_prob(samples)

    # Re-condition on the proposed points to score the reverse move.
    dist.condition(samples)
    log_q_reverse = dist.log_prob(start)

    log_ratio = target_log_ratio - log_q_forward + log_q_reverse
    return log_ratio.clamp(max=0)


def msjd_loss(start, samples):
    """Return the squared distance between ``samples`` and ``start``.

    The elementwise difference is squared and summed over dimension 1, so
    inputs must have at least two dimensions (presumably ``(batch, dim)``;
    confirm against callers).

    Args:
        start: Starting points.
        samples: End points; must be broadcast-compatible with ``start``.

    Returns:
        Tensor with dimension 1 reduced away.
    """
    displacement = samples - start
    squared = displacement.pow(2.0)
    return squared.sum(dim=1)
