import torch


class SVGD(object):
    """Particle set and gradient computation for Stein Variational Gradient Descent.

    The class owns a leaf tensor of particles and, on :meth:`update`, writes
    the negated SVGD direction into its ``.grad`` so that an external
    minimizing optimizer moves the particles along the SVGD direction.

    Args:
        init_particles: 2-D tensor; its shape is unpacked into ``(N, D)``.
            It is detached and cloned, so the caller's tensor is not modified.
        log_density: callable applied to the particle tensor. Its output is
            summed for differentiation and normalized into weights
            (presumably one log-density value per particle -- confirm
            against the callable that is passed in).
        kernel: callable ``kernel(x, y)``. Its output is matrix-multiplied
            with the ``(N, D)`` score tensor (presumably an ``(N, N)`` Gram
            matrix -- confirm against the kernel that is passed in).
        alpha: scale applied to the kernel-gradient (repulsive) term.
    """

    def __init__(self, init_particles, log_density, kernel, alpha=1.):
        self.log_density = log_density
        self.kernel = kernel
        self.alpha = alpha

        self.N, self.D = init_particles.shape
        # Independent leaf tensor: gradients are written to it by update().
        self._particles = init_particles.detach().clone().requires_grad_(True)

    @staticmethod
    def _normalized_weights(log_px):
        """Return ``exp(log_px)`` normalized to sum to one, detached.

        The maximum is subtracted before exponentiating for numerical
        stability; the sum runs over every element of ``log_px``.
        """
        shifted = torch.exp(log_px - log_px.max())
        return (shifted / shifted.sum()).detach()

    def optim_parameters(self):
        """Return the particle tensor itself (a single leaf tensor, not a list)."""
        return self._particles

    def particles(self):
        """Return the current particles detached from the autograd graph."""
        return self._particles.detach()

    def reset(self, particles):
        """Replace the particle values in place, keeping the same tensor object.

        ``N`` and ``D`` are refreshed from the new shape. If the shape changes,
        the stored ``.grad`` (which has the old shape) is dropped; if the
        shape is unchanged the existing ``.grad`` is left as it was.

        Args:
            particles: 2-D tensor of new particle values; it is copied.

        Raises:
            ValueError: if ``particles`` is not 2-D (shape unpacking fails).
        """
        new_values = particles.detach().clone()
        # Unpack before mutating anything so a bad shape leaves state intact.
        n, d = new_values.shape

        if new_values.shape != self._particles.shape:
            self._particles.grad = None

        # Assign through .data so references held elsewhere (e.g. by an
        # optimizer) keep pointing at the same tensor object.
        self._particles.data = new_values
        self.N, self.D = n, d

    def calc_weights(self):
        """Return normalized, detached weights from the current log densities."""
        return self._normalized_weights(self.log_density(self._particles))

    def calc_grad_norm(self, mean=False):
        """Return the squared L2 norm of each particle's stored gradient.

        Note that no square root is taken. Requires ``.grad`` to be populated
        (``update`` does this).

        Args:
            mean: if true, return the mean over particles instead.
        """
        grad = self._particles.grad
        grad_norm = (grad * grad).sum(-1)
        if mean:
            grad_norm = grad_norm.mean()
        return grad_norm.detach()

    def calc_grad_percentage(self, mean=False):
        """Return squared gradient norm divided by squared particle norm.

        Requires ``.grad`` to be populated (``update`` does this). A particle
        with zero norm yields a non-finite ratio, since no guard is applied
        to the division.

        Args:
            mean: if true, return the mean ratio over particles instead.
        """
        grad = self._particles.grad
        grad_norm = (grad * grad).sum(-1)
        particle_norm = (self._particles * self._particles).sum(-1)
        percent = grad_norm / particle_norm
        if mean:
            percent = percent.mean()
        return percent.detach()

    def update(self, normalize=False, create_graph=False):
        """Compute the SVGD gradients and store them in the particles' ``.grad``.

        Args:
            normalize: if true, return detached normalized weights instead of
                the raw log densities.
            create_graph: forwarded to ``torch.autograd.grad``.

        Returns:
            The log densities returned by ``log_density`` (not detached), or
            the detached normalized weights when ``normalize`` is true.
        """
        log_px, self._particles.grad = self.calc_grads(self._particles, create_graph=create_graph)

        if normalize:
            return self._normalized_weights(log_px)

        return log_px

    def grad_log_density(self, x, create_graph=False):
        """Return ``(log_px, d_log_px)``: log densities and their gradient w.r.t. ``x``.

        ``x`` is detached first, so the gradient is taken with respect to a
        fresh leaf copy rather than the tensor that was passed in.
        """
        x = x.detach().requires_grad_(True)
        log_px = self.log_density(x)
        d_log_px, = torch.autograd.grad(log_px.sum(), x, create_graph=create_graph)
        return log_px, d_log_px

    def grad_kernel(self, x, y, create_graph=False):
        """Return ``(k_xy, d_kx, d_ky)`` for ``kernel(x, y)``.

        ``d_kx`` and ``d_ky`` are the gradients of the summed kernel output
        with respect to ``x`` and ``y``. Both inputs are detached first, so
        they are differentiated independently even when the same tensor is
        passed for both. The returned kernel values are detached.
        """
        x = x.detach().requires_grad_(True)
        y = y.detach().requires_grad_(True)

        k_xy = self.kernel(x, y)
        d_kx, d_ky = torch.autograd.grad(k_xy.sum(), [x, y], create_graph=create_graph)

        return k_xy.detach(), d_kx, d_ky

    def calc_grads(self, x, create_graph=False):
        """Return ``(log_px, -phi)`` where ``phi`` is the SVGD direction at ``x``.

        ``phi`` is negated so that a minimizing optimizer steps along ``phi``.
        """
        N = x.size(0)

        log_px, d_log_px = self.grad_log_density(x, create_graph=create_graph)
        k_xx, d_kxx, _ = self.grad_kernel(x, x, create_graph=create_graph)

        # SVGD's repulsive term for particle i is sum_j grad_{x_j} k(x_j, x_i),
        # whereas d_kxx[i] is sum_j grad_{x_i} k(x_i, x_j). For a symmetric,
        # translation-invariant kernel (e.g. RBF) these differ only in sign,
        # hence the subtraction. NOTE(review): this assumes such a kernel.
        phi = (torch.matmul(k_xx, d_log_px) - self.alpha * d_kxx) / N
        return log_px, -phi
