import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class bit2int(torch.autograd.Function):
    """Decode sign-encoded bit planes into signed integers.

    The leading dimension of ``x`` indexes bits, most significant first. An
    entry counts as bit 1 when it is strictly positive and as bit 0 otherwise.
    The decoded value is the weighted bit sum minus ``2**(len(mask) - 1)``.

    The backward pass does not differentiate the thresholding: every bit plane
    receives the incoming gradient multiplied by that bit's place value and by
    the matching ``mask`` entry.
    """

    @staticmethod
    def _place_values(n_bits, ndim, device):
        """Return ``[2**(n_bits-1), ..., 2, 1]`` shaped ``(n_bits, 1, ..., 1)``.

        The result has ``ndim`` dimensions so it broadcasts over a tensor
        whose leading axis is the bit axis.
        """
        exponents = torch.arange(n_bits - 1, -1, -1)
        trailing = [1] * (ndim - 1)
        return (2 ** exponents).to(device).reshape(n_bits, *trailing)

    @staticmethod
    def forward(ctx, x, mask):
        """Collapse the bit axis of ``x`` into signed integer values.

        Args:
            ctx: Autograd context; receives ``mask`` as a tensor for backward.
            x: Tensor whose first dimension is the bit axis.
            mask: Sequence with one entry per bit; only its length is used here.

        Returns:
            Tensor with the shape of ``x[0]``.
        """
        n_bits = len(mask)
        # relu followed by sign yields 1 for positive entries and 0 otherwise.
        bits = torch.relu(x).sign()
        place = bit2int._place_values(n_bits, x.ndim, x.device)
        ctx.save_for_backward(torch.tensor(mask))
        return (bits * place).sum(dim=0) - 2 ** (n_bits - 1)

    @staticmethod
    def backward(ctx, dy):
        """Spread ``dy`` over the bit axis, scaled per bit and gated by the mask.

        Returns:
            A gradient for ``x`` and ``None`` for ``mask``.
        """
        (mask,) = ctx.saved_tensors
        n_bits = len(mask)
        # Broadcasting against the place values adds the leading bit axis.
        scaled = dy.unsqueeze(0) * bit2int._place_values(n_bits, dy.ndim + 1, dy.device)
        gate = mask.reshape(n_bits, *([1] * dy.ndim)).to(dy.device)
        return scaled * gate, None

class BiF_Linear(nn.Module):
    """Linear layer whose weight is stored as one latent tensor per bit.

    ``self.weight`` has shape ``(bit_num, out_features, in_features)`` with
    the most significant bit at index 0. ``forward`` decodes the bit planes
    with ``bit2int`` and scales the result by ``self.alpha``, a parameter
    created with ``requires_grad=False``.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        bias: Accepted but ignored; no bias is created or applied.
        train_mask: One flag per bit, most significant first, given either as
            a sequence or as a string of digits such as ``"1100"``. Its length
            fixes ``bit_num``. It is handed to ``bit2int`` on every forward
            call. Required, despite the ``None`` default.

    Raises:
        TypeError: If ``train_mask`` is ``None``.
    """

    def __init__(self, in_features, out_features, bias=False, train_mask=None):
        super(BiF_Linear, self).__init__()
        if train_mask is None:
            raise TypeError(
                "train_mask is required: pass one flag per weight bit, "
                "e.g. [1, 1, 0, 0] or '1100'"
            )
        # Normalize here too, so a string mask works from the first forward.
        self.train_mask = self._normalize_mask(train_mask)
        self.bit_num = len(self.train_mask)
        self.in_features, self.out_features = in_features, out_features

        self.weight = nn.Parameter(torch.empty((self.bit_num, out_features, in_features)))
        self.alpha = torch.nn.Parameter(torch.tensor(0.005), requires_grad=False)

        self.weight_init()

    @staticmethod
    def _normalize_mask(train_mask):
        """Turn a digit string into a list of ints; return other masks as-is."""
        if isinstance(train_mask, str):
            return [int(flag) for flag in train_mask]
        return train_mask

    def weight_init(self):
        """Initialize the bit planes from a quantized Kaiming-normal sample.

        A reference weight is drawn, ``alpha`` is set so that the largest
        magnitude maps to ``2**(bit_num - 1)``, and the scaled weight is
        stochastically rounded to an integer code. Each bit of that code sets
        the sign of the matching latent entry; the latent magnitudes are
        independent half-normal draws.
        """
        half_range = 2 ** (self.bit_num - 1)
        reference = torch.empty_like(self.weight.data[0])
        nn.init.kaiming_normal_(reference)

        self.alpha.data = reference.abs().max() / half_range
        # Adding uniform noise before floor rounds up with probability equal
        # to the fractional part.
        code = torch.floor(reference / self.alpha + torch.rand_like(reference)) + half_range
        # The largest positive weight lands on 2**bit_num, which does not fit
        # in bit_num bits and would wrap to the most negative value.
        code = code.clamp_(0, 2 * half_range - 1).long()

        # Plane 0 holds the most significant bit.
        bits = torch.stack([(code >> shift) & 1 for shift in range(self.bit_num - 1, -1, -1)])
        signs = 2 * bits.to(self.weight.dtype) - 1

        shape = self.weight.shape
        scale = float((2 / np.prod(shape[:-1])) ** 0.5)
        self.weight.data = signs * scale * torch.normal(0, 1, shape).abs()

    def reset_mask(self, train_mask):
        """Replace the per-bit mask.

        Args:
            train_mask: Sequence or digit string with ``bit_num`` entries.

        Raises:
            ValueError: If the mask length differs from ``bit_num``.
        """
        mask = self._normalize_mask(train_mask)
        if len(mask) != self.bit_num:
            raise ValueError(
                f"train_mask has {len(mask)} entries, expected {self.bit_num}"
            )
        self.train_mask = mask

    def forward(self, input):
        """Apply the linear map using the decoded, ``alpha``-scaled weight."""
        weight = self.alpha * bit2int.apply(self.weight, self.train_mask)
        return F.linear(input, weight, None)

class BiF_Conv2d(nn.Module):
    """2-D convolution whose weight is stored as one latent tensor per bit.

    ``self.weight`` has shape
    ``(bit_num, out_channels, in_channels, kernel_size, kernel_size)`` with
    the most significant bit at index 0. ``forward`` decodes the bit planes
    with ``bit2int`` and scales the result by ``self.alpha``, a parameter
    created with ``requires_grad=False``.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Integer side length, used for both kernel dimensions.
        stride: Passed through to ``F.conv2d``.
        padding: Passed through to ``F.conv2d``.
        bias: Accepted but ignored; no bias is created or applied.
        train_mask: One flag per bit, most significant first, given either as
            a sequence or as a string of digits such as ``"1100"``. Its length
            fixes ``bit_num``. It is handed to ``bit2int`` on every forward
            call. Required, despite the ``None`` default.

    Raises:
        TypeError: If ``train_mask`` is ``None``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=False, train_mask=None):
        super(BiF_Conv2d, self).__init__()
        if train_mask is None:
            raise TypeError(
                "train_mask is required: pass one flag per weight bit, "
                "e.g. [1, 1, 0, 0] or '1100'"
            )
        # Normalize here too, so a string mask works from the first forward.
        self.train_mask = self._normalize_mask(train_mask)
        self.bit_num = len(self.train_mask)
        self.in_channels, self.out_channels = in_channels, out_channels
        self.kernel_size, self.stride, self.padding = kernel_size, stride, padding

        self.weight = nn.Parameter(torch.empty((self.bit_num, out_channels, in_channels, kernel_size, kernel_size)))
        self.alpha = torch.nn.Parameter(torch.tensor(0.005), requires_grad=False)

        self.weight_init()

    @staticmethod
    def _normalize_mask(train_mask):
        """Turn a digit string into a list of ints; return other masks as-is."""
        if isinstance(train_mask, str):
            return [int(flag) for flag in train_mask]
        return train_mask

    def weight_init(self):
        """Initialize the bit planes from a quantized Kaiming-normal sample.

        A reference kernel is drawn, ``alpha`` is set so that the largest
        magnitude maps to ``2**(bit_num - 1)``, and the scaled kernel is
        stochastically rounded to an integer code. Each bit of that code sets
        the sign of the matching latent entry; the latent magnitudes are
        independent half-normal draws.
        """
        half_range = 2 ** (self.bit_num - 1)
        reference = torch.empty_like(self.weight.data[0])
        nn.init.kaiming_normal_(reference)

        self.alpha.data = reference.abs().max() / half_range
        # Adding uniform noise before floor rounds up with probability equal
        # to the fractional part.
        code = torch.floor(reference / self.alpha + torch.rand_like(reference)) + half_range
        # The largest positive weight lands on 2**bit_num, which does not fit
        # in bit_num bits and would wrap to the most negative value.
        code = code.clamp_(0, 2 * half_range - 1).long()

        # Plane 0 holds the most significant bit.
        bits = torch.stack([(code >> shift) & 1 for shift in range(self.bit_num - 1, -1, -1)])
        signs = 2 * bits.to(self.weight.dtype) - 1

        shape = self.weight.shape
        scale = float((2 / np.prod(shape[:-1])) ** 0.5)
        self.weight.data = signs * scale * torch.normal(0, 1, shape).abs()

    def reset_mask(self, train_mask):
        """Replace the per-bit mask.

        Args:
            train_mask: Sequence or digit string with ``bit_num`` entries.

        Raises:
            ValueError: If the mask length differs from ``bit_num``.
        """
        mask = self._normalize_mask(train_mask)
        if len(mask) != self.bit_num:
            raise ValueError(
                f"train_mask has {len(mask)} entries, expected {self.bit_num}"
            )
        self.train_mask = mask

    def forward(self, input):
        """Convolve ``input`` with the decoded, ``alpha``-scaled kernel."""
        weight = self.alpha * bit2int.apply(self.weight, self.train_mask)
        return F.conv2d(input, weight, None, self.stride, self.padding)

# Running this file directly only prints a marker; no layer is built or run.
if __name__ == "__main__":
    print("done")
