"""
Quantized segmentation modules (bbox attention and mask head) for DETR.
"""
import io
from collections import defaultdict
from typing import List, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Module, Parameter
from torch import Tensor
import logging
from ..quantization_utils.quant_modules import *
import pdb


def _expand(tensor, length: int):
    return tensor.unsqueeze(1).repeat(1, int(length), 1, 1, 1).flatten(0, 1)



class GroupNorm(Module):
    r"""Applies Group Normalization over a mini-batch of inputs as described in
    the paper `Group Normalization`_ .

    .. math::
        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The input channels are separated into :attr:`num_groups` groups, each containing
    ``num_channels / num_groups`` channels. The mean and standard-deviation are calculated
    separately over each group. :math:`\gamma` and :math:`\beta` are learnable
    per-channel affine transform parameter vectors of size :attr:`num_channels` if
    :attr:`affine` is ``True``.

    This layer uses statistics computed from input data in both training and
    evaluation modes.

    Args:
        num_groups (int): number of groups to separate the channels into
        num_channels (int): number of channels expected in input
        eps: a value added to the denominator for numerical stability. Default: 1e-5
        affine: a boolean value that when set to ``True``, this module
            has learnable per-channel affine parameters initialized to ones (for weights)
            and zeros (for biases). Default: ``True``.

    Shape:
        - Input: :math:`(N, C, *)` where :math:`C=\text{num\_channels}`
        - Output: :math:`(N, C, *)` (same shape as input)

    Examples::

        >>> input = torch.randn(20, 6, 10, 10)
        >>> # Separate 6 channels into 3 groups
        >>> m = GroupNorm(3, 6)
        >>> # Separate 6 channels into 6 groups (equivalent with InstanceNorm)
        >>> m = GroupNorm(6, 6)
        >>> # Put all 6 channels into a single group (equivalent with LayerNorm)
        >>> m = GroupNorm(1, 6)
        >>> # Activating the module
        >>> output = m(input)

    .. _`Group Normalization`: https://arxiv.org/abs/1803.08494
    """
    __constants__ = ['num_groups', 'num_channels', 'eps', 'affine']

    def __init__(self, num_groups, num_channels, eps=1e-5, affine=True):
        super(GroupNorm, self).__init__()
        self.num_groups = num_groups
        self.num_channels = num_channels
        self.eps = eps
        self.affine = affine
        if self.affine:
            self.weight = Parameter(torch.Tensor(num_channels))
            self.bias = Parameter(torch.Tensor(num_channels))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Initialize the affine parameters to the identity transform (weight=1, bias=0)."""
        # Use ``nn.init`` explicitly: the bare name ``init`` is not imported in this
        # module, so referencing it raised NameError at construction time.
        if self.affine:
            nn.init.ones_(self.weight)
            nn.init.zeros_(self.bias)

    def forward(self, input):
        """Normalize ``input`` per group and apply the optional affine transform."""
        return F.group_norm(
            input, self.num_groups, self.weight, self.bias, self.eps)

    def extra_repr(self):
        return '{num_groups}, {num_channels}, eps={eps}, ' \
            'affine={affine}'.format(**self.__dict__)


def _verify_batch_size(size):
    # type: (List[int]) -> None
    # XXX: JIT script does not support the reduce from functools, and mul op is a
    # builtin, which cannot be used as a value to a func yet, so rewrite this size
    # check to a simple equivalent for loop
    #
    # TODO: make use of reduce like below when JIT is ready with the missing features:
    # from operator import mul
    # from functools import reduce
    #
    #   if reduce(mul, size[2:], size[0]) == 1
    size_prods = size[0]
    for i in range(len(size) - 2):
        size_prods *= size[i + 2]
    if size_prods == 1:
        raise ValueError('Expected more than 1 value per channel when training, got input size {}'.format(size))



def group_norm(input, num_groups, weight=None, bias=None, eps=1e-5):
    # type: (Tensor, int, Optional[Tensor], Optional[Tensor], float) -> Tensor
    r"""Apply Group Normalization to ``input`` by splitting dim 1 into ``num_groups`` groups.

    See :class:`~torch.nn.GroupNorm` for details.

    Args:
        input: tensor whose dim 0 is the batch axis and dim 1 the channel axis.
        num_groups: number of groups the channel axis is divided into.
        weight: optional per-channel scale.
        bias: optional per-channel shift.
        eps: value added to the variance for numerical stability.

    Returns:
        The normalized tensor computed by ``torch.group_norm``.

    Raises:
        ValueError: if each group would contain only a single value
            (checked by ``_verify_batch_size``).
    """
    if not torch.jit.is_scripting():
        if type(input) is not Tensor:
            # These helpers were previously referenced without being imported, so
            # any Tensor subclass (e.g. a Parameter) raised NameError here.
            from torch.overrides import has_torch_function, handle_torch_function
            if has_torch_function((input,)):
                return handle_torch_function(
                    group_norm, (input,), input, num_groups, weight=weight, bias=bias, eps=eps)
    # Treat (batch * channels_per_group) as the "batch" and num_groups as the channels.
    _verify_batch_size(
        [input.size(0) * input.size(1) // num_groups, num_groups]
        + list(input.size()[2:]))
    return torch.group_norm(input, num_groups, weight, bias, eps,
                            torch.backends.cudnn.enabled)



class Q_GroupNorm(nn.Module):
    """Group normalization layer that can be initialized from an existing GroupNorm.

    Computes the same function as ``nn.GroupNorm`` through ``F.group_norm``.
    Call :meth:`set_param` to copy the configuration and the trained affine
    parameters of another GroupNorm module.

    Args:
        num_groups: number of groups to separate the channels into.
        num_channels: number of channels expected in the input.
        eps: value added to the denominator for numerical stability.
        affine: whether to hold learnable per-channel weight and bias.
    """
    def __init__(self, num_groups, num_channels, eps=1e-5, affine=True):
        super(Q_GroupNorm, self).__init__()
        self.num_groups = num_groups
        self.num_channels = num_channels
        self.eps = eps
        self.affine = affine
        if self.affine:
            self.weight = Parameter(torch.Tensor(num_channels))
            self.bias = Parameter(torch.Tensor(num_channels))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        self.reset_parameters()

    def set_param(self, gn):
        """Copy the configuration and affine parameters of ``gn`` into this module.

        The parameters are cloned, so later in-place updates to ``gn`` do not
        affect this module.
        """
        self.num_groups = gn.num_groups
        self.num_channels = gn.num_channels
        self.eps = gn.eps
        self.affine = gn.affine
        if self.affine:
            self.weight = Parameter(gn.weight.data.clone())
            self.bias = Parameter(gn.bias.data.clone())
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        # Deliberately no reset_parameters() here: it would overwrite the
        # weight/bias just copied from ``gn`` with ones/zeros.

    def reset_parameters(self):
        """Initialize the affine parameters to the identity transform (weight=1, bias=0)."""
        # ``nn.init`` is used explicitly; the bare name ``init`` is not imported here.
        if self.affine:
            nn.init.ones_(self.weight)
            nn.init.zeros_(self.bias)

    def forward(self, input):
        """Normalize ``input`` per group and apply the optional affine transform."""
        return F.group_norm(
            input, self.num_groups, self.weight, self.bias, self.eps)

    def extra_repr(self):
        return '{num_groups}, {num_channels}, eps={eps}, ' \
            'affine={affine}'.format(**self.__dict__)


class Q_bbox_atten(nn.Module):
    """Quantization-aware query/key attention producing per-head spatial attention maps.

    Queries are projected by ``quant_q_linear``. Keys are projected with the
    weights of ``quant_k_linear`` applied as a 1x1 ``F.conv2d`` over a spatial
    feature map. The result is a softmax-normalized attention weight tensor.

    Args:
        query_dim: input feature size of both the q and k projections.
        hidden_dim: output feature size of the projections; split evenly across heads.
        num_heads: number of attention heads.
        dropout: dropout probability handed to ``QuantDropout``.
        bias: whether the q/k projections carry a bias term.
    """
    def __init__(self, query_dim, hidden_dim, num_heads, dropout=0.0, bias=True):
        super().__init__()
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim

        # Separate activation quantizers for the query and the key inputs.
        self.quant_act_in_q = QuantAct()
        self.quant_act_in_k = QuantAct()

        # dropout = getattr(bbox_attention, 'dropout')
        # q_linear = getattr(bbox_attention, 'q_linear')
        # k_linear = getattr(bbox_attention, 'k_linear')

        self.quant_dropout = QuantDropout(dropout)
        # self.quant_dropout.set_param(dropout)

        # Both projections are seeded from freshly built nn.Linear layers; their
        # weights/biases are re-initialized a few lines below.
        self.quant_q_linear = QuantLinear()
        self.quant_q_linear.set_param(
            nn.Linear(query_dim, hidden_dim, bias=bias))

        # NOTE(review): not used by forward (only referenced from commented-out code).
        self.quant_conv = QuantConv2d()

        self.quant_k_linear = QuantLinear()
        self.quant_k_linear.set_param(
            nn.Linear(query_dim, hidden_dim, bias=bias))

        # self.quant_act_out = QuantAct()
        # NOTE(review): zeros_ on .bias assumes bias=True; with bias=False this
        # presumably receives None — confirm QuantLinear.set_param behaviour.
        nn.init.zeros_(self.quant_k_linear.bias)
        nn.init.zeros_(self.quant_q_linear.bias)
        nn.init.xavier_uniform_(self.quant_k_linear.weight)
        nn.init.xavier_uniform_(self.quant_q_linear.weight)
        # Scaled dot-product factor: (per-head channel count) ** -0.5.
        self.normalize_fact = float(hidden_dim / self.num_heads) ** -0.5

    def forward(self, q, k, mask: Optional[Tensor] = None):
        """Compute attention weights of every query over the spatial positions of ``k``.

        Args:
            q: query features. The ``view`` below needs a 3-D tensor after the
                projection, presumably (batch, num_queries, query_dim) — TODO confirm.
            k: key feature map passed to ``F.conv2d``, presumably
                (batch, query_dim, H, W) — TODO confirm.
            mask: optional boolean mask, unsqueezed twice at dim 1 so it broadcasts
                over the query and head axes; presumably (batch, H, W). True entries
                are filled with -inf before the softmax.

        Returns:
            Tensor of shape (batch, num_queries, num_heads, H, W). The softmax is
            taken jointly over the head and spatial axes (``flatten(2)``).
        """
        # The scaling factor returned here is overwritten on the next line; the key
        # path below does not consume it.
        k, act_scaling_factor = self.quant_act_in_k(k)

        q, act_scaling_factor = self.quant_act_in_q(q)
        # NOTE(review): the result is used as a bare tensor (q.view below), whereas
        # QuantConv2d calls elsewhere in this file are unpacked as
        # (tensor, scaling_factor) — confirm QuantLinear.forward's return type.
        q = self.quant_q_linear(q, act_scaling_factor)

        # Key projection: the linear weight reshaped to a 1x1 conv kernel. This uses
        # the float weights directly rather than QuantLinear's forward.
        k = F.conv2d(
            k, self.quant_k_linear.weight.unsqueeze(-1).unsqueeze(-1), self.quant_k_linear.bias)

        # self.quant_conv.set_param(k)
        # k = self.quant_conv(k)

        # Split the hidden channels into heads.
        qh = q.view(q.shape[0], q.shape[1], self.num_heads,
                    self.hidden_dim // self.num_heads)
        kh = k.view(k.shape[0], self.num_heads, self.hidden_dim //
                    self.num_heads, k.shape[-2], k.shape[-1])
        weights = torch.einsum("bqnc,bnchw->bqnhw",
                               qh * self.normalize_fact, kh)

        if mask is not None:
            weights.masked_fill_(mask.unsqueeze(1).unsqueeze(1), float("-inf"))
        # Normalize over heads and spatial positions together, then restore the 5-D shape.
        weights = F.softmax(weights.flatten(2), dim=-1).view(weights.size())
        weights, act_scaling_factor = self.quant_dropout(weights)

        # weights, act_scaling_factor = self.quant_act_out(weights)
        return weights


class Q_mask_head(nn.Module):
    """Quantization-aware version of a mask head built from an existing ``mask_head``.

    Each convolution of ``mask_head`` (``lay1``..``lay5``, ``out_lay``,
    ``adapter1``..``adapter3``) is wrapped in a ``QuantConv2d`` initialized via
    ``set_param``. Each convolution is preceded by its own ``QuantAct``:

    - ``quant_act_in``, ``quant_act1``..``quant_act4`` and ``quant_act_out`` feed
      ``lay1``..``lay5`` and ``out_lay``.
    - ``quant_act5``..``quant_act7`` feed the FPN adapters.

    The GroupNorm layers ``gn1``..``gn5`` are reused from ``mask_head`` as-is
    (shared, not copied).
    """
    def __init__(self, mask_head):
        super().__init__()
        self.quant_act_in = QuantAct()

        lay1 = getattr(mask_head, 'lay1')
        self.quant_lay1 = QuantConv2d()
        self.quant_lay1.set_param(lay1)
        gn1 = getattr(mask_head, 'gn1')
        self.quant_gn1 = gn1

        self.quant_act1 = QuantAct()
        lay2 = getattr(mask_head, 'lay2')
        self.quant_lay2 = QuantConv2d()
        self.quant_lay2.set_param(lay2)
        gn2 = getattr(mask_head, 'gn2')
        self.quant_gn2 = gn2

        self.quant_act2 = QuantAct()
        lay3 = getattr(mask_head, 'lay3')
        self.quant_lay3 = QuantConv2d()
        self.quant_lay3.set_param(lay3)
        gn3 = getattr(mask_head, 'gn3')
        self.quant_gn3 = gn3

        self.quant_act3 = QuantAct()
        lay4 = getattr(mask_head, 'lay4')
        self.quant_lay4 = QuantConv2d()
        self.quant_lay4.set_param(lay4)
        gn4 = getattr(mask_head, 'gn4')
        self.quant_gn4 = gn4

        self.quant_act4 = QuantAct()
        lay5 = getattr(mask_head, 'lay5')
        self.quant_lay5 = QuantConv2d()
        self.quant_lay5.set_param(lay5)
        gn5 = getattr(mask_head, 'gn5')
        self.quant_gn5 = gn5

        self.quant_act_out = QuantAct()
        out_lay = getattr(mask_head, 'out_lay')
        self.quant_out_lay = QuantConv2d()
        self.quant_out_lay.set_param(out_lay)

        # Quantizers for the FPN inputs of adapter1..adapter3.
        self.quant_act5 = QuantAct()
        adapter1 = getattr(mask_head, 'adapter1')
        self.quant_adapter1 = QuantConv2d()
        self.quant_adapter1.set_param(adapter1)

        self.quant_act6 = QuantAct()
        adapter2 = getattr(mask_head, 'adapter2')
        self.quant_adapter2 = QuantConv2d()
        self.quant_adapter2.set_param(adapter2)

        self.quant_act7 = QuantAct()
        adapter3 = getattr(mask_head, 'adapter3')
        self.quant_adapter3 = QuantConv2d()
        self.quant_adapter3.set_param(adapter3)

    def _fuse_fpn(self, x: Tensor, fpn: Tensor, quant_act: nn.Module,
                  adapter: nn.Module) -> Tensor:
        """Quantize ``fpn``, project it with ``adapter`` and add the upsampled ``x``."""
        # Quantize the FPN feature itself so the adapter receives the scaling factor
        # that actually describes its input. Previously the quantizer was applied
        # to ``x`` and x's scaling factor was passed to a conv over ``fpn``.
        fpn, act_scaling_factor = quant_act(fpn)
        cur_fpn, _ = adapter(fpn, act_scaling_factor)
        if cur_fpn.size(0) != x.size(0):
            # FPN features are per image; x is (presumably) per image * query, so tile.
            cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
        return cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")

    def forward(self, boxatt_out: Tensor, bbox_mask: Tensor, fpns: List[Tensor]):
        """Run the quantized mask head.

        Args:
            boxatt_out: feature map repeated ``bbox_mask.shape[1]`` times along
                dim 0 (via ``_expand``) and concatenated with ``bbox_mask`` on dim 1.
            bbox_mask: attention maps whose first two dims are merged
                (``flatten(0, 1)``); presumably the output of the bbox attention.
            fpns: at least three feature maps, fused in order after lay2, lay3
                and lay4.

        Returns:
            The output of ``out_lay`` (the scaling factor is discarded).
        """
        x = torch.cat([_expand(boxatt_out, bbox_mask.shape[1]),
                       bbox_mask.flatten(0, 1)], 1)

        x, act_scaling_factor = self.quant_act_in(x)
        x, act_scaling_factor = self.quant_lay1(x, act_scaling_factor)
        x = F.relu(self.quant_gn1(x))

        x, act_scaling_factor = self.quant_act1(x)
        x, act_scaling_factor = self.quant_lay2(x, act_scaling_factor)
        x = F.relu(self.quant_gn2(x))

        x = self._fuse_fpn(x, fpns[0], self.quant_act5, self.quant_adapter1)

        x, act_scaling_factor = self.quant_act2(x)
        x, act_scaling_factor = self.quant_lay3(x, act_scaling_factor)
        x = F.relu(self.quant_gn3(x))

        x = self._fuse_fpn(x, fpns[1], self.quant_act6, self.quant_adapter2)

        x, act_scaling_factor = self.quant_act3(x)
        x, act_scaling_factor = self.quant_lay4(x, act_scaling_factor)
        x = F.relu(self.quant_gn4(x))

        x = self._fuse_fpn(x, fpns[2], self.quant_act7, self.quant_adapter3)

        x, act_scaling_factor = self.quant_act4(x)
        x, act_scaling_factor = self.quant_lay5(x, act_scaling_factor)
        x = F.relu(self.quant_gn5(x))

        x, act_scaling_factor = self.quant_act_out(x)
        x, act_scaling_factor = self.quant_out_lay(x, act_scaling_factor)

        return x
