import numpy as np
import torch

def jac_loss_estimate(f0, z0, vecs=2, create_graph=True):
    """Estimate tr(J^T J) = tr(J J^T) with the Hutchinson estimator.

    J is the Jacobian of ``f0`` with respect to ``z0``. For Gaussian v shaped
    like ``f0``, E[||v^T J||^2] = tr(J J^T), so averaging the squared norm of
    the vector-Jacobian product v^T J over ``vecs`` random draws estimates the
    squared Frobenius norm of J.

    Args:
        f0 (torch.Tensor): Output of the function f (whose J is to be analyzed).
            Must have been computed from ``z0`` with autograd tracking.
        z0 (torch.Tensor): Input to the function f.
        vecs (int, optional): Number of random Gaussian vectors to use. Must be
            at least 1. Defaults to 2.
        create_graph (bool, optional): Whether to create a backward graph
            through the vector-Jacobian products (e.g., to train on this loss).
            Defaults to True.

    Returns:
        torch.Tensor: A 0-dim tensor holding the estimate divided by the number
        of elements of ``z0`` (the shape-normalized jacobian loss).

    Raises:
        ValueError: If ``vecs`` is less than 1.
    """
    if vecs < 1:
        raise ValueError(f"vecs must be >= 1, got {vecs}")
    result = 0
    for _ in range(vecs):
        # grad_outputs must have f0's shape; randn_like also matches its
        # device and dtype.
        v = torch.randn_like(f0)
        # retain_graph=True keeps f0's graph alive for the next draw and for
        # any later backward pass by the caller.
        vJ = torch.autograd.grad(
            f0, z0, v, retain_graph=True, create_graph=create_graph
        )[0]
        # Squared norm, as the Hutchinson identity requires; summing squares
        # directly avoids the sqrt/square round-trip of norm() ** 2.
        result = result + vJ.pow(2).sum()
    return result / vecs / z0.numel()


def jac_loss_exact(f0, z0, vecs=2, create_graph=False):
    """Estimate tr(J^T J) = tr(J J^T) with the Hutchinson estimator.

    Despite its name, this is a stochastic estimate, not an exact trace
    computation: it averages ||v^T J||^2 over ``vecs`` Gaussian draws v shaped
    like ``f0``, where J is the Jacobian of ``f0`` with respect to ``z0``.

    Args:
        f0 (torch.Tensor): Output of the function f (whose J is to be analyzed).
            Must have been computed from ``z0`` with autograd tracking.
        z0 (torch.Tensor): Input to the function f.
        vecs (int, optional): Number of random Gaussian vectors to use. Must be
            at least 1. Defaults to 2.
        create_graph (bool, optional): Whether to create a backward graph
            through the vector-Jacobian products (e.g., to train on this loss).
            Defaults to False, in which case the returned tensor carries no
            autograd history.

    Returns:
        torch.Tensor: A 0-dim tensor holding the estimate divided by the number
        of elements of ``z0`` (the shape-normalized jacobian loss).

    Raises:
        ValueError: If ``vecs`` is less than 1.
    """
    if vecs < 1:
        raise ValueError(f"vecs must be >= 1, got {vecs}")
    result = 0
    for _ in range(vecs):
        # grad_outputs must have f0's shape (which may differ from z0's).
        v = torch.randn_like(f0)
        # retain_graph=True keeps f0's graph alive for the next draw and for
        # any later backward pass by the caller.
        vJ = torch.autograd.grad(
            f0, z0, v, retain_graph=True, create_graph=create_graph
        )[0]
        result = result + vJ.pow(2).sum()
    return result / vecs / z0.numel()