#!/usr/bin/env python
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import torch
import math
from torch import nn
from torch.nn import init
from torch.nn.modules.utils import _pair

from ..functions.deform_conv2d_func import DeformConv2dFunction

class DeformConv2d(nn.Module):
    """2-D deformable convolution with caller-supplied sampling offsets.

    The actual computation is delegated to ``DeformConv2dFunction.apply``.

    Args:
        in_channels (int): Input channels; must be divisible by ``groups``.
        out_channels (int): Output channels; must be divisible by ``groups``.
        kernel_size, stride, padding, dilation (int or pair): Convolution
            geometry; each is normalised to a 2-tuple with ``_pair``.
        groups (int): Number of convolution groups.
        deformable_groups (int): Number of offset groups. The offset tensor
            given to :meth:`forward` must have
            ``2 * deformable_groups * kernel_h * kernel_w`` channels.
        im2col_step (int): Forwarded unchanged to ``DeformConv2dFunction``.
        bias (bool): When False, the ``bias`` parameter still exists but is
            kept at zero with ``requires_grad=False``.

    Raises:
        ValueError: If ``in_channels`` or ``out_channels`` is not divisible
            by ``groups``.
    """

    def __init__(self, in_channels, out_channels,
                 kernel_size, stride, padding, dilation=1, groups=1, deformable_groups=1, im2col_step=32, bias=True):
        super(DeformConv2d, self).__init__()

        if in_channels % groups != 0:
            raise ValueError('in_channels {} must be divisible by groups {}'.format(in_channels, groups))
        if out_channels % groups != 0:
            raise ValueError('out_channels {} must be divisible by groups {}'.format(out_channels, groups))

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = _pair(padding)
        self.dilation = _pair(dilation)
        self.groups = groups
        self.deformable_groups = deformable_groups
        self.im2col_step = im2col_step
        self.use_bias = bias

        self.weight = nn.Parameter(torch.empty(
            out_channels, in_channels // groups, *self.kernel_size))
        # A bias tensor is always allocated. When bias=False it is zeroed by
        # reset_parameters() and frozen below.
        # NOTE(review): presumably because DeformConv2dFunction always expects
        # a bias tensor — confirm.
        self.bias = nn.Parameter(torch.empty(out_channels))
        self.reset_parameters()
        if not self.use_bias:
            self.bias.requires_grad = False

    def reset_parameters(self):
        """(Re)initialise ``weight`` and ``bias``.

        ``weight`` is filled with Kaiming-uniform (``a=sqrt(5)``). An enabled
        bias is drawn from ``U(-1/sqrt(fan_in), 1/sqrt(fan_in))``, with
        ``fan_in`` taken from ``weight``. A disabled bias (``bias=False``) is
        reset to zeros, so it stays inert even if this method is called again
        after construction.
        """
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is None:
            return
        if not self.use_bias:
            init.zeros_(self.bias)
            return
        fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
        bound = 1 / math.sqrt(fan_in)
        init.uniform_(self.bias, -bound, bound)

    def forward(self, input, offset):
        """Run the deformable convolution.

        Args:
            input: Input feature map.
            offset: Sampling offsets; ``offset.shape[1]`` must equal
                ``2 * deformable_groups * kernel_size[0] * kernel_size[1]``.

        Returns:
            The result of ``DeformConv2dFunction.apply``.

        Raises:
            ValueError: If ``offset`` has the wrong number of channels.
        """
        expected = 2 * self.deformable_groups * self.kernel_size[0] * self.kernel_size[1]
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if offset.shape[1] != expected:
            raise ValueError('offset must have {} channels, got {}'.format(expected, offset.shape[1]))
        return DeformConv2dFunction.apply(input, offset,
                                          self.weight,
                                          self.bias,
                                          self.stride,
                                          self.padding,
                                          self.dilation,
                                          self.groups,
                                          self.deformable_groups,
                                          self.im2col_step)

# Functional alias of DeformConv2dFunction.apply. It is called as
# _DeformConv2d(input, offset, weight, bias, stride, padding, dilation,
# groups, deformable_groups, im2col_step), the same argument order that
# DeformConv2d.forward uses.
_DeformConv2d = DeformConv2dFunction.apply

class DeformConv2dPack(DeformConv2d):
    """Deformable convolution that predicts its own offsets from the input.

    A plain ``nn.Conv2d`` (``conv_offset``) maps the input to
    ``2 * deformable_groups * kernel_h * kernel_w`` offset channels. It uses
    the same kernel size, stride, padding and dilation as the deformable conv,
    so the offset map is expected to match the output's spatial size. Its
    weight and bias are zero-initialised, so the predicted offsets start at
    zero.

    Args:
        lr_mult (float): Stored as ``conv_offset.lr_mult``.
            NOTE(review): presumably read by external optimizer setup to
            scale this submodule's learning rate — confirm.
        Remaining args: as for ``DeformConv2d``.
    """

    def __init__(self, in_channels, out_channels,
                 kernel_size, stride, padding,
                 dilation=1, groups=1, deformable_groups=1, im2col_step=32, bias=True, lr_mult=0.1):
        super(DeformConv2dPack, self).__init__(in_channels, out_channels,
                                               kernel_size, stride, padding, dilation, groups,
                                               deformable_groups, im2col_step, bias)

        offset_channels = self.deformable_groups * 2 * self.kernel_size[0] * self.kernel_size[1]
        # dilation must match the deformable conv. Otherwise, for dilation > 1,
        # the offset map's spatial size differs from the output's.
        self.conv_offset = nn.Conv2d(self.in_channels,
                                     offset_channels,
                                     kernel_size=self.kernel_size,
                                     stride=self.stride,
                                     padding=self.padding,
                                     dilation=self.dilation,
                                     bias=True)
        self.conv_offset.lr_mult = lr_mult
        # NOTE(review): flag presumably checked by an external init routine so
        # it leaves the zero-initialised offset conv alone — confirm.
        self.conv_offset.inited = True
        self.init_offset()

    def init_offset(self):
        """Zero the offset conv's weight and bias so predicted offsets start at zero."""
        self.conv_offset.weight.data.zero_()
        self.conv_offset.bias.data.zero_()

    @staticmethod
    def _batch_im2col_step(batch_size):
        """Return the largest divisor of ``batch_size`` not exceeding ``max(batch_size // 2, 1)``.

        For even batches this equals ``batch_size // 2``, and for batch size 1
        it is 1. For odd batches it avoids a step that does not divide the
        batch.
        NOTE(review): assumes DeformConv2dFunction requires
        ``batch % im2col_step == 0``, as common DCN kernels do — confirm.
        """
        step = max(batch_size // 2, 1)
        while batch_size % step:
            step -= 1
        return step

    def forward(self, input, return_offset=False):
        """Predict offsets from ``input`` and apply the deformable convolution.

        The constructor's ``im2col_step`` is not used here. The step is
        derived from the batch size instead.

        Returns:
            ``(out, offset)`` when ``return_offset`` is true, otherwise
            ``(out, None)``.
        """
        offset = self.conv_offset(input)
        im2col_step = self._batch_im2col_step(input.size(0))
        out = DeformConv2dFunction.apply(input, offset,
                                         self.weight,
                                         self.bias,
                                         self.stride,
                                         self.padding,
                                         self.dilation,
                                         self.groups,
                                         self.deformable_groups,
                                         im2col_step)
        if return_offset:
            return out, offset
        return out, None


class DeformConv2dPackMore(DeformConv2d):
    """Deformable convolution whose offsets come from a small conv head.

    ``conv_offset`` is: 1x1 conv (``in_channels -> in_channels // 4``, no
    bias) -> ``BatchNorm2d`` -> ``ReLU`` -> conv with this layer's kernel
    size, stride, padding and dilation, producing
    ``2 * deformable_groups * kernel_h * kernel_w`` offset channels. The last
    conv is zero-initialised, so the predicted offsets start at zero.

    Args:
        lr_mult (float): Stored as ``conv_offset[-1].lr_mult``.
            NOTE(review): presumably read by external optimizer setup —
            confirm.
        Remaining args: as for ``DeformConv2d``.
    """

    def __init__(self, in_channels, out_channels,
                 kernel_size, stride, padding,
                 dilation=1, groups=1, deformable_groups=1, im2col_step=64, bias=True, lr_mult=0.1):
        super(DeformConv2dPackMore, self).__init__(in_channels, out_channels,
                                                   kernel_size, stride, padding, dilation, groups,
                                                   deformable_groups, im2col_step, bias)

        offset_channels = self.deformable_groups * 2 * self.kernel_size[0] * self.kernel_size[1]
        hidden = self.in_channels // 4
        self.conv_offset = nn.Sequential(
            nn.Conv2d(self.in_channels, hidden, kernel_size=1, bias=False),
            nn.BatchNorm2d(hidden),
            nn.ReLU(inplace=True),
            # dilation must match the deformable conv so the offset map has
            # the same spatial size as the output.
            nn.Conv2d(hidden, offset_channels, kernel_size=self.kernel_size, stride=self.stride,
                      padding=self.padding, dilation=self.dilation, bias=True)
        )
        self.conv_offset[-1].lr_mult = lr_mult
        # NOTE(review): flag presumably checked by an external init routine —
        # confirm.
        self.conv_offset[-1].inited = True
        self.init_offset()

    def init_offset(self):
        """Zero the final offset conv's weight and bias so predicted offsets start at zero."""
        self.conv_offset[-1].weight.data.zero_()
        self.conv_offset[-1].bias.data.zero_()

    @staticmethod
    def _batch_im2col_step(batch_size):
        """Return the largest divisor of ``batch_size`` not exceeding ``max(batch_size // 2, 1)``.

        This is never 0, which fixes a batch of 1, where ``batch_size // 2``
        was 0. For even batches it equals ``batch_size // 2``.
        NOTE(review): assumes DeformConv2dFunction requires
        ``batch % im2col_step == 0``, as common DCN kernels do — confirm.
        """
        step = max(batch_size // 2, 1)
        while batch_size % step:
            step -= 1
        return step

    def forward(self, input):
        """Predict offsets from ``input`` and return the deformable convolution output.

        The constructor's ``im2col_step`` is not used here. The step is
        derived from the batch size instead.
        """
        offset = self.conv_offset(input)
        im2col_step = self._batch_im2col_step(input.size(0))
        return DeformConv2dFunction.apply(input, offset,
                                          self.weight,
                                          self.bias,
                                          self.stride,
                                          self.padding,
                                          self.dilation,
                                          self.groups,
                                          self.deformable_groups,
                                          im2col_step)
