from nn_compression.cv import CvModel, cifar10
from nn_compression.coding import DeepCABAC
from nn_compression.quantisation import gptq_quantise_network
from pathlib import Path
import torch


def test_deep_cabac():
    """Round-trip a 3-bit GPTQ-quantised ResNet-18 through DeepCABAC.

    Encodes the quantised network to a bitstream file, decodes it again and
    checks every decoded tensor matches the corresponding entry of the
    quantised network's ``state_dict``.

    The bitstream lives in a private temporary directory so the test neither
    pollutes the working directory nor collides with concurrent runs, and the
    file is removed even when an assertion or the codec itself fails.
    """
    import tempfile

    net = CvModel.RESNET18_CIFAR10.load()
    net.train(False)
    cifar = cifar10()
    # Sanity-check the pretrained baseline before compressing it.
    assert cifar.evaluate(net) > 0.9
    # inplace=True: netq is presumably the same module object as net — TODO confirm.
    netq = gptq_quantise_network(net, 3, cifar.calibration_sample(1), inplace=True)

    with tempfile.TemporaryDirectory() as tmp_dir:
        coder = DeepCABAC(str(Path(tmp_dir) / "tmp.nnc"))
        coder.encode(netq)
        # Snapshot once (after encoding, as the original did) instead of
        # rebuilding the state dict for every decoded tensor.
        expected = netq.state_dict()
        decoded = coder.decode()

        # Guard against a vacuous pass if decode() yields nothing.
        assert decoded, "DeepCABAC.decode() returned no tensors"
        for name, tensor in decoded.items():
            assert name in expected, f"unexpected decoded tensor {name!r}"
            assert torch.allclose(tensor, expected[name]), f"mismatch in {name!r}"
