from collections import defaultdict

from torch import nn

import cell
import operations


class SuperNet(nn.Module):
    """Weight-sharing search network built from ``cell.NasBenchSearchCell`` cells.

    A stem convolution feeds a sequence of search cells, one per entry of
    ``hparams["architecture"]``. Each entry is that cell's level ``l``: the
    cell is built for ``filter_multiplier * 2**l`` channels. Before a cell
    runs, the ResNet blocks in ``self.red`` bring the feature maps up to its
    level. A global-average-pool + linear decoder produces the logits.

    Every cell has a matching ``cell.NasBenchGenotype`` in ``self.genotypes``.
    Their parameters are exposed by :meth:`arch_parameters`; all remaining
    parameters are exposed by :meth:`weight_parameters`.
    """

    def __init__(self, num_classes, hparams):
        """Build the stem, reduction blocks, search cells and decoder.

        Args:
            num_classes: Number of output features of the final linear layer.
            hparams: Mapping that must provide ``"architecture"`` (sequence of
                per-cell levels, each in ``0..len(self.red)``, i.e. ``0..2``),
                ``"filter_multiplier"`` (channel count at level 0),
                ``"block_multiplier"`` (forwarded to every search cell) and,
                when the architecture is non-empty, ``"PS"`` (forwarded as
                ``PS=`` to every genotype; its meaning is defined in ``cell``).

        Raises:
            ValueError: If a level in ``hparams["architecture"]`` cannot be
                reached with the available reduction blocks.
        """
        super().__init__()

        self.arch = hparams["architecture"]

        filters = hparams["filter_multiplier"]
        self.block_multiplier = hparams["block_multiplier"]
        self.filters = filters

        self.cells = nn.ModuleList()
        self.genotypes = nn.ModuleList()

        # Level 0: 3-channel input -> `filters` channels, spatial size kept
        # (3x3 kernel, padding 1, default stride).
        self.stem = nn.Sequential(
            nn.Conv2d(3, filters, 3, padding=1, bias=False),
            nn.BatchNorm2d(filters, affine=False),
        )

        # red[k] takes feature maps from level k to level k + 1 by doubling
        # the channel count (the trailing 2 is presumably the stride).
        self.red = nn.ModuleList(
            [
                operations.ResNetBasicblock(filters, 2 * filters, 2),
                operations.ResNetBasicblock(2 * filters, 4 * filters, 2),
            ]
        )

        # Channel count fed to the decoder: that of the last cell, or the
        # stem's when the architecture is empty.
        out_filters = filters
        for level in self.arch:
            if not 0 <= level <= len(self.red):
                raise ValueError(
                    f"architecture level {level!r} is outside "
                    f"0..{len(self.red)}"
                )
            out_filters = filters * 2**level
            self.genotypes.append(
                cell.NasBenchGenotype(
                    out_filters,
                    PS=hparams["PS"],
                )
            )
            self.cells.append(
                cell.NasBenchSearchCell(
                    self.block_multiplier,
                    out_filters,
                )
            )

        self.decoder = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            operations.View(
                (-1, out_filters),
            ),
            nn.Linear(out_filters, num_classes),
        )

    def forward(self, x, collect_stats=False):
        """Run the network on a batch of inputs.

        Args:
            x: Input for the stem convolution, which expects 3 input channels
                (presumably an ``(N, 3, H, W)`` image batch; confirm with the
                data pipeline).
            collect_stats: Forwarded to every search cell. When true, each
                cell's returned dict is tagged with ``"cell_idx"`` and
                ``"cell_level"`` lists (one entry per item in its ``"coeff"``
                list) and merged into the returned ``coeffs_dict``.

        Returns:
            ``(out, coeffs_dict)``: the decoder output and a
            ``defaultdict(list)``. The dict stays empty unless
            ``collect_stats`` is true.
        """
        fmaps = self.stem(x)
        coeffs_dict = defaultdict(list)
        # Level of the previous cell; the stem output counts as level 0.
        prev_level = 0
        for idx, (search_cell, genotype, level) in enumerate(
            zip(self.cells, self.genotypes, self.arch)
        ):
            # Apply every reduction between the previous level and this one,
            # so a skipped level (e.g. 0 -> 2) or a first cell above level 0
            # still gets matching channels. There are no blocks for moving
            # down a level, so that case applies nothing.
            for step in range(prev_level, level):
                fmaps = self.red[step](fmaps)
            prev_level = level

            fmaps, curr_coeffs_dict = search_cell(
                fmaps,
                genotype,
                collect_stats=collect_stats,
            )

            if collect_stats:
                num_coeffs = len(curr_coeffs_dict["coeff"])
                curr_coeffs_dict["cell_idx"] = [idx] * num_coeffs
                curr_coeffs_dict["cell_level"] = [level] * num_coeffs
                for key, val in curr_coeffs_dict.items():
                    coeffs_dict[key].extend(val)

        out = self.decoder(fmaps)

        return out, coeffs_dict

    def arch_parameters(self):
        """Yield the parameters of the genotype modules."""
        yield from self.genotypes.parameters()

    def weight_parameters(self):
        """Yield the parameters of every submodule except ``genotypes``.

        The reduction blocks in ``red`` are included. If they were left out,
        an optimizer built from this generator would never update them.
        """
        yield from self.stem.parameters()
        yield from self.red.parameters()
        yield from self.cells.parameters()
        yield from self.decoder.parameters()
