"""
 Copyright (c) 2022 Intel Corporation
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
      http://www.apache.org/licenses/LICENSE-2.0
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""

import torch
from torch import nn
from torch.nn import functional as F

from examples.torch.object_detection.layers import DetectionOutput, PriorBox
from nncf.torch.dynamic_graph.context import no_nncf_trace


class SSDDetectionOutput(nn.Module):
    def __init__(self, num_input_features, num_classes, config):
        super().__init__()
        self.config = config
        self.num_classes = num_classes
        self.loss_inference = config.get('loss_inference', False)

        self.heads = nn.ModuleList()
        for i, num_features in enumerate(num_input_features):
            self.heads.append(SSDHead(
                num_features, num_classes, config.min_sizes[i], config.max_sizes[i],
                config.aspect_ratios[i], config.steps[i], config.variance, config.flip, config.clip
            ))

        self.detection_output = DetectionOutput(
            num_classes, 0, config.get('top_k', 200), config.get('keep_top_k', 200), 0.01, 0.45, 1, 1,
            "CENTER_SIZE", 0
        )

    def forward(self, source_features, img_tensor):
        locs = []
        confs = []
        priors = []
        for features, head in zip(source_features, self.heads):
            loc, conf, prior = head(features, img_tensor)
            locs.append(loc)
            confs.append(conf)
            priors.append(prior)

        batch = source_features[0].size(0)
        loc = torch.cat([o.view(batch, -1) for o in locs], 1)
        conf = torch.cat([o.view(batch, -1) for o in confs], 1)
        conf_softmax = F.softmax(conf.view(conf.size(0), -1, self.num_classes), dim=-1)

        with no_nncf_trace():
            priors = torch.cat(priors, dim=2)

        if self.training:
            return loc.view(batch, -1, 4), conf.view(batch, -1, self.num_classes), priors.view(1, 2, -1, 4)

        with no_nncf_trace():
            if self.loss_inference is True:
                return loc.view(batch, -1, 4), conf.view(batch, -1, self.num_classes), priors.view(1, 2, -1, 4)
            return self.detection_output(loc, conf_softmax.view(batch, -1), priors)


class SSDHead(nn.Module):
    def __init__(self, num_features, num_classes, min_size, max_size, aspect_ratios, steps, varience, flip, clip):
        super().__init__()
        self.num_classes = num_classes
        self.clip = clip
        self.flip = flip
        self.varience = varience
        self.steps = steps
        self.aspect_ratios = aspect_ratios
        self.max_size = max_size
        self.min_size = min_size
        self.input_features = num_features

        num_prior_per_cell = 2 + 2 * len(aspect_ratios)
        self.loc = nn.Conv2d(num_features, num_prior_per_cell * 4, kernel_size=3, padding=1)
        self.conf = nn.Conv2d(num_features, num_prior_per_cell * num_classes, kernel_size=3, padding=1)
        self.prior_box = PriorBox(min_size, max_size, aspect_ratios, flip, clip, varience, steps, 0.5, 0, 0,
                                  0, 0, 0)

    def forward(self, features, image_tensor):
        loc = self.loc(features)
        conf = self.conf(features)

        with no_nncf_trace():
            priors = self.prior_box(features, image_tensor).to(loc.device)
        # Knowledge Distillation Algo differentiates all model outputs with requires_grad=True.
        # Priors shouldn't be differentiated so they are explicitly excluded from backpropagation graph.

        priors = priors.detach()

        loc = loc.permute(0, 2, 3, 1).contiguous()
        conf = conf.permute(0, 2, 3, 1).contiguous()

        return loc, conf, priors


class MultiOutputSequential(nn.Sequential):
    def __init__(self, outputs, modules):
        super().__init__(*modules)
        self.outputs = [str(o) for o in outputs]

    #pylint:disable=W0237
    def forward(self, x):
        outputs = []
        for name, module in self._modules.items():
            x = module(x)
            if name in self.outputs:
                outputs.append(x)
        return outputs, x
