"""
 Copyright (c) 2022 Intel Corporation
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
      http://www.apache.org/licenses/LICENSE-2.0
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""

import torch
from torch import nn

from examples.torch.common import restricted_pickle_module
from examples.torch.common.example_logger import logger
from examples.torch.object_detection.layers.modules.ssd_head import MultiOutputSequential, SSDDetectionOutput
from nncf.torch.checkpoint_loading import load_state


def conv_bn(inp, oup, kernel, stride, padding):
    """Build a bias-free Conv2d -> BatchNorm2d -> in-place ReLU block.

    Args:
        inp: Number of input channels.
        oup: Number of output channels.
        kernel: Convolution kernel size.
        stride: Convolution stride.
        padding: Zero-padding applied by the convolution.

    Returns:
        An ``nn.Sequential`` containing the three layers in that order.
    """
    layers = [
        # Bias is omitted because the following BatchNorm adds its own shift.
        nn.Conv2d(
            in_channels=inp,
            out_channels=oup,
            kernel_size=kernel,
            stride=stride,
            padding=padding,
            bias=False,
        ),
        nn.BatchNorm2d(oup),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)


def conv_dw(inp, oup, stride):
    """Build a depthwise-separable convolution block.

    The block is a 3x3 depthwise convolution (``groups=inp``) with BatchNorm
    and ReLU, followed by a 1x1 pointwise convolution with BatchNorm and ReLU.
    All convolutions are bias-free and all ReLUs are in-place.

    Args:
        inp: Number of input channels (also the depthwise output channels).
        oup: Number of output channels produced by the pointwise convolution.
        stride: Stride of the depthwise convolution. The pointwise one uses 1.

    Returns:
        An ``nn.Sequential`` of six layers.
    """
    depthwise = (
        nn.Conv2d(inp, inp, kernel_size=3, stride=stride, padding=1, groups=inp, bias=False),
        nn.BatchNorm2d(inp),
        nn.ReLU(inplace=True),
    )
    pointwise = (
        nn.Conv2d(inp, oup, kernel_size=1, stride=1, padding=0, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU(inplace=True),
    )
    return nn.Sequential(*depthwise, *pointwise)


def mobilenet(start_input_channels=3):
    """Build the MobileNet-style backbone used by MobileNetSSD.

    The network is a stride-2 ``conv_bn`` stem followed by 13
    depthwise-separable ``conv_dw`` stages, for 14 layers at indices 0-13.

    Args:
        start_input_channels: Number of channels of the input tensor.

    Returns:
        A ``MultiOutputSequential`` over the 14 layers, constructed with the
        index list ``[11, 13]``. Those indices are the last 512-channel stage
        (cumulative stride 16) and the final 1024-channel stage (cumulative
        stride 32). They are presumably the layers whose outputs are exposed
        as detection sources (TODO confirm against MultiOutputSequential).
    """
    # (in_channels, out_channels, stride) for each depthwise-separable stage,
    # listed in network order.
    dw_stages = [
        (32, 64, 1),
        (64, 128, 2),
        (128, 128, 1),
        (128, 256, 2),
        (256, 256, 1),
        (256, 512, 2),
        *([(512, 512, 1)] * 5),
        (512, 1024, 2),
        (1024, 1024, 1),
    ]
    # Layers are instantiated in network order, stem first.
    layers = [conv_bn(start_input_channels, 32, 3, 2, 1)]
    layers.extend(conv_dw(c_in, c_out, stride) for c_in, c_out, stride in dw_stages)
    return MultiOutputSequential([11, 13], layers)


def extra_layers(start_input_channels):
    """Build the extra SSD feature layers that follow the backbone.

    There are four stages. Each stage is a 1x1 ``conv_bn`` bottleneck followed
    by a 3x3 stride-2 ``conv_bn``, for eight layers in total.

    Args:
        start_input_channels: Channel count of the incoming feature map.
            ``MobileNetSSD`` passes 1024.

    Returns:
        A ``MultiOutputSequential`` over the eight layers, constructed with the
        index list ``[1, 3, 5, 7]``, i.e. the stride-2 conv of every stage.
        Those outputs presumably feed the detection head (TODO confirm
        against MultiOutputSequential).
    """
    # (bottleneck_channels, output_channels) for each stage.
    stages = ((256, 512), (128, 256), (128, 256), (64, 128))
    layers = []
    channels = start_input_channels
    for mid, out in stages:
        layers.append(conv_bn(channels, mid, 1, 1, 0))
        layers.append(conv_bn(mid, out, 3, 2, 1))
        channels = out
    return MultiOutputSequential([1, 3, 5, 7], layers)


class MobileNetSSD(nn.Module):
    """SSD detector built from a MobileNet backbone plus extra feature layers.

    Args:
        num_classes: Number of classes, forwarded to ``SSDDetectionOutput``.
        cfg: Detection configuration, stored on the module and forwarded to
            ``SSDDetectionOutput``.
    """

    def __init__(self, num_classes, cfg):
        super().__init__()
        self.cfg = cfg
        self.num_classes = num_classes

        self.basenet = mobilenet()
        # The backbone's last stage has 1024 channels.
        self.extras = extra_layers(1024)

        # Channel counts of the feature maps given to the head, in source
        # order: two from the backbone (512, 1024) and four from the extras.
        head_in_channels = [512, 1024, 512, 256, 256, 128]
        self.detection_head = SSDDetectionOutput(head_in_channels, num_classes, cfg)

    def forward(self, x):
        """Run backbone, extras and detection head on the batch ``x``.

        Only a copy of the first sample, reshaped to a batch of one, is
        handed to the head alongside the collected feature sources.
        """
        first_image = x[0].clone().unsqueeze(0)

        backbone_sources, features = self.basenet(x)
        extra_sources, features = self.extras(features)
        all_sources = backbone_sources + extra_sources

        return self.detection_head(all_sources, first_image)


def build_ssd_mobilenet(cfg, size, num_classes, config):
    """Create a MobileNetSSD and optionally load pretrained backbone weights.

    Backbone weights are loaded from ``config.basenet`` only when that is set
    and neither ``config.resuming_checkpoint_path`` nor ``config.weights`` is
    given.

    Args:
        cfg: Detection configuration passed to ``MobileNetSSD``.
        size: Input resolution. Only 300 is accepted.
        num_classes: Number of classes passed to ``MobileNetSSD``.
        config: Run configuration providing ``basenet``,
            ``resuming_checkpoint_path`` and ``weights``.

    Returns:
        The constructed ``MobileNetSSD`` model.

    Raises:
        ValueError: If ``size`` is not 300.
    """
    if size != 300:
        raise ValueError("Only Mobilenet-SSD with input size 300 is supported")
    mobilenet_ssd = MobileNetSSD(num_classes, cfg)

    should_load_basenet = (
        config.basenet
        and config.resuming_checkpoint_path is None
        and config.weights is None
    )
    if not should_load_basenet:
        return mobilenet_ssd

    logger.debug('Loading base network...')
    #
    # ** WARNING: torch.load functionality uses Python's pickling facilities that
    # may be used to perform arbitrary code execution during unpickling. Only load the data you
    # trust.
    #
    checkpoint_state = torch.load(config.basenet,
                                  pickle_module=restricted_pickle_module)['state_dict']
    # NOTE(review): str.replace drops *every* 'model.' occurrence, not just a
    # leading prefix. This also strips it from e.g. 'module.model.' keys;
    # confirm this is intended before narrowing it to a prefix strip.
    basenet_weights = {
        name.replace('model.', ''): tensor
        for name, tensor in checkpoint_state.items()
    }
    load_state(mobilenet_ssd.basenet, basenet_weights, is_resume=False)
    return mobilenet_ssd
