import torch

from models.base_model import SequentialModel


def get_vector_of_params(model: "SequentialModel") -> torch.Tensor:
    """Concatenate the weights of every layer in ``model`` into one flat vector.

    Each layer in ``model.layers`` that exposes a non-``None`` ``weight``
    attribute contributes its weight tensor, flattened, in layer order.
    Only ``weight`` is read; other parameters such as ``bias`` are NOT
    included in the result.

    Args:
        model: Object whose ``layers`` attribute is an iterable of layers.

    Returns:
        A 1-D tensor holding the concatenated, flattened weights. It is
        built under ``torch.no_grad()`` and therefore does not track
        gradients. If no layer has a weight, an empty 1-D tensor is
        returned instead of raising.
    """
    with torch.no_grad():
        thetas = [
            torch.flatten(layer.weight)
            for layer in model.layers
            # Some layers register ``weight`` as None (e.g. torch BatchNorm
            # with affine=False); hasattr() alone would let those through.
            if getattr(layer, "weight", None) is not None
        ]
        if not thetas:
            # torch.cat raises on an empty list; return an empty vector.
            return torch.empty(0)
        return torch.cat(thetas)
