from typing import List

import torch
from typeguard import check_argument_types


def replace_module(
    parent_module: torch.nn.Module,
    child_name: str,
    old_module: torch.nn.Module,
    new_module: torch.nn.Module,
):
    """Replace the target module with the new module."""
    assert check_argument_types()
    # TODO add hook and whether requires_grad to them
    device = old_module.weight.device
    setattr(parent_module, child_name, new_module)

    # copy weight and bias from the target module
    new_module.weight = old_module.weight
    if hasattr(old_module, "bias") and old_module.bias is not None:
        new_module.bias = old_module.bias

    # move the new_module to the same device as the old_module
    new_module.to(device)


def check_target_module_exists(key: str, target_modules: List[str]):
    """Check if the target_modules matchs the given key."""
    assert check_argument_types()
    return any([key.endswith(target_key) for target_key in target_modules])


def get_submodules(model: torch.nn.Module, key: str):
    """Return the submodules of the given key."""
    assert check_argument_types()
    parent_module = model.get_submodule(".".join(key.split(".")[:-1]))
    target_name = key.split(".")[-1]
    target_module = model.get_submodule(key)
    return parent_module, target_name, target_module
