import torch

from models.base_model import SequentialModel


def get_grad_vector_of_params(model: SequentialModel) -> torch.Tensor:
    """Computes the flattened vector of the gradient of the loss wrt. the layer weights.

    Only the ``weight`` attribute of each layer in ``model.layers`` is
    collected; other parameters (e.g. biases) are not included. Layers that
    have no ``weight`` attribute, or whose ``weight`` is ``None`` (such as a
    normalization layer created without affine parameters), are skipped.

    Note
    ----
        Requires a loss backward pass before the call.

    Parameters
    ----------
    model
        Model object exposing an iterable ``layers`` attribute.

    Returns
    -------
        1-D tensor concatenating the flattened weight gradients in layer
        order. An empty tensor if no layer carries a weight.

    Raises
    ------
    RuntimeError
        If a collected weight has no gradient (``.grad is None``), e.g. because
        no backward pass was run or the weight does not require grad.
    """
    grads = []
    for layer in model.layers:
        # getattr with a default covers both "no weight attribute" and layers
        # that register ``weight`` explicitly as None.
        weight = getattr(layer, "weight", None)
        if weight is None:
            continue
        if weight.grad is None:
            raise RuntimeError(
                f"Layer {layer!r} has a weight without a gradient; run a loss "
                "backward pass first and make sure the weight requires grad."
            )
        grads.append(torch.flatten(weight.grad))
    # torch.cat rejects an empty list, so handle weight-less models explicitly.
    if not grads:
        return torch.empty(0)
    return torch.cat(grads)
