import torch


class CustomFn(torch.nn.Module):
    """Wrap an arbitrary callable so it can be used as a ``torch.nn.Module``.

    This lets a plain function (e.g. a lambda or ``torch.tanh``) be placed in
    module containers such as ``torch.nn.Sequential``. If ``fn`` is itself a
    ``torch.nn.Module``, assigning it as an attribute registers it as a child
    module, so its parameters are tracked by this wrapper.

    Args:
        fn: Callable invoked on every forward pass.

    Raises:
        TypeError: If ``fn`` is not callable.
    """

    def __init__(self, fn):
        super().__init__()
        # Fail fast at construction rather than on the first forward pass.
        if not callable(fn):
            raise TypeError(
                f"CustomFn expects a callable, got {type(fn).__name__}"
            )
        self._fn = fn

    def forward(self, X, *args, **kwargs):
        """Return ``fn(X, *args, **kwargs)``.

        Any extra positional or keyword arguments are passed straight through
        to the wrapped callable; with only ``X`` this is simply ``fn(X)``.
        """
        return self._fn(X, *args, **kwargs)

    def extra_repr(self):
        """Describe the wrapped callable in the module's repr."""
        # A wrapped nn.Module is already printed as a registered child, so
        # only describe plain callables here to avoid showing it twice.
        if isinstance(self._fn, torch.nn.Module):
            return ""
        return f"fn={getattr(self._fn, '__qualname__', repr(self._fn))}"