import torch
from torch import nn

import operations
import utils

def _avg_pool_3x3(filters, affine, stride):
    """3x3 average pooling; padded positions are excluded from the mean."""
    return nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False)


def _max_pool_3x3(filters, affine, stride):
    """3x3 max pooling (``filters`` and ``affine`` are unused)."""
    return nn.MaxPool2d(3, stride=stride, padding=1)


def _skip_connect(filters, affine, stride):
    """Identity, or ``FactorizedReduce`` when the op has to downsample."""
    if stride != 1:
        return operations.FactorizedReduce(filters, filters, affine)
    return operations.Identity()


def _stacked_sep_conv(kernel_size, padding):
    """Return a factory building two stacked ``SepConvReLUBN`` layers.

    Only the first layer receives ``stride``; the second always uses the
    constructor's default stride.
    """

    def build(filters, affine, stride):
        first = operations.SepConvReLUBN(
            filters, filters, kernel_size, padding, affine=affine, stride=stride
        )
        second = operations.SepConvReLUBN(
            filters, filters, kernel_size, padding, affine=affine
        )
        return nn.Sequential(first, second)

    return build


def _dilated_sep_conv(kernel_size, padding):
    """Return a factory building a single ``SepConvReLUBN`` with an extra arg 2.

    NOTE(review): the positional ``2`` after ``padding`` is presumably the
    dilation rate — confirm against ``operations.SepConvReLUBN``.
    """

    def build(filters, affine, stride):
        return operations.SepConvReLUBN(
            filters, filters, kernel_size, padding, 2, affine=affine, stride=stride
        )

    return build


# Op id -> factory called as ``factory(filters, affine, stride=...)``.
# The trailing names follow the usual DARTS primitive naming; they are
# inferred from the constructors, not declared anywhere in this file.
OPS = {
    0: _avg_pool_3x3,  # avg_pool_3x3
    1: _max_pool_3x3,  # max_pool_3x3
    2: _skip_connect,  # skip_connect
    3: _stacked_sep_conv(3, 1),  # sep_conv_3x3
    4: _stacked_sep_conv(5, 2),  # sep_conv_5x5
    5: _dilated_sep_conv(3, 2),  # dil_conv_3x3
    6: _dilated_sep_conv(5, 4),  # dil_conv_5x5
}


class RetrainCell(nn.Module):
    """DARTS cell, retrain stage.

    Builds a fixed (already searched) cell. Two ``ConvReLUBN`` layers
    preprocess the cell inputs to ``filter_multiplier`` channels. Each
    genotype entry then adds one intermediate node, computed as the sum of
    two ops applied to earlier states. The cell output is the concatenation
    of all intermediate nodes along dim 1.

    Args:
        pp_filters: channel count of the ``pp`` input (cell two steps back).
        p_filters: channel count of the ``p`` input (previous cell).
        filters: reference channel count. If ``pp_filters`` / ``p_filters``
            differs from it, every op reading that input is built with
            ``stride=2``.
        filter_multiplier: channel width of every state inside the cell.
        genotype: iterable of ``((input_id1, op_id1), (input_id2, op_id2))``,
            one entry per intermediate node. ``input_id`` indexes the state
            list (0 = preprocessed ``pp``, 1 = preprocessed ``p``, 2+ =
            earlier intermediate nodes) and ``op_id`` is a key of ``OPS``.
        affine: forwarded to the preprocessing layers and op factories.
    """

    def __init__(
        self, pp_filters, p_filters, filters, filter_multiplier, genotype, affine
    ):
        super().__init__()

        self.filter_multiplier = filter_multiplier

        # Bring both inputs to ``filter_multiplier`` channels (the positional
        # ``1, 0`` are presumably kernel size and padding).
        self.pre_pp = operations.ConvReLUBN(
            pp_filters, filter_multiplier, 1, 0, affine=affine
        )
        self.pre_p = operations.ConvReLUBN(
            p_filters, filter_multiplier, 1, 0, affine=affine
        )

        # Inputs whose channel count differs from ``filters`` get stride-2 ops.
        # NOTE(review): presumably the mismatch marks an input at a higher
        # spatial resolution (i.e. from before a reduction) — confirm against
        # the network that instantiates these cells.
        red_ids = set()
        if pp_filters != filters:
            red_ids.add(0)
        if p_filters != filters:
            red_ids.add(1)

        # Registered op modules, two per node, in genotype order.
        self.ops = nn.ModuleList()
        # Per node: ((input_id1, op1), (input_id2, op2)); op1/op2 are the same
        # objects stored in self.ops. Kept for backward compatibility —
        # forward() only reads the input ids from it.
        self.genotype = []

        for (i1, op_id1), (i2, op_id2) in genotype:
            op1 = OPS[op_id1](
                filter_multiplier, affine, stride=(2 if i1 in red_ids else 1)
            )
            op2 = OPS[op_id2](
                filter_multiplier, affine, stride=(2 if i2 in red_ids else 1)
            )
            self.ops.append(op1)
            self.ops.append(op2)
            self.genotype.append(((i1, op1), (i2, op2)))

    def forward(self, pp, p, drop_path):
        """Run the cell.

        Args:
            pp: input from the cell two steps back.
            p: input from the previous cell.
            drop_path: drop-path rate. ``utils.drop`` is applied to every op
                output, but only in training mode and when ``drop_path > 0``.

        Returns:
            ``torch.cat`` of all intermediate node states along dim 1.
        """
        states = [self.pre_pp(pp), self.pre_p(p)]
        for node, ((i1, _), (i2, _)) in enumerate(self.genotype):
            # Take the modules from the registered ModuleList, not from the
            # plain self.genotype list: nn.DataParallel replicas share that
            # list, which would still reference the original device's modules.
            op1 = self.ops[2 * node]
            op2 = self.ops[2 * node + 1]
            h1 = op1(states[i1])
            h2 = op2(states[i2])
            # NOTE(review): drop-path also hits identity/skip ops here; the
            # reference DARTS implementation skips Identity — confirm intent.
            if self.training and drop_path > 0:
                h1 = utils.drop(h1, drop_path)
                h2 = utils.drop(h2, drop_path)
            states.append(h1 + h2)

        return torch.cat(states[2:], dim=1)
