import torch
from DECLinear import DECLinear

from plugin import *
import dp_ext

class LUTGEMMLinear(DECLinear):
    """Quantized linear layer evaluated by the ``dec_lutgemm`` kernel.

    Stores packed quantized weights plus per-group scale and offset tensors on
    the CUDA device. ``forward`` handles exactly one token (first two input
    dimensions equal to 1) and writes into a preallocated output tensor.

    All buffers are allocated with ``torch.empty`` (uninitialized); they are
    presumably filled afterwards, e.g. by loading a quantized checkpoint --
    confirm with the loading code.

    Args:
        in_features: Size of each input sample. The packed weight has
            ``in_features // 32`` rows and there are
            ``in_features // group_size`` quantization groups.
        out_features: Size of each output sample.
        bitwidth: Number of quantization bits; sizes the second dimension of
            ``qweight`` and ``alpha``.
        group_size: Number of input features that share one ``alpha`` /
            ``q_bias`` entry.
        bias: If True, allocate an additive ``bias`` buffer of shape
            ``(out_features,)``; otherwise ``self.bias`` is None.
        dtype: Dtype of ``alpha``, ``q_bias``, ``bias`` and the output tensor.
    """

    def __init__(self, in_features, out_features, bitwidth, group_size, bias=False, dtype=torch.half):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.bitwidth = bitwidth
        self.group_size = group_size
        self.dtype = dtype

        # Packed quantized weights: (in_features // 32, bitwidth, out_features)
        # int32. NOTE(review): presumably each int32 packs one bit of 32
        # consecutive input features per bit-plane -- confirm against the kernel.
        self.register_buffer(
            'qweight',
            torch.empty((in_features//32, bitwidth, out_features), dtype=torch.int32, device='cuda')
        )

        # Per-group, per-bit scale factors: (in_features // group_size, bitwidth, out_features).
        self.register_buffer(
            'alpha',
            torch.empty((in_features // group_size, bitwidth, out_features), dtype=self.dtype, device='cuda')
        )

        # Per-group quantization offset: (in_features // group_size, out_features).
        # This is distinct from the optional additive ``bias`` below.
        self.register_buffer(
            'q_bias',
            torch.empty((in_features // group_size, out_features), dtype=self.dtype, device='cuda')
        )

        if bias:
            self.register_buffer(
                "bias",
                torch.empty((out_features,), dtype=self.dtype, device='cuda')
            )
        else:
            self.bias = None

        # Preallocated result tensor reused (and returned) by every forward call.
        # It is a plain attribute, not a registered buffer, so it is not part of
        # the state_dict and is not moved by ``Module.to``.
        self.output = torch.zeros((1, 1, out_features), dtype=self.dtype, device='cuda')

    def forward(self, x, **kwargs):
        """Run the quantized product for a single token via ``dec_lutgemm``.

        Args:
            x: Input tensor whose first two dimensions must both be 1
                (presumably shape ``(1, 1, in_features)`` -- confirm against
                the kernel).
            **kwargs: Accepted and ignored, presumably for call-signature
                compatibility with other linear modules.

        Returns:
            ``self.output``, the preallocated ``(1, 1, out_features)`` tensor.
            The same tensor object is returned on every call and is
            overwritten by the next call; clone it if the result must outlive
            the next forward.

        Raises:
            ValueError: If ``x.shape[0]`` or ``x.shape[1]`` is not 1.
        """
        # Explicit check rather than ``assert``: assertions are stripped under
        # ``python -O``, which would let a wrongly shaped input reach the kernel.
        if x.shape[0] != 1 or x.shape[1] != 1:
            raise ValueError(
                "LUTGEMMLinear only supports batch size 1 and sequence length 1, "
                f"got input shape {tuple(x.shape)}"
            )

        # Clear the output before the kernel call. NOTE(review): presumably
        # dec_lutgemm accumulates into the output instead of overwriting it --
        # confirm.
        self.output.zero_()

        # ``dec_config`` is not set in this class; presumably provided by
        # DECLinear or assigned externally before the first forward -- confirm.
        dec_lutgemm(self.dec_config, x, self.qweight, self.alpha, self.q_bias, self.output, self.bitwidth, self.group_size)

        if self.bias is not None:
            # In-place add; (out_features,) broadcasts over the (1, 1, out_features) output.
            self.output += self.bias

        return self.output
