import copy
from torch import nn
from torch.utils.data import Dataset, TensorDataset
from data_utils.arrays import take_batches
from nn_compression.quantisation import (
    AffineGridQuantiser,
)
from nn_compression.quantisation._gptq import GptqLayer
import torch

from nn_compression.quantisation._network_quantisation import gptq_quantise_network


# naive implementation, only possible for sequential nets
def quantize_gtpq_sequential(net: nn.Sequential, nbits: int, x_cal: torch.Tensor):
    """Apply GPTQ to every weighted layer of a sequential net, one layer at a time.

    Serves as a simple reference to cross-check ``gptq_quantise_network``.

    Args:
        net: Network to quantise. Its layers are only run forward; GPTQ is
            applied to a deep copy.
        nbits: Bit width passed to ``AffineGridQuantiser``.
        x_cal: Calibration input for the first layer of ``net``.

    Returns:
        A deep copy of ``net`` in which every layer that has a ``weight``
        attribute has been processed by ``GptqLayer.fasterquant``.
    """
    quant_net = copy.deepcopy(net)

    # A single quantiser instance is shared by all layers.
    quantizer = AffineGridQuantiser(nbits)

    # Calibration needs no gradients. Without no_grad, every forward through
    # ``net`` would build an autograd graph and feed tensors with
    # requires_grad=True into ``add_batch``.
    with torch.no_grad():
        for layer, orig_layer in zip(quant_net, net):
            if hasattr(layer, "weight"):
                gptq = GptqLayer(layer, quantizer)
                gptq.add_batch(x_cal)
                gptq.fasterquant()
            # Propagate through the original (unquantised) layer, so every
            # layer is calibrated on full-precision activations. An alternative
            # would be to use the forward of the partially quantised net.
            x_cal = orig_layer(x_cal)
    return quant_net


def test_general_gptq_linear():
    """``gptq_quantise_network`` must match the sequential reference on an MLP."""
    torch.manual_seed(0)  # reproducible weights and calibration data
    net = nn.Sequential(nn.Linear(2, 10), nn.ReLU(), nn.Linear(10, 3))
    net.train(False)

    nbit = 4
    samples = 5

    x_cal = torch.rand(samples, 2)

    # Both implementations must quantise with the same bit width.
    seq_net = quantize_gtpq_sequential(net, nbit, x_cal)
    general_net = gptq_quantise_network(net, nbit, x_cal)

    seq_layers = list(seq_net)
    general_layers = list(general_net)
    # zip() alone would silently ignore a difference in layer count.
    assert len(seq_layers) == len(general_layers)
    for seq_layer, general_layer in zip(seq_layers, general_layers):
        has_weight = hasattr(seq_layer, "weight")
        # Report a structural mismatch as an assertion, not an AttributeError.
        assert has_weight == hasattr(general_layer, "weight")
        if has_weight:
            assert torch.allclose(seq_layer.weight, general_layer.weight)


def test_general_gptq_conv():
    """``gptq_quantise_network`` must match the sequential reference on a conv net."""
    torch.manual_seed(0)  # reproducible weights and calibration data
    net = nn.Sequential(nn.Conv2d(2, 10, 3), nn.ReLU(), nn.Conv2d(10, 3, 3))
    net.train(False)

    nbit = 4
    samples = 5

    x_cal = torch.rand(samples, 2, 21, 21)

    # Both implementations must quantise with the same bit width.
    seq_net = quantize_gtpq_sequential(net, nbit, x_cal)
    general_net = gptq_quantise_network(net, nbit, x_cal)

    seq_layers = list(seq_net)
    general_layers = list(general_net)
    # zip() alone would silently ignore a difference in layer count.
    assert len(seq_layers) == len(general_layers)
    for seq_layer, general_layer in zip(seq_layers, general_layers):
        has_weight = hasattr(seq_layer, "weight")
        # Report a structural mismatch as an assertion, not an AttributeError.
        assert has_weight == hasattr(general_layer, "weight")
        if has_weight:
            assert torch.allclose(seq_layer.weight, general_layer.weight)
