from typing import (
    Dict,
    Iterable,
    List,
    Optional,
)

import torch.nn as nn


def find_modules(
    module: nn.Module,
    module_types: Optional[Iterable[type]] = None,
    name: str = '',
) -> Dict[str, nn.Module]:
    """ Finds all submodules of certain types in a given module.

    The module tree is walked depth-first through ``named_children``. As soon
    as a module whose exact type is listed in ``module_types`` is reached, it
    is recorded under its dotted path and its own children are not searched.

    Args:
        module (nn.Module): The module to search in.
        module_types (Iterable[type], optional): The module types to search
            for. Any iterable is accepted, including one-shot iterators.
            Defaults to None, which is treated as [nn.Linear].
        name (str, optional): The name of the module, used as the prefix of
            the returned keys. Defaults to ''.

    Returns:
        Dict[str, nn.Module]: A dictionary mapping the dotted name of each
            found submodule to the submodule itself, in traversal order. If
            `module` itself matches, the result is ``{name: module}``.
    """

    if module_types is None:
        # Resolve the default here rather than in the signature so that no
        # mutable list object is shared between calls.
        module_types = (nn.Linear,)
    else:
        # Materialize once: a generator/iterator would otherwise be consumed
        # by the first membership test and appear empty deeper in the tree.
        module_types = tuple(module_types)

    # Exact type match (not isinstance): an instance of a subclass of a
    # listed type is not reported itself, but its children are still searched.
    if type(module) in module_types:
        return {name: module}

    # No match at this level, so collect the matches from every child.
    submodules: Dict[str, nn.Module] = {}
    for child_name, child in module.named_children():
        # Only insert the '.' separator when there is a non-empty prefix.
        qualified_name = f'{name}.{child_name}' if name else child_name
        submodules.update(
            find_modules(
                module=child,
                module_types=module_types,
                name=qualified_name,
            )
        )

    return submodules
