import math
import time
import copy

import torch
import torch.nn as nn
import transformers

from quant import *

# Disable TF32 for both CUDA matmuls and cuDNN so float32 math stays full precision.
torch.backends.cuda.matmul.allow_tf32 = torch.backends.cudnn.allow_tf32 = False

class SparseLLAMA:
    """Prune (and optionally quantize) the weight of a single layer.

    Usage visible in this class: construct with a layer, feed calibration
    inputs through ``add_batch`` to accumulate ``self.H``, call ``fastprune``
    to overwrite ``layer.weight`` with a sparse version, then ``free``.

    ``args`` must provide ``step``, ``prunen``, ``prunem``, ``sparsity``,
    ``prune_ratio_per_iter`` and ``interval``. A ``quantizer`` attribute may
    be attached externally; it is only used when present (``hasattr`` checks).
    """

    def __init__(self, layer, args):
        self.layer = layer
        self.dev = self.layer.weight.device
        W = layer.weight.data.clone()
        # Bring the weight into the same (rows, columns) layout as an
        # nn.Linear weight; Conv1D stores it transposed.
        if isinstance(self.layer, nn.Conv2d):
            W = W.flatten(1)
        if isinstance(self.layer, transformers.Conv1D):
            W = W.t()
        self.rows = W.shape[0]
        self.columns = W.shape[1]
        # Accumulator over calibration inputs, shape (columns, columns).
        self.H = torch.zeros((self.columns, self.columns), device=self.dev)
        self.nsamples = 0
        self.args = args

    def add_batch(self, inp, nsamples, out, blocksize=1024):
        """Fold one batch of layer inputs into the running matrix ``self.H``.

        Args:
            inp: input tensor seen by the layer; a 2-D input is treated as a
                batch of one.
            nsamples, out, blocksize: accepted for interface compatibility
                but not used here.
        """
        eps = 1e-08
        if len(inp.shape) == 2:
            inp = inp.unsqueeze(0)
        batch = inp.shape[0]
        if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D):
            if len(inp.shape) == 3:
                inp = inp.reshape((-1, inp.shape[-1]))
            inp = inp.t()
        # Rescale previous contributions by nsamples / (nsamples + batch)
        # before adding this batch.
        self.H *= self.nsamples / (self.nsamples + batch)
        self.nsamples += batch
        inp = math.sqrt(2 / self.nsamples) * inp.float()
        # Center by the global mean and scale by rsqrt of the mean entry of
        # inp @ inp^T before accumulating.
        inp = (inp - inp.mean()) * torch.rsqrt(inp.matmul(inp.t()).mean() + eps)
        self.H += inp.matmul(inp.t())

    def proximal_gradient_descend_v7(self, x, w, b, quantizer):
        """Prune (and optionally quantize) a slice of weights, then refine it.

        The refinement steps follow the gradient ``2 * w @ val + b``. Each
        outer iteration does four things:

        1. Scores every entry.
        2. Zeroes the lowest-scoring entries. With ``args.prunen != 0`` this
           is N:M groups along dim 0; otherwise it is a global fraction that
           grows with each iteration.
        3. Optionally quantizes.
        4. Runs two masked refinement loops that step along a mix of
           ``sign(momentum)`` and the raw gradient.

        Args:
            x: initial values. In ``fastprune`` this is a
                (columns, interval) slice of the weight, transposed. It is
                not modified; a clone is used.
            w: square matrix multiplying ``val`` from the left.
            b: linear term, same shape as ``x``.
            quantizer: if not None, stored as ``self.quantizer`` and used to
                quantize after each pruning step.

        Returns:
            Tensor shaped like ``x``; pruned entries are exactly 0.
        """
        if quantizer is not None:
            self.quantizer = quantizer
        lr = self.args.step
        val = x.clone()

        beta1 = 0.9
        beta2 = 0.99
        m = torch.zeros_like(b)

        # Column vector of w's diagonal, broadcast against val below.
        w_diag = torch.diag(w)
        w_diag = w_diag.reshape(w_diag.shape[0], 1)

        alpha = 0.1  # weight given to the previous iteration's scores
        rou = 1.4  # mix between sign(c) and the raw gradient in the update
        if self.args.prunen == 0:
            total_iters = round(self.args.sparsity / self.args.prune_ratio_per_iter)
        else:
            # Number of complete groups of `prunem` rows along dim 0. Kept
            # under its own name so the optimizer loops below cannot clobber it.
            nm_groups = int(val.shape[0] / self.args.prunem)
            total_iters = round(val.shape[0] / (nm_groups * self.args.prunem))
        optim_iters = 20
        vmask = (torch.zeros_like(val) == 1)
        last_q = None
        for index in range(total_iters):
            # Score every entry; the lowest scores are pruned.
            q = val.abs() * w.norm(p=2, dim=0).reshape(val.shape[0], 1)
            x_square = torch.mul(val, val)
            p = torch.mul(w_diag, x_square)
            p_grad = 2 * torch.matmul(w, val) + b
            p = p - torch.mul(p_grad, val)
            q = torch.mul(q, p)

            # Smooth scores with the previous iteration's.
            if index > 0:
                q = (1 - alpha) * q + alpha * last_q
            last_q = q

            if self.args.prunen != 0:
                # Structured N:M sparsity: within each group of `prunem`
                # consecutive rows, mark the `prunen` lowest-scoring entries
                # of every column.
                for ii in range(0, nm_groups * self.args.prunem, self.args.prunem):
                    idx = index * self.args.prunem * nm_groups + ii
                    tmp = q[idx : (idx + self.args.prunem), :].float()
                    vmask.scatter_(0, idx + torch.topk(tmp, self.args.prunen, dim=0, largest=False)[1], True)
                val[vmask] = 0.0
                mask = torch.ne(val, 0.0)
            else:
                # Unstructured: zero the entries with the smallest |score|;
                # the pruned fraction grows by prune_ratio_per_iter each time.
                _, topk_index = torch.topk(q.flatten().abs(),
                                           int(q.flatten().shape[0] * (index + 1) * self.args.prune_ratio_per_iter),
                                           largest=False)
                r = val.flatten().clone()
                r[topk_index] = 0.0
                val = r.reshape(val.shape)
                mask = torch.ne(val, 0.0)

            if hasattr(self, 'quantizer'):
                val = quantize(val.t(),
                               self.quantizer.scale,
                               self.quantizer.zero,
                               self.quantizer.maxq).t()

            # First pass: refine val directly; masking keeps pruned entries 0.
            for step in range(optim_iters):
                grad = 2 * torch.matmul(w, val) + b
                c = beta1 * m + (1 - beta1) * grad
                m = beta2 * m + (1 - beta2) * grad

                v1 = rou * torch.sign(c) + (2.0 - rou) * grad
                val = (1.0 - 0.01 * lr) * val - lr * 1.0 / (step + 1.0) * v1
                val = torch.mul(mask, val)

            # Second pass: optimize a masked correction `variable` with a
            # 10x smaller step. Its gradient equals the objective's gradient
            # at val + variable, with momentum restarted from zero.
            pgrad = 2 * torch.matmul(w, val) + b
            variable = torch.zeros_like(val)
            m = torch.zeros_like(b)
            for step in range(optim_iters):
                grad = 2 * torch.matmul(w, variable) + pgrad
                c = beta1 * m + (1 - beta1) * grad
                m = beta2 * m + (1 - beta2) * grad

                v1 = rou * torch.sign(c) + (2.0 - rou) * grad
                variable = (1.0 - 0.01 * lr) * variable - lr * 0.1 / (step + 1.0) * v1
                variable = torch.mul(mask, variable)

            val = val + variable
        return val

    def fastprune(self):
        """Prune ``self.layer.weight`` in place using the accumulated ``self.H``.

        Consumes ``self.H`` (the attribute is deleted). The weight is
        processed in slices of ``args.interval`` rows, each refined by
        ``proximal_gradient_descend_v7``.
        """
        # Working copy of the weight in (rows, columns) layout, as float32.
        W = self.layer.weight.data.clone()
        if isinstance(self.layer, nn.Conv2d):
            W = W.flatten(1)
        if isinstance(self.layer, transformers.Conv1D):
            W = W.t()
        W = W.float()

        if hasattr(self, 'quantizer'):
            if not self.quantizer.ready():
                self.quantizer.find_params(W, weight=True)

        H = self.H
        del self.H
        # Columns with a zero diagonal get a unit diagonal so H stays
        # factorizable; the matching weight columns are zeroed.
        dead = (torch.diag(H) == 0)
        H[dead, dead] = 1
        W[:, dead] = 0

        # Add 1% of the mean diagonal to the diagonal (dampening).
        damp = 0.01 * torch.mean(torch.diag(H))
        diag = torch.arange(self.columns, device=self.dev)
        H[diag, diag] += damp

        # Lower Cholesky factor L (H = L L^T), then transposed to L^T.
        H = torch.linalg.cholesky(H, upper=False)
        H = H.t()
        b = H.matmul((-2) * W.t())

        interval = self.args.interval
        rows = W.shape[0]
        for index in range(0, rows, interval):
            x = W[index:(index + interval), :].t()

            quantizer = None
            if hasattr(self, 'quantizer'):
                # Fit quantization parameters to this slice of rows.
                # NOTE: proximal_gradient_descend_v7 stores this copy in
                # self.quantizer, so afterwards self.quantizer holds the
                # last slice's parameters.
                quantizer = copy.deepcopy(self.quantizer)
                quantizer.find_params(W[index:(index + interval), :], weight=True)
            val = self.proximal_gradient_descend_v7(x, H,
                                                    b[:, index:(index + interval)],
                                                    quantizer)
            W[index:(index + interval), :] = val.t().clone()
        print(" W.sparsity: {}, W.max: {} \n".format(
                torch.mean((W == 0).float()), W.abs().max()))
        # Undo the Conv1D transpose from above; a bare reshape alone would
        # scramble the weight.
        if isinstance(self.layer, transformers.Conv1D):
            W = W.t()
        self.layer.weight.data = W.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype)

    def free(self):
        """Drop the accumulated matrix and release cached CUDA memory."""
        self.H = None
        torch.cuda.empty_cache()
