from typing import Callable

import torch


def estimate_policy_curvature(
	policy: Callable[[torch.Tensor], torch.Tensor],
	obs: torch.Tensor,
	eps: float = 1e-3,
) -> torch.Tensor:
	"""Estimate the curvature of a policy function.

	For each observation in the batch this approximates

		sqrt(sum_k sum_i (d^2 policy(obs)_k / d obs_i^2) ** 2)

	using central finite differences. Here ``k`` runs over every output
	component and ``i`` runs over every input feature. Only the pure
	second partials (the diagonals of each output's Hessian) are
	evaluated. Mixed partials ``d^2 f / (d obs_i d obs_j)`` with
	``i != j`` are not. The result therefore approximates a lower bound
	on the Frobenius norm of the full Hessian, not the norm itself.

	The policy is called ``2 * D + 1`` times under ``torch.no_grad()``,
	where ``D`` is the number of features per observation.

	Args:
		policy (Callable[[torch.Tensor], torch.Tensor]): Policy function
			mapping a batch of observations to a batch of outputs. The
			first output dimension must be the batch dimension; any
			trailing output dimensions are flattened.
		obs (torch.Tensor): Floating-point batch of observations. The
			first dimension is the batch ``B``; all remaining dimensions
			(if any) are treated as features and flattened.
		eps (float, optional): Positive step size for the finite
			differences. Defaults to `1e-3`. The estimate divides by
			``eps**2``, so very small steps amplify round-off error,
			especially for float32 inputs.

	Returns:
		torch.Tensor: Estimated curvature for each observation, shape
		``(B,)``. Its dtype is at least the default float dtype, and
		wider if the policy output is wider (e.g. float64).

	Raises:
		TypeError: If ``obs`` is not a floating-point tensor. An integer
			tensor would truncate the step ``eps`` to zero.
		ValueError: If ``eps`` is not positive.
	"""
	if not obs.is_floating_point():
		raise TypeError(f"obs must be a floating-point tensor, got {obs.dtype}")
	if not eps > 0:
		raise ValueError(f"eps must be positive, got {eps}")

	batch = obs.size(0)
	# Product of the non-batch dims; torch.Size.numel() is 1 for a 1-D obs
	# and, unlike view(B, -1), also works for an empty batch.
	num_features = obs.shape[1:].numel()

	with torch.no_grad():
		base = policy(obs)
		two_base = 2 * base  # loop-invariant part of the central difference

		# Accumulate in at least default float precision, keeping float64
		# when the policy produces it, instead of silently downcasting.
		acc_dtype = torch.promote_types(base.dtype, torch.get_default_dtype())
		curvature = torch.zeros(batch, dtype=acc_dtype, device=base.device)

		# One reusable, contiguous perturbation buffer. Viewing it as
		# obs.shape is always valid, even when obs itself is non-contiguous.
		delta_flat = torch.zeros(
			batch, num_features, dtype=obs.dtype, device=obs.device
		)
		for i in range(num_features):
			delta_flat[:, i] = eps
			delta = delta_flat.view(obs.shape)
			plus = policy(obs + delta)
			minus = policy(obs - delta)
			delta_flat[:, i] = 0  # reset for the next feature

			second_der = (plus + minus - two_base).to(acc_dtype) / (eps**2)
			squared = second_der.pow(2)
			# Flatten all non-batch output dims. This handles (B,), (B, A)
			# and higher-rank outputs alike.
			curvature += squared.reshape(batch, squared.shape[1:].numel()).sum(dim=1)

	return curvature.sqrt()
