import math
import time
import tqdm
import torch
import torch.nn as nn
import utils
import quant_utils
import logging
import model_utils

from quant_utils import Quantizer

# Disable TF32 so CUDA matmuls and cuDNN convolutions run in full fp32
# precision instead of the reduced-mantissa TF32 format.
torch.backends.cuda.matmul.allow_tf32 = torch.backends.cudnn.allow_tf32 = False

#! The implementations are adapted from SparseGPT's GitHub repo:
#! https://github.com/IST-DASLab/sparsegpt

class SparseGPT:
    """SparseGPT one-shot pruner (with optional quantization) for a single layer.

    Usage: call ``add_batch`` with calibration inputs (e.g. from a forward
    hook) to accumulate the Hessian proxy ``H = (2 / nsamples) * X X^T``, then
    ``fasterprune`` to sparsify ``layer.weight`` in place, then ``free``.
    If a ``quantizer`` attribute is attached before ``fasterprune`` is called,
    the kept weights are also quantized with it.

    Supported layers: ``nn.Linear`` (weight stored as ``(out, in)``) and
    Hugging Face ``Conv1D`` (weight stored as ``(in, out)``). The latter is
    detected by class name so this module does not need to import
    ``transformers``.
    """

    def __init__(self, layer):
        self.layer = layer
        self.dev = self.layer.weight.device
        W = self._weight_matrix()
        # rows = output features, columns = input features (= size of H).
        self.rows = W.shape[0]
        self.columns = W.shape[1]
        # Running Hessian proxy (torch's default dtype), filled by add_batch.
        self.H = torch.zeros((self.columns, self.columns), device=self.dev)
        self.nsamples = 0

    @staticmethod
    def _is_conv1d(layer):
        """Return True if ``layer`` looks like a Hugging Face ``transformers.Conv1D``.

        Matched by class name (this is NOT ``torch.nn.Conv1d``) so the module
        does not need to import ``transformers``.
        """
        return type(layer).__name__ == 'Conv1D'

    def _weight_matrix(self):
        """Return ``layer.weight.data`` oriented as (rows, columns) = (out, in).

        ``Conv1D`` stores its weight as (in, out), so it is transposed here.
        The result aliases the layer's storage; clone it before modifying.
        """
        W = self.layer.weight.data
        if self._is_conv1d(self.layer):
            W = W.t()
        return W

    def add_batch(self, inp, out):
        """Accumulate one batch of layer inputs into the running Hessian proxy.

        Args:
            inp: input activations of the layer. A 2-D tensor is treated as a
                batch of one; for Linear/Conv1D layers all leading dimensions
                are flattened and the last one must equal ``self.columns``.
            out: layer output; unused, accepted for forward-hook compatibility.

        ``self.nsamples`` counts the leading (batch) dimension, not tokens.
        """
        if len(inp.shape) == 2:
            inp = inp.unsqueeze(0)
        tmp = inp.shape[0]  # batch size

        if isinstance(self.layer, nn.Linear) or self._is_conv1d(self.layer):
            if len(inp.shape) == 3:
                inp = inp.reshape((-1, inp.shape[-1]))
            inp = inp.t()  # (columns, tokens)

        # Running average: H <- H * old/new + (2 / new) * X X^T
        self.H *= self.nsamples / (self.nsamples + tmp)
        self.nsamples += tmp
        # Scaling X by sqrt(2 / new) makes X X^T carry the 2/new factor
        # (equivalent to OBC's implementation).
        inp = math.sqrt(2 / self.nsamples) * inp.float()
        self.H += inp.matmul(inp.t())

    def fasterprune(self, sparsity, prunen=0, prunem=0, blocksize=128, percdamp=.01, actorder=False):
        """Prune (and optionally quantize) ``layer.weight`` in place.

        Args:
            sparsity: fraction of weights to prune in each column block when
                ``prunen == 0`` (unstructured mode). ``<= 0`` prunes nothing.
            prunen, prunem: N:M semi-structured sparsity. In every group of
                ``prunem`` consecutive input columns, ``prunen`` weights per row
                are pruned. ``prunen == 0`` selects unstructured mode.
            blocksize: number of columns handled per lazy-update block.
            percdamp: diagonal dampening, as a fraction of ``mean(diag(H))``.
            actorder: process columns in descending order of ``diag(H)``.

        Raises:
            ValueError: on an invalid N:M pattern, or if the result has NaNs.

        Consumes ``self.H`` (the attribute is deleted).
        """
        if prunen != 0 and not (0 < prunen <= prunem):
            raise ValueError(f'Invalid N:M sparsity pattern: prunen={prunen}, prunem={prunem}')

        W = self._weight_matrix().clone().float()

        if hasattr(self, 'quantizer'):
            if not self.quantizer.ready():
                self.quantizer.find_params(W, weight=True)

        # Columns whose inputs were always zero carry no information: give them
        # a unit diagonal (keeps H invertible) and zero their weights.
        H = self.H
        del self.H
        dead = torch.diag(H) == 0
        H[dead, dead] = 1
        W[:, dead] = 0

        if actorder:
            # Visit columns with the largest Hessian diagonal first.
            perm = torch.argsort(torch.diag(H), descending=True)
            W = W[:, perm]
            H = H[perm][:, perm]
            invperm = torch.argsort(perm)

        Losses = torch.zeros(self.rows, device=self.dev)

        # Dampen the diagonal, then take the upper Cholesky factor of H^-1.
        damp = percdamp * torch.mean(torch.diag(H))
        diag = torch.arange(self.columns, device=self.dev)
        H[diag, diag] += damp
        H = torch.linalg.cholesky(H)
        H = torch.cholesky_inverse(H)
        H = torch.linalg.cholesky(H, upper=True)
        Hinv = H

        mask = None  # placeholder for a precomputed mask; currently always None

        for i1 in range(0, self.columns, blocksize):
            i2 = min(i1 + blocksize, self.columns)
            count = i2 - i1

            # Work on one column block at a time.
            W1 = W[:, i1:i2].clone()
            Q1 = torch.zeros_like(W1)
            Err1 = torch.zeros_like(W1)
            Losses1 = torch.zeros_like(W1)
            Hinv1 = Hinv[i1:i2, i1:i2]

            if prunen == 0:
                if mask is not None:
                    mask1 = mask[:, i1:i2]
                elif sparsity <= 0:
                    mask1 = torch.zeros_like(W1, dtype=torch.bool)
                else:
                    # Saliency w^2 / d^2 (d = diagonal of the Cholesky factor
                    # of H^-1); prune everything at or below the block's
                    # sparsity-quantile threshold. Clamp so sparsity == 1
                    # does not index past the end.
                    tmp = W1 ** 2 / (torch.diag(Hinv1).reshape((1, -1))) ** 2
                    k = min(int(tmp.numel() * sparsity), tmp.numel() - 1)
                    thresh = torch.sort(tmp.flatten())[0][k]
                    mask1 = tmp <= thresh
            else:
                # N:M mode: the mask is filled group by group in the loop below.
                mask1 = torch.zeros_like(W1) == 1

            # Column by column inside the block.
            for i in range(count):
                w = W1[:, i]
                d = Hinv1[i, i]

                if prunen != 0 and i % prunem == 0:
                    tmp = W1[:, i:(i + prunem)] ** 2 / (torch.diag(Hinv1)[i:(i + prunem)].reshape((1, -1))) ** 2
                    mask1.scatter_(1, i + torch.topk(tmp, prunen, dim=1, largest=False)[1], True)

                q = w.clone()
                q[mask1[:, i]] = 0

                if hasattr(self, 'quantizer'):
                    q = self.quantizer.quantize(q.unsqueeze(1)).flatten()

                Q1[:, i] = q  # pruned (and possibly quantized) column
                Losses1[:, i] = (w - q) ** 2 / d ** 2

                # Spread this column's error onto the not-yet-processed columns.
                err1 = (w - q) / d
                W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
                Err1[:, i] = err1

            W[:, i1:i2] = Q1
            Losses += torch.sum(Losses1, 1) / 2
            # Lazy batch update of all columns after this block.
            W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])

        # Only synchronize when the work actually ran on a CUDA device.
        if W.is_cuda:
            torch.cuda.synchronize(W.device)

        if actorder:
            W = W[:, invperm]

        # Restore the layer's own weight layout and dtype.
        if self._is_conv1d(self.layer):
            W = W.t()
        self.layer.weight.data = W.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype)
        if torch.any(torch.isnan(self.layer.weight.data)):
            raise ValueError('NaN in weights')

    def free(self):
        """Drop the Hessian state and release cached memory."""
        self.H = None
        self.Losses = None
        torch.cuda.empty_cache()
        # Project helper; presumably also runs garbage collection — confirm in utils.
        utils.cleanup_memory(verbos=False)
        
        
@torch.no_grad()
def sparsegpt_fwrd(model, dataloader, dev, args):
    """Prune (and optionally quantize) every decoder layer of ``model`` with SparseGPT.

    Layers are processed one at a time on ``dev``. Calibration activations are
    captured at the input of layer 0. Each layer's linear sub-modules are then
    pruned group by group, and the pruned layer's outputs become the next
    layer's inputs.

    Args:
        model: an OPT or Llama/Qwen/Mistral-style causal LM; must expose
            ``model.seqlen``.
        dataloader: iterable of batches whose first element is passed to
            ``model(...)`` (presumably token ids).
        dev: device on which calibration runs.
        args: namespace providing ``model``, ``debug``, ``nsamples``,
            ``w_bits``, ``sparsity``, ``prunen``, ``prunem``, ``percdamp``,
            ``blocksize`` and ``act_order``.

    Returns:
        ``None`` when ``args.debug`` is set. Otherwise a dict; this function
        adds no entries to it.

    Raises:
        ValueError: if ``args.model`` is not a recognised model family.
    """
    logging.info('-----SparseGPT Calibration Start-----')
    if args.debug:
        return

    if 'opt' in args.model:
        opt_type = True
    elif 'llama' in args.model.lower() or 'qwen' in args.model.lower() or 'mistral' in args.model.lower():
        opt_type = False
    else:
        raise ValueError(f'Unknown model {args.model}')

    use_cache = model.config.use_cache
    model.config.use_cache = False
    try:
        quantizers = _sparsegpt_sequential(model, dataloader, dev, args, opt_type)
    finally:
        # Restore the caller's KV-cache setting even if calibration fails.
        model.config.use_cache = use_cache
    utils.cleanup_memory(verbos=True)

    logging.info('-----SparseGPT Calibration Done-----\n')
    return quantizers


class _StopForward(Exception):
    """Raised by the layer-0 catcher to abort the model forward pass early.

    A dedicated type, so genuine errors raised by the model are not swallowed.
    """


def _move_embeddings(model, device, opt_type):
    """Move the modules that run before decoder layer 0 (plus Llama's norm/rotary) to ``device``."""
    if opt_type:
        decoder = model.model.decoder
        decoder.embed_tokens = decoder.embed_tokens.to(device)
        decoder.embed_positions = decoder.embed_positions.to(device)
        if getattr(decoder, 'project_out', None):
            decoder.project_out = decoder.project_out.to(device)
        if getattr(decoder, 'project_in', None):
            decoder.project_in = decoder.project_in.to(device)
    else:
        model.model.embed_tokens = model.model.embed_tokens.to(device)
        model.model.norm = model.model.norm.to(device)
        model.model.rotary_emb = model.model.rotary_emb.to(device)


def _capture_layer0_inputs(model, layers, dataloader, dev, args, opt_type):
    """Run calibration batches up to decoder layer 0 and record its inputs and kwargs.

    Returns ``(inps, cache)``. ``inps`` has shape
    ``(args.nsamples, model.seqlen, hidden_size)``. ``cache`` holds the
    captured ``attention_mask`` and, for Llama-style models,
    ``position_embeddings``.
    """
    _move_embeddings(model, dev, opt_type)
    layers[0] = layers[0].to(dev)

    dtype = next(iter(model.parameters())).dtype
    inps = torch.zeros(
        (args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev
    )
    cache = {"i": 0, "attention_mask": None}

    class Catcher(nn.Module):
        """Stand-in for layer 0: stores its input and kwargs, then aborts the forward."""

        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, inp, **kwargs):
            inps[cache["i"]] = inp
            cache["i"] += 1
            cache["attention_mask"] = kwargs["attention_mask"]
            if not opt_type:
                cache['position_embeddings'] = kwargs['position_embeddings']
            raise _StopForward

    layers[0] = Catcher(layers[0])
    try:
        for batch in dataloader:
            try:
                model(batch[0].to(dev))
            except _StopForward:
                pass
    finally:
        # Always unwrap, even if the model raised a real error.
        layers[0] = layers[0].module
    layers[0] = layers[0].cpu()

    _move_embeddings(model, 'cpu', opt_type)
    torch.cuda.empty_cache()
    return inps, cache


def _sparsegpt_sequential(model, dataloader, dev, args, opt_type):
    """Prune each decoder layer in order, feeding it the pruned outputs of the previous one."""
    layers = model.model.decoder.layers if opt_type else model.model.layers
    inps, cache = _capture_layer0_inputs(model, layers, dataloader, dev, args, opt_type)
    outs = torch.zeros_like(inps)

    # OPT layers take only the attention mask; Llama-style layers also need
    # the precomputed rotary position embeddings.
    layer_kwargs = {'attention_mask': cache["attention_mask"]}
    if not opt_type:
        layer_kwargs['position_embeddings'] = cache['position_embeddings']

    # Groups of sub-modules pruned together; later groups see inputs produced by
    # already-pruned earlier groups. The '.module' suffix presumably reflects a
    # wrapper around each Linear (see quant_utils) — confirm.
    if opt_type:
        sequential = [
            ['self_attn.k_proj.module', 'self_attn.v_proj.module', 'self_attn.q_proj.module'],
            ['self_attn.out_proj.module'],
            ['fc1.module'],
            ['fc2.module'],
        ]
    else:
        sequential = [
            ['self_attn.k_proj.module', 'self_attn.v_proj.module', 'self_attn.q_proj.module'],
            ['self_attn.o_proj.module'],
            ['mlp.up_proj.module', 'mlp.gate_proj.module'],
            ['mlp.down_proj.module'],
        ]

    quantizers = {}
    for i in tqdm.tqdm(range(len(layers))):
        layer = layers[i].to(dev)
        full = quant_utils.find_layers(layer, layers=[torch.nn.Linear])

        for names in sequential:
            subset = {n: full[n] for n in names}

            gpts = {}
            for name, module in subset.items():
                gpts[name] = SparseGPT(module)
                if args.w_bits < 16:
                    gpts[name].quantizer = Quantizer()
                    gpts[name].quantizer.configure(
                        args.w_bits, perchannel=True, sym=False, mse=False, grouprows=128
                    )

            def add_batch(name):
                def hook(_, inp, out):
                    gpts[name].add_batch(inp[0].data, out.data)
                return hook

            # Calibration pass: hooks feed each sub-module's inputs to its SparseGPT.
            handles = [subset[name].register_forward_hook(add_batch(name)) for name in subset]
            try:
                for j in range(args.nsamples):
                    outs[j] = layer(inps[j].unsqueeze(0), **layer_kwargs)[0]
            finally:
                for h in handles:
                    h.remove()

            for name in subset:
                gpts[name].fasterprune(
                    args.sparsity,
                    prunen=args.prunen,
                    prunem=args.prunem,
                    percdamp=args.percdamp,
                    blocksize=args.blocksize,
                    actorder=args.act_order,
                )
                gpts[name].free()

        # Re-run the now-pruned layer to produce the next layer's inputs.
        for j in range(args.nsamples):
            outs[j] = layer(inps[j].unsqueeze(0), **layer_kwargs)[0]

        layers[i] = layer.cpu()
        del layer
        del gpts
        torch.cuda.empty_cache()

        inps, outs = outs, inps

    return quantizers
