import torch.nn.init as init
from torch import nn
import torch

def kaiming_init(m):
    """Initialise a single module in place; meant to be used with ``Module.apply``.

    * ``nn.Linear`` / ``nn.Conv2d``: weights drawn with
      ``init.kaiming_normal_`` (PyTorch defaults), bias (if any) set to 0.
    * ``nn.BatchNorm1d`` / ``nn.BatchNorm2d``: scale set to 1 and shift set
      to 0. These are only touched when present: with ``affine=False`` both
      are ``None``.
    * Any other module is left unchanged.

    Args:
        m: The module to initialise.
    """
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        init.kaiming_normal_(m.weight)
        if m.bias is not None:
            init.zeros_(m.bias)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
        # BatchNorm built with affine=False has no learnable weight/bias.
        if m.weight is not None:
            init.ones_(m.weight)
        if m.bias is not None:
            init.zeros_(m.bias)


class MNISTEncoder(nn.Module):
    """Fully-connected encoder from a 28x28 image to a ``z_dim`` vector.

    Architecture: Flatten -> 784 -> 512 -> 256 -> 128 -> ``z_dim``, with ReLU
    between the hidden layers and no activation on the output.

    Args:
        z_dim: Size of the output vector.
    """

    def __init__(self, z_dim):
        super().__init__()

        self.main = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, z_dim),
        )
        self.weight_init()

    def weight_init(self):
        """Run ``kaiming_init`` over every submodule of this network."""
        # Module.apply recurses through the whole module tree, so nested
        # submodules are reached too. Unlike iterating ``self._modules``, it
        # does not assume each child is an iterable container.
        self.apply(kaiming_init)

    def forward(self, x):
        """Encode ``x``.

        Each sample must flatten (from dim 1 onward) to 28 * 28 = 784 values,
        e.g. a ``(batch, 1, 28, 28)`` or ``(batch, 784)`` tensor.

        Returns:
            Tensor of shape ``(batch, z_dim)``.
        """
        x = self.main(x)
        return x

class MNISTDecoder(nn.Module):
    """Fully-connected decoder from a ``z_dim`` vector to a 1x28x28 image.

    Architecture: ``z_dim`` -> 128 -> 256 -> 512 -> 784 with ReLU between the
    hidden layers. A final Sigmoid squashes outputs into (0, 1), and the result
    is reshaped to ``(1, 28, 28)`` per sample.

    Args:
        z_dim: Size of the input vector.
    """

    def __init__(self, z_dim):
        super().__init__()

        self.main = nn.Sequential(
            nn.Linear(z_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 28 * 28),
            nn.Sigmoid(),
            nn.Unflatten(dim=1, unflattened_size=(1, 28, 28)),
        )

        self.weight_init()

    def forward(self, x):
        """Decode ``x``.

        ``x`` is expected to be ``(batch, z_dim)``. The Unflatten on dim 1
        requires the 784 features to sit on dim 1, which only holds for 2-D
        input.

        Returns:
            Tensor of shape ``(batch, 1, 28, 28)`` with values in (0, 1).
        """
        x = self.main(x)
        return x

    def weight_init(self):
        """Run ``kaiming_init`` over every submodule of this network."""
        # Module.apply recurses through the whole module tree, so nested
        # submodules are reached too. Unlike iterating ``self._modules``, it
        # does not assume each child is an iterable container.
        self.apply(kaiming_init)

class MNIST_NN_CornerCrop(nn.Module):
    """Two-layer MLP mapping a ``z_dim`` vector to a 1x14x14 output.

    Architecture: ``z_dim`` -> 1024 (no bias) -> ReLU -> 196 (with bias),
    reshaped to ``(1, 14, 14)`` per sample. There is no output activation.
    NOTE(review): the name suggests the output is one 14x14 corner of an MNIST
    image; confirm against the training code.

    Args:
        z_dim: Size of the input vector.
    """

    def __init__(self, z_dim):
        super().__init__()

        self.main = nn.Sequential(
            nn.Linear(z_dim, 1024, bias=False),
            nn.ReLU(),
            nn.Linear(1024, 14 * 14, bias=True),
            nn.Unflatten(dim=1, unflattened_size=(1, 14, 14)),
        )
        self.weight_init()

    def weight_init(self):
        """Run ``kaiming_init`` over every submodule of this network."""
        # Module.apply recurses through the whole module tree, so nested
        # submodules are reached too. Unlike iterating ``self._modules``, it
        # does not assume each child is an iterable container.
        self.apply(kaiming_init)

    def forward(self, x):
        """Map ``x`` of shape ``(batch, z_dim)`` to ``(batch, 1, 14, 14)``."""
        out = self.main(x)
        return out
