from collections import defaultdict

from torch import nn

import cell
import operations


class SuperNet(nn.Module):
    """Search network: a conv stem, a chain of ``cell.DartsSearchCell``s, and a linear head.

    One search cell is built per entry of ``hparams["architecture"]``; the
    entry is used as a power-of-two exponent for that cell's base width
    (``filter_multiplier * 2**level``). Every cell is paired with its own
    ``cell.DartsGenotype``, which is passed to the cell on each forward call.
    Genotype parameters and all remaining parameters are exposed through the
    separate iterators :meth:`arch_parameters` and :meth:`weight_parameters`.

    Args:
        num_classes: Number of outputs of the final ``nn.Linear``.
        hparams: Mapping providing ``"architecture"`` (a re-iterable sequence
            of levels, read here and again on every forward pass),
            ``"filter_multiplier"``, ``"block_multiplier"`` and ``"PS"``
            (forwarded unchanged to ``cell.DartsGenotype``).
    """

    def __init__(self, num_classes, hparams):
        super().__init__()

        self.arch = hparams["architecture"]

        filter_multiplier = hparams["filter_multiplier"]
        block_multiplier = hparams["block_multiplier"]
        self.cells = nn.ModuleList()
        # One genotype per cell; their parameters are what arch_parameters() yields.
        self.genotypes = nn.ModuleList()

        filters = filter_multiplier * block_multiplier
        # Channel counts of every feature map produced so far: the stem output
        # twice (it feeds both inputs of the first cell), then one entry per cell.
        filter_list = [filters, filters]
        self.stem = nn.Sequential(
            nn.Conv2d(3, filters, 3, padding=1, bias=False),
            nn.BatchNorm2d(filters, affine=False),
        )
        for level in self.arch:
            curr_filters = filter_multiplier * 2**level
            self.genotypes.append(
                cell.DartsGenotype(
                    curr_filters,
                    PS=hparams["PS"],
                )
            )
            filter_list.append(block_multiplier * curr_filters)
            # filter_list[-3:] holds the channel counts of the two feature maps
            # this cell receives in forward() followed by its own output count.
            # NOTE(review): argument meaning inferred from this call site only —
            # confirm against cell.DartsSearchCell's signature.
            self.cells.append(
                cell.DartsSearchCell(
                    block_multiplier,
                    *filter_list[-3:],
                    curr_filters,
                )
            )

        # Global average pool, reshape to (-1, channels) (operations.View is
        # presumably a reshape — confirm), then the linear classifier.
        self.decoder = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            operations.View(
                (-1, filter_list[-1]),
            ),
            nn.Linear(filter_list[-1], num_classes),
        )

    def forward(self, x, collect_stats=False):
        """Run the stem, every search cell, and the decoder on ``x``.

        Args:
            x: Input for the stem, whose first layer is an ``nn.Conv2d`` with
                3 input channels.
            collect_stats: Forwarded to every cell. When true, the stats dicts
                returned by the cells are concatenated key by key, and
                ``"cell_idx"`` / ``"cell_level"`` lists (one entry per element
                of the cell's ``"coeff"`` list) are added to tag each entry
                with the cell it came from.

        Returns:
            ``(out, coeffs_dict)``: the decoder output and a
            ``defaultdict(list)`` that stays empty unless ``collect_stats``
            is true.
        """
        stem = self.stem(x)
        # Each cell consumes the outputs of the two preceding stages; the stem
        # output serves as both inputs of the first cell. Only the two most
        # recent feature maps are ever needed, so no full history is kept.
        prev_prev, prev = stem, stem
        coeffs_dict = defaultdict(list)
        stages = zip(self.cells, self.genotypes, self.arch)
        # Named search_cell so the imported ``cell`` module is not shadowed.
        for idx, (search_cell, genotype, level) in enumerate(stages):
            out, cell_stats = search_cell(
                prev_prev,
                prev,
                genotype,
                collect_stats=collect_stats,
            )
            prev_prev, prev = prev, out

            if collect_stats:
                n_coeffs = len(cell_stats["coeff"])
                # Build a fresh dict instead of writing into the one the cell
                # returned; key order/overwrite semantics match item assignment.
                tagged = {
                    **cell_stats,
                    "cell_idx": [idx] * n_coeffs,
                    "cell_level": [level] * n_coeffs,
                }
                for key, val in tagged.items():
                    coeffs_dict[key].extend(val)

        return self.decoder(prev), coeffs_dict

    def arch_parameters(self):
        """Yield the parameters of every per-cell genotype."""
        yield from self.genotypes.parameters()

    def weight_parameters(self):
        """Yield the stem, cell, and decoder parameters (everything except genotypes)."""
        yield from self.stem.parameters()
        yield from self.cells.parameters()
        yield from self.decoder.parameters()
