from nn_compression.quantisation._gptq_reimplementation import (
    Gptq as GptqReimplementation,
)
from nn_compression.quantisation._gptq import GptqLayer as GptqOriginal
from nn_compression.quantisation import AffineGridQuantiser

import torch
import torch.nn as nn
import copy


def test_obc_reimplementation():
    """Check that the GPTQ reimplementation reproduces the original GPTQ weights.

    A seeded ``nn.Linear(20, 50)`` layer is quantised by both implementations
    using the same ``AffineGridQuantiser(2)`` instance and the same calibration
    batch ``x`` (100 samples of 20 features). The resulting weight tensors
    must agree under ``torch.allclose`` default tolerances.
    """
    torch.manual_seed(0)
    x = torch.randn(100, 20)
    layer = nn.Linear(20, 50)
    quant = AffineGridQuantiser(2)

    # Each implementation receives its own deep copy of the layer, so neither
    # can see in-place weight updates made by the other.
    gptq_orig = GptqOriginal(copy.deepcopy(layer), quant)
    gptq_orig.add_batch(x)
    # NOTE(review): 0.01 is presumably the damping fraction and is meant to
    # match the original implementation's default in fasterquant() — confirm.
    gptq_reimpl = GptqReimplementation(quant, 0.01)

    gptq_orig.fasterquant()
    quantised_layer_reimpl = gptq_reimpl.quantize(copy.deepcopy(layer), x)

    weight_orig = gptq_orig.layer.weight.data
    weight_reimpl = quantised_layer_reimpl.weight.data

    # The message expression runs only on failure, so the diff is not computed
    # on the passing path.
    assert torch.allclose(weight_orig, weight_reimpl), (
        "GPTQ reimplementation diverges from the original: max |diff| = "
        f"{(weight_orig - weight_reimpl).abs().max().item():.3e}"
    )
