import torch as tc
from torch.func import jacfwd, jacrev, vmap

def grad(model, x: tc.Tensor, apply_linear=True, create_graph=False, mode="forward"):
    """
    Differentiate the output of ``model.forward`` with respect to the input ``x``.

    Side effect: ``x.requires_grad_()`` is called on the tensor that was passed in,
    so the caller's ``x`` requires grad after this function returns.

    Args:
        model: Object exposing ``forward(x, apply_linear=...)``. As the Jacobian path
            runs the model under ``vmap``, the model should avoid in-place operations.
        x: Input tensor. The Jacobian path maps over its first dimension and is
            written for inputs of shape (n_points, n_features).
        apply_linear: Selects which model output is differentiated (see Returns).
        create_graph: Only used when ``apply_linear`` is true. Forwarded to
            ``torch.autograd.grad`` as both ``create_graph`` and ``retain_graph`` so
            the returned gradient can itself be differentiated.
        mode: Only used when ``apply_linear`` is false. ``"forward"`` selects
            ``jacfwd`` and ``"reverse"`` selects ``jacrev``. Reverse mode is the
            better choice when the output dimension is smaller than the input
            dimension.

    Returns:
        If ``apply_linear`` is true: the gradient of the summed output of
        ``model.forward(x, apply_linear=True)`` with respect to ``x``, with the same
        shape as ``x`` (the model output is expected to be a scalar per point).

        Otherwise: the per-point Jacobian of ``model.forward(., apply_linear=False)``.
        For ``x`` of shape (n_points, n_features) and a model output of shape
        (n_points, out_dim) the result has shape (n_points, out_dim, n_features).
        It can be used to fit a linear layer to the gradient of the network.

    Raises:
        ValueError: If ``apply_linear`` is false and ``mode`` is neither
            ``"forward"`` nor ``"reverse"``.
    """
    # Flag the caller's tensor as differentiable (in place) before any forward pass.
    x.requires_grad_()

    if apply_linear:
        # Reduce to a single scalar so that one backward pass yields d(output)/dx.
        total = model.forward(x, apply_linear=True).sum()
        (input_grad,) = tc.autograd.grad(
            total, x, create_graph=create_graph, retain_graph=create_graph
        )
        return input_grad

    def per_point(point):
        # The model expects a batch axis: add one going in, drop it coming out.
        batched = point.unsqueeze(0)  # (1, n_features)
        return model.forward(batched, apply_linear=False).squeeze(0)

    # Choose the Jacobian transform for a single input point.
    if mode == "forward":
        jacobian_of = jacfwd
    elif mode == "reverse":
        jacobian_of = jacrev
    else:
        raise ValueError("Automatic differentiation mode unknown")

    # vmap lifts the single-point Jacobian over the batch dimension of x.
    return vmap(jacobian_of(per_point))(x)
