import torch


class ForwardHook:
    """Capture the most recent inputs and output of a module's forward pass.

    ``hook_fn`` is registered as a forward hook on ``module``. After each
    forward call of that module, ``self.input`` holds the positional-argument
    tuple the module was called with and ``self.output`` holds whatever its
    forward returned. Both are ``None`` until the module has run at least once.

    The hook stays registered until :meth:`close` is called. The object can
    also be used as a context manager, which closes it automatically::

        with ForwardHook(layer) as h:
            model(x)
        features = h.output

    Note: captured objects are stored as-is (not detached or copied). If they
    are tensors that require grad, they keep their autograd graph alive until
    they are overwritten or this object is released.
    """

    def __init__(self, module: torch.nn.Module):
        """Register the hook on ``module``.

        Args:
            module: the module whose forward inputs/output are captured.
        """
        # Create the capture slots before registering, so the object is fully
        # initialised before the callback can ever be invoked.
        self.input = None
        self.output = None
        # Handle returned by PyTorch; used by close() to unregister the hook.
        self.hook = module.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        """Forward-hook callback: record the latest ``input`` and ``output``.

        ``module`` is unused. Returns ``None``, so the module's output is left
        unchanged.
        """
        # `input` shadows the builtin, but the parameter name is kept for
        # backward compatibility with any keyword callers.
        self.input = input
        self.output = output

    def close(self) -> None:
        """Unregister the hook. Previously captured values are kept."""
        self.hook.remove()

    def __enter__(self) -> "ForwardHook":
        """Return self, so the hook can be used in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        """Unregister the hook on leaving the ``with`` block.

        Any exception raised inside the block propagates.
        """
        self.close()
