import torch
from nn_compression.quantisation import (
    AffineGridQuantiser,
    PerRowGridQuantiser,
    RdQuantiserDeepCabac,
)
from nn_compression.cv import CvModel, cifar10
import pytest


@pytest.fixture
def w():
    """Provide the weight matrix of the ``fc`` layer of ``resnet18_cifar10``.

    The tensor is returned detached from the autograd graph. The fixture uses
    pytest's default (function) scope, so the model is loaded again for every
    test that requests it.
    """
    fc_layer = CvModel("resnet18_cifar10").load().fc
    return fc_layer.weight.data.detach()


def test_lm0_equivalence(w):
    """Check that ``RdQuantiserDeepCabac(4, 0)`` matches ``AffineGridQuantiser(4)``.

    Both quantisers are fitted on the same weight matrix. Each column is then
    quantised with the same uncertainty value, and the quantised outputs must
    agree column by column.

    NOTE(review): the second ``RdQuantiserDeepCabac`` argument is presumably
    the rate-distortion weight (lambda). At 0 it should reduce to plain grid
    rounding, which is what this test asserts. Confirm against the quantiser API.
    """
    q = RdQuantiserDeepCabac(4, 0)
    q2 = AffineGridQuantiser(4)

    q.find_params(w)
    q2.find_params(w)
    # The same uncertainty value is handed to both quantisers for every column.
    pv = torch.tensor(2)

    for j in range(w.shape[1]):
        wq = q.quantise_with_uncertainty(w[:, j], pv, col=j)
        wq2 = q2.quantise_with_uncertainty(w[:, j], pv, col=j)

        # Name the offending column on failure, as the row-wise test does.
        assert torch.allclose(wq.x, wq2.x), f"Column {j} is not equal"


def test_lm0_equivalence_rowwise(w):
    """Check the zero-lambda equivalence when both quantisers use a per-row grid.

    This is the same check as the column-wise equivalence test, but both
    quantisers are built with ``per_row_grid=True``. Each column of ``w`` is
    quantised by both, and the outputs must agree.

    NOTE(review): the second ``RdQuantiserDeepCabac`` argument is presumably
    the rate-distortion weight (lambda). Confirm against the quantiser API.
    """
    q = RdQuantiserDeepCabac(4, 0, per_row_grid=True)
    q2 = AffineGridQuantiser(4, per_row_grid=True)

    q.find_params(w)
    q2.find_params(w)
    # The same uncertainty value is handed to both quantisers for every column.
    pv = torch.tensor(2)

    for j in range(w.shape[1]):
        wq = q.quantise_with_uncertainty(w[:, j], pv, col=j)
        wq2 = q2.quantise_with_uncertainty(w[:, j], pv, col=j)

        assert torch.allclose(wq.x, wq2.x), f"Column {j} is not equal"
