import torch
from DECLinear import DECLinear

from plugin import *
import dp_ext

class APLinearSel(DECLinear):
    """Quantized linear layer evaluated through ``anyprec_gemv_sel``.

    The layer owns the following CUDA buffers (all allocated with
    ``torch.empty``, i.e. uninitialised until real values are loaded into
    them):

    * ``qweight`` -- ``int32`` tensor of shape
      ``(bitwidth, out_features, in_features // 32)``.
    * ``lut3`` .. ``lut6`` -- tensors of shape ``(out_features, 2 ** n)`` in
      ``dtype`` for ``n`` in 3..6.
    * ``bias`` -- tensor of shape ``(out_features,)`` in ``dtype`` when
      ``bias=True``; otherwise the attribute is ``None``.

    ``output`` is a pre-allocated ``(1, 1, out_features)`` tensor that
    ``forward`` passes to the kernel and returns on every call. It is a plain
    attribute, not a registered buffer, so it is not part of the state dict.

    NOTE(review): the ``in_features // 32`` dimension presumably means 32
    weights are packed per ``int32`` word, in which case ``in_features``
    should be a multiple of 32 -- confirm against the kernel.

    Args:
        in_features: Size of the input feature dimension.
        out_features: Size of the output feature dimension.
        bitwidth: Leading dimension of ``qweight``; also forwarded to the
            kernel on every call.
        bias: Whether to allocate a ``bias`` buffer.
        dtype: Floating-point dtype of the lookup tables, bias and output.
    """

    def __init__(self, in_features, out_features, bitwidth, bias=False, dtype=torch.half):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.bitwidth = bitwidth
        self.dtype = dtype

        self.register_buffer(
            'qweight',
            torch.empty(
                (bitwidth, out_features, in_features // 32),
                dtype=torch.int32,
                device='cuda',
            ),
        )

        # One lookup table per supported precision; registered in ascending
        # order so the buffer (and state-dict) ordering is lut3, lut4, lut5,
        # lut6.
        for lut_bits in (3, 4, 5, 6):
            self.register_buffer(
                f'lut{lut_bits}',
                torch.empty(
                    (out_features, 2 ** lut_bits),
                    dtype=self.dtype,
                    device='cuda',
                ),
            )

        if bias:
            self.register_buffer(
                "bias",
                torch.empty((out_features,), dtype=self.dtype, device='cuda'),
            )
        else:
            self.bias = None

        # Allocated once and reused by every forward() call.
        self.output = torch.zeros(
            (1, 1, self.out_features), dtype=self.dtype, device='cuda'
        )

    def forward(self, x, bsel, sne, **kwargs):
        """Run the kernel on a single input vector and return the result.

        Args:
            x: Input tensor whose first two dimensions must both be 1.
            bsel: Forwarded unchanged to ``anyprec_gemv_sel``.
            sne: Forwarded unchanged to ``anyprec_gemv_sel``.
            **kwargs: Accepted and ignored.

        Returns:
            ``self.output``, the layer's shared pre-allocated tensor. The same
            tensor object is returned by every call, so callers that need to
            keep a result across calls must copy it first.

        Raises:
            AssertionError: If ``x.shape[0]`` or ``x.shape[1]`` is not 1.
        """
        # Explicit raise instead of ``assert`` so the check is not stripped
        # under ``python -O``. AssertionError is kept so that existing callers
        # catching it are unaffected.
        if x.shape[0] != 1 or x.shape[1] != 1:
            raise AssertionError(
                f"APLinearSel.forward expects x.shape[0] == x.shape[1] == 1, "
                f"got shape {tuple(x.shape)}"
            )

        # NOTE(review): anyprec_gemv_sel is not defined in this file;
        # presumably it comes from the `plugin` star-import and writes its
        # result into self.output -- confirm.
        anyprec_gemv_sel(
            x,
            self.qweight,
            self.lut3,
            self.lut4,
            self.lut5,
            self.lut6,
            self.output,
            self.bitwidth,
            bsel,
            sne,
        )

        if self.bias is not None:
            # In-place add; bias (out_features,) broadcasts over the
            # (1, 1, out_features) output.
            self.output += self.bias

        return self.output