from torch import Tensor


def mask_post_accumulate_hook(param: Tensor, mask: Tensor) -> None:
    """Multiply the gradient stored on ``param`` by ``mask``.

    The product ``param.grad * mask`` is assigned back to ``param.grad``.
    When ``param.grad`` is ``None`` the call does nothing instead of
    failing on ``None * mask``.

    Args:
        param: Object whose ``grad`` attribute is read and rewritten.
        mask: Factor the gradient is multiplied with.
    """
    grad = param.grad
    if grad is None:
        # No gradient is stored, so there is nothing to mask.
        return
    param.grad = grad * mask


def mask_grad_hook(grad: Tensor, mask: Tensor) -> Tensor:
    """Return ``grad * mask``.

    Args:
        grad: Gradient value to be masked.
        mask: Factor the gradient is multiplied with.

    Returns:
        The result of ``grad * mask``.
    """
    masked = grad * mask
    return masked


def get_input_hook(name: str, activations: dict, procees_fn=None):
    """Build a forward hook that records a layer's input under ``name``.

    Every call of the returned hook appends the input it received to the
    list ``activations[name]``, creating that list on first use. When
    ``procees_fn`` is given, its return value for the input is stored
    instead of the raw input.

    Args:
        name: Key in ``activations`` under which inputs are collected.
        activations: Dictionary that is mutated by the returned hook.
        procees_fn: Optional callable applied to the input before storing.

    Returns:
        A callable taking ``(module, inputs, output)`` that returns ``None``.
    """

    def hook(module, inputs, output):
        # The list is created before processing, so the key exists even if
        # procees_fn raises.
        activations.setdefault(name, [])
        captured = inputs if procees_fn is None else procees_fn(inputs)
        activations[name].append(captured)

    return hook


def get_output_hook(name: str, activations: dict, process_fn=None):
    """Build a forward hook that records a layer's output under ``name``.

    Every call of the returned hook appends the output it received to the
    list ``activations[name]``, creating that list on first use. When
    ``process_fn`` is given, its return value for the output is stored
    instead of the raw output.

    Args:
        name: Key in ``activations`` under which outputs are collected.
        activations: Dictionary that is mutated by the returned hook.
        process_fn: Optional callable applied to the output before storing.
            Defaults to ``None``, in which case the output is stored as
            received.

    Returns:
        A callable taking ``(model, input, output)`` that returns ``None``.
    """

    def hook(model, input, output):
        if name not in activations:
            activations[name] = []
        if process_fn is not None:
            output = process_fn(output)
        activations[name].append(output)

    return hook


def get_forward_hook(
    name: str,
    activations: dict,
    input_process_fn=None,
    output_process_fn=None,
):
    """Build a forward hook that records a layer's input and output.

    Every call of the returned hook appends to the two lists kept in
    ``activations[name]`` under the keys ``"input"`` and ``"output"``;
    that entry is created on first use. Each value is passed through its
    processing callable first when one is supplied.

    Args:
        name: Key in ``activations`` under which the record is kept.
        activations: Dictionary that is mutated by the returned hook.
        input_process_fn: Optional callable applied to the input.
        output_process_fn: Optional callable applied to the output.

    Returns:
        A callable taking ``(module, inputs, outputs)`` that returns ``None``.
    """

    def hook(module, inputs, outputs):
        # The entry is created first, so it exists even if a processing
        # callable raises.
        activations.setdefault(name, {"input": [], "output": []})
        seen_in = inputs if input_process_fn is None else input_process_fn(inputs)
        seen_out = (
            outputs if output_process_fn is None else output_process_fn(outputs)
        )
        record = activations[name]
        record["input"].append(seen_in)
        record["output"].append(seen_out)

    return hook
