# -*- coding: utf-8 -*-
"""RegNet in PyTorch.

Paper: "Designing Network Design Spaces".

Reference: https://github.com/keras-team/keras-applications/blob/master/keras_applications/efficientnet.py
"""
import torch
import torch.nn as nn
import torch.nn.functional as F


__all__ = ["regnet"]


class SE(nn.Module):
    """Squeeze-and-Excitation gate that rescales the channels of its input.

    The input is averaged down to a 1x1 spatial map, passed through two 1x1
    convolutions (ReLU after the first, sigmoid after the second), and the
    resulting per-channel weights are multiplied back onto the input.

    Args:
        in_planes: number of channels of the input (and of the output).
        se_planes: number of channels of the intermediate representation.
    """

    def __init__(self, in_planes, se_planes):
        super().__init__()
        # First conv maps in_planes -> se_planes, second maps back to in_planes.
        self.se1 = nn.Conv2d(in_planes, se_planes, kernel_size=1, bias=True)
        self.se2 = nn.Conv2d(se_planes, in_planes, kernel_size=1, bias=True)

    def forward(self, x):
        """Return ``x`` multiplied by its sigmoid channel gate."""
        squeezed = F.adaptive_avg_pool2d(x, (1, 1))
        hidden = F.relu(self.se1(squeezed))
        gate = torch.sigmoid(self.se2(hidden))
        return x * gate


class Block(nn.Module):
    """Residual bottleneck block with a grouped 3x3 convolution.

    Main path: 1x1 conv + BN + ReLU, grouped 3x3 conv + BN + ReLU, an optional
    ``SE`` gate, then 1x1 conv + BN. The shortcut branch is added to the result
    and a final ReLU is applied.

    Args:
        w_in: number of input channels.
        w_out: number of output channels.
        stride: stride of the 3x3 convolution and of the projection shortcut.
        group_width: channels per group in the 3x3 convolution.
        bottleneck_ratio: multiplier applied to ``w_out`` to get the inner width.
        se_ratio: multiplier applied to ``w_in`` to get the SE hidden width;
            the SE gate is only created when this is greater than zero.
    """

    def __init__(self, w_in, w_out, stride, group_width, bottleneck_ratio, se_ratio):
        super().__init__()
        inner_width = int(round(w_out * bottleneck_ratio))

        # 1x1 convolution into the inner width.
        self.conv1 = nn.Conv2d(w_in, inner_width, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(inner_width)

        # Grouped 3x3 convolution; this is where the block's stride is applied.
        self.conv2 = nn.Conv2d(
            inner_width,
            inner_width,
            kernel_size=3,
            stride=stride,
            padding=1,
            groups=inner_width // group_width,
            bias=False,
        )
        self.bn2 = nn.BatchNorm2d(inner_width)

        # Optional squeeze-and-excitation gate on the inner features.
        self.with_se = se_ratio > 0
        if self.with_se:
            self.se = SE(inner_width, int(round(w_in * se_ratio)))

        # 1x1 convolution out to the block's output width.
        self.conv3 = nn.Conv2d(inner_width, w_out, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(w_out)

        # A projection is needed whenever the main path changes the spatial
        # size or the channel count; otherwise the shortcut is the identity.
        if stride != 1 or w_in != w_out:
            self.shortcut = nn.Sequential(
                nn.Conv2d(w_in, w_out, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(w_out),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Run the main path, add the shortcut branch and apply ReLU."""
        h = self.bn1(self.conv1(x))
        h = self.bn2(self.conv2(F.relu(h)))
        h = F.relu(h)
        if self.with_se:
            h = self.se(h)
        h = self.bn3(self.conv3(h))
        h += self.shortcut(x)
        return F.relu(h)


class RegNet(nn.Module):
    """RegNet classifier assembled from a configuration dictionary.

    The network is a 3x3 convolution stem (3 -> 64 channels, stride 1) followed
    by four stages of ``Block`` modules, global average pooling and a linear
    classifier.

    Args:
        cfg: dictionary with the keys ``"depths"``, ``"widths"`` and
            ``"strides"`` (one entry per stage), ``"group_width"``,
            ``"bottleneck_ratio"``, ``"se_ratio"`` and ``"dataset"``.
        save_activations: when true, every forward pass stores the outputs of
            the four stages in ``self.activations``.

    Raises:
        NotImplementedError: if ``cfg["dataset"]`` is not a supported dataset.
    """

    def __init__(self, cfg, save_activations=False):
        super().__init__()
        self.cfg = cfg
        self.dataset = cfg["dataset"]
        self.num_classes = self._decide_num_classes()

        # Stem; `in_planes` tracks the channel count fed to the next block and
        # is updated by `_make_layer` as stages are built.
        self.in_planes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        self.layer1 = self._make_layer(0)
        self.layer2 = self._make_layer(1)
        self.layer3 = self._make_layer(2)
        self.layer4 = self._make_layer(3)
        self.linear = nn.Linear(self.cfg["widths"][-1], self.num_classes)

        # a placeholder for activations in the intermediate layers.
        self.save_activations = save_activations
        self.activations = None

    def _make_layer(self, idx):
        """Build stage ``idx`` as a sequence of ``cfg["depths"][idx]`` blocks."""
        depth = self.cfg["depths"][idx]
        width = self.cfg["widths"][idx]
        stride = self.cfg["strides"][idx]
        group_width = self.cfg["group_width"]
        bottleneck_ratio = self.cfg["bottleneck_ratio"]
        se_ratio = self.cfg["se_ratio"]

        layers = []
        for i in range(depth):
            # Only the first block of a stage applies the stage stride.
            s = stride if i == 0 else 1
            layers.append(
                Block(self.in_planes, width, s, group_width, bottleneck_ratio, se_ratio)
            )
            self.in_planes = width
        return nn.Sequential(*layers)

    def _decide_num_classes(self):
        """Return the number of output classes for ``self.dataset``.

        Raises:
            NotImplementedError: if the dataset name is not recognised, so the
                failure is reported here rather than as an obscure error when
                the classifier layer is created.
        """
        if self.dataset == "cifar10" or self.dataset == "svhn":
            return 10
        elif self.dataset == "cifar100":
            return 100
        elif "imagenet" in self.dataset:
            return 1000
        elif "femnist" == self.dataset:
            return 62
        raise NotImplementedError(
            f"Unsupported dataset for RegNet: {self.dataset!r}."
        )

    def forward(self, x):
        """Return the classifier output for ``x``.

        When ``self.save_activations`` is set, the outputs of the four stages
        are kept in ``self.activations`` (in stage order) as a side effect.
        """
        out = F.relu(self.bn1(self.conv1(x)))

        stage_outputs = []
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
            stage_outputs.append(out)

        out = F.adaptive_avg_pool2d(out, (1, 1))
        out = out.view(out.size(0), -1)
        out = self.linear(out)

        if self.save_activations:
            self.activations = stage_outputs
        return out


def regnet_confs(net_name, dataset):
    """Build the ``RegNet`` variant called ``net_name`` for ``dataset``.

    Args:
        net_name: one of ``"RegNetX_200MF"``, ``"RegNetX_400MF"`` or
            ``"RegNetY_400MF"``.
        dataset: dataset name stored in the configuration under ``"dataset"``.

    Returns:
        A ``RegNet`` instance built from the selected configuration.

    Raises:
        KeyError: if ``net_name`` is not one of the known variants.
    """
    # Per-variant settings: (depths, widths, group_width, se_ratio).
    variants = {
        "RegNetX_200MF": ([1, 1, 4, 7], [24, 56, 152, 368], 8, 0),
        "RegNetX_400MF": ([1, 2, 7, 12], [32, 64, 160, 384], 16, 0),
        "RegNetY_400MF": ([1, 2, 7, 12], [32, 64, 160, 384], 16, 0.25),
    }
    depths, widths, group_width, se_ratio = variants[net_name]

    # Strides and bottleneck ratio are the same for every variant.
    cfg = {
        "depths": depths,
        "widths": widths,
        "strides": [1, 1, 2, 2],
        "group_width": group_width,
        "bottleneck_ratio": 1,
        "se_ratio": se_ratio,
        "dataset": dataset,
    }
    return RegNet(cfg)


def regnet(conf, arch=None):
    """Create a RegNet model from an experiment configuration object.

    Args:
        conf: object exposing ``data`` (dataset name) and ``arch`` (name of the
            RegNet variant passed on to ``regnet_confs``).
        arch: accepted but not read by this function; the variant name is
            always taken from ``conf.arch``.

    Returns:
        The model returned by ``regnet_confs``.

    Raises:
        NotImplementedError: if ``conf.data`` contains neither ``"cifar"`` nor
            ``"svhn"``.
    """
    data_name = conf.data

    # Only CIFAR-style and SVHN dataset names are wired up here.
    if not any(tag in data_name for tag in ("cifar", "svhn")):
        raise NotImplementedError
    return regnet_confs(conf.arch, data_name)


if __name__ == "__main__":
    # Smoke test: build one variant, report its size and run a single forward
    # pass on a random 32x32 RGB input.

    def get_n_model_params(model):
        """Return the number of trainable parameters of ``model``, in millions."""
        return sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6

    # Build the model through this module's `regnet_confs` factory.
    net = regnet_confs(net_name="RegNetX_200MF", dataset="cifar10")
    print(f"The net has {get_n_model_params(net)} M.")
    x = torch.randn(1, 3, 32, 32)
    y = net(x)
    print(y.shape)
