import torch
from torch import nn

from entropy_utils import search_scale_entropy, get_entropy_bits, get_huffman_bits
from gptq_triton.accumulate_hessian import accumulate_hessian
from gptq_triton.gptq_loop import gptq_loop
from gptq_triton.min_pivot import min_pivot_order
from gptq_triton.quantize import find_quantization_meta, mse_scale, dequantize, quantize


class HessianUtil:
    def __init__(self) -> None:
        super().__init__()
        self.hessian_orig: torch.Tensor | None = None  # (C, C), hessian matrix H = X.t() @ X
        self.hessian: torch.Tensor | None = None  # (C, C), hessian matrix H, permuted by act_order
        self.hessian_inv: torch.Tensor | None = None  # (C, C), upper Cholesky of H^{-1} permuted by act_order
        self.perm: torch.Tensor | None = None  # (C), permutation of act_order
        self.perm_inv: torch.Tensor | None = None  # (C), recover permutation of act_order

    @torch.no_grad()
    def add_batch(self, inp: torch.Tensor, hessian_dtype: torch.dtype = torch.float32, use_kernel: bool = True) -> None:
        """
        Add a batch of input to the Hessian matrix
        inp: (..., C)
        """
        assert self.hessian_inv is None, 'The Hessian has already been inverted. No updates allowed.'
        assert not hasattr(inp, 'fake_mode'), 'You should not use fake tensors as input.'
        if self.hessian is None:
            self.hessian: torch.Tensor = torch.zeros(inp.size(-1), inp.size(-1), dtype=hessian_dtype, device=inp.device)
        if inp.dim() <= 1:
            inp: torch.Tensor = inp.reshape(1, -1)
        accumulate_hessian(self.hessian, inp.flatten(end_dim=-2), save_lower_only=False, debug_mode=not use_kernel)

    @torch.no_grad()
    def invert(self, order: str = 'act', damp_ratio: float = 1e-2) -> None:
        """
        Invert the Hessian matrix (compute the Cholesky decomposition of the inverse)
        """
        assert self.hessian is not None, 'The Hessian is not ready to be inverted. Please call add_batch first.'
        # assert self.hessian_inv is None, 'The Hessian has already been inverted.'
        device: torch.device = self.hessian.device

        if self.hessian_orig is None:
            self.hessian_orig: torch.Tensor = self.hessian.clone()
        else:
            self.hessian: torch.Tensor = self.hessian_orig.clone()

        match order.lower():
            case None | '' | 'none':
                self.perm: torch.Tensor = torch.arange(self.hessian.size(-1), device=device)
            case 'right2left':
                self.perm: torch.Tensor = torch.arange(self.hessian.size(-1) - 1, -1, -1, device=device)
            case 'random':
                self.perm: torch.Tensor = torch.randperm(self.hessian.size(-1), device=device)
            case 'act':
                self.perm: torch.Tensor = self.hessian.diagonal().argsort(descending=True)
            case 'min_pivot':
                self.perm: torch.Tensor = min_pivot_order(hessian=self.hessian, direct=False).flip(dims=(-1,))
            case _:
                raise NotImplementedError
        self.perm_inv: torch.Tensor = self.perm.argsort(descending=False)
        self.hessian = self.hessian[self.perm][:, self.perm]

        diag_indices: torch.Tensor = torch.arange(len(self.hessian), device=device)
        damp: torch.Tensor = damp_ratio * self.hessian.diagonal().mean()

        self.hessian_inv: torch.Tensor = torch.empty_like(self.hessian)
        info: torch.Tensor = torch.empty((), dtype=torch.int32, device=device)
        max_try: int = 100
        while (max_try := max_try - 1) >= 0:
            self.hessian[diag_indices, diag_indices] += damp
            torch.linalg.cholesky_ex(self.hessian, upper=False, check_errors=False, out=(self.hessian_inv, info))
            torch.cholesky_inverse(self.hessian_inv, upper=False, out=self.hessian_inv)
            torch.linalg.cholesky_ex(self.hessian_inv, upper=True, check_errors=False, out=(self.hessian_inv, info))
            if not self.hessian_inv.isnan().any():
                break
        assert max_try >= 0, 'Hessian inversion failed. Please try using more samples or increasing damp_ratio.'
        # self.hessian = None
        self.hessian_inv /= self.hessian_inv.diagonal()[:, None].contiguous()  # normalize the diagonal to 1 so that we do not need to divide it in GPTQ loops later


@torch.no_grad()
def get_quantization_grid(
        weight: torch.Tensor,
        quant_group_size: int,
        quant_bit_width: float,
        quant_use_entropy_mode: str = 'none',
        quant_symmetric: bool = False,
        quant_dtype: torch.dtype = None,
        quant_vertical: bool = False,
        quant_use_mse: bool = True,
        quant_max_shrink: float = .8,
        quant_n_grid: int = 100,
        quant_norm: float = 2.4,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Get the quantization grid for the weight matrix
    weight: (..., (R), C)
    scale: (..., (R), C)
    qzero: (..., (R), C)
    maxq: ()
    """
    dtype, device = weight.dtype, weight.device
    if quant_vertical:
        weight = weight.transpose(-2, -1)  # (..., C, R)
    weight = weight.unflatten(dim=-1, sizes=(-1, quant_group_size))  # (..., G, gs)

    match quant_use_entropy_mode.lower():
        case None | '' | 'none':
            scale, qzero, maxq = find_quantization_meta(
                x=weight,
                bit_width='ternary' if 1. < quant_bit_width < 2. else int(round(quant_bit_width)),
                symmetric=quant_symmetric,
                dtype=quant_dtype,
            )  # (..., G), (..., G), ()
            if quant_use_mse:
                mse_scale(
                    x=weight.contiguous(),  # (..., G, gs)
                    p=1. - torch.linspace(0., quant_max_shrink, quant_n_grid, dtype=dtype, device=device),  # (..., P)
                    scale=scale,  # (..., G)
                    qzero=qzero,  # (..., G)
                    maxq=maxq,  # ()
                    dtype=quant_dtype,
                    norm=quant_norm,
                    debug_mode=False,
                )
        case 'grouped_e':
            scale = search_scale_entropy(weight, target_entropy=quant_bit_width, dim=-1)  # (..., G)
            qzero = torch.zeros_like(scale)  # (..., G)
            maxq = torch.as_tensor(torch.nan, dtype=dtype, device=device)  # ()
        case 'all_e':
            scale = search_scale_entropy(weight.flatten(), target_entropy=quant_bit_width, dim=-1).expand(*weight.shape[:-1])  # (..., G)
            qzero = torch.zeros_like(scale)  # (..., G)
            maxq = torch.as_tensor(torch.nan, dtype=dtype, device=device)  # ()
        case _:
            raise NotImplementedError

    scale = scale.repeat_interleave(quant_group_size, dim=-1)  # (..., C)
    qzero = qzero.repeat_interleave(quant_group_size, dim=-1)  # (..., C)

    weight = weight.flatten(start_dim=-2)  # (..., C)
    if quant_vertical:
        weight = weight.transpose(-2, -1)  # (..., R, C)
        scale = scale.transpose(-2, -1)  # (..., R, C)
        qzero = qzero.transpose(-2, -1)  # (..., R, C)

    assert weight.shape == scale.shape == qzero.shape and maxq.shape == ()
    return scale, qzero, maxq  # (..., (R), C), (..., (R), C), ()


@torch.no_grad()
def gptq_quantize(
        weight: torch.Tensor,
        hessian_util: HessianUtil | None,
        quant_group_size: int,
        quant_bit_width: float,
        quant_use_entropy_mode: str = 'none',
        quant_symmetric: bool = False,
        quant_do_clip: bool = True,
        quant_use_mse: bool = True,
        quant_max_shrink: float = .8,
        quant_n_grid: int = 100,
        quant_norm: float = 2.4,
        save_device: torch.device = torch.device('cpu'),
        do_rtn: bool = False,
) -> dict[str, dict]:
    """
    Quantize the weight matrix using GPTQ
    weight: (R, C)

    Two code paths, selected by quant_use_entropy_mode:
      * anything other than 'strict_e' / 'strict_h': the grid comes from
        `get_quantization_grid`; the weights are then either rounded directly
        (do_rtn) or processed by `gptq_loop` in the channel order `hessian_util.perm`.
      * 'strict_e' / 'strict_h': a single scalar scale (zero point 0, one group per
        row) is bisected in [0, 3 * rms(weight)] for 10 midpoint steps, comparing the
        bits reported by `get_entropy_bits` / `get_huffman_bits` with quant_bit_width;
        one extra pass with the final upper bound produces the result.

    quant_group_size <= 0 means one group per row and disables the MSE search.
    quant_do_clip: pass maxq to `gptq_loop` (only when maxq is finite).
    save_device: device of the tensors stored in 'quant_meta'.

    Returns {'quant_meta': {...}, 'metrics': {...}} where quant_meta holds
    'qweight' (R, C) int16, 'scale' (R, G), 'qzero' (R, G) int16 and
    'quant_group_sizes' (G) int16.

    NOTE(review): hessian_util is annotated as optional but is dereferenced
    unconditionally (hessian_inv.dtype, hessian_orig), even when do_rtn is set.
    """
    dtype, device = weight.dtype, weight.device
    n_out_features, n_in_features = weight.shape
    weight_orig: torch.Tensor = weight  # (R, C)
    # All intermediate math runs in the dtype of the inverted Hessian.
    h_dtype: torch.dtype = hessian_util.hessian_inv.dtype

    if quant_use_entropy_mode not in ['strict_e', 'strict_h']:
        weight: torch.Tensor = weight.to(dtype=h_dtype, copy=True)  # (R, C)
        if quant_group_size <= 0:
            quant_group_size: int = n_in_features
            quant_use_mse: bool = False  # TODO: triton mse runs very slowly
        scale, qzero, maxq = get_quantization_grid(
            weight=weight,
            quant_group_size=quant_group_size,
            quant_bit_width=quant_bit_width,
            quant_use_entropy_mode=quant_use_entropy_mode,
            quant_symmetric=quant_symmetric,
            quant_dtype=dtype,
            quant_vertical=False,
            quant_use_mse=quant_use_mse,
            quant_max_shrink=quant_max_shrink,
            quant_n_grid=quant_n_grid,
            quant_norm=quant_norm,
        )  # (R, C), (R, C), ()
        if do_rtn:
            # Round-to-nearest: no Hessian-based error compensation.
            # The clones presumably guard against in-place kernels — TODO confirm.
            qweight = quantize(weight.clone(), scale, qzero, maxq)  # (R, C)
            weight = dequantize(qweight.clone(), scale, qzero, dtype)  # (R, C)
        else:
            # gptq_loop works on the transposed (C, R) layout with the input channels
            # reordered by hessian_util.perm; the order is undone after this branch.
            qweight, weight = gptq_loop(
                weight=weight.transpose(-2, -1)[hessian_util.perm],  # (C, R)
                hessian_inv=hessian_util.hessian_inv,  # (C, C)
                scale=scale.transpose(-2, -1)[hessian_util.perm],  # (C, R)
                qzero=qzero.transpose(-2, -1)[hessian_util.perm],  # (C, R)
                maxq=maxq if quant_do_clip and maxq.isfinite() else None,  # ()
                dtype=dtype,
                gptq_block_size=128,
                debug_mode=False,
            )  # (C, R), (C, R)
    else:
        # Strict entropy / Huffman mode: one group per row and one shared scalar scale.
        quant_group_size: int = n_in_features
        weight_perm: torch.Tensor = weight_orig.to(dtype=h_dtype).transpose(-2, -1)[hessian_util.perm]  # (C, R)
        qzero: torch.Tensor = torch.zeros(n_in_features, n_out_features, dtype=h_dtype, device=device)  # (C, R)

        # Root-mean-square of the weights; the search interval is [0, 3 * rms].
        std: float = torch.linalg.vector_norm(weight).item() * weight.numel() ** -.5
        s_min, s_max = 0., std * 3.
        max_iter = 10
        tol = 1e-3  # only referenced by the commented-out early exit below
        for i in range(max_iter + 1):
            # max_iter midpoint probes, then one final pass at the upper bound s_max.
            s_mid: float = (s_min + s_max) * .5 if i < max_iter else s_max
            scale: torch.Tensor = torch.full((n_in_features, n_out_features), s_mid, dtype=h_dtype, device=device)  # (C, R)

            if do_rtn:
                # NOTE(review): unlike the non-strict branch, `weight` is passed without a
                # clone and in its original dtype — confirm `quantize` does not mutate it.
                qweight: torch.Tensor = quantize(weight, scale.transpose(-2, -1), qzero.transpose(-2, -1), None)  # (R, C)
                _: torch.Tensor = dequantize(qweight, scale.transpose(-2, -1), qzero.transpose(-2, -1), dtype)  # (R, C)
            else:
                # GPTQ quantization
                # `weight` is rebound to the (C, R) output here; every iteration
                # restarts from a fresh clone of weight_perm.
                qweight, weight = gptq_loop(
                    weight=weight_perm.clone(),  # (C, R)
                    hessian_inv=hessian_util.hessian_inv,  # (C, C)
                    scale=scale,  # (C, R)
                    qzero=qzero,  # (C, R)
                    maxq=None,  # ()
                    dtype=dtype,
                    gptq_block_size=128,
                    debug_mode=False,
                )  # (C, R), (C, R)

            match quant_use_entropy_mode:
                case 'strict_e':
                    entropy: float = get_entropy_bits(qweight).item()
                case 'strict_h':
                    entropy: float = get_huffman_bits(qweight)
                case _:
                    raise NotImplementedError
            # if do_rtn and abs(entropy - quant_bit_width) < tol:
            #     break
            # Measured bits below the target: lower the upper bound of the scale;
            # otherwise raise the lower bound.
            if entropy < quant_bit_width:
                s_max: float = s_mid
            else:
                s_min: float = s_mid
        # scale is constant-filled and qzero is all zeros, so viewing the (C, R)
        # buffers as (R, C) yields the same values a transpose would.
        scale = scale.view(n_out_features, n_in_features)  # (R, C)
        qzero = qzero.view(n_out_features, n_in_features)  # (R, C)

    if not do_rtn:
        # Undo the (C, R) layout and the channel permutation used by gptq_loop.
        weight: torch.Tensor = weight.transpose(-2, -1)[:, hessian_util.perm_inv]  # (R, C)
        qweight: torch.Tensor = qweight.transpose(-2, -1)[:, hessian_util.perm_inv]  # (R, C)

    # scale / qzero are repeated within a group, so every quant_group_size-th
    # column gives one value per group.
    quant_meta: dict[str, torch.Tensor] = {
        # 'weight': weight.contiguous().to(device=save_device),  # (R, C)
        'qweight': qweight.contiguous().to(dtype=torch.int16, device=save_device),  # (R, C)
        'scale': scale[:, ::quant_group_size].contiguous().to(dtype=dtype, device=save_device),  # (R, G)
        'qzero': qzero[:, ::quant_group_size].contiguous().to(dtype=torch.int16, device=save_device),  # (R, G)
        'quant_group_sizes': torch.full((n_in_features // quant_group_size,), quant_group_size, dtype=torch.int16, device=save_device),  # (G)
    }

    # Reconstruction error measured on exactly what was stored in quant_meta.
    delta: torch.Tensor = reconstruct_nn_linear(
        quant_meta=quant_meta,
        bias=None,
        dtype=dtype,
        device=weight.device,
    ).weight.data.to(dtype=h_dtype) - weight_orig.to(dtype=h_dtype)  # (R, C)

    metrics: dict = {
        'shape': tuple(weight.shape),
        # mean squared weight error
        'error_w2': torch.linalg.vector_norm(delta).item() ** 2. / delta.numel(),
        # row-wise delta @ H @ delta^T, averaged over rows and divided by C
        'error_whw': (delta[..., None, :] @ hessian_util.hessian_orig @ delta[..., None]).mean().item() / hessian_util.hessian_orig.size(-1),
        # fraction of (qweight - qzero) outside [-2^(bits-1), 2^(bits-1))
        'violation_ratio_qw': ((quant_meta['qweight'] - quant_meta['qzero'].repeat_interleave(quant_group_size, dim=-1) < - 2 ** (quant_bit_width - 1)) | (quant_meta['qweight'] - quant_meta['qzero'].repeat_interleave(quant_group_size, dim=-1) >= 2 ** (quant_bit_width - 1))).mean(dtype=h_dtype).item(),
        'range_qw': (quant_meta['qweight'].min().item(), quant_meta['qweight'].max().item()),
        'entropy_qw': get_entropy_bits(quant_meta['qweight']).item(),
        'huffman_qw': get_huffman_bits(quant_meta['qweight']),
        # histogram {quantized value: count}
        'distribution_qw': {s.item(): c.item() for s, c in zip(*quant_meta['qweight'].unique(sorted=True, return_counts=True))},
    }

    return {'quant_meta': quant_meta, 'metrics': metrics}


@torch.no_grad()
def gptq_outliers_quantize(
        weight: torch.Tensor,
        hessian_util: HessianUtil | None,
        quant_group_size: int,
        quant_bit_width: float,
        quant_use_entropy_mode: str = 'none',
        quant_symmetric: bool = False,
        quant_outlier_percentage: float = .05,
        quant_use_mse: bool = True,
        quant_max_shrink: float = .8,
        quant_n_grid: int = 100,
        quant_norm: float = 2.4,
        save_device: torch.device = torch.device('cpu'),
) -> dict[str, dict]:
    """
    Quantize the weight matrix using GPTQ, keeping clipped entries as outliers
    weight: (R, C)

    Outline:
      1. Permute hessian_util.hessian_orig by the reversed hessian_util.perm, damp it
         and take its upper Cholesky factor B (B^T B = H).
      2. Compute the grid with `get_quantization_grid` on the un-permuted weight.
      3. With y = B @ W, run a back-substitution from the last input channel to the
         first, one group of quant_group_size channels at a time. Each row is rounded
         with scale * sscale and clamped to [0, maxq]. Where clamping changed the
         rounded value, the un-quantized value is kept and its difference to the
         dequantized value is stored in `oweight`.
      4. Bisect the per-output-channel multiplier sscale in [0, 2] (10 midpoint steps
         plus a final pass at the upper bound) so that the number of non-zero outlier
         entries per output channel is at most quant_outlier_percentage * C.

    quant_outlier_percentage is used as a fraction of C (0.05 = 5 %).
    quant_group_size <= 0 means one group per row and disables the MSE search.

    Returns {'quant_meta': {...}, 'metrics': {...}}; compared with gptq_quantize,
    quant_meta additionally holds 'oweight' (R, C).

    NOTE(review): hessian_util is annotated as optional but is dereferenced unconditionally.
    NOTE(review): maxq is used as the clamp bound; in the entropy modes
    get_quantization_grid returns a NaN maxq — confirm those modes are meant to be
    supported here.
    """
    dtype, device = weight.dtype, weight.device
    n_out_features, n_in_features = weight.shape
    weight_orig: torch.Tensor = weight  # (..., R, C)
    h_dtype: torch.dtype = hessian_util.hessian_inv.dtype

    # Reverse of the GPTQ order; the loops below walk the channels from last to first.
    perm_flip: torch.Tensor = hessian_util.perm.flip(dims=(-1,))  # (C)
    perm_flip_inv: torch.Tensor = perm_flip.argsort(descending=False)  # (C)
    # Tensor indexing returns a copy, so the in-place damping below leaves hessian_orig untouched.
    hessian: torch.Tensor = hessian_util.hessian_orig[..., perm_flip, :][..., perm_flip]  # (..., C, C)
    diag_indices: torch.Tensor = torch.arange(n_in_features, device=device)  # (C)
    damp_ratio: float = 1e-2

    # Add damping (cumulatively) until the Cholesky factorization reports success.
    max_try: int = 100
    while (max_try := max_try - 1) >= 0:
        hessian[..., diag_indices, diag_indices] += damp_ratio * hessian.diagonal(dim1=-2, dim2=-1).mean(dim=-1)  # (..., C)
        basis_in, info = torch.linalg.cholesky_ex(hessian, upper=True, check_errors=False)  # (..., C, C), upper triangular, column vectors, B^T B = H
        if not info.to(dtype=torch.bool).any():
            break
    assert max_try >= 0, 'Hessian inversion failed. Please try using more samples or increasing damp_ratio.'

    if quant_group_size <= 0:
        quant_group_size: int = n_in_features
        quant_use_mse: bool = False  # TODO: triton mse runs very slowly

    weight: torch.Tensor = weight.to(dtype=h_dtype, copy=True).contiguous()  # (..., R, C)
    scale, qzero, maxq = get_quantization_grid(
        weight=weight,
        quant_group_size=quant_group_size,
        quant_bit_width=quant_bit_width,
        quant_use_entropy_mode=quant_use_entropy_mode,
        quant_symmetric=quant_symmetric,
        quant_dtype=dtype,
        quant_vertical=False,
        quant_use_mse=quant_use_mse,
        quant_max_shrink=quant_max_shrink,
        quant_n_grid=quant_n_grid,
        quant_norm=quant_norm,
    )  # (R, C), (R, C), ()

    # Move everything to the permuted, transposed (C, R) layout used by the loops.
    weight: torch.Tensor = weight[..., perm_flip].transpose(-2, -1)  # (..., C, R)
    scale: torch.Tensor = scale[..., perm_flip].transpose(-2, -1)  # (..., C, R)
    qzero: torch.Tensor = qzero[..., perm_flip].transpose(-2, -1)  # (..., C, R)
    # Output buffers; every row is overwritten in each bisection pass.
    weight_dq: torch.Tensor = torch.empty_like(weight)  # (..., C, R)
    qweight: torch.Tensor = torch.empty_like(weight)  # (..., C, R)
    oweight: torch.Tensor = torch.empty_like(weight)  # (..., C, R)

    # Per-output-channel bisection bounds for the scale multiplier.
    sscale_min = torch.zeros(*scale.shape[:-2], 1, n_out_features, dtype=h_dtype, device=device)  # (..., 1, R)
    sscale_max = torch.full_like(sscale_min, 2.)  # (..., 1, R), this should not be too large!
    max_iter: int = 10

    for it in range(max_iter + 1):
        # max_iter midpoint probes, then one final pass at the upper bound.
        if it < max_iter:
            sscale = (sscale_min + sscale_max) * .5  # (..., 1, R)
        else:
            sscale = sscale_max  # (..., 1, R)

        y = basis_in @ weight  # (..., C, R), column vectors

        # Groups from the last one down to the first.
        for i1 in range((n_in_features - 1) // quant_group_size * quant_group_size, -1, -quant_group_size):
            i2: int = i1 + quant_group_size

            # Basic slices are views: the in-place writes below update
            # y, weight_dq, qweight and oweight directly.
            basis_chunk: torch.Tensor = basis_in[..., i1:i2, i1:i2]  # (..., G, G), upper triangular, column vectors, B^T B = H[i1:i2, i1:i2]
            y_chunk: torch.Tensor = y[..., i1:i2, :]  # (..., G, R)
            w_chunk_dq: torch.Tensor = weight_dq[..., i1:i2, :]  # (..., G, R)
            qw_chunk: torch.Tensor = qweight[..., i1:i2, :]  # (..., G, R)
            scale_chunk = scale[..., i1:i2, :] * sscale  # (..., G, R)
            qzero_chunk = qzero[..., i1:i2, :]  # (..., G, R)
            ow_chunk: torch.Tensor = oweight[..., i1:i2, :]  # (..., G, R)

            # Rows of the group from the last one down to the first.
            for j1 in range(quant_group_size - 1, -1, -1):
                j2 = j1 + 1
                # Solve for this row of the (error-compensated) weight.
                w_vec: torch.Tensor = y_chunk[..., j1:j2, :] / basis_chunk[..., j1:j2, j1:j2]  # (..., 1, R)
                w_vec_int: torch.Tensor = (w_vec / scale_chunk[..., j1:j2, :] + qzero_chunk[..., j1:j2, :]).round()  # (..., 1, R)
                qw_chunk[..., j1:j2, :] = w_vec_int.clamp(0., maxq)  # (..., 1, R)
                w_vec_dq: torch.Tensor = (qw_chunk[..., j1:j2, :] - qzero_chunk[..., j1:j2, :]) * scale_chunk[..., j1:j2, :]  # (..., 1, R)
                # Unclipped entries take the dequantized value; clipped entries keep the exact value.
                w_chunk_dq[..., j1:j2, :] = torch.where(w_vec_int == qw_chunk[..., j1:j2, :], w_vec_dq, w_vec)  # (..., 1, R)
                # Outlier residual: zero where unclipped, w_vec - w_vec_dq where clipped.
                ow_chunk[..., j1:j2, :] = w_chunk_dq[..., j1:j2, :] - w_vec_dq  # (..., 1, R)
                # Remove this row's contribution from the rows of the group still to be solved.
                y_chunk[..., :j1, :] -= basis_chunk[..., :j1, j1:j2] @ w_chunk_dq[..., j1:j2, :]  # (..., ?, R)

            # Remove the whole group's contribution from all earlier rows.
            y[..., :i1, :] -= basis_in[..., :i1, i1:i2] @ w_chunk_dq  # (..., ?, R)

        # Few enough outliers for a channel: lower its upper bound; otherwise raise its lower bound.
        outlier_count = oweight.to(dtype=torch.bool).sum(dim=-2, keepdim=True)  # (..., 1, R)
        mask = outlier_count <= quant_outlier_percentage * n_in_features  # (..., 1, R)
        sscale_min = torch.where(mask, sscale_min, sscale)  # (..., 1, R)
        sscale_max = torch.where(mask, sscale, sscale_max)  # (..., 1, R)

    # Back to the original (R, C) layout and channel order.
    weight = weight_dq.transpose(-2, -1)[..., perm_flip_inv].to(dtype=dtype)  # (..., R, C)
    qweight = qweight.transpose(-2, -1)[..., perm_flip_inv]  # (..., R, C)
    scale = (scale * sscale_max).transpose(-2, -1)[..., perm_flip_inv]  # (..., R, C)
    qzero = qzero.transpose(-2, -1)[..., perm_flip_inv]  # (..., R, C)
    oweight = oweight.transpose(-2, -1)[..., perm_flip_inv]  # (..., R, C)

    # scale / qzero are repeated within a group, so every quant_group_size-th
    # column gives one value per group.
    quant_meta: dict[str, torch.Tensor] = {
        # 'weight': weight.contiguous().to(device=save_device),  # (R, C)
        'qweight': qweight.contiguous().to(dtype=torch.int16, device=save_device),  # (R, C)
        'scale': scale[:, ::quant_group_size].contiguous().to(dtype=dtype, device=save_device),  # (R, G)
        'qzero': qzero[:, ::quant_group_size].contiguous().to(dtype=torch.int16, device=save_device),  # (R, G)
        'quant_group_sizes': torch.full((n_in_features // quant_group_size,), quant_group_size, dtype=torch.int16, device=save_device),  # (G)
        'oweight': oweight.contiguous().to(dtype=dtype, device=save_device),  # (R, C)
    }

    # Reconstruction error measured on exactly what was stored in quant_meta.
    delta: torch.Tensor = reconstruct_nn_linear(
        quant_meta=quant_meta,
        bias=None,
        dtype=dtype,
        device=weight.device,
    ).weight.data.to(dtype=h_dtype) - weight_orig.to(dtype=h_dtype)  # (R, C)

    metrics: dict = {
        'shape': tuple(weight.shape),
        # mean squared weight error
        'error_w2': torch.linalg.vector_norm(delta).item() ** 2. / delta.numel(),
        # row-wise delta @ H @ delta^T, averaged over rows and divided by C
        'error_whw': (delta[..., None, :] @ hessian_util.hessian_orig @ delta[..., None]).mean().item() / hessian_util.hessian_orig.size(-1),
        # fraction of (qweight - qzero) outside [-2^(bits-1), 2^(bits-1))
        'violation_ratio_qw': ((quant_meta['qweight'] - quant_meta['qzero'].repeat_interleave(quant_group_size, dim=-1) < - 2 ** (quant_bit_width - 1)) | (quant_meta['qweight'] - quant_meta['qzero'].repeat_interleave(quant_group_size, dim=-1) >= 2 ** (quant_bit_width - 1))).mean(dtype=h_dtype).item(),
        'range_qw': (quant_meta['qweight'].min().item(), quant_meta['qweight'].max().item()),
        # fraction (not percent) of non-zero entries in oweight
        'outlier_percentages': oweight.to(dtype=torch.bool).sum().item() / oweight.numel(),
    }

    return {'quant_meta': quant_meta, 'metrics': metrics}


@torch.no_grad()
def construct_matrix(
        qweight: torch.Tensor,  # (R, C), int16
        scale: torch.Tensor,  # (R, G), float16 or bfloat16
        qzero: torch.Tensor,  # (R, G) or (G) or (), int16
        group_sizes: torch.Tensor,  # (G), int16
        oweight: torch.Tensor = None,  # (R, C) or (C) or (), float16 or bfloat16
) -> torch.Tensor:
    """
    Reconstruct matrix from quantizer information

    Column group k covers group_sizes[k] consecutive columns and is dequantized with
    scale[:, k] and qzero[:, k]; oweight, when given, is added to the result.
    Returns the dense weight, (R, C), in the dtype of scale.
    Raises ValueError when group_sizes does not sum to the number of columns of qweight.
    """
    qzero = qzero.expand(scale.shape)  # (R, G)
    # Accumulate the group boundaries in int64 so the int16 sizes cannot overflow.
    group_ids: list[int] = [0] + group_sizes.cumsum(dim=-1, dtype=torch.int64).tolist()  # (G+1)
    # The output buffer is uninitialized: groups that do not cover every column
    # would leak garbage values into the result, so reject them up front.
    if group_ids[-1] != qweight.size(-1):
        raise ValueError(
            f'group_sizes sum to {group_ids[-1]} columns, but qweight has {qweight.size(-1)} columns.'
        )
    weight = torch.empty_like(qweight, dtype=scale.dtype, device=qweight.device)  # (R, C)
    for k, (i1, i2) in enumerate(zip(group_ids[:-1], group_ids[1:])):
        weight[:, i1:i2] = dequantize(qx=qweight[:, i1:i2], scale=scale[:, k:k+1], qzero=qzero[:, k:k+1])
    if oweight is not None:
        weight += oweight
    return weight  # (R, C)


@torch.no_grad()
def reconstruct_nn_linear(
        quant_meta: dict[str, torch.Tensor],
        bias: torch.Tensor | None,
        dtype: torch.dtype,
        device: torch.device = torch.device('cpu'),
) -> nn.Linear:
    """
    Reconstruct nn.Linear from quantizer information
    bias: (R)
    """
    if 'weight' in quant_meta:
        weight = quant_meta['weight'].to(dtype=dtype, device=device)
    else:
        qweight = quant_meta['qweight'].to(dtype=torch.int16, device=device)
        qzero = quant_meta['qzero'].to(dtype=torch.int16, device=device)
        group_sizes = quant_meta['quant_group_sizes'].to(dtype=torch.int16, device=device)
        scale = quant_meta['scale'].to(dtype=dtype, device=device)
        oweight = quant_meta['oweight'].to(dtype=dtype, device=device) if 'oweight' in quant_meta else None
        weight = construct_matrix(qweight, scale, qzero, group_sizes, oweight)
    nn_linear = nn.Linear(*weight.shape[::-1], bias=bias is not None, dtype=dtype, device=device)
    nn_linear.weight.copy_(weight)
    if bias is not None:
        nn_linear.bias.copy_(bias)
    nn_linear.eval()
    return nn_linear
