import torch
import time
import math

class SparseGPT:
    """Prune one layer's weight matrix with the SparseGPT one-shot algorithm.

    Workflow: pass calibration inputs of the layer to :meth:`add_batch` (any
    number of times), then call :meth:`sparseprune` once. The weight is
    treated as a 2-D ``(rows, columns)`` matrix whose columns match the last
    dimension of the inputs given to :meth:`add_batch`.

    Weights to prune are chosen with an OBS-style saliency that uses the
    accumulated input statistics, and the error each pruned weight introduces
    is propagated into the not-yet-processed weights so the layer output on
    the calibration data changes as little as possible.
    """

    def __init__(self, layer):
        self.layer = layer
        self.dev = self.layer.weight.device
        # Only the shape is needed here; copying the weight would be wasted work.
        self.rows = layer.weight.shape[0]
        self.columns = layer.weight.shape[1]
        # Running value of 2/N * X^T X over all input rows seen by add_batch.
        self.H = torch.zeros((self.columns, self.columns), device=self.dev)
        self.nsamples = 0

    def add_batch(self, inp, out=None):
        """Fold a batch of layer inputs into the Hessian proxy ``self.H``.

        After all calls, ``self.H == 2 / N * X^T X`` where ``X`` stacks every
        input row seen so far and ``N`` is their count.

        Args:
            inp: tensor of shape ``(N, columns)`` or ``(B, T, columns)``;
                3-D inputs are flattened to ``(B * T, columns)`` rows.
            out: unused; accepted so callers may pass ``(inp, out)``.

        Raises:
            ValueError: if ``inp`` is not 2-D or 3-D, or its last dimension
                differs from ``self.columns``.
        """
        if inp.dim() == 3:
            inp = inp.reshape(-1, inp.shape[-1])
        elif inp.dim() != 2:
            raise ValueError(f"Unexpected input shape: {inp.shape}")
        if inp.shape[1] != self.columns:
            raise ValueError(f"Input feature dimension {inp.shape[1]} does not match expected {self.columns}")

        tmp = inp.shape[0]
        if tmp == 0:
            # Nothing to accumulate; returning early also avoids 0/0 below on
            # the very first call.
            return
        inp = inp.t()  # (features, N)

        # Rescale the old average so the update stays a mean over all rows.
        self.H *= self.nsamples / (self.nsamples + tmp)
        self.nsamples += tmp
        inp = math.sqrt(2 / self.nsamples) * inp.float()
        self.H += inp.matmul(inp.t())

    def _hinv_upper_factor(self, H, percdamp):
        """Return the upper Cholesky factor U of ``(H + damp * I)^-1``.

        ``damp`` starts at ``percdamp * mean(diag(H))`` and, if either
        factorization fails numerically, is raised (relative to the scale of
        H) and retried.

        Raises:
            RuntimeError: if no tried damping makes the factorization succeed.
        """
        mean_diag = torch.mean(torch.diag(H)).item()
        damp = percdamp * mean_diag
        eye = torch.eye(self.columns, device=H.device, dtype=H.dtype)
        for _ in range(8):
            L, info = torch.linalg.cholesky_ex(H + damp * eye)
            if info.item() == 0:
                U, info = torch.linalg.cholesky_ex(torch.cholesky_inverse(L), upper=True)
                if info.item() == 0:
                    return U
            # Scale-relative fallback so the damping never dwarfs H itself.
            damp = max(10.0 * damp, 1e-4 * mean_diag)
        raise RuntimeError("Hessian is not positive definite even after damping")

    @staticmethod
    def _prune_mask(W1, hinv_diag, sparsity):
        """Return a bool mask (True = prune) of the least-salient entries of W1.

        Saliency is ``w_ij^2 / U_jj^2``. Exactly
        ``numel - int(numel * (1 - sparsity))`` entries of the block are
        selected, regardless of ties.
        """
        scores = W1 ** 2 / hinv_diag.reshape(1, -1) ** 2
        numel = scores.numel()
        n_prune = numel - int(numel * (1 - sparsity))
        mask = torch.zeros(numel, dtype=torch.bool, device=W1.device)
        if n_prune > 0:
            mask[torch.topk(scores.flatten(), n_prune, largest=False).indices] = True
        return mask.reshape(W1.shape)

    def sparseprune(self, sparsity=0.5, blocksize=128, percdamp=0.01, actorder=False):
        """Prune the layer weight in place, compensating the remaining weights.

        Consumes ``self.H`` (it is deleted), so call this once after all
        :meth:`add_batch` calls.

        Args:
            sparsity: fraction in ``[0, 1]`` of weights to zero inside each
                block of ``blocksize`` columns.
            blocksize: number of columns whose pruning mask is chosen jointly
                and whose error updates are applied together.
            percdamp: diagonal damping as a fraction of ``mean(diag(H))``.
            actorder: process columns in descending ``diag(H)`` order.

        Raises:
            ValueError: if ``sparsity`` is outside ``[0, 1]``.
            RuntimeError: if the damped Hessian cannot be factorized.
        """
        if not 0.0 <= sparsity <= 1.0:
            raise ValueError(f"sparsity must be in [0, 1], got {sparsity}")

        W = self.layer.weight.data.clone().float()

        H = self.H
        del self.H
        # Input features that were always zero carry no information: give them
        # a unit diagonal so H stays invertible and zero their weights.
        dead = torch.diag(H) == 0
        H[dead, dead] = 1
        W[:, dead] = 0

        if actorder:
            perm = torch.argsort(torch.diag(H), descending=True)
            W = W[:, perm]
            H = H[perm][:, perm]
            invperm = torch.argsort(perm)

        Hinv = self._hinv_upper_factor(H, percdamp)

        for i1 in range(0, self.columns, blocksize):
            i2 = min(i1 + blocksize, self.columns)
            W1 = W[:, i1:i2].clone()
            Q1 = torch.zeros_like(W1)
            Err1 = torch.zeros_like(W1)
            Hinv1 = Hinv[i1:i2, i1:i2]
            mask1 = self._prune_mask(W1, torch.diag(Hinv1), sparsity)

            for i in range(i2 - i1):
                w = W1[:, i]
                d = Hinv1[i, i]
                q = w.clone()
                q[mask1[:, i]] = 0
                Q1[:, i] = q
                # OBS update: spread this column's pruning error over the
                # not-yet-processed columns of the block.
                err1 = (w - q) / d
                W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
                Err1[:, i] = err1

            W[:, i1:i2] = Q1
            # Lazy batch update of all columns after this block.
            W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])

        if actorder:
            W = W[:, invperm]

        self.layer.weight.data = W.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype)

    def free(self):
        """Drop the Hessian and release cached CUDA memory."""
        self.H = None
        torch.cuda.empty_cache()
