import torch
import FrEIA


class GlowInvertibleNetwork(torch.nn.Module):
    """Invertible network made of stacked GLOW coupling blocks (built with FrEIA).

    The graph is ``input -> [GLOWCouplingBlock -> PermuteRandom] * layer_num
    -> output``. Each coupling block's subnetworks are two-layer MLPs
    (``Linear -> ReLU -> Linear``). Each permutation is seeded with its stage
    index (``seed=k``), so the permutations do not change between
    instantiations built with the same arguments.

    Args:
        dim: Number of features per sample. The input node is declared with
            ``dim``, and the example below feeds inputs of shape
            ``(batch, dim)``.
        layer_num: Number of coupling + permutation stages.
        hidden_dim: Hidden width of every coupling subnetwork.
        clamp: Value forwarded as the ``clamp`` argument of every
            ``GLOWCouplingBlock``. Defaults to 2.0.

    NOTE(review): the example assumes ``ReversibleGraphNet`` returns the
    output tensor directly. Newer FrEIA releases deprecate it in favour of
    ``GraphINN``, whose forward returns an ``(output, log_jac_det)`` tuple.
    Confirm this against the pinned FrEIA version.

    Example:
        >>> model = GlowInvertibleNetwork(5)
        >>> x = torch.rand(12, 5)
        >>> y = model(x)
        >>> y.shape
        torch.Size([12, 5])
        >>> x_inv = model(y, rev=True)
        >>> torch.isclose(x, x_inv).all()
        tensor(True)
    """

    def __init__(self, dim, layer_num=8, hidden_dim=8, clamp=2.0):
        super().__init__()

        def subnet_fc(c_in, c_out):
            # Passed to FrEIA as `subnet_constructor`, which calls it with the
            # input/output widths each coupling subnetwork needs.
            # hidden_dim is captured from the enclosing scope.
            return torch.nn.Sequential(
                torch.nn.Linear(c_in, hidden_dim),
                torch.nn.ReLU(),
                torch.nn.Linear(hidden_dim, c_out),
            )

        nodes = [FrEIA.framework.InputNode(dim, name='input')]
        for k in range(layer_num):
            nodes.append(FrEIA.framework.Node(
                nodes[-1],
                FrEIA.modules.GLOWCouplingBlock,
                {'subnet_constructor': subnet_fc, 'clamp': clamp},
                name=f'coupling_{k}',
            ))
            # Reorder the features between couplings so the next block splits
            # them differently. Seeding with k keeps the order deterministic.
            nodes.append(FrEIA.framework.Node(
                nodes[-1],
                FrEIA.modules.PermuteRandom,
                {'seed': k},
                name=f'permute_{k}',
            ))
        nodes.append(FrEIA.framework.OutputNode(nodes[-1], name='output'))
        self.net = FrEIA.framework.ReversibleGraphNet(nodes, verbose=False)

    def forward(self, x, rev=False):
        """Apply the network to ``x``, or its inverse when ``rev`` is True.

        Returns whatever the underlying FrEIA graph returns for ``x``.
        """
        return self.net(x, rev=rev)


if __name__ == '__main__':
    # Executed as a script: run the docstring examples of this module.
    from doctest import testmod
    testmod()
