def _mulqr_round_groups(groups, group_max, n, max_bits):
    """Round rows of ``groups`` in place onto the grid of their magnitude bucket.

    Row ``i`` falls in bucket ``b`` (for ``n <= b < max_bits``) when
    ``2**(b-1) <= group_max[i] < 2**b``. Its elements are then rounded to a
    multiple of ``2**(b+1-n)``. In rows whose maximum reaches the threshold
    ``2**b - 2**(b-4)``, elements at or above that threshold are floored
    instead of rounded (presumably so they cannot round up to ``2**b`` --
    TODO confirm intent). Rows whose maximum lies in no bucket are left
    unchanged.

    Args:
        groups: 2-D tensor of shape (num_groups, group_size); modified in place.
        group_max: per-row maximum of ``groups``. It is computed once by the
            caller before any rounding, so each row lands in exactly one bucket.
        n: lowest bucket exponent; also sets the grid step ``2**(b+1-n)``.
        max_bits: exclusive upper bound on the bucket exponent ``b``.
    """
    for b in range(n, max_bits):
        mul_xth = 2 ** (b - 1)
        round_value = 2 ** (b + 1 - n)
        threshold = 2 * mul_xth - 2 ** (b - 4)

        in_bucket = (group_max >= mul_xth) & (group_max < mul_xth * 2)
        if not in_bucket.any():
            continue

        near_top = group_max >= threshold
        indices_both = in_bucket & near_top
        indices_round = in_bucket & (~near_top)

        if indices_both.any():
            # Mixed treatment: floor the elements at/above the threshold,
            # round the rest.
            selected = groups[indices_both]
            element_floor = selected >= threshold
            element_round = selected < threshold
            selected[element_floor] = torch.floor(selected[element_floor] / round_value) * round_value
            selected[element_round] = torch.round(selected[element_round] / round_value) * round_value
            groups[indices_both] = selected

        if indices_round.any():
            groups[indices_round] = torch.round(groups[indices_round] / round_value) * round_value


class MulQRQuant(Function):
    """Re-round groups of ``gr`` consecutive values onto a coarser power-of-two grid.

    The input is flattened, zero-padded to a multiple of ``gr`` and split into
    groups. A group whose maximum lies in ``[2**(b-1), 2**b)`` for some
    ``n <= b < max_bits`` has its elements rounded to multiples of
    ``2**(b+1-n)`` (see ``_mulqr_round_groups``). ``max_bits`` is 16 when
    ``dt == 'act'`` and 8 for every other ``dt``.

    In this file ``quantizer.forward`` applies it to ``IntQuant`` output
    (magnitudes). No ``backward`` is defined in this class.
    """

    @staticmethod
    def forward(ctx, x, dt, n, gr):
        """Return a re-rounded copy of ``x``; ``x`` itself is not modified.

        Args:
            ctx: autograd context (unused).
            x: tensor of any shape.
            dt: data kind. ``'act'`` searches exponents ``[n, 16)``; any other
                value (``'weight'``, ``'key'``, ...) searches ``[n, 8)``.
            n: lowest bucket exponent / grid parameter.
            gr: group size.

        Returns:
            Tensor with the same shape as ``x``.
        """
        flat = torch.reshape(x, (-1,))
        org_len = len(flat)
        vacant_num = (-org_len) % gr
        if vacant_num:
            flat = F.pad(flat, (0, vacant_num), 'constant', 0)
        else:
            # Without padding, reshape may return a view of ``x``. Copy it so the
            # in-place rounding below never writes into the caller's tensor.
            flat = flat.clone()
        groups = flat.view(-1, gr)
        group_max, _ = groups.max(dim=1)

        max_bits = 16 if dt == 'act' else 8
        _mulqr_round_groups(groups, group_max, n, max_bits)

        return groups.view(-1)[:org_len].view_as(x)


class IntQuant(Function):
    """Scale ``x`` by ``S``, round to the nearest integer, and cap the result.

    The cap is ``2**(N-1) - 1``, the largest signed ``N``-bit integer. There is
    no lower clamp, so negative products are returned unclipped. No
    ``backward`` is defined in this class.
    """

    @staticmethod
    def forward(ctx, x, S, N):
        ceiling = 2 ** (N - 1) - 1
        scaled = x * S
        return scaled.round().clamp(max=ceiling)


class quantizer(nn.Module):
    """Fake-quantizer driven by a ``QInfo``-style config stored in ``self.qinfo``.

    ``forward`` behaviour depends on ``qinfo.phase``:
      * 0 -- pass ``x`` through unchanged;
      * 1 -- feed ``x`` to ``self.observer`` and pass it through;
      * 2 -- quantize/dequantize with ``self.scale`` (only ``qm <= 5`` is
        implemented).

    ``update_quant_params`` derives ``self.scale`` from the observer's
    ``max_val``/``min_val``.
    """

    def __init__(self, channels, qinfo, device='cuda'):
        super(quantizer, self).__init__()

        self.qinfo = qinfo
        self.device = device

        # Weights and activations are observed with ``channels``; every other
        # data kind uses NormalMinMaxObserver(0).
        observer_channels = channels if qinfo.data in ('weight', 'act') else 0
        self.observer = NormalMinMaxObserver(observer_channels)

        self.register_buffer('scale', torch.ones_like(self.observer.max_val, dtype=torch.float32))
        self.register_buffer('zero_point', torch.zeros_like(self.observer.max_val, dtype=torch.float32))

    def update_quant_params(self):
        """Recompute ``self.scale`` and reset ``self.zero_point`` to zeros.

        ``scale = quant_range / max(|max_val|, |min_val|)``, where
        ``quant_range`` depends on ``qinfo.qm``:
          * ``qm <= 5``: ``2**(n-1) - 1``
          * ``qm == 6``: ``2**((2**e - 2) - (2**(e-1) - 1))``
          * ``qm == 7``: ``2**(2**e)``
          * ``qm == 8``: ``2**(2**e - 1)``
        For any other ``qm`` the scale is left unchanged.

        NOTE(review): a zero entry in the data range yields an infinite scale,
        which makes ``forward`` return NaN for exactly-zero inputs (0 * inf).
        """
        qm = self.qinfo.qm
        if qm <= 5:
            quant_range = 2 ** (self.qinfo.n - 1) - 1
        elif qm == 6:
            bias = 2 ** (self.qinfo.e - 1) - 1
            quant_range = 2 ** ((2 ** self.qinfo.e - 2) - bias)
        elif qm == 7:
            quant_range = 2 ** (2 ** self.qinfo.e)
        elif qm == 8:
            quant_range = 2 ** (2 ** self.qinfo.e - 1)
        else:
            quant_range = None

        if quant_range is not None:
            data_range = torch.max(torch.abs(self.observer.max_val), torch.abs(self.observer.min_val))
            self.scale = quant_range / data_range

        self.zero_point = torch.zeros_like(self.scale)

    def truncate(self, x, bitwidth):
        """Round ``x`` to the nearest multiple of ``2**(8 - bitwidth)``."""
        dev_factor = 2 ** (8 - bitwidth)
        x = torch.round(x / dev_factor) * dev_factor
        return x

    def clipping(self, x, sign, bitwidth):
        """Clamp ``x`` from above to the largest ``bitwidth``-bit value.

        The limit is the signed maximum when ``sign == enums.signed`` and the
        unsigned maximum otherwise. No lower clamp is applied.
        """
        if sign == enums.signed:
            maxa = 2 ** (bitwidth - 1)
        else:
            maxa = 2 ** bitwidth
        x = torch.clamp(x, max=maxa - 1)
        return x

    def forward(self, x):
        """Apply the phase-dependent behaviour described on the class.

        Raises:
            ValueError: if ``qinfo.phase`` is not 0, 1 or 2.
            NotImplementedError: in phase 2 when ``qinfo.qm > 5``. Scales exist
                for those modes, but no quantize/dequantize path does.
        """
        x = x.to(self.device)
        self.scale = self.scale.to(x.device)
        phase = self.qinfo.phase

        if phase == 0:
            return x
        if phase == 1:
            self.observer(x)
            return x
        if phase != 2:
            raise ValueError(f"quantizer: unsupported phase {phase!r} (expected 0, 1 or 2)")
        if self.qinfo.qm > 5:
            raise NotImplementedError(
                f"quantizer: phase-2 forward is not implemented for qm={self.qinfo.qm!r}")

        # Quantize magnitudes and reapply the sign after dequantization.
        sign = torch.sign(x)
        out = IntQuant.apply(torch.abs(x), self.scale, self.qinfo.n)
        if self.qinfo.QRn > 0 and self.qinfo.qm == 3:
            out = MulQRQuant.apply(out, self.qinfo.data, self.qinfo.QRn, self.qinfo.QRg)
        return (out * sign) / self.scale

class IntQLinear(nn.Linear):
    """``nn.Linear`` whose input and weight can be routed through quantizers.

    Quantizer configs are read from the global ``opt``. Activations use
    ``opt.qna``/``opt.QRa`` and weights use ``opt.qnw``/``opt.QRw``.

    Args:
        ic: input features.
        oc: output features.
        bias: forwarded to ``nn.Linear``.
        is_qw: build ``quantW`` (a ``quantizer`` over ``oc`` channels) for the weight.
        is_qa: build ``quantA`` for the input activations.
    """

    def __init__(self, ic, oc, bias=True, is_qw=True, is_qa=True):
        super(IntQLinear, self).__init__(ic, oc, bias)
        self.is_qw = is_qw
        self.is_qa = is_qa

        shared = dict(phase=opt.qphase, qm=opt.qm, e=opt.qe, QRg=opt.QRg, nosub=opt.qnosub)

        self.qinfoa = QInfo(data='act', n=opt.qna, QRn=opt.QRa, **shared)
        if is_qa:
            self.quantA = quantizer(0, qinfo=self.qinfoa)

        self.qinfow = QInfo(data='weight', n=opt.qnw, QRn=opt.QRw, **shared)
        if is_qw:
            self.quantW = quantizer(oc, qinfo=self.qinfow)

    def forward(self, x):
        # Each quantizer is used only when enabled AND its qm is truthy.
        quantize_act = self.is_qa and self.qinfoa.qm
        activations = self.quantA(x) if quantize_act else x

        quantize_weight = self.is_qw and self.qinfow.qm
        weights = self.quantW(self.weight) if quantize_weight else self.weight

        return F.linear(activations, weights, self.bias)

class IntQEmbedding(nn.Embedding):
    """``nn.Embedding`` whose table can be routed through a weight quantizer.

    Args:
        num: number of embedding rows.
        dim: embedding size.
        pd: padding index, forwarded to ``nn.Embedding``.
        is_qw: build ``quantW`` and apply it in ``forward`` when ``qinfow.qm``
            is truthy.
    """

    def __init__(self, num, dim, pd=None, is_qw=True):
        super(IntQEmbedding, self).__init__(num, dim, pd)
        self.is_qw = is_qw

        # The table uses a fixed n=16 rather than a value from ``opt``.
        self.qinfow = QInfo(phase=opt.qphase, data='weight', qm=opt.qm, n=16, e=opt.qe, nosub=opt.qnosub)
        if is_qw:
            self.quantW = quantizer(0, qinfo=self.qinfow)

    def forward(self, x):
        table = self.weight
        if self.is_qw and self.qinfow.qm:
            table = self.quantW(table)
        return F.embedding(x, table, self.padding_idx)

class ScaledDotProductAttention(nn.Module):
    """Attention ``softmax(Q K^T / sqrt(d_k)) V`` with optionally quantized weights.

    Args:
        dropout_p: dropout probability on the attention weights (active only
            while ``self.training``).
        quantize_attn: when True, the softmax output is passed through an
            activation ``quantizer`` configured from the global ``opt``.
        device: NOTE(review): accepted but currently unused. The quantizer is
            constructed without a ``device`` argument.
    """

    def __init__(self, dropout_p=0.0, quantize_attn=True, device='cuda'):
        super(ScaledDotProductAttention, self).__init__()
        self.dropout_p = dropout_p
        self.quantize_attn = quantize_attn

        if self.quantize_attn:
            self.qinfoa = QInfo(phase=opt.qphase, data='act', qm=opt.qm, n=opt.qna, e=opt.qe, QRn=opt.QRa, QRg=opt.QRg, nosub=opt.qnosub)
            self.quantA = quantizer(0, qinfo=self.qinfoa)

    def forward(self, query, key, value, attn_mask=None, is_causal=False):
        """Attend over the last two dimensions of the inputs.

        Args:
            query: (..., L, d_k) tensor.
            key: (..., S, d_k) tensor.
            value: (..., S, d_v) tensor.
            attn_mask: optional mask broadcastable to (..., L, S). Positions
                where it equals 0 (or False) are set to -inf before the softmax.
            is_causal: if True, query position i attends only to key
                positions j <= i.

        Returns:
            (..., L, d_v) tensor.
        """
        d_k = query.size(-1)
        scores = torch.matmul(query, key.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k, dtype=query.dtype))

        if attn_mask is not None:
            scores = scores.masked_fill(attn_mask == 0, float('-inf'))

        if is_causal:
            # Keep the mask 2-D (L, S): it broadcasts over any leading batch/head
            # dims. A (1, 1, L, S) mask would make out-of-place masked_fill grow
            # 3-D scores to 4-D.
            causal_mask = torch.tril(torch.ones(scores.size(-2), scores.size(-1), device=query.device))
            scores = scores.masked_fill(causal_mask == 0, float('-inf'))

        attn_weights = F.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)

        if self.quantize_attn:
            attn_weights = self.quantA(attn_weights)

        if self.dropout_p > 0.0:
            attn_weights = F.dropout(attn_weights, p=self.dropout_p, training=self.training)

        attn_output = torch.matmul(attn_weights, value)
        return attn_output
