import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.autograd import Variable
from torch.nn.parameter import Parameter
from torch.nn.modules.utils import _pair

import math
import numpy as np

from .functions import quantize
from .functions import binarize
from .functions import no_grad_mul
from .functions import round_back

import pdb

class Mod_conv2d_try(nn.Module):
    """2-D convolution over ``binarize``d weights with an optional carry term.

    The weight is passed through ``binarize`` and scaled per output channel by
    the mean absolute weight (``alpha``). When the carry path is enabled
    (``_cs == 1`` in :meth:`forward`), the result of :meth:`carry` minus its
    running mean is added to the convolution output.

    Notes grounded in this class only:
      * ``a_bits``, ``c_bits``, ``padding_mode`` and ``is_carry`` are stored
        but never read by any method here.
      * ``bias`` is created and initialised when requested, but every
        convolution in this class passes ``None`` as the bias.
      * ``running_mean`` has shape ``(out_channels, out_w, out_w)``.
    """

    def __init__(self, a_bits, c_bits, in_channels, out_channels, out_w, kernel_size, stride=1,
                 padding=0, dilation=1, group=1, bias=True, padding_mode='zeros', carry=False):
        super(Mod_conv2d_try, self).__init__()
        self.a_bits = a_bits
        self.c_bits = c_bits
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = _pair(padding)
        self.dilation = _pair(dilation)
        self.groups = group
        self.padding_mode = padding_mode
        # Allocated uninitialised; reset_parameters() below fills it in.
        self.weight = Parameter(torch.empty(
            self.out_channels, self.in_channels // self.groups, *self.kernel_size))
        if bias:
            self.bias = Parameter(torch.empty(self.out_channels))
        else:
            self.register_parameter('bias', None)
        # Carry-mean state: one running mean per output position.
        self.is_carry = carry
        self.register_buffer('running_mean', torch.zeros(out_channels, out_w, out_w))
        self.reset_parameters()
        self.reset_running_stats()

    def reset_parameters(self):
        """Initialise the weight (Kaiming uniform) and, if present, the bias."""
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def reset_running_stats(self):
        """Zero the running carry mean in place."""
        self.running_mean.zero_()

    def qconv2d_forward(self, input, weight, ss, alpha, _cs):
        """Run the quantised convolution and optional carry correction.

        Args:
            input: Activation tensor; it is divided by ``ss`` before use.
            weight: Weight tensor fed to ``binarize``.
            ss: Input scale.
            alpha: Ignored. It is overwritten by the per-output-channel mean
                absolute weight computed here; kept for interface
                compatibility.
            _cs: When equal to ``1`` the carry correction is applied.

        Returns:
            Tuple ``(output, ofrate, s_, std_m)`` where ``output`` is
            ``q_out * s_``, ``ofrate`` is the fraction of ``q_out`` entries
            with magnitude ``>= 2**7``, ``s_`` is ``ss * alpha`` broadcast
            over channels, and ``std_m`` is the mean batch variance of
            :meth:`carry_diff` (``0.0`` when the carry path is off).
        """
        std_m = 0.0
        flat_weight = weight.view(weight.shape[0], -1)
        alpha = flat_weight.abs().mean(-1).detach()
        alpha_col = alpha.view(-1, 1, 1, 1)
        qw = binarize.apply(weight)
        qw = no_grad_mul.apply(qw, alpha_col)
        q_out = F.conv2d(input/ss, qw/alpha_col, None, self.stride, self.padding, self.dilation, self.groups)
        if _cs == 1:
            carry_es = self.carry_diff(input/ss)
            # The exact carry is computed on detached tensors, so it adds no
            # gradient path of its own.
            carry_ = self.carry(input.detach()/ss, qw.detach()/alpha_col)
            mc = self.batch_mean_carry(carry_)

            ## carry for bitpack
            std_m = torch.var(carry_es, dim=0).mean()
            q_out = q_out + carry_ - mc

        # Numerator: entries with |q_out| >= 2**7. Denominator: entries whose
        # magnitude compares >= -1.0 (every non-NaN entry), plus an epsilon.
        ofrate = ((q_out).abs() >= (2**7)).float().sum().item()/((q_out.abs() >= -1.0).float().sum().item() + 1e-10)
        s_ = ss * alpha.view(1, -1, 1, 1)

        return q_out * s_, ofrate, s_, std_m

    def carry(self, x, qw, bit_width=2**8):
        """Accumulate carries of the convolution result modulo ``bit_width``.

        Builds ``unsignedrep`` as the convolution result plus ``bit_width``
        for every product whose operands have opposite signs, then repeatedly
        folds the carry (``// bit_width``) back into the remainder until the
        carry tensor sums to zero.

        Args:
            x: Input activations.
            qw: Weights used for the convolution.
            bit_width: Modulus of the accumulator (default ``2**8``).

        Returns:
            Tensor holding the sum of all carries produced per output element.
        """
        mixedsignresult = F.conv2d(x, qw, None, self.stride,
                                   self.padding, self.dilation, self.groups)
        mixedabsresult = F.conv2d(x, abs(qw), None, self.stride,
                                  self.padding, self.dilation, self.groups)
        pos_result = (mixedsignresult + mixedabsresult) // 2
        neg_result = (mixedsignresult - mixedabsresult) // 2
        # Count of (positive x, negative w) plus (negative x, positive w) pairs.
        neg_count = F.conv2d(torch.clamp_min(torch.sign(x), 0),
                             torch.clamp_min(-torch.sign(qw), 0),
                             None, self.stride, self.padding, self.dilation, self.groups) \
            + F.conv2d(torch.clamp_min(-torch.sign(x), 0),
                       torch.clamp_min(torch.sign(qw), 0),
                       None, self.stride, self.padding, self.dilation, self.groups)
        unsignedrep = pos_result + neg_result + neg_count * bit_width

        carry = unsignedrep // bit_width
        # total_carry aliases the first carry tensor; `carry` is rebound to a
        # new tensor before total_carry is modified in place.
        total_carry = carry
        remain = unsignedrep % bit_width
        while carry.sum() != 0:
            remain += carry
            carry = remain // bit_width
            remain = remain % bit_width
            total_carry += carry
        return total_carry

    def batch_mean_carry(self, input):
        """Return the mean carry to subtract, updating the running mean.

        In training mode the batch mean over dim 0 is blended into
        ``running_mean`` with factor ``0.1``. The buffer is updated *in place*
        so that the update is visible through every reference to the buffer
        (rebinding the attribute would create a new tensor instead).

        Args:
            input: Carry tensor; it is averaged over dim 0.

        Returns:
            In eval mode, the ``running_mean`` buffer. In training mode, the
            updated running mean (a fresh tensor, not the buffer itself), or
            the batch mean if the updated running mean sums to zero.
        """
        if not self.training:
            return self.running_mean

        momentum = 0.1
        batch_mean = input.mean([0])
        with torch.no_grad():
            updated = momentum * batch_mean + (1 - momentum) * self.running_mean
            if updated.shape == self.running_mean.shape:
                self.running_mean.copy_(updated)
            else:
                # Broadcasting changed the shape; keep the previous
                # rebinding behaviour rather than raising.
                self.running_mean = updated

        if updated.sum() == 0:
            return batch_mean
        return updated

    def carry_diff(self, x, alpha=1, bit_width=2**8, kx=2, kw=2):
        """Variant of :meth:`carry` that keeps gradient paths open.

        Differences from :meth:`carry` visible in the code: the weights are
        ``binarize(self.weight)``; floor operations are written as
        ``v / d - ((v % d) / d).detach()`` so the subtracted remainder carries
        no gradient; the sign indicators use ``tanh(k * v)`` instead of
        ``sign(v)``; and the loop runs while the carry sum is ``> 0``.

        Args:
            x: Input activations.
            alpha: Multiplier applied to the returned carry (default ``1``).
            bit_width: Modulus of the accumulator (default ``2**8``).
            kx: Slope applied to ``x`` inside ``tanh``.
            kw: Slope applied to the weights inside ``tanh``.

        Returns:
            ``total_carry * alpha``.
        """
        qw = binarize.apply(self.weight)

        mixedsignresult = F.conv2d(x, qw, None, self.stride,
                                   self.padding, self.dilation, self.groups)
        mixedabsresult = F.conv2d(x, abs(qw), None, self.stride,
                                  self.padding, self.dilation, self.groups)
        pos_result = mixedsignresult + mixedabsresult
        neg_result = mixedsignresult - mixedabsresult
        pos_result = pos_result / 2 - ((pos_result % 2)/2).detach()
        neg_result = neg_result / 2 - ((neg_result % 2)/2).detach()
        neg_count = F.conv2d(torch.clamp_min(torch.tanh(x*kx), 0),
                             torch.clamp_min(-torch.tanh(qw*kw), 0),
                             None, self.stride, self.padding, self.dilation, self.groups) \
            + F.conv2d(torch.clamp_min(-torch.tanh(x*kx), 0),
                       torch.clamp_min(torch.tanh(qw*kw), 0),
                       None, self.stride, self.padding, self.dilation, self.groups)
        unsignedrep = pos_result + neg_result + neg_count * bit_width

        carry = unsignedrep / bit_width - ((unsignedrep % bit_width)/bit_width).detach()
        total_carry = carry
        remain = unsignedrep % bit_width
        while carry.sum() > 0:
            remain += carry
            carry = remain / bit_width - ((remain % bit_width)/bit_width).detach()
            remain = remain % bit_width
            total_carry += carry
        return total_carry*alpha

    def forward(self, input, scale, alpha, _cs):
        """Apply :meth:`qconv2d_forward` using this module's own weight."""
        return self.qconv2d_forward(input, self.weight, scale, alpha, _cs)



if __name__ == '__main__':
    # Smoke test for the running carry mean. The module is sized so that its
    # running_mean buffer (out_channels, out_w, out_w) == (2, 2, 2) matches the
    # per-sample shape of the toy inputs; a mismatched buffer cannot be
    # broadcast against the batch mean and raises a RuntimeError.
    conv = Mod_conv2d_try(4, 1, 2, 2, 2, kernel_size=3, padding=0, bias=False, carry=True)
    result = None
    for _ in range(3):
        result = conv.batch_mean_carry(torch.randn(4, 2, 2, 2))
    print(result)
