from .resnet import CBA
import torch.nn as nn


# Simplified VGG following the implementation from "Rethinking the value of network pruning"
# https://github.com/Eric-mingjie/rethinking-network-pruning/blob/master/cifar/lottery-ticket/weight-level/models/cifar/vgg.py
class VGGSimple(nn.Module):
    """Simplified VGG: five conv stages, a final 2x2 average pool and one FC head.

    Stages are separated by 2x2 max-pools (four in total). With 32x32 inputs
    (e.g. CIFAR, per the reference implementation above) the final feature
    map is 2x2, which the 2x2 average pool in ``forward`` reduces to 1x1, so
    the flattened features match the ``fc`` head's input width.
    """

    def __init__(
        self, model_cfg=None, n_size=2, n_classes=10, input_size=3, scale=1,
        bn=True
    ):
        """Build the conv stages and the classifier head.

        Args:
            model_cfg: Object passed through to every ``CBA`` block and whose
                ``fc(in_features, n_classes)`` builds the classifier head.
                NOTE(review): the ``None`` default cannot work (``model_cfg.fc``
                is dereferenced below); callers must always supply it.
            n_size: One of [0, 1, 2, 3] for vgg-11, 13, 16, 19 respectively;
                selects how many conv blocks each of the five stages has.
            n_classes: Output size of the classifier head.
            input_size: Channel count fed to the first ``CBA`` block
                (presumably the number of input image channels — confirm).
            scale: Width multiplier applied to every stage width
                (truncated with ``int``).
            bn: Forwarded to each ``CBA`` block.

        Raises:
            ValueError: If ``n_size`` is not one of 0, 1, 2, 3.
        """
        super().__init__()

        # Conv blocks per stage for each supported depth.
        n_repeats_all = {
            0: [1, 1, 2, 2, 2],
            1: [2, 2, 2, 2, 2],
            2: [2, 2, 3, 3, 3],
            3: [2, 2, 4, 4, 4],
        }
        widths = [64, 128, 256, 512, 512]

        if n_size not in n_repeats_all:
            # Raise instead of exiting the interpreter: this is library code
            # and the caller decides how to handle a bad configuration.
            raise ValueError(
                f"Unsupported value n_size ({n_size}), can only be: 0,1,2,3"
            )
        n_repeats = n_repeats_all[n_size]

        blocks = []
        in_width = input_size
        width = in_width
        last_stage = len(widths) - 1
        for stage, (n_convs, base_width) in enumerate(zip(n_repeats, widths)):
            width = int(base_width * scale)
            for _ in range(n_convs):
                blocks.append(
                    CBA(model_cfg, in_width, width, 3, bn=bn, bias=False)
                )
                in_width = width
            # Down-sample between stages only; the last stage feeds the
            # average pool applied in ``forward``.
            if stage < last_stage:
                blocks.append(nn.MaxPool2d(2, 2))
        self.blocks = nn.Sequential(*blocks)

        # ``width`` is the channel count of the last stage.
        self.top = model_cfg.fc(width, n_classes)

    def forward(self, nx):
        """Run the conv stages, 2x2 average pool, flatten and classify."""
        nx = self.blocks(nx)
        # Stateless pooling: identical to ``nn.AvgPool2d(2)`` without building
        # a new module on every forward pass.
        nx = nn.functional.avg_pool2d(nx, 2)
        # Keep the batch dimension and flatten the rest; unlike ``view`` this
        # also works for non-contiguous tensors and empty batches.
        nx = nx.flatten(1)
        nx = self.top(nx)
        return nx
