"""VGG11/13/16/19 and a small convolutional autoencoder classifier in Paddle, sized for 3x32x32 inputs."""
import paddle
import paddle.nn as nn


# Layer configurations for VGG._make_layers, keyed by model name.
# An integer c adds a 3x3 Conv2D (padding 1) with c output channels, followed
# by BatchNorm2D and ReLU; "M" adds a 2x2 MaxPool2D with stride 2.
# Every configuration contains five "M" entries, so a 32x32 input is reduced
# to a 1x1 map with 512 channels, which matches VGG's nn.Linear(512, ...) head.
cfg = {
    "VGG11": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
    "VGG13": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
    "VGG16": [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512, "M"],
    "VGG19": [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512, 512, "M", 512, 512, 512, 512, "M"],
}


class VGG(nn.Layer):
    """VGG-style convolutional classifier built from a ``cfg`` entry.

    Args:
        vgg_name: Key into the module-level ``cfg`` table
            ("VGG11", "VGG13", "VGG16" or "VGG19").
        num_classes: Number of output logits.

    The head is ``nn.Linear(512, num_classes)``, so the flattened feature map
    must have exactly 512 elements. Every ``cfg`` entry ends with 512 channels
    after five 2x2 max-pools, which holds for 3x32x32 inputs (32 / 2**5 == 1).
    """

    def __init__(self, vgg_name, num_classes=10):
        super(VGG, self).__init__()
        self.features = self._make_layers(cfg[vgg_name])
        self.classifier = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return class logits of shape [batch, num_classes] for images ``x``."""
        out = self.features(x)
        # Flatten everything except the batch axis. This replaces the legacy
        # ``paddle.fluid.layers.reshape`` call, whose ``.numpy()`` read of the
        # batch size forced a device->host copy and broke static-graph export.
        out = paddle.flatten(out, start_axis=1)
        out = self.classifier(out)
        return out

    def get_feature(self, x):
        """Return the output of the convolutional stack, without the classifier."""
        return self.features(x)

    def _make_layers(self, cfg):
        """Build the convolutional feature stack from a ``cfg`` list.

        Each integer ``c`` appends Conv2D(in, c, 3x3, padding=1) -> BatchNorm2D(c)
        -> ReLU. Each ``"M"`` appends a 2x2, stride-2 MaxPool2D. The first
        convolution expects 3 input channels.
        """
        layers = []
        in_channels = 3
        for x in cfg:
            if x == "M":
                layers += [nn.MaxPool2D(kernel_size=2, stride=2)]
            else:
                layers += [
                    nn.Conv2D(in_channels, x, kernel_size=3, padding=1),
                    nn.BatchNorm2D(x),
                    nn.ReLU(),
                ]
                in_channels = x
        # A 1x1 average pool with stride 1 is an identity op; it is kept only so
        # the Sequential's sublayer layout stays unchanged.
        layers += [nn.AvgPool2D(kernel_size=1, stride=1)]
        return nn.Sequential(*layers)


class GaussianClassifier(nn.Layer):
    """Score inputs by squared Euclidean distance to learned class centers.

    Args:
        num_classes: Number of class centers (rows of the embedding table).
        hidden_dim: Dimensionality of each center and of each input vector.

    NOTE(review): the output is a distance (smaller means closer), not a logit;
    negate it before softmax/cross-entropy if it is used as a classification head.
    """

    def __init__(self, num_classes=10, hidden_dim=48 * 4 * 4):
        super().__init__()
        # One learnable center per class, stored as the embedding weight.
        self.embed = nn.Embedding(num_classes, hidden_dim)

    def forward(self, x):
        """Return squared distances of shape [batch, num_classes].

        Assumes ``x`` is [batch, hidden_dim] (TODO confirm at call sites).
        """
        centers = self.embed.weight.unsqueeze(0)  # [1, num_classes, hidden_dim]
        offsets = x.unsqueeze(1) - centers  # broadcast to [batch, num_classes, hidden_dim]
        return paddle.square(offsets).sum(axis=-1)


class AutoEncoder_VGG(VGG):
    """Convolutional autoencoder with a linear classification head.

    Although it subclasses :class:`VGG`, this model does not build the VGG
    ``features`` stack, because ``__init__`` deliberately skips ``VGG.__init__``.
    ``_make_layers`` is inherited unchanged from :class:`VGG` and is not used here.

    Args:
        vgg_name: Accepted for signature compatibility with :class:`VGG`; unused.
        num_classes: Number of output logits.

    Layer sizes assume 3x32x32 inputs (see the shape comments). The classifier's
    ``48 * 4 * 4`` input width depends on that size.
    """

    def __init__(self, vgg_name, num_classes=10):
        # super(VGG, self) resolves to the class after VGG in the MRO (nn.Layer
        # here), so VGG.__init__ and its cfg-driven feature stack are bypassed.
        super(VGG, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2D(3, 12, 4, stride=2, padding=1),  # [batch, 12, 16, 16]
            nn.ReLU(),
            nn.Conv2D(12, 24, 4, stride=2, padding=1),  # [batch, 24, 8, 8]
            nn.ReLU(),
            nn.Conv2D(24, 48, 4, stride=2, padding=1),  # [batch, 48, 4, 4]
            nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            nn.Conv2DTranspose(48, 24, 4, stride=2, padding=1),  # [batch, 24, 8, 8]
            nn.ReLU(),
            nn.Conv2DTranspose(24, 12, 4, stride=2, padding=1),  # [batch, 12, 16, 16]
            nn.ReLU(),
            nn.Conv2DTranspose(12, 3, 4, stride=2, padding=1),  # [batch, 3, 32, 32]
            nn.Sigmoid(),  # reconstruction values lie in (0, 1)
        )
        # Alternative head tried previously: GaussianClassifier(num_classes,
        # hidden_dim=48 * 4 * 4) in place of the Linear layer.
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(48 * 4 * 4, num_classes),
        )

    def forward(self, x):
        """Return ``(pred, x_hat)``.

        ``pred`` is the class logits computed from the encoded features, and
        ``x_hat`` is the decoder's reconstruction of ``x``.
        """
        f = self.encoder(x)
        pred = self.classifier(f)
        x_hat = self.decoder(f)
        return pred, x_hat

    def get_feature(self, x):
        """Return the encoder output (the latent feature map)."""
        return self.encoder(x)


def test():
    """Smoke test: run VGG11 on a random batch of two 3x32x32 images and print the output shape."""
    net = VGG("VGG11")
    # paddle.randn takes the whole shape as one list argument; passing the
    # dims positionally raises a TypeError.
    x = paddle.randn([2, 3, 32, 32])
    y = net(x)
    # Paddle exposes the shape as ``Tensor.shape``; ``size()`` is a PyTorch idiom.
    print(y.shape)


# test()
