import torch
from nn_compression.networks import LayerWiseHessian
from nn_compression.quantisation._interfaces import Quantiser
from nn_compression._interfaces import Quantisable


class GptqLayer:
    """Performs GPTQ quantisation on a single layer of a neural network. The layer is modified in place.

    Weights are quantised column by column. The scaled rounding error of each
    column is propagated onto the columns that have not been quantised yet,
    using the factor exposed as ``hessian.cholesky_inverse``. The quantised
    weights are written back through ``hessian.set_weight``.
    """

    def __init__(self, layer: Quantisable, quantizer: Quantiser):
        """Initialise the GPTQ quantiser for a layer of a neural network.

        Currently only works on nn.Linear layers.

        Args:
            layer: The layer to quantise. If it already has a ``hessian``
                attribute, that Hessian estimate is reused. Otherwise a new
                ``LayerWiseHessian`` is created for the layer.
            quantizer: The quantiser that rounds the weights. Its
                ``find_params`` and ``quantise_with_uncertainty`` methods are
                called during :meth:`fasterquant`.
        """
        assert isinstance(layer, Quantisable), "Layer must be a quantisable layer."
        self.layer = layer

        # Reuse a Hessian that is already attached to the layer, so that any
        # statistics accumulated on it are kept. Otherwise start a fresh estimate.
        if hasattr(layer, "hessian"):
            self.hessian = layer.hessian
        else:
            self.hessian = LayerWiseHessian(layer)

        # Weight matrix as returned by the Hessian helper, indexed as
        # [row, column]. fasterquant subtracts propagated errors from it in place.
        self.W = self.hessian.get_weights()
        self.columns = self.W.shape[1]

        self.quantizer = quantizer

    def add_batch(self, x: torch.Tensor):
        """Add a batch of data to the GPTQ quantiser. The batch is used to estimate the Hessian matrix.

        Batches that are all zeros (within ``torch.allclose`` tolerance) are
        ignored and do not contribute to the Hessian estimate.
        """
        if torch.allclose(x, torch.zeros_like(x)):
            return
        self.hessian.add_batch(x)

    def fasterquant(
        self,
        blocksize=128,
    ):
        """Quantise the weights of a linear layer using GPTQ, in-place.

        Columns are processed in blocks of ``blocksize``. Inside a block, each
        column is quantised and its scaled error is immediately subtracted from
        the remaining columns of that block. Once the block is done, the
        accumulated error is propagated to all later columns with a single
        matrix multiplication (the "lazy batch" update).

        Note that ``self.W`` is updated in place with the propagated errors.

        Args:
            blocksize: The number of columns to quantise at once. Must be positive.

        Raises:
            ValueError: If ``blocksize`` is not a positive number.
        """
        if blocksize <= 0:
            # A negative step would make the block loop below empty. Q would then
            # stay all zeros and be written into the layer. A zero step would
            # fail inside range() with a less helpful message.
            raise ValueError(f"blocksize must be positive, got {blocksize}.")

        self.quantizer.find_params(self.W, weight=True)
        # The row slices Hinv1[i, i:] and Hinv_chol[i1:i2, i2:] below are the
        # usage pattern for an upper-triangular Cholesky factor of H^-1.
        # NOTE(review): confirm this is what LayerWiseHessian.cholesky_inverse returns.
        Hinv_chol = self.hessian.cholesky_inverse

        Q = torch.zeros_like(self.W)

        for i1 in range(0, self.columns, blocksize):
            i2 = min(i1 + blocksize, self.columns)
            count = i2 - i1

            # Private copy of the block, so that intra-block error updates do
            # not touch self.W until the block-level update below.
            W1 = self.W[:, i1:i2].clone()
            Q1 = torch.zeros_like(W1)
            Err1 = torch.zeros_like(W1)
            Hinv1 = Hinv_chol[i1:i2, i1:i2]

            for i in range(count):
                w = W1[:, i]
                d = Hinv1[i, i]

                # we actually use the column, as the naming is not consistent
                # between GPTQ (and their implementation) and the Dettmers 2022 paper
                # which they refer to
                # NOTE(review): the quantiser receives d**2 (the squared factor
                # diagonal) as its uncertainty here, while GptqLayerNoUpdate passes
                # the raw Hessian diagonal H[i, i]. Confirm which quantity
                # quantise_with_uncertainty expects.
                q = self.quantizer.quantise_with_uncertainty(
                    w.unsqueeze(1), d**2, col=(i + i1)
                ).x.flatten()
                Q1[:, i] = q

                # err1 is a new tensor, so it stays valid after the in-place
                # update of W1 (and hence of the view w) on the next line.
                err1 = (w - q) / d
                W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
                Err1[:, i] = err1

            Q[:, i1:i2] = Q1

            # Lazy batch update: propagate this block's errors to all later columns.
            self.W[:, i2:] -= Err1.matmul(Hinv_chol[i1:i2, i2:])

        self.hessian.set_weight(Q.to(self.layer.weight.data.dtype))


class GptqLayerNoUpdate(GptqLayer):
    """GPTQ-style quantiser that rounds every column independently.

    Unlike :class:`GptqLayer`, no quantisation error is propagated to the
    remaining columns. Each column is simply handed to the quantiser together
    with the matching diagonal entry of the Hessian estimate.
    """

    def fasterquant(self, blocksize=None):
        """Quantise every weight column independently and write the result back in place.

        Args:
            blocksize: Unused. It is accepted so that the call signature mirrors
                ``GptqLayer.fasterquant``.
        """
        target_device = self.W.device
        self.quantizer.find_params(self.W, weight=True)
        quantised = torch.zeros_like(self.W, device=target_device)

        for col_idx in range(self.columns):
            # NOTE(review): this passes the raw Hessian diagonal, whereas
            # GptqLayer passes the squared diagonal of the inverse-Hessian
            # Cholesky factor. Confirm which one the quantiser expects.
            hess_diag = self.hessian.H[col_idx, col_idx]
            column = self.W[:, col_idx].unsqueeze(1)
            result = self.quantizer.quantise_with_uncertainty(column, hess_diag, col=col_idx)
            quantised[:, col_idx] = result.x.flatten()

        target_dtype = self.layer.weight.data.dtype
        self.hessian.set_weight(quantised.to(target_dtype))
