# What do we want to do here?
# Reimplement the GPTQ method, but without all the tricks its authors used to
# speed up the implementation. Without those tricks, the method essentially
# reduces to the procedure described in the Optimal Brain Compression paper:
# https://arxiv.org/pdf/2208.11580.pdf
from ._interfaces import Quantiser
import torch.nn
import copy


class Gptq:
    """Unoptimised GPTQ quantisation of ``torch.nn.Linear`` weights.

    Each output row of the weight matrix is quantised one column at a time.
    After each column is quantised, the resulting error is spread over the
    row's not-yet-quantised columns using the inverse Hessian of the
    layer-wise reconstruction error (eq. 7 of Optimal Brain Compression).
    """

    def __init__(self, quantizer: "Quantiser", dampening_percent: float = 0.01) -> None:
        """
        Args:
            quantizer: provides ``find_params(weight)`` and ``quantise(value)``.
                The object returned by ``quantise`` must expose the quantised
                tensor as its ``.x`` attribute.
            dampening_percent: fraction of the mean Hessian diagonal that is
                added to the diagonal before the Hessian is inverted.
        """
        self.quantizer = quantizer
        self.dampening_percent = dampening_percent

    def quantize(self, layer: torch.nn.Module, X: torch.Tensor) -> torch.nn.Module:
        """Return a quantised deep copy of ``layer``; ``layer`` is left untouched."""
        layer_copy = copy.deepcopy(layer)
        self.quantize_(layer_copy, X)
        return layer_copy

    def quantize_(self, layer: torch.nn.Module, X: torch.Tensor) -> None:
        """Quantize the weights of a linear layer using GPTQ, in-place.

        Args:
            layer: the ``torch.nn.Linear`` whose ``weight`` is overwritten with
                quantised values. The bias is not modified.
            X: calibration inputs of shape ``(batch_size, in_features)``.

        Raises:
            ValueError: if ``layer`` is not a ``torch.nn.Linear``, if ``X`` is
                not 2-D, or if ``X.shape[1]`` differs from the layer's input
                feature count.
        """
        if not isinstance(layer, torch.nn.Linear):
            raise ValueError("Only Linear layers are supported")
        if len(X.shape) != 2:
            raise ValueError("x must have shape (batch_size, features)")

        W = layer.weight.data.clone()
        n_rows, n_cols = W.shape
        # Validate before find_params so a bad call has no side effects.
        if X.shape[1] != n_cols:
            raise ValueError(
                f"x has {X.shape[1]} features but the layer expects {n_cols}"
            )

        self.quantizer.find_params(layer.weight.data)

        # Hessian of the layer-wise squared error. It is shared by every output
        # row. The factor 2 follows GPTQ; GPTQ additionally divides by
        # X.shape[0], which only rescales H (and therefore the dampening).
        H = 2 * X.t() @ X

        # A zero diagonal entry means that feature is zero for every sample
        # ("dead"). It carries no signal, so give it a unit diagonal (keeps H
        # invertible) and zero the corresponding weights.
        dead = torch.diag(H) == 0
        H[dead, dead] = 1
        W[:, dead] = 0

        # Dampen the diagonal in place through a view of H. Unlike adding a
        # fresh torch.eye, this keeps the addition on H's device and dtype.
        dampening = self.dampening_percent * torch.mean(torch.diag(H))
        H.diagonal().add_(dampening)

        Hinv = torch.linalg.inv(H)

        for idx in range(n_rows):
            w_row = W[idx, :]  # view: the in-place updates below write into W
            # Columns are quantised in a fixed left-to-right order, following
            # GPTQ's insight that the column quantisation order does not
            # matter. Eliminating columns mutates the inverse Hessian, so each
            # row starts from a fresh copy.
            Hinv_i = Hinv.clone()
            for col in range(n_cols):
                w = w_row[col]
                # eq 7, Optimal Brain Compression
                w_quant = self.quantizer.quantise(w).x  # type: ignore
                # Compute delta_w before Hinv_i is modified. Its entry at `col`
                # moves w onto w_quant. Already-eliminated columns have a zero
                # (up to rounding) column in Hinv_i, so they receive no update.
                delta_w = -(w - w_quant) / Hinv_i[col, col] * Hinv_i[:, col]
                self.drop_index_inverse_(Hinv_i, col)
                w_row += delta_w
                layer.weight.data[idx, col] = w_quant

    @staticmethod
    def drop_index_inverse_(Hinv: torch.Tensor, idx: int) -> None:
        """Eliminate index ``idx`` from an inverse Hessian, in-place.

        Applies the rank-1 update
        ``Hinv -= outer(Hinv[:, idx], Hinv[idx, :]) / Hinv[idx, idx]``.
        Afterwards, the sub-matrix without row/column ``idx`` equals the
        inverse of H with that row/column removed. Row and column ``idx``
        themselves become zero (up to rounding).
        """
        Hinv -= torch.outer(Hinv[:, idx], Hinv[idx, :]) / Hinv[idx, idx]
