import pytest
from nn_compression.cv import cifar10, CvModel
from nn_compression.quantisation._gptq import GptqLayer
from nn_compression.quantisation import (
    AffineGridQuantiser,
    gptq_quantise_network,
)
import torch


@pytest.mark.slow
def test_gptq_performance():
    """End-to-end check that GPTQ quantisation keeps ResNet-18 usable on CIFAR-10.

    Loads ``CvModel.RESNET18_CIFAR10``. Before quantisation, the model's
    ``cifar.evaluate`` score must exceed 0.9. The model is then quantised in
    place with ``gptq_quantise_network``, using ``3`` as the quantisation
    setting (presumably a bit width; TODO confirm) and 10 calibration samples.
    After quantisation, the score must exceed 0.8. Marked ``slow`` so it can
    be deselected in quick runs.
    """
    model = CvModel.RESNET18_CIFAR10.load()
    # ``eval()`` is exactly ``train(False)``: it switches modules such as
    # dropout / batch-norm into inference behaviour.
    model.eval()
    cifar = cifar10()

    float_score = cifar.evaluate(model)
    assert float_score > 0.9

    calibration = cifar.calibration_sample(10)
    # inplace=True: the returned network is expected to be ``model`` itself.
    quantised = gptq_quantise_network(model, 3, calibration, inplace=True)
    assert cifar.evaluate(quantised) > 0.8


def test_gptq_ab():
    """GPTQ on a tiny 1-input, 5-output linear layer yields known weights.

    The layer's weight (shape ``(5, 1)``) is set to fixed values. A single
    calibration input ``[1.0]`` is fed to the ``GptqLayer`` wrapper, and
    ``fasterquant`` is run. The test expects the layer's weights to be
    overwritten in place with the reference values below, to within
    ``atol=1e-3``. The meaning of ``AffineGridQuantiser(5)``'s argument
    (bits vs. grid levels) is not visible here. TODO confirm.
    """
    layer = torch.nn.Linear(1, 5)
    initial = torch.tensor([-3.2, -2.1, 0.0, 1.5, 2.8])
    layer.weight = torch.nn.Parameter(initial.reshape(layer.weight.shape))

    gptq = GptqLayer(layer, AffineGridQuantiser(5))
    gptq.add_batch(torch.tensor([1.0]))
    gptq.fasterquant()

    # Reference values for the quantised weights, checked in place on ``layer``.
    expected = torch.tensor([-3.3032, -2.0645, 0.0, 1.4452, 2.8903])
    assert torch.allclose(layer.weight.flatten(), expected, atol=1e-3)


# def test_gptq_per_row():
#    torch.manual_seed(2)
#    net = torch.nn.Linear(5, 50)
#    x = torch.randn(10, 5)
#    qnet = gptq_quantise_network(net, 2, x, per_row_grid=True)
#
#    rows = []
#    for col_idx in range(qnet.weight.shape[1]):
#        row = qnet.weight[:, col_idx]
#        assert len(row.unique()) == 3 or len(row.unique()) == 4
#        rows.append(row)
#    for i in range(1, len(rows)):
#        assert not torch.allclose(rows[i], rows[0])
#
