import geoopt
from torch import nn


def add_model_args(parser):
    """Hook for adding model-specific command-line options to ``parser``.

    This model registers nothing, so ``parser`` is left untouched.
    NOTE(review): ``Decoder`` reads ``args.latent_dim``, which is not added
    here; presumably shared code registers it — confirm.
    """
    return None


class Encoder(nn.Module):
    """Two-layer tanh MLP mapping a 28x28 input to a 200-dim feature vector.

    ``args`` is accepted to match the constructor signature used by the
    other model classes, but it is not read.
    """

    def __init__(self, args):
        super().__init__()

        hidden = 200
        # Layer order fixes the state_dict keys (encoder.1.*, encoder.3.*) and
        # the order in which Linear weights are randomly initialised.
        layers = [
            nn.Flatten(),  # collapses every dim after the first
            nn.Linear(28 * 28, hidden),
            nn.Tanh(),
            nn.Linear(hidden, hidden),
            nn.Tanh(),
        ]
        self.encoder = nn.Sequential(*layers)
        # Width of the tensor returned by forward().
        self.output_dim = hidden

    def forward(self, x):
        """Flatten ``x`` after its first dim and return tanh hidden features."""
        return self.encoder(x)


class Decoder(nn.Module):
    """Tanh MLP mapping a latent vector to a ``(1, 28, 28)``-shaped tensor.

    The final Linear layer's output is returned with no activation applied.
    """

    def __init__(self, args):
        super().__init__()

        # Input width is twice the configured latent size.
        # NOTE(review): presumably two concatenated latent parts — confirm
        # against the code that builds the decoder input.
        self.latent_dim = 2 * args.latent_dim

        hidden = 200
        # Layer order fixes the state_dict keys (decoder.0/2/4.*) and the
        # order in which Linear weights are randomly initialised.
        self.decoder = nn.Sequential(*[
            nn.Linear(self.latent_dim, hidden),
            nn.Tanh(),
            nn.Linear(hidden, hidden),
            nn.Tanh(),
            nn.Linear(hidden, 28 * 28),
        ])

    def forward(self, z):
        """Decode ``z`` of shape ``(..., latent_dim)`` to ``(..., 1, 28, 28)``."""
        flat = self.decoder(z)
        # Keep all leading (batch-like) dims; split the last into 1x28x28.
        target_shape = tuple(flat.shape[:-1]) + (1, 28, 28)
        return flat.view(target_shape)

