"""
Gradient Computation Helpers

Utility functions for extracting gradients from differentiable losses.
"""

import torch

def get_gradient_from_diffloss(X, Y, loss, which='both'):
    """
    Extract gradients of a scalar loss with respect to point clouds.

    Computes gradients via automatic differentiation for one or both inputs.
    ``loss`` must already have been computed from ``X`` and/or ``Y`` with
    autograd tracking enabled, so every requested input must require grad
    and appear in ``loss``'s graph. The tensors are differentiated as-is and
    are not cloned: a copy made after ``loss`` was built would not be part
    of its graph.

    Args:
        X: Source point cloud, shape (N, D_x)
        Y: Target point cloud, shape (M, D_y)
        loss: Scalar differentiable loss (torch.Tensor with grad_fn)
        which: Specifies which gradients to compute:
               - 'x': Return gradient w.r.t X only
               - 'y': Return gradient w.r.t Y only
               - 'both': Return both gradients

    Returns:
        torch.Tensor: Gradient of loss w.r.t X if which='x', same shape as X
        torch.Tensor: Gradient of loss w.r.t Y if which='y', same shape as Y
        tuple: (gradX, gradY) if which='both'

    Raises:
        ValueError: If ``which`` is not one of 'x', 'y', 'both'.
        RuntimeError: Raised by ``torch.autograd.grad`` when a requested
            input does not require grad or was not used to compute ``loss``.
    """
    if which == 'x':
        inputs = (X,)
    elif which == 'y':
        inputs = (Y,)
    elif which == 'both':
        inputs = (X, Y)
    else:
        raise ValueError(f"which must be 'x', 'y' or 'both', got {which!r}")

    # Use one autograd.grad call over all requested inputs. Separate calls
    # would free the graph after the first one, and the second would fail.
    grads = torch.autograd.grad(loss, inputs)

    if which == 'both':
        return tuple(grads)  # (gradX, gradY)
    return grads[0]
