import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.autograd import Variable
from torch.nn.parameter import Parameter
from torch.nn.modules.utils import _pair

import math
import pdb

from .functions import quantize
from .functions import binarize
from .functions import no_grad_mul
from .functions import round_back

class Mod_Linear(nn.Module):
    """Linear layer with binarized weights rescaled per output row.

    The weight matrix has shape ``(out_features, in_features)``. In
    ``forward`` each output row is scaled by ``alpha``, the mean absolute
    value of that row's weights.

    Args:
        a_bits: stored on the instance; not read by any method of this class.
        c_bits: stored on the instance; not read by any method of this class.
        in_features: size of the last dimension of the input.
        out_features: number of output features.
        bias: if True, allocate a learnable bias of shape ``(out_features,)``.
    """

    def __init__(self, a_bits, c_bits, in_features, out_features, bias=True):
        super().__init__()
        self.a_bits = a_bits
        self.c_bits = c_bits
        self.in_features = in_features
        self.out_features = out_features
        # Uninitialized storage; filled in by reset_parameters().
        self.weight = Parameter(torch.empty(self.out_features, self.in_features))
        if bias:
            self.bias = Parameter(torch.empty(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Initialize weight and bias the same way ``nn.Linear`` does."""
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            # Guard in_features == 0, which would otherwise divide by zero.
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0.0
            init.uniform_(self.bias, -bound, bound)

    @staticmethod
    def _overflow_fraction(q_out, threshold=2 ** 11, eps=1e-5):
        """Return the share of non-NaN entries with ``|q_out| >= threshold``, plus ``eps``.

        NaN entries fail every comparison, so they are excluded from both
        the numerator and the denominator. An empty or all-NaN tensor
        yields ``eps`` instead of raising ``ZeroDivisionError``.
        """
        magnitude = q_out.abs()
        overflowed = (magnitude >= threshold).sum()
        counted = (~torch.isnan(magnitude)).sum()
        # Stack both counts so that only a single device->host transfer occurs.
        n_over, n_valid = torch.stack((overflowed, counted)).tolist()
        if n_valid == 0:
            return eps
        return n_over / n_valid + eps

    def forward(self, input, s, alpha):
        """Apply the binarized linear map and report how often outputs overflow.

        Args:
            input: tensor passed through ``F.linear`` after division by ``s``.
            s: input scale. It must broadcast against ``input`` and against
                a ``(1, out_features)`` row. NOTE(review): the exact shape is
                not visible here; confirm with the callers.
            alpha: ignored. It is recomputed from ``self.weight`` below and
                is kept in the signature only for caller compatibility.

        Returns:
            A tuple ``(q_out * s_, of, s_)``:
            - ``q_out`` is the linear output computed on ``input / s``.
            - ``s_ = s * alpha`` is broadcast as a row.
            - ``of`` is the fraction of non-NaN ``q_out`` entries whose
              magnitude is at least ``2**11``, plus ``1e-5``.
        """
        # Per-output-row scale with shape (out_features,). It replaces the
        # caller-supplied ``alpha``.
        alpha = self.weight.abs().mean(-1)
        alpha_col = alpha.view(-1, 1)
        # binarize / no_grad_mul come from .functions. Judging by their names,
        # they are presumably sign-binarization and a multiply that does not
        # propagate gradients. Verify against .functions.
        qw = binarize.apply(self.weight)
        qw = no_grad_mul.apply(qw, alpha_col)
        q_out = F.linear(input / s, qw / alpha_col, self.bias)
        of = self._overflow_fraction(q_out)
        s_ = s * alpha.view(1, -1)
        return q_out * s_, of, s_


if __name__ == '__main__':
    # Executing this module directly runs no model code; it only prints a
    # fixed placeholder message.
    banner = "output"
    print(banner)
