"""
Calculate the vector-Jacobian products (VJPs) required to evaluate the adjoint integral,
i.e. the integrands at the prespecified points used for the quadrature calculation.
"""

import torch


def calc_vjps(func, t, aug_state, adjoint_params, t_requires_grad):
    """Compute VJPs of ``func`` w.r.t. ``adjoint_params`` and (optionally) time.

    ``func`` is evaluated at the *negated* time ``-t`` because the adjoint
    pass runs backward in time. The resulting output is contracted with the
    adjoint ``adj_y`` through ``torch.autograd.grad``.

    Args:
        func: callable invoked as ``func(time, y)``.
        t: time tensor. It is detached, cast to the dtype of ``y`` and negated
            before use.
        aug_state: indexable; ``aug_state[0]`` is the state ``y`` and
            ``aug_state[1]`` the adjoint ``adj_y``. Each has ``squeeze(0)``
            applied (presumably a leading singleton dim -- confirm with caller).
        adjoint_params: sequence of tensors to differentiate with respect to.
        t_requires_grad: if True, also differentiate w.r.t. the negated time.
            If False, time is passed to ``func`` without grad tracking and
            ``vjp_t`` is all zeros.

    Returns:
        ``(vjp_params, vjp_t)``: a list with one tensor per entry of
        ``adjoint_params`` (zeros where the output does not depend on it),
        and the VJP w.r.t. the negated time, shaped like ``t``.
    """
    params = list(adjoint_params)
    with torch.enable_grad():
        y = aug_state[0].squeeze(0).detach()
        adj_y = aug_state[1].squeeze(0)
        # Minus sign: the adjoint goes backward in time, so all times are
        # multiplied by -1.
        t_neg = -t.detach().to(y.dtype)
        if t_requires_grad:
            # requires_grad_ is in place, so only enable it when d/dt is
            # wanted; otherwise func must see a time without grad tracking.
            t_neg.requires_grad_(True)
        func_eval = func(t_neg, y)

        inputs = [*params, t_neg] if t_requires_grad else params
        if inputs and func_eval.requires_grad:
            grads = list(torch.autograd.grad(
                func_eval, inputs, adj_y,
                allow_unused=True, retain_graph=False))
        else:
            # Nothing to differentiate, or the output is independent of every
            # input: autograd.grad would raise, and every VJP is zero anyway.
            grads = [None] * len(inputs)

        vjp_t = grads.pop() if t_requires_grad else None

        # autograd.grad returns None for inputs the output does not depend on;
        # replace those with zeros of the matching shape.
        vjp_t = torch.zeros_like(t_neg) if vjp_t is None else vjp_t
        vjp_params = [torch.zeros_like(param) if vjp is None else vjp
                      for param, vjp in zip(params, grads)]

    return vjp_params, vjp_t