import torch.nn as nn
import torch.nn.functional as F
from .resnet import CBA


class SEModule(nn.Module):
    """Squeeze-and-excitation block: rescale each channel by a learned gate.

    The gate is computed as: global average pool -> 1x1 ``model_cfg.conv``
    down to ``max(1, int(ratio * width))`` channels -> ``model_cfg.act()``
    -> 1x1 ``model_cfg.conv`` back to ``width`` channels -> sigmoid.
    The gate is then multiplied element-wise with the block input.

    Args:
        model_cfg: config exposing ``conv`` (conv-layer constructor called
            as ``conv(in, out, kernel)``) and ``act`` (activation-module
            constructor).
        width: number of input (and output) channels.
        ratio: bottleneck ratio; the hidden width is ``ratio * width``,
            truncated and clamped to at least 1.
    """

    def __init__(self, model_cfg, width, ratio):
        super().__init__()

        squeezed = max(1, int(ratio*width))
        layers = [
            nn.AdaptiveAvgPool2d((1, 1)),
            # No bias=False here: keeping the conv bias is intentional.
            model_cfg.conv(width, squeezed, 1),
            model_cfg.act(),
            model_cfg.conv(squeezed, width, 1),
            # TODO: either make the gate configurable via model_cfg or
            # switch to a hard sigmoid.
            nn.Sigmoid(),
        ]
        self.excite = nn.Sequential(*layers)

    def forward(self, nx):
        # The pooled (1x1 spatial) gate broadcasts over the input's spatial
        # dims -- assumes model_cfg.conv keeps the 1x1 spatial size.
        return self.excite(nx) * nx


class MobileConvBlock(nn.Module):
    """Mobile inverted-bottleneck convolution block (MBConv).

    Main path: optional 1x1 expansion ``CBA`` -> depthwise ``ksize`` ``CBA``
    (carrying ``stride``) -> optional ``SEModule`` -> 1x1 projection ``CBA``
    built with ``act=False``. An identity shortcut is added when
    ``in_width == width`` and ``stride == 1``.

    Args:
        model_cfg: config object forwarded to ``CBA`` and ``SEModule``.
        in_width: number of input channels.
        width: number of output channels.
        ksize: kernel size of the depthwise conv.
        bn: accepted for signature compatibility but currently unused;
            every ``CBA`` in the block is built with ``bn=True``.
        act: accepted for signature compatibility but currently unused;
            the expansion/depthwise convs are built with ``act=True`` and
            the projection conv with ``act=False``.
        widening_ratio: expansion factor; the hidden width is
            ``int(in_width * widening_ratio)``, clamped to at least 1. No
            expansion conv is built when it equals 1.
        se_ratio: squeeze-and-excite ratio applied to the hidden width;
            ``0`` (or negative) disables SE.
        stride: stride of the depthwise conv.
    """

    def __init__(
        self, model_cfg, in_width, width, ksize, bn=True, act=True,
        widening_ratio=6, se_ratio=0, stride=1
    ):
        super().__init__()

        # Hidden (expanded) width; clamped so that tiny ratios cannot
        # produce a zero-channel layer.
        mid_width = max(1, int(in_width * widening_ratio))

        # -- Main path
        main_path = []
        # Expansion point-wise conv. It is needed whenever the ratio is not
        # 1; testing only `> 1` skipped it for ratios < 1, which left the
        # depthwise conv below expecting `mid_width` channels while
        # receiving `in_width` channels.
        if widening_ratio != 1:
            main_path.append(CBA(
                model_cfg, in_width, mid_width, 1,
                bn=True, act=True, bias=False
            ))
        # Depthwise conv; the only layer that applies the block's stride.
        main_path.append(CBA(
            model_cfg, mid_width, mid_width, ksize,
            bn=True, act=True, bias=False, depthwise=True,
            stride=stride
        ))
        # Squeeze & excite on the expanded features.
        # NOTE(review): the SE bottleneck is sized from `mid_width`; the
        # reference EfficientNet sizes it from the block input width --
        # confirm this is intended.
        if se_ratio > 0:
            main_path.append(
                SEModule(model_cfg, mid_width, se_ratio)
            )
        # Projection point-wise conv without activation (linear bottleneck).
        main_path.append(CBA(
            model_cfg, mid_width, width, 1,
            bn=True, act=False, bias=False
        ))

        self.main_path = nn.Sequential(*main_path)
        # Identity shortcut only when input and output shapes can match;
        # no projection shortcut is used otherwise.
        self.has_skip = (in_width == width) and (stride == 1)

    def forward(self, nx):
        """Apply the main path and, when shapes allow, the identity skip."""
        _in = nx
        nx = self.main_path(nx)
        if self.has_skip:
            nx = nx + _in
        return nx


class EfficientNet(nn.Module):
    """EfficientNet-style classifier built from MBConv blocks.

    Layout: stem 3x3 ``CBA`` (32 channels) -> MBConv1 block
    (``scale * 16`` channels) -> five stages of MBConv6 blocks -> global
    average pool -> dropout -> ``model_cfg.fc`` classifier. The stage
    widths/repeats are simplified and do not match EfficientNet-B0.

    Args:
        model_cfg: config object forwarded to every sub-block;
            ``model_cfg.fc(in_features, n_classes)`` builds the classifier.
        n_classes: number of output logits.
        input_size: number of input channels (despite the name, it is not
            a spatial size).
        se_ratio: squeeze-and-excite ratio used by every MBConv block;
            ``0`` disables SE.
        scale: width multiplier applied to every stage after the stem.
        repeat: depth multiplier applied to each stage's block count.
        dropout: dropout probability applied before the classifier.
    """

    def __init__(
        self, model_cfg, n_classes=10, input_size=3, se_ratio=0,
        scale=1, repeat=1, dropout=0.2
    ):
        super().__init__()

        # Stem: first (2.5D) 3x3 conv, then an MBConv1 block.
        # NOTE(review): the stem width is not multiplied by `scale`, unlike
        # every later stage -- confirm this is intended.
        prev_bw = new_bw = 32
        self.conv0 = CBA(
            model_cfg, input_size, new_bw, 3
        )
        # Clamp to >= 1 so that a small `scale` cannot build zero-width
        # layers.
        new_bw = max(1, int(scale*16))
        self.mbconv1 = MobileConvBlock(
            model_cfg, prev_bw, new_bw, 1, widening_ratio=1, se_ratio=se_ratio
        )
        prev_bw = new_bw

        # Scaled MBConv blocks (Not the accurate MB0 values!!!)
        ksizes = [3, 3, 3, 3, 3]
        widths = [24, 60, 120, 200, 400]
        n_repeats = [1, 2, 3, 4, 1]
        strides = [1, 2, 1, 2, 1]

        mbconv6 = []

        for ksize, width, n_repeat, stride in zip(
            ksizes, widths, n_repeats, strides
        ):
            new_bw = max(1, int(scale*width))
            # Keep at least one block per stage: with `repeat` < 1, round()
            # can return 0, which would silently drop the whole stage
            # together with its downsampling stride.
            n_blocks = max(1, round(n_repeat*repeat))
            for i in range(n_blocks):
                mbconv6.append(MobileConvBlock(
                    model_cfg, prev_bw, new_bw, ksize,
                    widening_ratio=6, se_ratio=se_ratio,
                    # Only the first block of a stage downsamples.
                    stride=(stride if i == 0 else 1)
                ))
                prev_bw = new_bw

        self.mbconv6 = nn.Sequential(*mbconv6)

        # Last Fc-layer: its input width is the final stage's width.
        self.dropfc = nn.Sequential(
            nn.Dropout(dropout),
            model_cfg.fc(prev_bw, n_classes)
        )

    def forward(self, nx):
        """Return classifier logits for the input batch ``nx``."""
        nx = self.mbconv1(self.conv0(nx))
        nx = self.mbconv6(nx)
        # Global average pool to a 1x1 map, then drop both spatial dims.
        nx = F.adaptive_avg_pool2d(nx, 1).squeeze_(-1).squeeze_(-1)
        nx = self.dropfc(nx)
        return nx
