import time 
import torch 
import torch.nn as nn 
import utils
import quant_utils
import logging
import tqdm

#! The implementations are adapted from the Wanda GitHub repository:
#! https://github.com/locuslab/wanda

# Define WrappedGPT class
class WrappedGPT:
    """
    Wrap a layer and accumulate the input-activation statistics Wanda needs.

    For an ``nn.Linear`` layer, ``scaler_row[c]`` is the sum of squared values
    of input feature ``c`` over every token passed to ``add_batch``, divided by
    the number of samples (leading-dimension entries) passed so far.
    """

    def __init__(self, layer, layer_id=0, layer_name="none"):
        weight = layer.weight
        self.layer = layer
        self.dev = weight.device
        self.rows, self.columns = weight.data.shape[0], weight.data.shape[1]

        # One accumulator per input column, kept on the layer's device.
        self.scaler_row = torch.zeros((self.columns), device=self.dev)
        self.nsamples = 0

        # Identification only; not read anywhere in this class.
        self.layer_id = layer_id
        self.layer_name = layer_name

    def add_batch(self, inp, out):
        """Fold one batch of layer inputs into ``scaler_row``.

        Args:
            inp: layer input; a 2-D tensor is treated as a single sample.
            out: layer output; unused, accepted so callers can forward hook args.
        """
        if inp.dim() == 2:
            inp = inp.unsqueeze(0)
        n_new = inp.shape[0]

        if isinstance(self.layer, nn.Linear):
            # Flatten the leading dims into tokens, then move features to dim 0
            # so the norm below is taken per input feature.
            if inp.dim() == 3:
                inp = inp.reshape((-1, inp.shape[-1]))
            inp = inp.t()

        # Rescale the previous average to the new sample count before adding.
        n_total = self.nsamples + n_new
        self.scaler_row *= self.nsamples / n_total
        self.nsamples = n_total

        feats = inp.type(torch.float32)
        self.scaler_row += torch.norm(feats, p=2, dim=1) ** 2 / self.nsamples


def return_given_alpha(alpha, sort_res, W_metric, tmp_metric, sum_before):
    """Build a pruning mask whose per-row cutoff comes from a score budget.

    For each row, count the entries of ``tmp_metric`` that are <= ``alpha *
    sum_before``; the sorted value at position ``count - 1`` becomes that row's
    threshold. Every entry of ``W_metric`` at or below it is masked.
    Presumably ``tmp_metric`` is the row-wise cumulative sum of the sorted
    scores, so the threshold is the last score that still fits the budget.
    TODO confirm with callers.

    Args:
        alpha: multiplier applied to ``sum_before`` to form the budget.
        sort_res: ``(values, indices)`` result of sorting ``W_metric`` on dim 1.
        W_metric: 2-D score tensor.
        tmp_metric: per-row values compared against the budget.
        sum_before: per-row totals; reshaped to a column here.

    Returns:
        ``(W_mask, cur_sparsity)``: boolean mask and fraction of masked entries.

    NOTE(review): if no entry of a row fits the budget, the gather index is -1.
    ``torch.gather`` rejects that index, so callers must choose ``alpha`` large
    enough to avoid it.
    """
    sorted_vals = sort_res[0]
    budget = (sum_before * alpha).reshape((-1, 1))
    n_within = (tmp_metric <= budget).sum(dim=1, keepdim=True)
    thres = sorted_vals.gather(1, n_within - 1)
    W_mask = W_metric <= thres
    cur_sparsity = W_mask.sum() / W_mask.numel()
    return W_mask, cur_sparsity


def _wanda_prune_mask(W_metric, args):
    """Select the weights to prune for one layer from its Wanda scores.

    Args:
        W_metric: 2-D score tensor; each row's entries compete only with each
            other (one row per output feature of the layer's weight).
        args: namespace providing ``prunen``/``prunem`` (N:M pattern; a
            ``prunen`` of 0 disables it) and ``sparsity`` (fraction of each
            row pruned in the unstructured case).

    Returns:
        Boolean tensor shaped like ``W_metric``; ``True`` marks weights to zero.
    """
    W_mask = torch.zeros_like(W_metric, dtype=torch.bool)  # nothing pruned yet
    if args.prunen != 0:
        # Structured N:M sparsity: in every block of `prunem` consecutive
        # columns, mark the `prunen` lowest-scoring entries of each row.
        for ii in range(0, W_metric.shape[1], args.prunem):
            block = W_metric[:, ii:(ii + args.prunem)].float()
            lowest = torch.topk(block, args.prunen, dim=1, largest=False)[1]
            W_mask.scatter_(1, ii + lowest, True)
    else:
        # Unstructured: mark the lowest-scoring `sparsity` fraction of each row.
        sort_res = torch.sort(W_metric, dim=-1, stable=True)
        indices = sort_res[1][:, :int(W_metric.shape[1] * args.sparsity)]
        W_mask.scatter_(1, indices, True)
        #! We only compare against the unstructured pruning version of Wanda.
        #! The use_variant flag from the original code is not included.
    return W_mask


@torch.no_grad()
def prune_wanda(model, dataloader, dev, args):
    """Prune every decoder layer of ``model`` in place using Wanda scores.

    Layers are processed one at a time on ``dev``. Within each layer, the
    projection groups in ``sequential`` are calibrated and pruned in order.
    Each group's activation statistics are collected after the earlier groups
    of the same layer were pruned. The pruned layer's outputs then become the
    inputs for the next layer.

    Args:
        model: decoder model exposing ``model.model.layers``, ``embed_tokens``,
            ``norm``, ``rotary_emb``, ``config.use_cache``,
            ``config.hidden_size`` and a ``seqlen`` attribute.
        dataloader: iterable of batches; ``batch[0]`` is passed to ``model``
            (presumably token ids — confirm against the data pipeline).
        dev: device used for calibration.
        args: needs ``nsamples``, ``sparsity``, ``prunen`` and ``prunem``.

    Returns:
        None. Weights are zeroed in place and layers are moved back to CPU.
    """
    logging.info('-----Wanda Calibration Start-----')
    use_cache = model.config.use_cache
    model.config.use_cache = False
    layers = model.model.layers

    # Only the embedding stack and the first decoder layer are needed on `dev`
    # to capture the inputs of layer 0.
    model.model.embed_tokens = model.model.embed_tokens.to(dev)
    model.model.norm = model.model.norm.to(dev)
    model.model.rotary_emb = model.model.rotary_emb.to(dev)
    layers[0] = layers[0].to(dev)

    dtype = next(iter(model.parameters())).dtype
    inps = torch.zeros(
        (args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev
    )
    cache = {"i": 0, "attention_mask": None}

    class Catcher(nn.Module):
        """Stand-in for layer 0 that records its inputs and aborts the forward."""

        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, inp, **kwargs):
            # NOTE(review): a dataloader yielding more than args.nsamples
            # batches makes this index overflow (IndexError).
            inps[cache["i"]] = inp
            cache["i"] += 1
            cache["attention_mask"] = kwargs["attention_mask"]
            cache['position_embeddings'] = kwargs['position_embeddings']
            # Stop the full-model forward here; caught just below.
            raise ValueError

    layers[0] = Catcher(layers[0])
    for batch in dataloader:
        try:
            model(batch[0].to(dev))
        except ValueError:
            pass
    layers[0] = layers[0].module

    layers[0] = layers[0].cpu()
    model.model.embed_tokens = model.model.embed_tokens.cpu()
    model.model.norm = model.model.norm.cpu()
    model.model.rotary_emb = model.model.rotary_emb.cpu()
    torch.cuda.empty_cache()

    outs = torch.zeros_like(inps)
    # Taken from the last captured batch and reused for every sample.
    attention_mask = cache["attention_mask"]
    position_embeddings = cache['position_embeddings']

    # Projection groups calibrated and pruned in order inside each layer. The
    # `.module` suffix suggests each Linear sits inside a wrapper module
    # (presumably the quantization wrapper from quant_utils — confirm).
    sequential = [
        ['self_attn.k_proj.module', 'self_attn.v_proj.module', 'self_attn.q_proj.module'],
        ['self_attn.o_proj.module'],
        ['mlp.up_proj.module', 'mlp.gate_proj.module'],
        ['mlp.down_proj.module']
    ]

    for i in tqdm.tqdm(range(len(layers))):
        layer = layers[i].to(dev)
        full = quant_utils.find_layers(layer, layers=[torch.nn.Linear])

        for names in sequential:
            subset = {n: full[n] for n in names}
            wandas = {name: WrappedGPT(subset[name]) for name in subset}

            def add_batch(name):
                # Forward hook feeding the module's input into its WrappedGPT.
                def tmp(_, inp, out):
                    wandas[name].add_batch(inp[0].data, out.data)
                return tmp

            handles = [subset[name].register_forward_hook(add_batch(name)) for name in subset]
            # Forward pass only to trigger the hooks; `outs` is recomputed below.
            for j in range(args.nsamples):
                outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask,
                                position_embeddings=position_embeddings)[0]
            for h in handles:
                h.remove()

            for name in subset:
                # Wanda score: |W| scaled by the per-input-feature activation norm.
                W_metric = torch.abs(subset[name].weight.data) * torch.sqrt(
                    wandas[name].scaler_row.reshape((1, -1))
                )
                W_mask = _wanda_prune_mask(W_metric, args)
                subset[name].weight.data[W_mask] = 0

        # Re-run the fully pruned layer so the next layer is calibrated on the
        # outputs of the pruned network.
        for j in range(args.nsamples):
            outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask,
                            position_embeddings=position_embeddings)[0]

        layers[i] = layer.cpu()
        del layer
        del wandas
        torch.cuda.empty_cache()
        # This layer's outputs become the next layer's inputs.
        inps, outs = outs, inps

    model.config.use_cache = use_cache
    utils.cleanup_memory(verbos=True)

    logging.info('-----Wanda Calibration Done-----\n')
