from .comlib import *

# torch.func.vjp
# torch.autograd.functional.vjp

def compute_jacobian(func, inputs):
    """Return the Jacobian of ``func`` evaluated at ``inputs``.

    The Jacobian is computed with ``torch.autograd.functional.jacobian`` in
    reverse mode, without vectorisation, and without building a graph for
    higher-order derivatives (``create_graph=False``).

    Args:
        func: Callable taking the tensor(s) in ``inputs`` and returning a
            tensor or a tuple of tensors.
        inputs: A Tensor or a tuple of Tensors. ``jacobian`` accepts only
            these two forms.

    Returns:
        Whatever ``torch.autograd.functional.jacobian`` returns for these
        arguments: a tensor, or a (nested) tuple of tensors when ``inputs``
        and/or the outputs of ``func`` are tuples.
    """
    options = dict(
        create_graph=False,
        strict=False,
        vectorize=False,
        strategy='reverse-mode',
    )
    return torch.autograd.functional.jacobian(func, inputs, **options)
    
def compute_jacobian_dot_grad(func, input_tuple, grad_tuple):
    """Return the Jacobian-vector product of ``func`` at ``input_tuple``.

    Uses forward-mode AD via ``torch.func.jvp``, so the full Jacobian is never
    materialised. The result is the directional derivative of ``func`` at
    ``input_tuple`` along ``grad_tuple``.

    Args:
        func: Callable taking ``len(input_tuple)`` tensor arguments and
            returning a tensor or a tuple of tensors.
        input_tuple: Point at which the Jacobian is evaluated. This is a
            tuple (or list) of tensors; a single bare tensor is treated as a
            1-tuple. Callers may pass tensors with ``requires_grad=True``.
        grad_tuple: Tangent directions, one tensor per entry of
            ``input_tuple`` with matching shape. It accepts the same forms as
            ``input_tuple``.

    Returns:
        The JVP, detached from any autograd graph. This is a tensor if
        ``func`` returns a tensor, otherwise a tuple of tensors.
    """
    # torch.func.jvp insists on tuples for primals/tangents; accept a bare
    # tensor or a list as a convenience.
    primals = (input_tuple,) if isinstance(input_tuple, torch.Tensor) else tuple(input_tuple)
    tangents = (grad_tuple,) if isinstance(grad_tuple, torch.Tensor) else tuple(grad_tuple)

    _, jvp_out = torch.func.jvp(func, primals, tangents)

    # func may return a tuple of tensors; detach each element in that case.
    if isinstance(jvp_out, torch.Tensor):
        return jvp_out.detach()
    return tuple(t.detach() for t in jvp_out)


def compute_jacobian_dot_grads(func, input_tuple, grad_tuples):
    """Return a batch of Jacobian-vector products of ``func`` at ``input_tuple``.

    ``torch.func.jvp`` is vectorised with ``torch.vmap`` over the leading
    dimension of every tensor in ``grad_tuples``. Slice ``i`` of the result is
    the JVP of ``func`` at ``input_tuple`` along the ``i``-th tangent
    direction.

    Args:
        func: Callable taking ``len(input_tuple)`` tensor arguments and
            returning a tensor or a tuple of tensors.
        input_tuple: Point at which the Jacobian is evaluated. This is a
            tuple (or list) of tensors; a bare tensor is treated as a 1-tuple.
            It is shared by, not batched over, all tangents. Callers may pass
            tensors with ``requires_grad=True``.
        grad_tuples: Batched tangent directions, one tensor per entry of
            ``input_tuple``. Each tensor has shape ``(B, *input.shape)`` with
            the same batch size ``B``. It accepts the same forms as
            ``input_tuple``.

    Returns:
        The batched JVPs, detached from any autograd graph. This is a tensor
        of shape ``(B, *output.shape)`` if ``func`` returns a tensor,
        otherwise a tuple of such tensors.
    """
    # torch.func.jvp insists on tuples for primals/tangents; accept a bare
    # tensor or a list as a convenience.
    primals = (input_tuple,) if isinstance(input_tuple, torch.Tensor) else tuple(input_tuple)
    tangents = (grad_tuples,) if isinstance(grad_tuples, torch.Tensor) else tuple(grad_tuples)

    def _single_jvp(tangent_slice):
        # Only the tangent output is needed. Discarding the primal output
        # here keeps vmap from returning a per-sample copy of it.
        return torch.func.jvp(func, primals, tangent_slice)[1]

    batched = torch.vmap(_single_jvp)(tangents)

    # func may return a tuple of tensors; detach each element in that case.
    if isinstance(batched, torch.Tensor):
        return batched.detach()
    return tuple(t.detach() for t in batched)