import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.autograd import Variable
from torch.nn.parameter import Parameter
from torch.nn.modules.utils import _pair

import math

from functions import quantize

class q_Linear(nn.Module):
    """Linear layer that quantizes its weight on every forward pass.

    The trainable parameter ``weight`` is kept in full precision; ``forward``
    passes it through ``quantize.apply(weight, bits)`` (from the project's
    ``functions`` module) before calling ``F.linear``. The bias is used
    unquantized.

    Args:
        bits: Forwarded as the second argument of ``quantize.apply``.
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the output.
        bias: If ``False``, the layer has no additive bias.
    """

    def __init__(self, bits, in_features, out_features, bias=True):
        super(q_Linear, self).__init__()
        self.bits = bits
        self.in_features = in_features
        self.out_features = out_features
        # Allocated uninitialized; reset_parameters() fills the values.
        self.weight = Parameter(torch.empty(out_features, in_features))
        if bias:
            self.bias = Parameter(torch.empty(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize weight (Kaiming-uniform, a=sqrt(5)) and bias.

        The bias is drawn from U(-1/sqrt(fan_in), 1/sqrt(fan_in)), the same
        scheme ``nn.Linear`` uses by default.
        """
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            # fan_in is 0 when in_features == 0; avoid dividing by zero.
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            init.uniform_(self.bias, -bound, bound)

    def forward(self, input):
        """Return ``F.linear(input, quantize.apply(weight, bits), bias)``."""
        qw = quantize.apply(self.weight, self.bits)
        return F.linear(input, qw, self.bias)

    def extra_repr(self):
        # Shown inside repr(module), mirroring nn.Linear's repr.
        return 'bits={}, in_features={}, out_features={}, bias={}'.format(
            self.bits, self.in_features, self.out_features,
            self.bias is not None)

    def carry(self, x, alpha=1, bit_width=(2**8)):
        """Count carries when the ternary product ``x @ qw.T`` is accumulated mod ``bit_width``.

        ``qw = quantize.apply(self.weight, self.bits, 1)`` must contain only
        -1, 0 and 1. The signed result is split into the sum of positive
        products and the sum of negative products. ``bit_width`` is added once
        per negative product, so each one contributes ``bit_width + x_i * w_i``.
        The resulting per-output total is then decomposed into carries and a
        remainder by ``_propagate_carry``.

        Args:
            x: Activations; must be non-negative, with last dimension
                ``in_features``. NOTE(review): the ``// 2`` steps floor, so
                non-integer activations are truncated. Presumably ``x`` is
                integer-valued — confirm with callers.
            alpha: Scalar the carry count is multiplied by.
            bit_width: Accumulator modulus (e.g. ``2**8``). Despite the name,
                it is a word size, not a number of bits. Must be positive.

        Returns:
            Tensor of carry counts with the shape of ``F.linear(x, qw)``,
            multiplied by ``alpha``.

        Raises:
            ValueError: if ``qw`` is not ternary, if ``x`` has a negative
                entry, if ``bit_width`` is not positive, or if the accumulated
                value is not finite (which would otherwise loop forever).
        """
        qw = quantize.apply(self.weight, self.bits, 1)
        if ((qw != 0) & (qw != 1) & (qw != -1)).any():
            raise ValueError('quantized weight has to be ternary for the carry operation')
        if (x < 0).any():
            raise ValueError('activation is assumed to be positive')

        # With w in {-1, 0, 1}:  x.w = P - N  and  x.|w| = P + N, where P (N)
        # is the sum of x_i over w_i = +1 (w_i = -1).
        signed_sum = F.linear(x, qw)
        abs_sum = F.linear(x, abs(qw))
        pos_result = (signed_sum + abs_sum) // 2   # P
        neg_result = (signed_sum - abs_sum) // 2   # -N
        # Number of strictly negative products x_i * w_i per output. Because x
        # was checked to be non-negative, the second term is always zero.
        neg_count = (
            F.linear(torch.clamp_min(torch.sign(x), 0),
                     torch.clamp_min(-torch.sign(qw), 0))
            + F.linear(torch.clamp_min(-torch.sign(x), 0),
                       torch.clamp_min(torch.sign(qw), 0))
        )

        unsigned_rep = pos_result + neg_result + neg_count * bit_width
        return self._propagate_carry(unsigned_rep, bit_width) * alpha

    @staticmethod
    def _propagate_carry(unsigned_rep, bit_width):
        """Return the element-wise total of carries from repeatedly folding overflow back in.

        Starts from ``carry = unsigned_rep // bit_width`` and
        ``remain = unsigned_rep % bit_width``. It then keeps adding ``carry``
        to ``remain`` and re-splitting until no element carries, summing
        every carry along the way.
        """
        if bit_width <= 0:
            raise ValueError('bit_width must be positive')
        if not torch.isfinite(unsigned_rep).all():
            # inf/NaN never reach a zero carry, so the loop would not terminate.
            raise ValueError('accumulated value must be finite')
        carry = unsigned_rep // bit_width
        remain = unsigned_rep % bit_width
        total_carry = carry
        # Check every element instead of ``carry.sum() != 0``. Carries can be
        # negative (e.g. when an activation exceeds bit_width under a -1
        # weight), and opposite-signed carries would cancel in a sum and stop
        # propagation too early.
        while (carry != 0).any():
            remain = remain + carry
            carry = remain // bit_width
            remain = remain % bit_width
            total_carry = total_carry + carry
        return total_carry


if __name__ == '__main__':
    # Smoke test for q_Linear.carry: one output neuron with nine all-ones
    # weights. Each line prints True when the carry count matches.
    # NOTE(review): if quantize.apply(weight, 3, 1) keeps these weights at 1,
    # the dot product is 9 * 3 = 27 < 2**8 and carry() returns 0, so the first
    # check would print False — confirm the intended expectation.
    layer = q_Linear(3, 9, 1)
    layer.weight.data = torch.ones(1, 9)

    x = torch.full((1, 1, 9), 3.)
    print(layer.carry(x)[0, 0, 0].item() == 3)

    x = torch.zeros(1, 1, 9)
    print(layer.carry(x)[0, 0, 0].item() == 0)
