import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.autograd import Variable
from torch.nn.parameter import Parameter
from torch.nn.modules.utils import _pair

import math

from .functions import quantize
from .functions import binarize
from .functions import ternarize
from .functions import no_grad_mul
from .functions import round_back

import pdb

class Mod_conv2d_try(nn.Module):
    """2-D convolution whose weights pass through a low-bit quantizer.

    ``c_bits`` selects how :meth:`qconv2d_forward` treats the weights:

    * ``c_bits == 0`` -> ``binarize`` with a per-output-channel scale
      (mean absolute weight of each output channel),
    * ``c_bits == 1`` -> ``ternarize`` with the same per-output-channel scale,
    * ``c_bits > 1``  -> ``quantize`` with one tensor-wide scale
      ``(max(w) - min(w)) / (2**c_bits - 1)``,
    * ``c_bits < 0``  -> no quantization (full-precision weights, unit scale).

    Notes:
        ``a_bits`` is stored but not read anywhere in this class.
        ``padding_mode`` is stored, but ``F.conv2d`` is always called with its
        default zero padding. ``carry`` is accepted and ignored.
    """

    # |value| at or above which an entry of the un-rescaled convolution output
    # counts towards the reported overflow rate. 2**7 = 128; presumably chosen
    # for a signed 8-bit budget -- confirm. Override per instance if needed.
    OVERFLOW_THRESHOLD = 2 ** 7

    def __init__(self, a_bits, c_bits, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, group=1, bias=True, padding_mode='zeros', carry=False):
        super().__init__()
        self.a_bits = a_bits
        self.c_bits = c_bits
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = _pair(padding)
        self.dilation = _pair(dilation)
        self.groups = group
        self.padding_mode = padding_mode
        # Uninitialised storage; filled by reset_parameters() below.
        self.weight = Parameter(torch.empty(
                self.out_channels, self.in_channels // self.groups, *self.kernel_size))
        if bias:
            self.bias = Parameter(torch.empty(self.out_channels))
        else:
            self.register_parameter('bias', None)

        # Low-bit weight quantizer (see class docstring for the c_bits mapping).
        # For c_bits < 0 quantize.apply is stored but never called.
        if self.c_bits == 0:
            self.q = binarize.apply
        elif self.c_bits == 1:
            self.q = ternarize.apply
        else:
            self.q = quantize.apply

        self.reset_parameters()

    def reset_parameters(self):
        """Initialise weight and bias the same way ``nn.Conv2d`` does."""
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def reset_running_stats(self):
        """Zero ``running_mean`` if one has been attached to this module.

        The constructor never creates ``running_mean``, so this is a no-op
        unless a caller or subclass sets it.
        """
        running_mean = getattr(self, 'running_mean', None)
        if running_mean is not None:
            running_mean.zero_()

    def qconv2d_forward(self, input, weight, ss, alpha):
        """Convolve ``input`` with quantized ``weight`` and report overflow.

        Args:
            input: activations passed to ``F.conv2d`` after division by ``ss``.
            weight: full-precision weight tensor (normally ``self.weight``).
            ss: input scale. Forced to ``1.`` when ``c_bits < 0``.
            alpha: ignored; every branch recomputes it.

        Returns:
            ``(out, ofrate, s_)``, where:

            * ``q_out = conv2d(input / ss, qw / alpha)`` (bias is not applied),
            * ``s_ = ss * alpha.view(1, -1, 1, 1)``,
            * ``out = q_out * s_``,
            * ``ofrate`` is the fraction of ``q_out`` entries with magnitude
              ``>= OVERFLOW_THRESHOLD``, plus ``1e-10``.
        """
        if self.c_bits > 1:
            min_val, max_val = weight.min(), weight.max()
            # NOTE(review): a constant weight tensor gives alpha == 0 and a
            # division by zero in ``qw / alpha`` below.
            alpha = (max_val - min_val) / (2 ** self.c_bits - 1)
            qw = self.q(weight, self.c_bits, alpha, 0)
        elif self.c_bits < 0:
            # Full precision: unit weight scale and no input scaling. The scale
            # lives on the weight's device/dtype so CPU (and any GPU) works.
            alpha = torch.ones((), dtype=weight.dtype, device=weight.device)
            ss = 1.
            qw = weight
        else:
            # Per-output-channel scale = mean |w| over each output filter.
            # NOTE(review): an all-zero filter gives alpha == 0 for that channel.
            input_ = weight.view(weight.shape[0], -1)
            alpha = input_.abs().mean(-1).detach()
            qw = self.q(weight)
            # Presumably multiplies without propagating a gradient to alpha -- confirm.
            qw = no_grad_mul.apply(qw, alpha.view(-1, 1, 1, 1))

        q_out = F.conv2d(input / ss, qw / alpha.view(-1, 1, 1, 1), None,
                         self.stride, self.padding, self.dilation, self.groups)

        # Exact integer count of overflowing entries, normalised by the total
        # entry count. max(..., 1) keeps an empty batch from dividing by zero.
        n_overflow = (q_out.abs() >= self.OVERFLOW_THRESHOLD).sum().item()
        ofrate = n_overflow / max(q_out.numel(), 1) + 1e-10

        s_ = ss * alpha.view(1, -1, 1, 1)
        return q_out * s_, ofrate, s_

    def forward(self, input, scale, alpha):
        """Run :meth:`qconv2d_forward` with this module's own weight."""
        return self.qconv2d_forward(input, self.weight, scale, alpha)



if __name__ == "__main__":
    # Script entry point: prints a fixed marker string and builds nothing.
    banner = "input grad"
    print(banner)
