"""VGG11/13/16/19 in Paddle."""
import paddle
import paddle.nn as nn


# Layer specs for the supported VGG variants, consumed by VGG._make_layers:
# an int entry is a 3x3 conv (followed by BatchNorm and ReLU) with that many
# output channels, and "M" is a 2x2 / stride-2 max-pool. Every variant has five
# stages of widths 64/128/256/512/512, each stage ending in one "M"; the
# variants differ only in how many convs each stage stacks (the tuples below).
cfg = {
    name: [
        entry
        for width, repeats in zip((64, 128, 256, 512, 512), convs_per_stage)
        for entry in [width] * repeats + ["M"]
    ]
    for name, convs_per_stage in (
        ("VGG11", (1, 1, 2, 2, 2)),
        ("VGG13", (2, 2, 2, 2, 2)),
        ("VGG16", (2, 2, 3, 3, 3)),
        ("VGG19", (2, 2, 4, 4, 4)),
    )
}


class VGG(nn.Layer):
    """VGG network with batch norm: a conv feature stack plus one linear classifier.

    The feature extractor is built from the module-level ``cfg`` entry named by
    ``vgg_name``. The classifier is ``nn.Linear(512, num_classes)``, so each
    sample's flattened feature map must have exactly 512 elements. Every ``cfg``
    entry ends with 512 channels after five 2x2 max-pools, which yields a 1x1
    spatial map (512 features) for 32x32 inputs such as CIFAR-10; other input
    sizes would need a different classifier width.

    Args:
        vgg_name: key into ``cfg`` ("VGG11", "VGG13", "VGG16" or "VGG19").
            An unknown name raises ``KeyError``.
        num_classes: number of output logits.
        in_channels: channel count of the input images (3 for RGB).
    """

    def __init__(self, vgg_name, num_classes=10, in_channels=3):
        super().__init__()
        self.features = self._make_layers(cfg[vgg_name], in_channels=in_channels)
        self.classifier = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return class logits of shape [N, num_classes] for an image batch ``x``."""
        out = self.features(x)
        # Collapse every axis after the batch axis into one feature axis.
        # paddle.flatten keeps the batch size symbolic. The previous
        # paddle.fluid.layers.reshape + paddle.shape(out).numpy()[0] version
        # relied on the legacy fluid API (removed in newer Paddle releases)
        # and copied the batch size back to the host, which is not supported
        # under paddle.jit.to_static.
        out = paddle.flatten(out, start_axis=1)
        return self.classifier(out)

    def get_feature(self, x):
        """Return the (unflattened) output of the convolutional feature extractor."""
        return self.features(x)

    def _make_layers(self, cfg, in_channels=3):
        """Build the convolutional feature extractor from a layer spec.

        Args:
            cfg: sequence of ints and "M" markers. Each int appends a 3x3 conv
                (padding 1) with that many output channels, followed by
                BatchNorm2D and ReLU; each "M" appends a 2x2 / stride-2 max-pool.
            in_channels: channel count of the tensor fed to the first conv.

        Returns:
            An ``nn.Sequential`` containing the layers in spec order.
        """
        layers = []
        for spec in cfg:
            if spec == "M":
                layers.append(nn.MaxPool2D(kernel_size=2, stride=2))
            else:
                layers.extend(
                    [
                        nn.Conv2D(in_channels, spec, kernel_size=3, padding=1),
                        nn.BatchNorm2D(spec),
                        nn.ReLU(),
                    ]
                )
                # The next conv consumes this conv's output channels.
                in_channels = spec
        # A 1x1 / stride-1 average pool is an identity op; it is kept only so the
        # Sequential keeps the same layer layout (it has no parameters).
        layers.append(nn.AvgPool2D(kernel_size=1, stride=1))
        return nn.Sequential(*layers)


def test():
    """Smoke test: run VGG11 on two random 32x32 RGB images and print the logits shape."""
    net = VGG("VGG11")
    # paddle.randn takes the shape as a single list/tuple argument; passing the
    # dimensions positionally would bind them to dtype/name and raise TypeError.
    x = paddle.randn([2, 3, 32, 32])
    y = net(x)
    # Tensor.size is the element count in Paddle (not a method); .shape gives the
    # dimensions, expected [2, 10] here.
    print(y.shape)


# test()
