from torch import nn

import cell
import operations


class Net(nn.Module):
    """Retraining network: a stem, a stack of searched cells and two heads.

    The network is assembled from ``hparams``:

    * ``"cells"`` -- one genotype per cell, forwarded to ``cell.RetrainCell``.
    * ``"architecture"`` -- one integer ``level`` per cell; that cell's base
      width is ``filter_multiplier * 2**level``.
    * ``"filter_multiplier"`` -- base channel multiplier.
    * ``"affine"`` -- passed to the stem ``BatchNorm2d`` and to every cell.

    ``forward`` returns ``(out, aux_out)``: the output of the main decoder on
    the last feature map, and the output of an auxiliary decoder attached
    part-way through the cell stack (see ``_aux_index``).
    """

    def __init__(self, num_classes, hparams):
        super().__init__()

        genotypes = hparams["cells"]
        self.arch = hparams["architecture"]

        filter_multiplier = hparams["filter_multiplier"]
        # A cell's output width is block_multiplier times its base width.
        # Only the first genotype's length is consulted.
        # NOTE(review): assumes every genotype has the same length -- confirm.
        block_multiplier = len(genotypes[0])
        filters = filter_multiplier * block_multiplier
        affine = hparams["affine"]

        # Drop-path value handed to every cell in forward(). It starts at 0,
        # presumably updated externally during training -- confirm.
        self.drop_path = 0
        self.stem = nn.Sequential(
            nn.Conv2d(3, filters, 3, padding=1, bias=False),
            nn.BatchNorm2d(filters, affine=affine),
        )
        self.cells = nn.ModuleList()

        # Channel count of every feature map forward() produces, in order:
        # the stem output twice (it seeds both inputs of the first cell),
        # then one entry per cell. Kept index-aligned with forward()'s fmaps.
        filter_list = [filters, filters]

        # zip() stops at the shorter of genotypes / architecture.
        for cell_gen, level in zip(genotypes, self.arch):
            cur_filters = filter_multiplier * 2**level
            filter_list.append(block_multiplier * cur_filters)
            # The last three entries are the widths of the two input feature
            # maps and of this cell's own output.
            self.cells.append(
                cell.RetrainCell(*filter_list[-3:], cur_filters, cell_gen, affine)
            )

        # Auxiliary classifier on the feature map selected by _aux_index.
        # NOTE(review): the 2x2 conv followed by a (-1, 768) view presumably
        # expects a 2x2 map after the 5/3 average pool (e.g. an 8x8 input
        # map) -- confirm against the input resolution.
        self.aux_decoder = nn.Sequential(
            nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False),
            nn.ReLU(inplace=False),
            nn.Conv2d(filter_list[self._aux_index], 128, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=False),
            nn.Conv2d(128, 768, 2, bias=False),
            nn.BatchNorm2d(768),
            operations.View(
                (-1, 768),
            ),
            nn.Linear(768, num_classes),
        )
        # Main classifier: global average pool, flatten, linear.
        self.decoder = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            operations.View(
                (-1, filter_list[-1]),
            ),
            nn.Linear(filter_list[-1], num_classes),
        )

        print(f"decoder params: {self._num_params(self.decoder)}")
        print(f"aux_decoder params: {self._num_params(self.aux_decoder)}")

    @property
    def _aux_index(self):
        """Index into the feature-map list of the aux decoder's input.

        Equals ``-(len(self.arch) // 3)``. For 12 cells this is -4, i.e. the
        output of the 9th cell. When ``len(self.arch) < 3`` it is 0, which
        selects the stem output. Defined once so that the aux decoder's input
        channels (``__init__``) and its actual input (``forward``) always
        agree.
        """
        return -(len(self.arch) // 3)

    @staticmethod
    def _num_params(module):
        """Return the total number of elements across ``module``'s parameters."""
        return sum(p.numel() for p in module.parameters())

    def forward(self, x):
        """Run the network on ``x``.

        Args:
            x: input for the stem's 3-input-channel ``Conv2d``
               (presumably a ``(batch, 3, H, W)`` image batch -- confirm).

        Returns:
            Tuple ``(out, aux_out)`` of main and auxiliary decoder outputs.
        """
        stem = self.stem(x)
        # The stem output seeds both inputs of the first cell; this keeps
        # fmaps index-aligned with filter_list built in __init__.
        fmaps = [stem, stem]
        # Named cell_module so the imported `cell` module is not shadowed.
        for cell_module in self.cells:
            # Each cell consumes the two most recent feature maps.
            fmaps.append(cell_module(*fmaps[-2:], self.drop_path))

        aux_out = self.aux_decoder(fmaps[self._aux_index])
        out = self.decoder(fmaps[-1])
        return out, aux_out
