import torch
from typing import Optional


def remove_center_of_gravity(x: torch.Tensor) -> torch.Tensor:
    """
    Shift every sample so that its mean over dimension 1 is zero.

    Args:
        x (Tensor): Input with exactly 3 dimensions, laid out as
            (batch, atoms, dim) according to the assertion message below.

    Returns:
        Tensor: Same shape as ``x``, with the unweighted mean over
        dimension 1 subtracted from each sample.

    Raises:
        AssertionError: If ``x`` does not have exactly 3 dimensions.
    """
    assert (
        x.ndim == 3
    ), f"Expected input with 3 dimensions (batch, atoms, dim), got {x.ndim}"
    # keepdim=True keeps the reduced axis as size 1 so the subtraction
    # broadcasts back over dimension 1.
    return x - torch.mean(x, dim=1, keepdim=True)


def get_trace_jacobian(
    ut: torch.Tensor,
    x: torch.Tensor,
    exact: bool = False,
    z: Optional[torch.Tensor] = None,
    *,
    create_graph: bool = False,
) -> torch.Tensor:
    """
    Compute the divergence (trace of the Jacobian of ut w.r.t x).

    ``ut`` must have been computed from ``x`` with autograd enabled, since
    both branches call ``torch.autograd.grad(..., x)``.

    Args:
        ut (Tensor): Output of the velocity model. Shape [batch_size, dim]
        x (Tensor): Input tensor to differentiate with respect to. Shape [batch_size, dim]
        exact (bool): Whether to compute the exact divergence (expensive but accurate).
            This runs one backward pass per column of ``ut``.
        z (Tensor, optional): Random vector for Hutchinson estimator. Required if exact=False.
            Must have the same shape as ``ut``.
        create_graph (bool): Forwarded to ``torch.autograd.grad``. Set to True
            when the returned divergence itself has to be differentiated
            (e.g. it enters a training loss). Defaults to False, which matches
            the previous behaviour.

    Returns:
        Tensor: Divergence estimate for each sample. Shape [batch_size]

    Raises:
        AssertionError: If ``exact`` is False and ``z`` is None.
    """
    if exact:
        div = None
        for i in range(ut.shape[1]):
            # Gradient of the batch-summed i-th output. Row b equals
            # d ut[b, i] / d x[b, :] only if ut[b] depends on x[b] alone
            # (no cross-sample coupling in the model).
            grad_ut_i = torch.autograd.grad(
                ut[:, i],
                x,
                grad_outputs=torch.ones_like(ut[:, i]),
                retain_graph=True,
                create_graph=create_graph,
            )[0]
            # Keep only the diagonal entry d ut[:, i] / d x[:, i].
            diag_i = grad_ut_i[:, i]
            div = diag_i if div is None else div + diag_i
        if div is None:
            # ut has no columns: the trace of an empty Jacobian is zero.
            # Return a tensor (not a Python float) so the return type holds.
            div = x.new_zeros(x.shape[0])
        return div

    assert (
        z is not None
    ), "Random vector z must be provided for Hutchinson estimator"
    # Hutchinson estimator: z^T J z, obtained with a single backward pass.
    # Differentiating sum_b <ut[b], z[b]> w.r.t. x yields z^T J per sample.
    ut_dot_z = torch.einsum("ij,ij->i", ut, z)
    grad_ut_dot_z = torch.autograd.grad(
        ut_dot_z,
        x,
        grad_outputs=torch.ones_like(ut_dot_z),
        retain_graph=True,
        create_graph=create_graph,
    )[0]
    return torch.einsum("ij,ij->i", grad_ut_dot_z, z)


def prior_log_density(x: torch.Tensor) -> torch.Tensor:
    """
    Return ``-0.5 * sum(x**2)`` taken over the last dimension.

    This equals the log-density of a standard normal up to an additive
    constant; the normalisation term is not included.

    Args:
        x (Tensor): Input values; the last dimension is reduced.

    Returns:
        Tensor: One value per leading index, i.e. ``x`` without its last
        dimension.
    """
    squared_norm = x.pow(2).sum(dim=-1)
    return -0.5 * squared_norm
