import torch

class Encoder(torch.nn.Module):
    """Compress a 28x28 image into a 32-dimensional code in (0, 1).

    The input is flattened from dimension 1 onward, so everything after the
    batch dimension must hold 28*28 = 784 values in total (for example a
    ``(batch, 1, 28, 28)`` or ``(batch, 784)`` tensor). It then goes through
    six fully-connected layers (784 -> 128 -> 128 -> 128 -> 32 -> 32 -> 32):
    ReLU after each hidden layer and a sigmoid after the last one, giving an
    output of shape ``(batch, 32)`` with entries in [0, 1].
    """

    def __init__(self):
        super().__init__()

        widths = (28 * 28, 128, 128, 128, 32, 32, 32)
        final_idx = len(widths) - 2

        stack = [torch.nn.Flatten(1)]
        for idx, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
            stack.append(torch.nn.Linear(fan_in, fan_out))
            # Only the last layer squashes into (0, 1); the rest use ReLU.
            if idx == final_idx:
                stack.append(torch.nn.Sigmoid())
            else:
                stack.append(torch.nn.ReLU())

        # Layers are created in the same order as a hand-written Sequential,
        # so the Sequential indices (and hence state_dict keys such as
        # "enc.1.weight") and seeded weight initialisation stay the same.
        self.enc = torch.nn.Sequential(*stack)

    def forward(self, x):
        """Return the 32-dimensional latent code for ``x``."""
        code = self.enc(x)
        return code
    
class Decoder(torch.nn.Module):
    """Expand a 32-dimensional code back into a 1x28x28 image in (0, 1).

    The layer widths are the reverse of ``Encoder``'s: six fully-connected
    layers (32 -> 32 -> 32 -> 128 -> 128 -> 128 -> 784) with ReLU between
    them. The 784 outputs are reshaped to ``(1, 28, 28)`` and passed through
    a sigmoid. Because the reshape acts on dimension 1, the input is expected
    to be a ``(batch, 32)`` tensor, and the output is ``(batch, 1, 28, 28)``.
    """

    def __init__(self):
        super().__init__()

        widths = (32, 32, 32, 128, 128, 128, 28 * 28)
        final_idx = len(widths) - 2

        stack = []
        for idx, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
            stack.append(torch.nn.Linear(fan_in, fan_out))
            # No activation after the last Linear; the sigmoid comes after
            # the reshape below.
            if idx != final_idx:
                stack.append(torch.nn.ReLU())

        # Turn the flat 784-vector into a single-channel image, then squash
        # into (0, 1). Layer order matches a hand-written Sequential, so
        # state_dict keys and seeded initialisation are unchanged.
        stack += [torch.nn.Unflatten(1, (1, 28, 28)), torch.nn.Sigmoid()]
        self.dec = torch.nn.Sequential(*stack)

    def forward(self, x):
        """Return the reconstructed image for latent code ``x``."""
        image = self.dec(x)
        return image
    
class Getter(torch.nn.Module):
    """Classification head on top of a frozen encoder.

    ``enc`` is frozen in place with ``requires_grad_(False)``. This mutates
    the object passed in, so its parameters stop receiving gradients
    everywhere it is shared. It stays registered as a submodule, so its
    weights still appear in ``state_dict()`` and ``parameters()``.
    The 32-dimensional code it produces is fed through a 32 -> 256 -> 10
    MLP that ends in a softmax over the last dimension. Each sample
    therefore gets 10 non-negative scores that sum to 1 (presumably class
    probabilities -- confirm with the caller).

    NOTE(review): the output is already softmax-normalised. It should not
    be passed to ``torch.nn.CrossEntropyLoss``, which expects raw logits.
    Use NLL on the log of the output instead -- confirm against the
    training code.
    """

    def __init__(self, enc: Encoder):
        super().__init__()

        # requires_grad_ returns the module itself, so freeze and register
        # in one step (the encoder is registered before the head, as before).
        self.enc = enc.requires_grad_(False)

        hidden = 256
        head_layers = [
            torch.nn.Linear(32, hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden, 10),
            torch.nn.Softmax(-1),
        ]
        self.get = torch.nn.Sequential(*head_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` with the frozen encoder and return the head's scores."""
        latent = self.enc(x)
        return self.get(latent)
    
class Putter(torch.nn.Module):
    """Re-encode an input under a conditioning tensor and decode the result.

    ``x`` is mapped to its 32-dimensional code by the frozen encoder, and
    that code is concatenated with ``l`` along the last dimension. A
    42 -> 256 -> 256 -> 32 MLP with a final sigmoid then produces a new
    code in (0, 1), matching the encoder's sigmoid output range. The frozen
    decoder turns that code back into an image.

    Both ``enc`` and ``dec`` are frozen in place with ``requires_grad_(False)``.
    This mutates the objects passed in, so only the ``put`` MLP is left
    trainable here. ``l`` must contain 10 features in its last dimension;
    presumably a one-hot (or soft) class label -- TODO confirm with the caller.
    """

    def __init__(self, enc: Encoder, dec: Decoder):
        super().__init__()

        # requires_grad_ returns the module itself; registration order
        # (enc, dec, put) is kept so parameters() ordering is unchanged.
        self.enc = enc.requires_grad_(False)
        self.dec = dec.requires_grad_(False)

        code_dim, cond_dim, hidden = 32, 10, 256
        self.put = torch.nn.Sequential(
            torch.nn.Linear(code_dim + cond_dim, hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden, hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden, code_dim),
            torch.nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor, l: torch.Tensor) -> torch.Tensor:
        """Return the decoded image for ``x`` conditioned on ``l``.

        Args:
            x: Input accepted by the encoder (784 values per sample after
                flattening from dimension 1).
            l: Conditioning tensor with 10 features in its last dimension.
                Its other dimensions must match the code for ``torch.cat``.
        """
        code = self.enc(x)
        joint = torch.cat([code, l], dim=-1)
        new_code = self.put(joint)
        return self.dec(new_code)
