import math
import torch
import pickle
import torch.nn as nn
from torch.nn import functional as F
from torch.nn import init
from torch.nn.parameter import Parameter
from torch.nn.modules.utils import _pair
from math import sqrt
from Layers.common import channel_shuffle, round_channels, round_groups, initialize_conv
from inspect import isfunction


class GroupedLinear(nn.Module):
    """Bias-free grouped channel mixing applied at every spatial position.

    The channel axis of a 4-D input is split into ``groups`` contiguous chunks and each
    chunk is multiplied by its own ``(in_features // groups, out_features // groups)``
    weight matrix.

    Args:
        in_features: number of input channels; must be divisible by ``groups``.
        out_features: number of output channels; must be divisible by ``groups``.
        groups: number of independent channel groups.
    """

    def __init__(self, in_features, out_features, groups):
        super(GroupedLinear, self).__init__()
        assert in_features % groups == 0
        assert out_features % groups == 0
        self.in_features = in_features
        self.out_features = out_features
        self.groups = groups
        # One (in_per_group, out_per_group) matrix per group.
        self.weight = torch.nn.Parameter(torch.empty(self.groups, self.in_features // self.groups, self.out_features // self.groups))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw the weight with Kaiming-uniform initialization (``a=sqrt(5)``)."""
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, input):
        """Map ``(batch, in_features, H, W)`` to ``(batch, out_features, H, W)``."""
        batch, in_features, height, width = input.shape
        # reshape (rather than view) also accepts non-contiguous inputs.
        grouped = input.reshape(batch, self.groups, in_features // self.groups, height, width)
        mixed = torch.einsum('bgihw,gio->bgohw', grouped, self.weight)
        # Tensor.reshape is not in-place: its result must be returned, otherwise the caller
        # receives the 5-D (batch, groups, out_per_group, H, W) tensor.
        return mixed.reshape(batch, self.out_features, height, width)


class Linear(nn.Linear):
    """``nn.Linear`` whose weight and bias are multiplied by fixed 0/1-style mask buffers.

    ``weight_mask`` (and ``bias_mask`` when a bias exists) are registered as buffers
    initialized to ones, so a freshly built layer behaves like a plain ``nn.Linear``.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        bias: whether the layer has an additive bias.
    """

    def __init__(self, in_features, out_features, bias=True):
        super(Linear, self).__init__(in_features, out_features, bias)
        self.register_buffer('weight_mask', torch.ones(self.weight.shape))
        if self.bias is not None:
            self.register_buffer('bias_mask', torch.ones(self.bias.shape))

    def forward(self, input):
        """Apply the masked linear map.

        If ``input.shape[1]`` differs from ``in_features`` the layer is re-initialized
        in place for the new width: weights are freshly drawn and masks reset to ones.
        """
        if input.shape[1] != self.in_features:
            device = self.weight.device
            # Preserve the bias setting; the default (bias=True) would silently add a
            # bias to a layer that was built with bias=False.
            self.__init__(input.shape[1], self.out_features, bias=self.bias is not None)
            self.to(device)

        W = self.weight_mask * self.weight

        if self.bias is not None:
            b = self.bias_mask * self.bias
        else:
            b = self.bias
        return F.linear(input, W, b)


class Conv2d(nn.Conv2d):
    """``nn.Conv2d`` whose weight and bias are multiplied by mask buffers before use.

    ``weight_mask`` (and ``bias_mask`` when a bias exists) are registered as buffers
    initialized to ones. The layer also adapts to inputs whose channel count does not
    match ``in_channels`` (see :meth:`forward`).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros'):
        super(Conv2d, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode)
        self.register_buffer('weight_mask', torch.ones(self.weight.shape))
        if self.bias is not None:
            self.register_buffer('bias_mask', torch.ones(self.bias.shape))

    def _reinitialize(self, in_channels, groups):
        """Rebuild the layer in place with new ``in_channels``/``groups``.

        All other hyper-parameters are kept. Weights are freshly drawn, masks are reset
        to ones, and the module is moved back to the device it was on.
        """
        device = self.weight.device
        # The constructor expects a boolean bias flag. Passing the bias Parameter itself
        # raises "Boolean value of Tensor ... is ambiguous" for multi-element biases.
        self.__init__(in_channels, self.out_channels, self.kernel_size, self.stride, self.padding, self.dilation, groups,
                      self.bias is not None, self.padding_mode)
        self.to(device)

    def _conv_forward(self, input, weight, bias):
        """Run the convolution with explicit ``weight``/``bias`` tensors."""
        if self.padding_mode != 'zeros':
            # The F.pad-ordered padding tuple is called `_reversed_padding_repeated_twice`
            # in current PyTorch; older releases exposed `_padding_repeated_twice`.
            explicit_pad = getattr(self, '_reversed_padding_repeated_twice', None)
            if explicit_pad is None:
                explicit_pad = self._padding_repeated_twice
            return F.conv2d(F.pad(input, explicit_pad, mode=self.padding_mode), weight, bias, self.stride, _pair(0), self.dilation,
                            self.groups)
        return F.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)

    def forward(self, input):
        """Apply the masked convolution.

        Inputs with more channels than ``in_channels`` are truncated to the leading
        ``in_channels`` channels. Inputs with fewer channels trigger an in-place
        re-initialization of the layer for the new channel count.
        """
        if input.shape[1] > self.in_channels:
            input = input[:, :self.in_channels, :, :]
        elif input.shape[1] < self.in_channels:
            print(f'WARINING: unusual case. Reinit {self} in_channels {self.in_channels} to {input.shape[1]}')
            self._reinitialize(input.shape[1], self.groups)

        W = self.weight_mask * self.weight
        if self.bias is not None:
            b = self.bias_mask * self.bias
        else:
            b = self.bias
        x = self._conv_forward(input, W, b)
        return x

    def set_group_number(self, groups):
        """Re-initialize the layer with a different ``groups`` value.

        Weights are freshly drawn and masks reset; the constructor stores ``groups``.
        """
        self._reinitialize(self.in_channels, groups)


class BatchNorm2d(nn.BatchNorm2d):
    """``nn.BatchNorm2d`` whose affine weight and bias are multiplied by mask buffers.

    When ``affine`` is true, ``weight_mask`` and ``bias_mask`` buffers (initialized to
    ones) are registered and applied to the affine parameters in :meth:`forward`.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True):
        super(BatchNorm2d, self).__init__(num_features, eps, momentum, affine, track_running_stats)
        if self.affine:
            for mask_name, param in (('weight_mask', self.weight), ('bias_mask', self.bias)):
                self.register_buffer(mask_name, torch.ones(param.shape))

    def forward(self, input):
        """Normalize ``input`` using the masked affine parameters."""
        self._check_input_dim(input)

        # Default averaging factor: the configured momentum, or 0.0 when it is None.
        momentum = self.momentum
        avg_factor = 0.0 if momentum is None else momentum

        updating_stats = self.training and self.track_running_stats
        if updating_stats and self.num_batches_tracked is not None:
            self.num_batches_tracked = self.num_batches_tracked + 1
            if momentum is None:
                # No momentum configured: cumulative moving average over batches seen.
                avg_factor = 1.0 / float(self.num_batches_tracked)

        if self.affine:
            scale = self.weight_mask * self.weight
            shift = self.bias_mask * self.bias
        else:
            scale, shift = self.weight, self.bias

        use_batch_stats = self.training or not self.track_running_stats
        return F.batch_norm(input, self.running_mean, self.running_var, scale, shift, use_batch_stats, avg_factor, self.eps)

    def expand_channel(self, ratio, divisor=8):
        """Re-initialize the layer with ``round_channels(num_features * ratio, divisor)`` features.

        The parent constructor is re-run, so parameters and running statistics are
        recreated, and the masks are reset to ones.
        """
        widened = round_channels(self.num_features * ratio, divisor)
        super(BatchNorm2d, self).__init__(widened, self.eps, self.momentum, self.affine, self.track_running_stats)
        if self.affine:
            for mask_name, param in (('weight_mask', self.weight), ('bias_mask', self.bias)):
                self.register_buffer(mask_name, torch.ones(param.shape))


class ChannelShuffle(nn.Module):
    """Module wrapper that applies ``channel_shuffle`` with a fixed group count."""

    def __init__(self, groups):
        super().__init__()
        self.groups = groups

    def forward(self, input):
        """Return ``channel_shuffle(input, self.groups)``."""
        shuffled = channel_shuffle(input, self.groups)
        return shuffled

    def __repr__(self):
        return f'{type(self).__name__}(groups={self.groups})'


def conv1x1_block(in_channels, out_channels, stride=1, groups=1, bias=False, activation=(lambda: nn.ReLU(inplace=True)), shuffle=False):
    """Build a ``ConvBlock`` with a 1x1 kernel; all other arguments are forwarded as given."""
    settings = dict(kernel_size=1, stride=stride, groups=groups, bias=bias, activation=activation, shuffle=shuffle)
    return ConvBlock(in_channels, out_channels, **settings)


def conv3x3_block(in_channels, out_channels, stride=1, groups=1, padding=1, activation=(lambda: nn.ReLU(inplace=True)), shuffle=False):
    """Build a ``ConvBlock`` with a 3x3 kernel; the conv bias keeps ``ConvBlock``'s default."""
    settings = dict(kernel_size=3, stride=stride, padding=padding, groups=groups, activation=activation, shuffle=shuffle)
    return ConvBlock(in_channels, out_channels, **settings)


def conv7x7_block(in_channels, out_channels, stride=1, groups=1, padding=1, activation=(lambda: nn.ReLU(inplace=True)), shuffle=False):
    """Build a ``ConvBlock`` with a 7x7 kernel.

    NOTE: the default ``padding=1`` does not preserve spatial size for a 7x7 kernel at
    stride 1; pass ``padding=3`` for "same" output size.
    """
    settings = dict(kernel_size=7, stride=stride, padding=padding, groups=groups, activation=activation, shuffle=shuffle)
    return ConvBlock(in_channels, out_channels, **settings)


def dwconv5x5_block(in_channels, out_channels, stride=1, padding=1, activation=(lambda: nn.ReLU(inplace=True)), shuffle=False):
    """Build a 5x5 ``ConvBlock`` whose group count equals ``out_channels``.

    NOTE: the default ``padding=1`` does not preserve spatial size for a 5x5 kernel at
    stride 1; pass ``padding=2`` for "same" output size.
    """
    settings = dict(kernel_size=5, stride=stride, padding=padding, groups=out_channels, activation=activation, shuffle=shuffle)
    return ConvBlock(in_channels, out_channels, **settings)


def dwconv3x3_block(in_channels, out_channels, stride=1, padding=1, activation=(lambda: nn.ReLU(inplace=True)), shuffle=False):
    """Build a 3x3 ``ConvBlock`` whose group count equals ``out_channels``."""
    settings = dict(kernel_size=3, stride=stride, padding=padding, groups=out_channels, activation=activation, shuffle=shuffle)
    return ConvBlock(in_channels, out_channels, **settings)


def depthwise_conv3x3(channels, stride=1, padding=1, bias=False):
    """Build a masked 3x3 ``Conv2d`` with one group per channel (``groups=channels``)."""
    return Conv2d(channels, channels, 3, stride=stride, padding=padding, groups=channels, bias=bias)


def conv1x1(in_channels, out_channels, stride=1, groups=1, bias=False):
    """Build a masked 1x1 ``Conv2d`` (no padding argument is passed)."""
    return Conv2d(in_channels, out_channels, 1, stride=stride, groups=groups, bias=bias)


def conv3x3(in_channels, out_channels, stride=1, groups=1, bias=False):
    """Build a masked 3x3 ``Conv2d``.

    No padding argument is passed, so ``Conv2d``'s default (``padding=0``) applies.
    """
    return Conv2d(in_channels, out_channels, 3, stride=stride, groups=groups, bias=bias)


class ConvBlock(nn.Module):
    """Masked convolution followed by batch norm, optional activation and optional channel shuffle.

    Args:
        in_channels, out_channels, kernel_size, stride, groups, padding, bias:
            forwarded to the masked ``Conv2d``.
        activation: anything accepted by ``get_activation``, or ``None`` to skip the
            activation entirely (no ``activ`` attribute is created in that case).
        shuffle: when true and the convolution has more than one group, the output
            channels are shuffled with ``channel_shuffle`` after the activation.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 groups=1,
                 padding=0,
                 bias=False,
                 activation=(lambda: nn.ReLU(inplace=True)),
                 shuffle=False):
        super(ConvBlock, self).__init__()
        self.activate = activation is not None

        self.conv = Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, groups=groups, bias=bias)
        self.bn = BatchNorm2d(out_channels)
        if self.activate:
            self.activ = get_activation(activation)
        self.shuffle = shuffle

    def forward(self, x):
        """Run conv -> bn -> (activation) -> (channel shuffle)."""
        out = self.bn(self.conv(x))
        if self.activate:
            out = self.activ(out)
        if self.shuffle and self.conv.groups > 1:
            out = channel_shuffle(out, self.conv.groups)
        return out

    def __repr__(self):
        # Append a line for the shuffle step, which is not a registered sub-module.
        text = super(ConvBlock, self).__repr__()
        if self.conv.groups > 1 and self.shuffle:
            text += f'\n Channel Shuffle groups={self.conv.groups}'
        return text


class SEBlock(nn.Module):
    """Squeeze-and-excitation style channel gating.

    The input is average-pooled to 1x1, passed through two layers (1x1 convolutions when
    ``use_conv`` is true, masked ``Linear`` layers otherwise) with an activation in
    between, squashed by ``out_activation`` and multiplied back onto the input.

    Args:
        channels: number of input (and output) channels.
        reduction: divisor used to derive the hidden width when ``mid_channels`` is None.
        mid_channels: explicit hidden width; overrides ``reduction`` at construction.
        round_mid: derive the hidden width with ``round_channels`` instead of floor division.
        use_conv: use 1x1 convolutions (true) or fully connected layers (false).
        mid_activation: activation between the two layers (see ``get_activation``).
        out_activation: activation producing the per-channel gate.
    """

    def __init__(self,
                 channels,
                 reduction=16,
                 mid_channels=None,
                 round_mid=False,
                 use_conv=True,
                 mid_activation=(lambda: nn.ReLU(inplace=True)),
                 out_activation=(lambda: nn.Sigmoid())):
        super(SEBlock, self).__init__()
        self.reduction = reduction
        # Remembered so that an in-place rebuild in forward() keeps the same rounding rule.
        self.round_mid = round_mid
        self.use_conv = use_conv
        if mid_channels is None:
            mid_channels = channels // reduction if not round_mid else round_channels(float(channels) / reduction)

        self.pool = nn.AdaptiveAvgPool2d(output_size=1)
        if use_conv:
            self.conv1 = conv1x1(in_channels=channels, out_channels=mid_channels, bias=True)
        else:
            self.fc1 = Linear(in_features=channels, out_features=mid_channels)
        self.activ = get_activation(mid_activation)
        if use_conv:
            self.conv2 = conv1x1(in_channels=mid_channels, out_channels=channels, bias=True)
        else:
            self.fc2 = Linear(in_features=mid_channels, out_features=channels)
        self.sigmoid = get_activation(out_activation)

    def forward(self, x):
        """Return ``x`` scaled by its per-channel gate.

        If the channel count of ``x`` differs from what the block was built for, the
        block is rebuilt in place for the new width. The rebuild keeps ``reduction``,
        ``round_mid``, ``use_conv`` and both activation modules; the hidden width is
        recomputed from ``reduction`` and the weights are freshly initialized.
        """
        # Works for both variants. Previously this asserted `not use_conv` and read
        # `self.fc1`, so the default (use_conv=True) configuration could never run.
        if self.use_conv:
            first_layer = self.conv1
            built_for = first_layer.in_channels
        else:
            first_layer = self.fc1
            built_for = first_layer.in_features

        if x.shape[1] != built_for:
            device = first_layer.weight.device
            self.__init__(channels=x.shape[1],
                          reduction=self.reduction,
                          round_mid=self.round_mid,
                          use_conv=self.use_conv,
                          mid_activation=self.activ,
                          out_activation=self.sigmoid)
            self.to(device)

        w = self.pool(x)
        if not self.use_conv:
            # Fully connected layers expect (batch, channels).
            w = w.view(x.size(0), -1)
        w = self.conv1(w) if self.use_conv else self.fc1(w)
        w = self.activ(w)
        w = self.conv2(w) if self.use_conv else self.fc2(w)
        w = self.sigmoid(w)
        if not self.use_conv:
            # Restore the trailing 1x1 spatial dims so the gate broadcasts over H and W.
            w = w.unsqueeze(2).unsqueeze(3)
        x = x * w
        return x


class Swish(nn.Module):
    """Swish activation: the input multiplied by its own sigmoid."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate


class IdentityAdd(nn.Module):
    """Residual merge: returns the sum of the branch output and the identity tensor."""

    def forward(self, x, identity):
        merged = x + identity
        return merged


class PaddedIdentityAdd(nn.Module):
    """Residual merge that tolerates a channel-count mismatch between its two inputs.

    When the channel counts differ, the narrower tensor is added onto the leading
    channels of the wider one and the remaining channels of the wider tensor are passed
    through unchanged. ``out_channels`` records the channel count of the last output.
    """

    def __init__(self):
        super(PaddedIdentityAdd, self).__init__()
        self.out_channels = None
        self.perm = None

    def forward(self, x, identity):
        # Order the two inputs by channel count so the same code handles both directions.
        if x.shape[1] <= identity.shape[1]:
            narrow, wide = x, identity
        else:
            narrow, wide = identity, x

        shared = narrow.shape[1]
        self.out_channels = wide.shape[1]
        if shared == self.out_channels:
            return x + identity

        overlap = wide[:, :shared, :, :] + narrow
        passthrough = wide[:, shared:, :, :]
        return torch.cat((overlap, passthrough), 1)

    def __repr__(self):
        return f'PaddedIdentityAdd(out_channels = {self.out_channels})'


def get_activation(activation):
    """Turn an activation specification into an ``nn.Module`` instance.

    Accepted forms:
        * a plain function / lambda taking no arguments: it is called and its result returned;
        * an ``nn.Module`` instance: returned unchanged;
        * one of the strings ``'relu6'``, ``'relu'`` or ``'swish'``.

    Raises:
        NotImplementedError: for any other value.
    """
    if isfunction(activation):
        return activation()
    if isinstance(activation, nn.Module):
        return activation

    # Builders are lazy so each call creates a fresh module.
    named_builders = {
        'relu6': lambda: nn.ReLU6(inplace=True),
        'relu': lambda: nn.ReLU(inplace=True),
        'swish': lambda: Swish(),
    }
    if isinstance(activation, str) and activation in named_builders:
        return named_builders[activation]()
    raise NotImplementedError


def replace_activation(model, activate):
    """Swap the activation of every ``ConvBlock`` in ``model`` that has one.

    ``activate`` is resolved through ``get_activation`` once per block; blocks built
    without an activation (``activate`` flag false) are left untouched.
    """
    for _, block in model.named_modules():
        if not isinstance(block, ConvBlock):
            continue
        if block.activate:
            block.activ = get_activation(activate)


def sparse_initialize(model):
    """Call ``initialize_conv`` on the weight and weight mask of every masked ``Conv2d`` in ``model``."""
    masked_convs = (layer for _, layer in model.named_modules() if isinstance(layer, Conv2d))
    for layer in masked_convs:
        initialize_conv(layer.weight, layer.weight_mask, sqrt_sparsity=False)


def channel_prune(model, pruned_in_channels=None, pruned_out_channels=None, prune_info=None):
    """Physically shrink the ``ConvBlock`` layers of ``model`` by re-initializing them.

    A block is resized when its name is in ``prune_info`` or when the module itself is a
    key of both ``pruned_in_channels`` and ``pruned_out_channels``. Any other block with a
    3x3 or 5x5 kernel is rebuilt as a depthwise layer matching the output width of the
    most recently resized block. Re-initialized blocks get fresh weights.

    Side effects: ``prune_info`` is updated in place with ``[pruned_in, pruned_out]`` per
    block name, then written to ``prune_info.pkl`` in the current working directory, and
    ``model.clean()`` is called.

    Args:
        model: module exposing ``named_modules()`` and ``clean()``.
        pruned_in_channels: mapping ``ConvBlock -> number of input channels to remove``.
        pruned_out_channels: mapping ``ConvBlock -> number of output channels to remove``.
        prune_info: mapping ``block name -> [pruned_in, pruned_out]``; takes precedence
            over the two module-keyed mappings.
    """
    if prune_info is None:
        prune_info = {}
    if pruned_in_channels is None:
        pruned_in_channels = {}
    if pruned_out_channels is None:
        pruned_out_channels = {}

    prev_channels = None
    for name, m in model.named_modules():
        if not isinstance(m, ConvBlock):
            continue

        activation = getattr(m, 'activ', None)
        # ConvBlock expects a boolean bias flag; passing the bias Parameter itself makes
        # nn.Conv2d evaluate bool(tensor), which fails for multi-element biases.
        has_bias = m.conv.bias is not None

        # explicitly pruned blocks (the pointwise convolutions)
        if (m in pruned_in_channels and m in pruned_out_channels) or name in prune_info:
            if name in prune_info:
                pruned_in = prune_info[name][0]
                pruned_out = prune_info[name][1]
            else:
                pruned_in = pruned_in_channels[m]
                pruned_out = pruned_out_channels[m]
                # for store&load purpose, a little hacky, may need refactor later
                prune_info[name] = [pruned_in, pruned_out]

            in_channels = m.conv.in_channels - pruned_in
            out_channels = m.conv.out_channels - pruned_out
            m.__init__(in_channels, out_channels, m.conv.kernel_size, m.conv.stride, m.conv.groups, m.conv.padding, has_bias, activation,
                       m.shuffle)
            prev_channels = out_channels
        # depthwise convolutions follow the width of the preceding pruned block
        elif m.conv.kernel_size == (3, 3) or m.conv.kernel_size == (5, 5):
            if prev_channels is None:
                continue
            m.__init__(prev_channels, prev_channels, m.conv.kernel_size, m.conv.stride, prev_channels, m.conv.padding, has_bias, activation,
                       m.shuffle)

    # Context manager guarantees the file is flushed and closed.
    with open('prune_info.pkl', 'wb') as info_file:
        pickle.dump(prune_info, info_file)
    model.clean()


def fake_channel_prune(model, pruned_in_channels, pruned_out_channels):
    """Simulate channel pruning by zeroing weight-mask entries instead of resizing layers.

    For every ``ConvBlock`` with a 1x1 kernel, the trailing ``pruned_out_channels[block]``
    output slices and the trailing ``pruned_in_channels[block]`` input slices of the
    convolution's ``weight_mask`` are set to zero in place. Blocks with other kernel
    sizes are left untouched.

    Args:
        model: module exposing ``named_modules()``.
        pruned_in_channels: mapping ``ConvBlock -> number of input channels to mask``.
        pruned_out_channels: mapping ``ConvBlock -> number of output channels to mask``.

    Raises:
        AssertionError: if a 1x1 ``ConvBlock`` is missing from either mapping.
    """
    for _, m in model.named_modules():
        if not isinstance(m, ConvBlock) or m.conv.kernel_size != (1, 1):
            continue
        assert m in pruned_in_channels and m in pruned_out_channels
        in_channels = m.conv.in_channels - pruned_in_channels[m]
        out_channels = m.conv.out_channels - pruned_out_channels[m]
        # Mask is indexed like the conv weight: (out, in, kH, kW).
        m.conv.weight_mask[out_channels:, :, :, :] = 0
        m.conv.weight_mask[:, in_channels:, :, :] = 0
