import torch


def norm_gradient_squared(outputs, inputs, sum_over_points=True):
    """Computes the squared L2 norm of the gradient of outputs with respect to
    inputs, for each example in the batch.

    Args:
        outputs (torch.Tensor): Usually shape (batch_size, 1), e.g. the output
            of a discriminator on real data. All elements are summed before
            differentiating, so the exact shape does not matter.
        inputs (torch.Tensor): Shape (batch_size, num_points, coordinate_dim + feature_dim)
            or shape (batch_size, num_points, feature_dim) depending on whether gradient
            is over coordinates and features or only features. Must be part of the
            autograd graph that produced ``outputs``.
        sum_over_points (bool): If True, sums over num_points dimension, otherwise takes mean.

    Returns:
        torch.Tensor: Shape (batch_size,). The result is itself differentiable
        (the gradient graph is created), so it can be added to a loss.

    Raises:
        ValueError: If ``inputs`` is not a 3-dimensional tensor.

    Notes:
        Differentiating ``outputs.sum()`` yields each example's own gradient
        only when each output depends solely on its own example (i.e. no
        cross-batch operations in the graph).
        This is inspired by the function in this repo
        https://github.com/LMescheder/GAN_stability/blob/master/gan_training/train.py
    """
    if inputs.dim() != 3:
        raise ValueError(
            "inputs must have shape (batch_size, num_points, dim), "
            f"got {tuple(inputs.shape)}"
        )
    # create_graph=True makes the gradient differentiable so the penalty can be
    # backpropagated; retain_graph=True keeps the graph behind `outputs` alive
    # for any later backward pass by the caller.
    grad_inputs, = torch.autograd.grad(
        outputs=outputs.sum(), inputs=inputs,
        create_graph=True, retain_graph=True,
    )
    # Squared L2 norm over the last (feature) dimension -> (batch_size, num_points).
    squared_norms = grad_inputs.pow(2).sum(dim=2)
    # Reduce over the num_points dimension -> (batch_size,).
    if sum_over_points:
        return squared_norms.sum(dim=1)
    return squared_norms.mean(dim=1)
