import math
import time
import tqdm
import torch
import torch.nn as nn
import utils
import quant_utils
import logging
import functools

from quant_utils import Quantizer


# Forbid TF32 tensor-core math for CUDA matmuls and for cuDNN convolutions.
# NOTE(review): presumably done so that the Hessian accumulation and the
# Cholesky-based solves in this file run in full float32 precision - confirm.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

class DuoGPT:
    """Per-layer statistics collector and blockwise prune/quantize solver.

    One instance wraps a single linear layer. `add_batch` accumulates three
    running statistics from the layer's calibration inputs, and `fasterprune`
    uses them to overwrite the layer weight with a pruned (and, when a
    `quantizer` attribute has been attached, quantized) version.

    Attributes set by `__init__`:
        layer:    the wrapped module (must expose a 2-D `weight`).
        dev:      device of the layer weight; all buffers live there.
        rows / columns: the two dimensions of the weight.
        H:        (columns, columns) running average of X @ X^T.
        dXXT:     (columns, columns) running average of dX @ X^T.
        dXdXT:    (columns,) running average of the row-wise sum of dX**2.
        nsamples: number of samples folded into the averages so far.
        fp_inp:   list of reference inputs, one per upcoming `add_batch`
                  call, each shaped (columns, tokens). The caller fills it.
    where X is the input handed to `add_batch` and dX = fp_inp[0] - X.
    """

    def __init__(self, layer):
        """Allocate zeroed statistics buffers sized from `layer.weight`."""
        self.layer = layer
        self.dev = self.layer.weight.device
        # Only the shape is needed here, so do not clone the whole weight.
        self.rows = layer.weight.shape[0]
        self.columns = layer.weight.shape[1]
        self.H = torch.zeros((self.columns, self.columns), device=self.dev)
        self.dXXT = torch.zeros((self.columns, self.columns), device=self.dev)
        self.nsamples = 0
        self.fp_inp = []
        self.dXdXT = torch.zeros(self.columns, device=self.dev) #! added for DuoGPT

    def add_batch(self, inp, out):
        """Fold one batch of layer inputs into the running statistics.

        Args:
            inp: layer input, either 2-D (tokens, columns) or 3-D
                (batch, tokens, columns). The leading dimension of the 3-D
                form is the number of samples added.
            out: unused; present so the method matches the forward-hook
                call shape used by the caller.

        Consumes (and removes) `self.fp_inp[0]`, which must be shaped
        (columns, total_tokens) to match the flattened, transposed `inp`.

        Raises:
            IndexError: if `self.fp_inp` is empty.
        """
        if len(inp.shape) == 2:
            inp = inp.unsqueeze(0)
        tmp = inp.shape[0]
        if len(inp.shape) == 3:
            inp = inp.reshape((-1, inp.shape[-1]))
        inp = inp.t()

        # Rescale the existing averages so that, after adding the new batch
        # (pre-scaled by sqrt(1 / nsamples) below), they stay running means.
        self.H *= self.nsamples / (self.nsamples + tmp)
        self.dXXT *= self.nsamples / (self.nsamples + tmp)
        self.dXdXT *= self.nsamples / (self.nsamples + tmp)
        self.nsamples += tmp
        inp = math.sqrt(1 / self.nsamples) * inp.float()
        self.H += inp.matmul(inp.t())

        #! DuoGPT's feature: difference between the reference input and the
        #! input actually seen, with the same scaling applied to both.
        dX = self.fp_inp[0].float() * math.sqrt(1 / self.nsamples) - inp
        self.dXXT += dX.matmul(inp.t())
        self.dXdXT += torch.sum(dX**2, dim=1)

        # The reference input has been consumed; drop it to release memory.
        del self.fp_inp[0]

    def fasterprune(
            self, sparsity, prunen=0, prunem=0, blocksize=128, percdamp=.01, actorder=False, alpha=0.125, args=None,
    ):
        """Prune (and optionally quantize) the layer weight in place.

        Columns are processed in blocks of `blocksize`. Inside a block each
        column is pruned/quantized in turn and the resulting error is pushed
        onto the not-yet-processed columns; the remaining columns of the
        matrix are updated once per block.

        Args:
            sparsity: fraction used to pick the per-block pruning threshold
                when `prunen == 0`.
            prunen, prunem: when `prunen != 0`, prune the `prunen` lowest
                scoring weights in every group of `prunem` consecutive
                columns (per row) instead of using `sparsity`.
            blocksize: number of columns handled per block.
            percdamp: damping added to the diagonal of H, as a fraction of
                its mean diagonal value.
            actorder: if true, process columns in a permuted order derived
                from the diagonal of H and undo the permutation at the end.
            alpha: scale applied to the dX-dependent correction terms.
            args: optional namespace. Read only when `actorder` is true:
                `block_act` (with `act_blocksize`) selects the interleaved
                block ordering, `dxxt_permutation` adds |diag(dXXT)| to the
                sort key. Missing attributes / `None` mean "off".

        Side effects:
            Replaces `self.layer.weight.data`; removes the `H`, `dXXT` and
            `dXdXT` attributes from this instance. The tensors those
            attributes pointed to are not modified, so they may safely be
            shared with other instances.

        Raises:
            ValueError: if the resulting weight contains NaN.
        """
        W = self.layer.weight.data.clone()
        W = W.float()

        if hasattr(self, 'quantizer'):
            if not self.quantizer.ready():
                self.quantizer.find_params(W, weight=True)

        # Work on a private copy: H is patched and damped in place below, and
        # the tensor behind self.H may be shared by several instances. Mutating
        # it would hide dead columns from, and accumulate damping for, the
        # instances processed afterwards.
        H = self.H.clone()
        del self.H
        dXdXT = self.dXdXT
        del self.dXdXT

        # Columns whose input was always zero: make H invertible there and
        # zero the corresponding weights.
        dead = torch.diag(H) == 0
        H[dead, dead] = 1
        W[:, dead] = 0
        # self.dXXT[:, dead] = 0 #! Does not require this, guarantee to be 0 for dead entry.

        if actorder:
            # Tolerate args=None (the default) and namespaces without these
            # optional switches; both simply select the plain ordering.
            block_act = getattr(args, 'block_act', False)
            use_dxxt_permutation = getattr(args, 'dxxt_permutation', False)
            if block_act:
                # Deal the globally sorted columns out to the blocks in chunks
                # of act_blocksize, round-robin over the blocks.
                global_sorted_indices = torch.argsort(torch.diag(H), descending=True)
                perm = torch.zeros(self.columns, dtype=torch.int, device=self.dev)
                num_blocks = self.columns // blocksize
                for i, global_idx in enumerate(global_sorted_indices):
                    block_num = (i // args.act_blocksize) % num_blocks
                    offset_within_block = (i // (args.act_blocksize * num_blocks)) * args.act_blocksize
                    pos_within_block = i % args.act_blocksize
                    new_pos = block_num * blocksize + offset_within_block + pos_within_block
                    perm[new_pos] = global_idx
            elif use_dxxt_permutation:
                perm = torch.argsort(torch.diag(H) + torch.diag(torch.abs(self.dXXT)), descending=True)
            else:
                perm = torch.argsort(torch.diag(H), descending=True)
            W = W[:, perm]
            H = H[perm][:, perm]
            # Indexing creates new tensors, so a shared dXXT / dXdXT is left intact.
            self.dXXT = self.dXXT[perm][:, perm]
            dXdXT = dXdXT[perm] #! New for DuoGPT
            invperm = torch.argsort(perm)

        Q = torch.zeros_like(W)

        damp = percdamp * torch.mean(torch.diag(H))
        diag = torch.arange(self.columns, device=self.dev)
        H[diag, diag] += damp
        # Hinv ends up as the upper Cholesky factor of the inverse of the
        # damped H.
        Hinv = torch.linalg.cholesky(H)
        Hinv = torch.cholesky_inverse(Hinv)
        Hinv = torch.linalg.cholesky(Hinv, upper=True)

        #! V2 implementation
        mask_triangle = torch.ones_like(Hinv).triu_(diagonal=1)  # strictly upper part
        dXXTL = self.dXXT @ Hinv.T
        P = alpha * (dXXTL * mask_triangle) @ Hinv
        dXXTL2 = torch.sum((dXXTL*mask_triangle)**2, dim=1)
        lP_COE4 = 2*torch.diag(dXXTL)/torch.diag(Hinv)
        del self.dXXT

        mask = None #! Add the masks for pruning weights

        for i1 in range(0, self.columns, blocksize):
            i2 = min(i1 + blocksize, self.columns)
            count = i2 - i1

            #! focus on the current block
            W1 = W[:, i1:i2].clone()
            Q1 = torch.zeros_like(W1)
            Err1 = torch.zeros_like(W1)
            Hinv1 = Hinv[i1:i2, i1:i2]
            P1 = P[i1:i2, i1:i2] #! V2

            Lp_coe2 = dXdXT[i1:i2].unsqueeze(0)
            Lp_coe3 = dXXTL2[i1:i2].unsqueeze(0)
            Lp_coe4 = lP_COE4[i1:i2].unsqueeze(0)
            # Per-column coefficient of the dX-dependent part of the score;
            # constant for the whole block, so compute it once.
            Lp_coe = Lp_coe2-Lp_coe3+Lp_coe4

            #! Mask selection
            if prunen == 0:
                if mask is not None:
                    mask1 = mask[:, i1:i2]
                else:
                    #! Fold the division into a multiplication of scale will result into small deviation.
                    #! Scale by alpha to align with the updates.
                    tmp = W1 ** 2 / torch.diag(Hinv1).reshape((1, -1)) ** 2 + W1**2 * Lp_coe * alpha

                    thresh = torch.sort(tmp.flatten())[0][int(tmp.numel() * sparsity)] #! Keep using this method to align with SparseGPT.
                    mask1 = tmp <= thresh
            else:
                # n:m mode: the mask is filled in group by group inside the
                # column loop below.
                mask1 = torch.zeros_like(W1) == 1

            for i in range(count):
                # NOTE: w is a view into W1; the update further down computes
                # its right-hand side before writing, so w is read pre-update.
                w = W1[:, i]
                d = Hinv1[i, i]

                if prunen != 0 and i % prunem == 0:
                    # Start of a new group of prunem columns: mark the prunen
                    # lowest-scoring weights of each row for pruning.
                    tmp = W1[:, i:(i + prunem)] ** 2 / torch.diag(Hinv1)[i:(i + prunem)].reshape((1, -1)) ** 2 + W1[:, i:(i + prunem)]**2 * Lp_coe[:,i:(i + prunem)] * alpha
                    mask1.scatter_(1, i + torch.topk(tmp, prunen, dim=1, largest=False)[1], True)

                q = w.clone()
                q[mask1[:, i]] = 0

                if hasattr(self, 'quantizer'):
                    q = self.quantizer.quantize(q.unsqueeze(1)).flatten()

                Q1[:, i] = q

                err1 = (w - q) / d
                # Propagate the error of this column onto columns i.. of the block.
                W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) - w.unsqueeze(1).matmul(P1[i, i:].unsqueeze(0))#! V2
                Err1[:, i] = err1

            Q[:, i1:i2] = Q1

            W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) - W1.matmul(P[i1:i2, i2:])#! lazy batch update

        # Only meaningful (and only callable) when a CUDA runtime is present.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

        if actorder:
            Q = Q[:, invperm]

        self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype)
        if torch.any(torch.isnan(self.layer.weight.data)):
            raise ValueError('NaN in weights')

    def free(self):
        """Drop the statistics buffers and ask the allocators to release memory."""
        self.H = None
        # self.Losses = None
        self.dXXT = None
        self.dXdXT = None
        torch.cuda.empty_cache()
        utils.cleanup_memory(verbos=False)


class FPInputsCache:
    """
    Cache the inputs seen by selected sub-modules during a forward pass.

    `sequential` is a list of groups of module names. Every name in every
    group gets its own list in `fp_cache`; each forward call of a hooked
    module appends one tensor shaped (features, tokens) to that list.
    """
    def __init__(self, sequential):
        # Flatten all groups, however many there are, keeping their order.
        self.names = [name for group in sequential for name in group]
        self.fp_cache = {name: [] for name in self.names}
        self.handles = []

    def cache_fp_input(self, m, inp, out, name):
        """Forward-hook body: store the first positional input of `m`.

        `m` and `out` are unused; `name` selects the cache list. A 3-D input
        is flattened to 2-D over its leading dimensions before the transpose.
        """
        inp = inp[0].detach()
        if len(inp.shape) == 3:
            inp = inp.reshape((-1, inp.shape[-1]))
        self.fp_cache[name].append(inp.t())

    def add_hook(self, full):
        """Register the caching hook on `full[name]` for every tracked name."""
        for name in self.names:
            self.handles.append(
                full[name].register_forward_hook(
                    functools.partial(self.cache_fp_input, name=name)
                )
            )

    def clear_hook(self):
        """Remove every registered hook; the cached tensors are kept."""
        for h in self.handles:
            h.remove()
        self.handles = []
        torch.cuda.empty_cache()

    def clear_cache(self):
        """Reset every tracked name to a fresh, empty list."""
        for name in self.names:
            self.fp_cache[name] = []


@torch.no_grad()
def duogpt_fwrd(model, dataloader, dev, args):
    """Run DuoGPT calibration over every transformer layer of `model`.

    For each layer, in order:
      1. the layer is run on `fp_inps` while hooks record the input of every
         linear module listed in `sequential` (the reference inputs);
      2. for each group in `sequential`, a `DuoGPT` solver is created per
         module, statistics are gathered by running the layer on `inps`, and
         `fasterprune` rewrites the module weights;
      3. the layer is run once more on `inps` to produce the inputs of the
         next layer.

    Args:
        model: model whose `args.model` name contains 'opt', or (case
            insensitively) 'llama', 'qwen' or 'mistral'. Must expose
            `config.use_cache`, `config.hidden_size`, `seqlen`, and either
            `model.decoder.layers` (OPT) or `model.layers` (others).
        dataloader: iterable of batches; `batch[0]` is fed to the model.
        dev: device on which the layers are processed.
        args: namespace. Fields read here: `model`, `nsamples`,
            `enable_ap_calibration`, `a_sparsity`, `w_bits`, `sparsity`,
            `prunen`, `prunem`, `percdamp`, `blocksize`, `act_order`,
            `scale_alpha`. It is also forwarded to `DuoGPT.fasterprune`.

    Returns:
        The `quantizers` dict. Nothing in this function adds entries to it,
        so it is returned empty.

    Raises:
        ValueError: if `args.model` matches none of the known families.
    """

    logging.info('-----DuoGPT Calibration Start-----')

    # Pick the model family from the model name; this decides where the
    # layers live and which keyword arguments a layer forward needs.
    if 'opt' in args.model:
        opt_type = True
        llama_type = False
    elif 'llama' in args.model.lower() or 'qwen' in args.model.lower() or 'mistral' in args.model.lower():
        llama_type = True
        opt_type = False
    else:
        raise ValueError(f'Unknown model {args.model}')

    # Disabled for the duration of calibration; restored before returning.
    use_cache = model.config.use_cache
    model.config.use_cache = False


    # Move only the modules that run before the first layer to `dev`.
    if opt_type:
        layers = model.model.decoder.layers
        model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(dev)
        model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(dev)
        if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out:
            model.model.decoder.project_out = model.model.decoder.project_out.to(dev)
        if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in:
            model.model.decoder.project_in = model.model.decoder.project_in.to(dev)
    elif llama_type:
        layers = model.model.layers
        model.model.norm = model.model.norm.to(dev)
        model.model.embed_tokens = model.model.embed_tokens.to(dev)
        model.model.rotary_emb = model.model.rotary_emb.to(dev)
    
    
    layers[0] = layers[0].to(dev)
    dtype = next(iter(model.parameters())).dtype
    # Buffer for the hidden states entering the first layer, one row per sample.
    inps = torch.zeros(
        (args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev
    )

    cache = {'i': 0, 'attention_mask': None}

    # Stand-in for the first layer: records its input and keyword arguments,
    # then aborts the forward pass by raising ValueError.
    class Catcher(nn.Module):
        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, inp, **kwargs):
            inps[cache['i']] = inp
            cache['i'] += 1
            cache['attention_mask'] = kwargs['attention_mask']
            if llama_type:
                cache['position_embeddings'] = kwargs['position_embeddings']
            # cache['position_embeddings'] = kwargs['position_embeddings']
            raise ValueError

    layers[0] = Catcher(layers[0]) #! Add catcher at the first transformer layer.
    for batch in dataloader:
        try:
            model(batch[0].to(dev)) #! This will catch the output states from the embedding.
        except ValueError:
            # Expected: raised by Catcher once the input has been recorded.
            # NOTE(review): any other ValueError raised by the model is
            # swallowed here as well.
            pass
    layers[0] = layers[0].module #! Remove the catcher module.
    layers[0] = layers[0].cpu() #! Put the layer back to cpu.

    # Move the pre-layer modules back to the CPU.
    if opt_type:
        model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.cpu()
        model.model.decoder.embed_positions = model.model.decoder.embed_positions.cpu()
        if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out:
            model.model.decoder.project_out = model.model.decoder.project_out.cpu()
        if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in:
            model.model.decoder.project_in = model.model.decoder.project_in.cpu()
    elif llama_type:
        model.model.embed_tokens = model.model.embed_tokens.cpu()
        model.model.norm = model.model.norm.cpu()
        model.model.rotary_emb = model.model.rotary_emb.cpu()
        torch.cuda.empty_cache()
        # Only defined on this branch; the OPT path never reads it.
        position_embeddings = cache['position_embeddings']


    outs = torch.zeros_like(inps)
    attention_mask = cache['attention_mask']

    # Returned to the caller; never populated in this function.
    quantizers = {}

    # Groups of linear modules, processed in this order within a layer.
    # Statistics are gathered once per group (on its first module).
    if opt_type:
        sequential = [
            ['self_attn.k_proj.module', 'self_attn.v_proj.module', 'self_attn.q_proj.module'],
            ['self_attn.out_proj.module'],
            ['fc1.module'],
            ['fc2.module']
        ]
    else:
        sequential = [
            ['self_attn.k_proj.module', 'self_attn.v_proj.module', 'self_attn.q_proj.module'],
            ['self_attn.o_proj.module'],
            ['mlp.up_proj.module', 'mlp.gate_proj.module'],
            ['mlp.down_proj.module']
        ]

    fp_inputs_cache = FPInputsCache(sequential)
    # Separate copy of the first-layer inputs: it is propagated through each
    # layer before that layer's weights are modified.
    fp_inps = inps.clone()

    for i in tqdm.tqdm(range(len(layers))):
        layer = layers[i].to(dev)
        full = quant_utils.find_layers(layer, layers=[torch.nn.Linear])

        fp_inputs_cache.add_hook(full)#! V2

        #! To be safe, force to disable the pruning of activation here before collecting FP x.
        if args.enable_ap_calibration:
            quant_utils.disable_act_sparsity(layer)
        
        #! getting FP X
        # Runs before any fasterprune call on this layer; the hooks store one
        # input tensor per sample for every module named in `sequential`.
        for j in range(args.nsamples):
            if opt_type:
                fp_inps[j] = layer(fp_inps[j].unsqueeze(0), attention_mask=attention_mask)[0]

            else:
                fp_inps[j] = layer(fp_inps[j].unsqueeze(0), attention_mask=attention_mask, position_embeddings=position_embeddings)[0]
        fp_inputs_cache.clear_hook()

        if args.enable_ap_calibration:
            quant_utils.enable_act_sparsity(layer, args.a_sparsity)#! turn on activation sparsity

        for names in sequential:
            subset = {n: full[n] for n in names}

            gpts = {}
            for name in subset:
                gpts[name] = DuoGPT(subset[name])
                # Same list object as the cache entry: add_batch removes the
                # entries it consumes from the cache itself.
                gpts[name].fp_inp = fp_inputs_cache.fp_cache[name] #! FP X
                if args.w_bits < 16:
                    gpts[name].quantizer = Quantizer()
                    gpts[name].quantizer.configure(
                        args.w_bits, perchannel=True, sym=False, mse=False, grouprows=128
                    )
            
            # Build a forward hook that feeds the module input to the solver
            # registered under `name`.
            def add_batch(name):
                def tmp(_, inp, out):
                    gpts[name].add_batch(inp[0].data, out.data)
                return tmp

            #! only calculate the H and dXXT for one of the parallel block, then duplicate.
            first_module_name = list(subset.keys())[0]
            handle = subset[first_module_name].register_forward_hook(add_batch(first_module_name))

            # One forward per sample; each triggers one add_batch call on the
            # first module of the group.
            for j in range(args.nsamples):
                if opt_type:
                    outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0]
                else:
                    outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask,
                            position_embeddings=position_embeddings)[0]
            handle.remove()

            # copy H and dXXT, and dXdXT
            # NOTE: these assignments share the same tensor objects between
            # the solvers of the group; no data is copied.
            for name in subset:
                if name != first_module_name:
                    gpts[name].H = gpts[first_module_name].H
                    gpts[name].dXXT = gpts[first_module_name].dXXT
                    gpts[name].dXdXT = gpts[first_module_name].dXdXT #! New for DuoGPT

            for name in subset:
                sparsity = args.sparsity
                gpts[name].fasterprune(
                    sparsity,
                    prunen=args.prunen,
                    prunem=args.prunem,
                    percdamp=args.percdamp,
                    blocksize=args.blocksize,
                    actorder=args.act_order,
                    alpha = args.scale_alpha,
                    args=args
                )
                gpts[name].free()
        #! For generating the outputs for the next layer.
        for j in range(args.nsamples):
            if opt_type:
                outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0]
            else:
                outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask,
                            position_embeddings=position_embeddings)[0]
        fp_inputs_cache.clear_cache()
        layers[i] = layer.cpu()
        del layer
        del gpts
        torch.cuda.empty_cache()

        # This layer's outputs become the next layer's inputs; the old input
        # buffer is reused as the next output buffer.
        inps, outs = outs, inps

    model.config.use_cache = use_cache
    utils.cleanup_memory(verbos=False)

    logging.info('-----DuoGPT Calibration Done-----\n')

    return quantizers