from .resnet import CBA
import torch.nn as nn


class ConvLot(nn.Module):
    """Conv-2 / Conv-4 / Conv-6 style network built from ``CBA`` layers.

    The feature extractor is ``n_size // 2`` stages. Each stage is two 3x3
    ``CBA`` layers followed by a 2x2 max-pool, and each stage doubles the
    channel width. The classifier is three fully-connected layers created
    through ``model_cfg.fc``.
    """

    def __init__(
        self, model_cfg=None, n_size=2, n_classes=10, input_size=3, scale=1,
        bn=False, im_dim=32
    ):
        """Build the conv stages and the fc head.

        Args:
            model_cfg: Config object that provides the ``fc`` layer factory and
                is passed through to every ``CBA``. It is required despite the
                ``None`` default.
            n_size: One of [2, 4, 6] for conv-2, -4, -6 (the number of conv
                layers; two per pooling stage).
            n_classes: Output width of the final fc layer.
            input_size: Number of input channels of the first conv layer.
            scale: Width multiplier for the conv widths (base 64) and the fc
                widths (base 256).
            bn: Accepted but currently unused. Every ``CBA`` is built with
                ``bn=False``.
            im_dim: Spatial side length of the (square) input. It is used to
                size the first fc layer.

        Raises:
            ValueError: If ``n_size`` is not 2, 4 or 6, if ``model_cfg`` is
                ``None``, or if ``im_dim`` is too small to survive the pooling
                stages.
        """
        super().__init__()

        # Explicit checks instead of ``assert`` so they survive ``python -O``.
        if n_size not in (2, 4, 6):
            raise ValueError(f"n_size must be one of 2, 4, 6, got {n_size!r}")
        if model_cfg is None:
            raise ValueError("ConvLot requires a model_cfg (used for fc layers)")

        n_pools = n_size // 2
        # Each MaxPool2d(2, 2) floors the spatial size, and repeated floor
        # division by 2 equals one floor division by 2**n_pools.
        # NOTE(review): assumes CBA(…, 3, …) preserves spatial size
        # (e.g. padding=1); TODO confirm against CBA.
        out_dim = im_dim // 2 ** n_pools
        if out_dim < 1:
            raise ValueError(
                f"im_dim={im_dim!r} is too small for {n_pools} pooling stages"
            )

        # NOTE(review): the ``bn`` argument is not forwarded; CBA always gets
        # bn=False. Confirm whether that is intentional before wiring it in.
        stages = []
        in_ch = input_size
        out_ch = int(scale * 64)
        for _ in range(n_pools):
            stages.append(nn.Sequential(
                CBA(model_cfg, in_ch, out_ch, 3, bn=False),
                CBA(model_cfg, out_ch, out_ch, 3, bn=False),
                nn.MaxPool2d(2, 2)
            ))
            in_ch = out_ch
            out_ch *= 2

        self.blocks = nn.Sequential(*stages)

        # After flattening, features = channels * (square spatial side)**2.
        flat_features = in_ch * out_dim ** 2
        hidden = int(256 * scale)
        self.fc1 = model_cfg.fc(flat_features, hidden)
        self.fc2 = model_cfg.fc(hidden, hidden)
        self.fc3 = model_cfg.fc(hidden, n_classes)

    def forward(self, nx):
        """Run the conv stages, flatten per sample, then apply fc1 -> fc2 -> fc3."""
        nx = self.blocks(nx)
        # flatten(1) keeps the batch dimension. Unlike view(), it also
        # handles non-contiguous tensors and an empty batch.
        nx = nx.flatten(1)
        # NOTE(review): there is no nonlinearity between the fc layers here.
        # Unless model_cfg.fc includes one, fc1..fc3 compose to a single
        # affine map. Confirm this is intended.
        return self.fc3(self.fc2(self.fc1(nx)))
