import re
import torch.nn as nn

from .feature_storage import FeatureStorage


def register_feature_cache_hooks(model: nn.Module, feat_names: str, features: 'FeatureStorage'):
    """Register forward hooks that cache the outputs of selected submodules.

    Every module yielded by ``model.named_modules()`` (including ``model``
    itself, whose name is ``""``) whose qualified name matches the regular
    expression ``feat_names`` gets a forward hook. Matching uses
    ``re.search``, so the pattern is unanchored unless it anchors itself.
    While ``features.enabled`` is truthy, the hook stores the module's output
    in ``features`` under the module's name.

    If the name starts with a whole ``_fsdp_wrapped_module`` path component,
    that component is dropped from the key, so
    ``_fsdp_wrapped_module.layer1`` is stored as ``layer1``. Regex matching is
    still done against the raw, unstripped name.

    Args:
        model: Module whose submodules are scanned.
        feat_names: Regular expression selecting the modules to hook.
        features: Storage written to by the hooks. It must expose an
            ``enabled`` attribute and support item assignment.

    Returns:
        Dict mapping each (stripped) module name to the handle returned by
        ``register_forward_hook``. Call ``.remove()`` on the handles to detach
        the hooks.
    """
    fsdp_prefix = '_fsdp_wrapped_module'
    pattern = re.compile(feat_names)

    def cache_output_hook(name):
        # The factory binds `name` per module. A hook defined directly in the
        # loop body would late-bind to the last loop value.
        def hook(module, inputs, outputs):
            if not features.enabled:
                return
            # For a tuple/list output keep only its first item. An empty
            # sequence has no first item, so it is stored unchanged rather
            # than raising IndexError in the middle of the forward pass.
            if isinstance(outputs, (tuple, list)) and outputs:
                outputs = outputs[0]
            features[name] = outputs
        return hook

    hooks = {}
    for module_name, module in model.named_modules():
        if not pattern.search(module_name):
            continue
        # Strip the prefix only when it is a whole path component, so names
        # such as "_fsdp_wrapped_module_x" are not mangled.
        if module_name == fsdp_prefix or module_name.startswith(fsdp_prefix + '.'):
            module_name = module_name[len(fsdp_prefix) + 1:]
        if module_name in hooks:
            # Stripping can map two modules to one key, e.g. an FSDP root
            # wrapper (named "") and its `_fsdp_wrapped_module` child. Keep the
            # first (outer) hook. Overwriting its handle would leave a hook
            # registered that the caller can no longer remove.
            continue
        hooks[module_name] = module.register_forward_hook(cache_output_hook(module_name))
    return hooks
    