import torch

## Discriminator Guidance
    def grad_disc(discriminator, x, t, y=None, **kwargs):
        """Return the discriminator-guidance gradient d/dx log(D / (1 - D)).

        ``D = discriminator(x, t, sigmoid=True, condition=None)`` is treated as
        a probability and clipped to ``[1e-5, 1 - 1e-5]``. It is then turned
        into the log density ratio ``log D - log(1 - D)``. The gradient of the
        sum of that ratio with respect to ``x`` is returned.

        Args:
            discriminator: callable invoked as
                ``discriminator(x, t, sigmoid=True, condition=None)``; assumed
                to return probabilities in [0, 1] -- TODO confirm.
            x: input tensor (floating point). It is detached first, so no
                gradient flows back into the caller's graph.
            t: forwarded unchanged as the discriminator's second argument.
            y: accepted but not used (see the review note in the body).
            **kwargs: accepted and ignored.

        Returns:
            A tensor with the same shape and dtype as ``x``, not attached to
            any autograd graph.
        """
        # NOTE(review): ``y`` is never forwarded -- the discriminator is always
        # queried with ``condition=None``. Confirm that unconditional guidance
        # is intended; otherwise ``condition=y`` is probably what was meant.

        # enable_grad lets this run inside a sampling loop wrapped in no_grad().
        with torch.enable_grad():
            x_in = x.detach().requires_grad_(True)
            pr = discriminator(x_in, t, sigmoid=True, condition=None)

            # Clip and take logs in at least float32. In float16/bfloat16 the
            # bound 1 - 1e-5 rounds to exactly 1.0, so a saturated prediction
            # would give log(0) = -inf and an inf/nan gradient. promote_types
            # keeps float32 as-is and does not downcast float64.
            pr = pr.to(torch.promote_types(pr.dtype, torch.float32))
            pr = torch.clip(pr, min=1e-5, max=1 - 1e-5)
            log_density_ratio = torch.log(pr) - torch.log(1 - pr)

            # Sum so that there is a single scalar to differentiate. This
            # presumably yields per-sample gradients only if each output
            # depends solely on its own sample -- TODO confirm for the model.
            dg = torch.autograd.grad(log_density_ratio.sum(), x_in)[0]

        return dg
      
