
class FeatureExtractor:
    """Capture intermediate outputs of a model through forward hooks.

    The wrapped ``model`` is expected to expose:

    * ``model.layers`` -- the collection of layer names that may be requested;
    * ``model.model`` -- an object providing ``named_modules()`` whose modules
      support ``children()`` and ``register_forward_hook(fn)`` (presumably a
      ``torch.nn.Module`` -- confirm against callers).

    After ``register_hooks`` and a forward pass, ``activations`` maps each hooked
    layer name to the detached output of that layer from the most recent pass.
    """

    def __init__(self, model):
        self.model = model
        self.layers = []       # layer names passed to the last successful register_hooks
        self.hooks = {}        # layer name -> handle returned by register_forward_hook
        self.activations = {}  # layer name -> detached output of the latest forward pass

    def _get_hook(self, name):
        """Return a forward hook that stores the layer's output under ``name``."""
        def hook(module, inputs, output):
            # NOTE(review): assumes ``output`` is a single tensor-like object with
            # ``detach()``; modules returning tuples would fail here -- confirm.
            self.activations[name] = output.detach()
        return hook

    def remove_hooks(self):
        """Detach every hook registered by this extractor and forget its handle.

        Previously captured ``activations`` are left untouched.
        """
        for handle in self.hooks.values():
            handle.remove()
        self.hooks.clear()

    def register_hooks(self, layers):
        """Hook the named layers so their outputs are recorded on each forward pass.

        Any hooks from a previous call are removed first, so calling this
        repeatedly does not stack duplicate hooks or keep stale layers alive.
        Only leaf modules (modules without children) are hooked. A requested
        name that refers to a non-leaf module is accepted but gets no hook.

        Args:
            layers: non-empty iterable of layer names, each of which must be
                contained in ``self.model.layers``.

        Raises:
            ValueError: if ``layers`` is empty or names a layer unknown to the
                model. Existing hooks are left in place in that case.
        """
        if not layers:
            raise ValueError('register_hooks needs model layers to extract!')

        # Explicit raise instead of ``assert`` so validation survives ``python -O``.
        known_layers = self.model.layers
        for layer in layers:
            if layer not in known_layers:
                raise ValueError(f'Layer {layer} not in the model!')

        # Without this, handles from an earlier call would be overwritten and
        # those hooks could never be removed (they would keep firing forever).
        self.remove_hooks()
        self.layers = layers

        for name, module in self.model.model.named_modules():
            if name not in layers:
                continue
            # Skip container modules; only leaf modules are hooked.
            if next(iter(module.children()), None) is not None:
                continue
            self.hooks[name] = module.register_forward_hook(self._get_hook(name))
