import torch
from nn_compression._interfaces import quantisable
from nn_compression.networks import recursively_find_named_children


def extract_quant_weights(net_with_quantised_weights: torch.nn.Module) -> dict:
    """Collect the weights of every quantisable sub-module of a network.

    Intended for a network that has previously been quantised. The children are
    visited in the order produced by ``recursively_find_named_children``. Only
    those for which ``quantisable`` returns a truthy value are kept.

    Args:
        net_with_quantised_weights: The (already quantised) network to inspect.

    Returns:
        A dict mapping each qualifying child's name to its ``weight`` attribute.
        The stored values are the modules' own ``weight`` objects, not copies.
        If two children share a name, the later one wins.
    """
    named_children = recursively_find_named_children(net_with_quantised_weights)
    return {
        child_name: child.weight
        for child_name, child in named_children
        if quantisable(child)
    }
