import time 
import heapq 
import logging
import torch 
import torch.nn as nn 
from .sparsegpt import SparseGPT 
from .layerwrapper import WrappedGPT
from .data import get_loaders 
from .admm import AdmmPruner

from .ablate import AblateGPT 
from slicegpt.model_adapter import LayerAdapter, ModelAdapter, rot_mask_Linear
from slicegpt.rotate import get_layer0_inputs, get_signals

def find_layers(module, layers=None, name='', excludes=None):
    """
    Recursively find the layers of a certain type in a module.

    Args:
        module (nn.Module): PyTorch module.
        layers (list): Layer types to collect. Matching is by exact type
            (``type(module) in layers``), so subclasses must be listed
            explicitly. Defaults to ``[nn.Linear]``.
        name (str): Dotted name of ``module``; used as the prefix of the
            returned keys.
        excludes (list): Substrings; a (sub)module whose dotted name contains
            any of them is skipped together with everything below it.
            Defaults to no exclusions.

    Returns:
        dict: Dictionary of layers of the given type(s) within the module,
        keyed by dotted name.
    """
    # None sentinels instead of mutable default arguments.
    if layers is None:
        layers = [nn.Linear]
    if excludes is None:
        excludes = []

    if any(exclude in name for exclude in excludes):
        return {}
    if type(module) in layers:
        return {name: module}

    res = {}
    for child_name, child in module.named_children():
        full_name = f"{name}.{child_name}" if name != '' else child_name
        res.update(find_layers(child, layers=layers, name=full_name, excludes=excludes))
    return res

def check_sparsity(model):
    """
    Report the fraction of exactly-zero weights in the model's layers.

    Walks ``model.model.layers``, collects every ``rot_mask_Linear`` /
    ``nn.Linear`` returned by ``find_layers``, prints one sparsity line per
    layer and returns the overall ratio. ``model.config.use_cache`` is set to
    False for the duration of the call and restored afterwards.

    Args:
        model: Model exposing ``config.use_cache`` and ``model.layers``.

    Returns:
        float: Number of zero-valued weights divided by the total number of
        weights over all inspected sub-modules.

    Raises:
        ZeroDivisionError: If a layer (or the whole model) contains no
            matching sub-modules.
    """
    use_cache = model.config.use_cache
    model.config.use_cache = False

    try:
        count = 0
        total_params = 0
        for i, layer in enumerate(model.model.layers):
            subset = find_layers(layer, layers=[rot_mask_Linear, nn.Linear])

            sub_count = 0
            sub_params = 0
            for module in subset.values():
                W = module.weight.data
                # Count zeros once per weight instead of reducing the tensor twice.
                sub_count += (W == 0).sum().item()
                sub_params += W.numel()

            count += sub_count
            total_params += sub_params

            print(f"layer {i} sparsity {float(sub_count)/sub_params:.6f}")
    finally:
        # Restore the caller's cache setting even if inspection fails.
        model.config.use_cache = use_cache

    return float(count) / total_params

def prepare_calibration_input(model, dataloader, device):
    """
    Capture the inputs that reach the first layer for each calibration batch.

    The first entry of ``model.model.layers`` is temporarily replaced by a
    ``Catcher`` module that stores its input and aborts the forward pass by
    raising ``ValueError``. The original layer and ``model.config.use_cache``
    are restored on exit, including when the model raises an unexpected error.

    Args:
        model: Callable model exposing ``config.use_cache``,
            ``config.hidden_size``, ``seqlen``, ``parameters()`` and
            ``model.layers``.
        dataloader: Sized iterable of batches; ``batch[0]`` is fed to the model.
        device: Device for the captured buffers and the model inputs.

    Returns:
        tuple: ``(inps, outs, attention_mask)`` where ``inps`` has shape
        ``(len(dataloader), model.seqlen, model.config.hidden_size)``, ``outs``
        is a zero tensor like ``inps``, and ``attention_mask`` is the
        ``attention_mask`` keyword seen in the last intercepted call (None if
        the first layer was never reached).
    """
    use_cache = model.config.use_cache
    model.config.use_cache = False
    try:
        layers = model.model.layers

        dtype = next(iter(model.parameters())).dtype
        inps = torch.zeros((len(dataloader), model.seqlen, model.config.hidden_size), dtype=dtype, device=device)
        inps.requires_grad = False
        cache = {'i': 0, 'attention_mask': None, "position_ids": None}

        class Catcher(nn.Module):
            """Stand-in for layer 0 that records its input, then aborts the forward pass."""
            def __init__(self, module):
                super().__init__()
                self.module = module
            def forward(self, inp, **kwargs):
                inps[cache['i']] = inp
                cache['i'] += 1
                cache['attention_mask'] = kwargs['attention_mask']
                # Stop here: nothing past layer 0 is needed.
                raise ValueError

        first_layer = layers[0]
        layers[0] = Catcher(first_layer)
        try:
            for batch in dataloader:
                try:
                    model(batch[0].to(device))
                except ValueError:
                    # Expected: raised by Catcher once the input has been stored.
                    pass
        finally:
            # Never leave the Catcher installed in the model.
            layers[0] = first_layer

        outs = torch.zeros_like(inps)
        attention_mask = cache['attention_mask']
    finally:
        model.config.use_cache = use_cache

    return inps, outs, attention_mask

def return_given_alpha(alpha, sort_res, W_metric, tmp_metric, sum_before):
    """
    Build a pruning mask from a per-row budget of ``alpha * sum_before``.

    Args:
        alpha: Scalar multiplier applied to ``sum_before``.
        sort_res: Indexable whose first element is a 2-D tensor the per-row
            threshold is gathered from (presumably the values returned by a
            row-wise ``torch.sort`` of ``W_metric`` -- confirm with caller).
        W_metric: 2-D metric tensor compared against the threshold.
        tmp_metric: 2-D tensor compared against the per-row budget (presumably
            the row-wise cumulative sum of the sorted metric -- confirm).
        sum_before: Per-row totals; reshaped to a column before comparison.

    Returns:
        tuple: ``(W_mask, cur_sparsity)`` -- a boolean mask that is True where
        ``W_metric`` is at or below the row threshold, and the fraction of
        True entries as a tensor.
    """
    row_budget = (sum_before * alpha).reshape((-1, 1))
    # Number of entries per row whose tmp_metric value fits inside the budget.
    within_budget = (tmp_metric <= row_budget).sum(dim=1, keepdims=True)
    sorted_values = sort_res[0]
    cutoff = torch.gather(sorted_values, dim=1, index=within_budget - 1)
    W_mask = W_metric <= cutoff
    cur_sparsity = W_mask.sum() / W_mask.numel()
    return W_mask, cur_sparsity

def prune_magnitude(args, model_adapter: ModelAdapter, dataloader, sparsity, device=torch.device("cuda:0"), prune_n=0, prune_m=0):
    """
    Zero out low-magnitude weights in every layer returned by the adapter.

    For each layer, all ``rot_mask_Linear`` / ``nn.Linear`` sub-modules found by
    ``find_layers`` are pruned in place, either with an n:m pattern
    (``prune_n`` smallest entries in every group of ``prune_m`` columns) or,
    when ``prune_n == 0``, with one magnitude threshold per weight matrix.

    Args:
        args: Unused.
        model_adapter: Adapter whose ``get_layers()`` yields objects with a
            ``.layer`` attribute.
        dataloader: Unused.
        sparsity: Fraction in [0, 1] selecting the threshold position in the
            sorted magnitudes (unstructured mode only).
        device: Device on which the magnitudes are sorted.
        prune_n: Entries pruned per group; 0 selects unstructured pruning.
        prune_m: Group size along dim 1 for n:m pruning.

    Returns:
        list[dict]: One dict per layer mapping sub-module name to the boolean
        mask of pruned positions.
    """
    layers = model_adapter.get_layers()
    W_masks = []

    for layer_adapter in layers:
        layer_masks = {}
        W_masks.append(layer_masks)

        subset = find_layers(layer_adapter.layer, layers=[rot_mask_Linear, nn.Linear])
        # subset.update({'mlp_shortcut_Q': layer.mlp_shortcut_Q, 'attn_shortcut_Q': layer.attn_shortcut_Q})

        for name, target in subset.items():
            if isinstance(target, (nn.Linear, rot_mask_Linear)):
                W = target.weight.data
            else:
                # Entries that are not modules are treated as the weight itself.
                W = target
            W_metric = torch.abs(W)
            if prune_n != 0:
                W_mask = torch.zeros_like(W, dtype=torch.bool)
                for ii in range(0, W_metric.shape[1], prune_m):
                    tmp = W_metric[:, ii:(ii + prune_m)].float()
                    # Mark the prune_n smallest magnitudes of this column group.
                    W_mask.scatter_(1, ii + torch.topk(tmp, prune_n, dim=1, largest=False)[1], True)
            else:
                # Sort on the requested device rather than a hard-coded CUDA device.
                sorted_metric = torch.sort(W_metric.flatten().to(device))[0]
                # Clamp so that sparsity == 1.0 selects the largest magnitude
                # instead of indexing one past the end.
                kth = min(int(W.numel() * sparsity), W.numel() - 1)
                thresh = sorted_metric[kth].cpu()
                W_mask = (W_metric <= thresh)

            W[W_mask] = 0
            layer_masks[name] = W_mask  # return the mask

    return W_masks

@torch.no_grad()
def prune_magnitude_subset(args, model_adapter: ModelAdapter, dataloader, subset, sparsity_ratio, apply_mask=True, device=torch.device("cuda:0"), prune_n=0, prune_m=0):
    """
    Zero out low-magnitude weights of the given modules in place.

    Uses an n:m pattern (``prune_n`` smallest entries in every group of
    ``prune_m`` columns) when ``prune_n != 0``; otherwise one magnitude
    threshold per weight matrix, taken at position
    ``int(numel * sparsity_ratio)`` of the sorted magnitudes. Every entry at or
    below the threshold is pruned, so ties can push the pruned fraction above
    ``sparsity_ratio``.

    Args:
        args: Unused.
        model_adapter: Unused.
        dataloader: Unused.
        subset: Mapping of name -> module exposing ``.weight``.
        sparsity_ratio: Fraction in [0, 1] (unstructured mode only).
        apply_mask: Unused.
        device: Device on which the magnitudes are sorted.
        prune_n: Entries pruned per group; 0 selects unstructured pruning.
        prune_m: Group size along dim 1 for n:m pruning.

    Returns:
        dict: name -> boolean mask of the pruned positions.
    """
    W_masks = {}

    for name, module in subset.items():
        W = module.weight.data
        W_metric = torch.abs(W)
        if prune_n != 0:
            W_mask = torch.zeros_like(W, dtype=torch.bool)
            for ii in range(0, W_metric.shape[1], prune_m):
                tmp = W_metric[:, ii:(ii + prune_m)].float()
                # Mark the prune_n smallest magnitudes of this column group.
                W_mask.scatter_(1, ii + torch.topk(tmp, prune_n, dim=1, largest=False)[1], True)
        else:
            # Sort on the requested device rather than a hard-coded CUDA device.
            sorted_metric = torch.sort(W_metric.flatten().to(device))[0]
            # Clamp so that sparsity_ratio == 1.0 does not index past the end.
            kth = min(int(W.numel() * sparsity_ratio), W.numel() - 1)
            thresh = sorted_metric[kth].cpu()
            W_mask = (W_metric <= thresh)

        W[W_mask] = 0  ## set weights to zero
        W_masks[name] = W_mask  # return the mask

    torch.cuda.empty_cache()

    return W_masks

@torch.no_grad()
def prune_wanda_subset(args, model_adapter: ModelAdapter, dataloader, subset, sparsity_ratio, apply_mask=True, device=torch.device("cuda:0"), prune_n=0, prune_m=0):
    """
    Prune the given modules in place using |W| * sqrt(scaler_row) as the metric.

    Each module in ``subset`` is wrapped in a ``WrappedGPT`` that receives the
    module's inputs and outputs through a forward hook while the full model is
    run over ``dataloader``. Afterwards the lowest-scoring weights are zeroed,
    either per row (unstructured) or in an n:m pattern.

    Args:
        args: Unused.
        model_adapter: Adapter exposing the model as ``.model``.
        dataloader: Iterable of batches with an ``'input_ids'`` entry.
        subset: Mapping of name -> module to prune.
        sparsity_ratio: Fraction of each row's columns to prune
            (unstructured mode only).
        apply_mask: Unused.
        device: Device the input ids are moved to.
        prune_n: Entries pruned per group; 0 selects unstructured pruning.
        prune_m: Group size along dim 1 for n:m pruning.

    Returns:
        dict: name -> boolean mask of the pruned positions.
    """
    W_masks = {}
    model = model_adapter.model

    wrapped_layers = {name: WrappedGPT(module) for name, module in subset.items()}

    def add_batch(name):
        def tmp(_, inp, out):
            wrapped_layers[name].add_batch(inp[0].data, out.data)
        return tmp

    handles = [subset[name].register_forward_hook(add_batch(name)) for name in wrapped_layers]
    try:
        for batch in dataloader:
            # Only the hook side effects matter; the model outputs are not
            # retained, so they can be freed after each batch.
            model(batch['input_ids'].to(device))
    finally:
        # Always detach the hooks, even if a forward pass fails.
        for h in handles:
            h.remove()

    for name, module in subset.items():
        # print(f"pruning layer {i} name {name}")
        W_metric = torch.abs(module.weight.data) * torch.sqrt(wrapped_layers[name].scaler_row.reshape((1, -1)))

        W_mask = torch.zeros_like(W_metric, dtype=torch.bool)  ## initialize a mask to be all False
        if prune_n != 0:
            # structured n:m sparsity
            for ii in range(0, W_metric.shape[1], prune_m):
                tmp = W_metric[:, ii:(ii + prune_m)].float()
                W_mask.scatter_(1, ii + torch.topk(tmp, prune_n, dim=1, largest=False)[1], True)
        else:
            sort_res = torch.sort(W_metric, dim=-1, stable=True)

            # unstructured pruning: lowest-scoring columns of every row
            indices = sort_res[1][:, :int(W_metric.shape[1] * sparsity_ratio)]
            W_mask.scatter_(1, indices, True)

        module.weight.data[W_mask] = 0  ## set weights to zero
        W_masks[name] = W_mask  # return the mask

    torch.cuda.empty_cache()

    return W_masks
    
    
@torch.no_grad()
def prune_wanda(args, model_adapter: ModelAdapter, dataloader, sparsity_ratio, apply_mask=True, device=torch.device("cuda:0"), prune_n=0, prune_m=0):
    """
    Prune the model layer by layer using |W| * sqrt(scaler_row) as the metric.

    Layer-0 inputs are obtained with ``get_layer0_inputs`` for every batch.
    For each entry of ``model.model.layers`` the calibration samples are run
    through the layer with forward hooks feeding ``WrappedGPT`` statistics,
    the ``rot_mask_Linear`` / ``nn.Linear`` sub-modules are pruned in place,
    and the layer is run again so the next layer sees the pruned outputs.

    Args:
        args: Namespace providing ``cal_nsamples`` (number of samples run
            through each layer; must not exceed the number of batches).
        model_adapter: Adapter exposing the model as ``.model``.
        dataloader: Iterable yielding ``(batch, _)`` pairs.
        sparsity_ratio: Fraction of each row's columns to prune
            (unstructured mode only).
        apply_mask: Unused.
        device: Device the captured inputs and kwargs are moved to.
        prune_n: Entries pruned per group; 0 selects unstructured pruning.
        prune_m: Group size along dim 1 for n:m pruning.

    Returns:
        list[dict]: One dict per layer mapping sub-module name to the boolean
        mask of pruned positions.
    """
    # logging.info(f"Pruning: wanda")

    model_adapter.model.eval()
    model = model_adapter.model

    def _on_device(value):
        # Optional kwargs may be None and are then passed through unchanged.
        return value if value is None else value.to(device=device)

    inps, outs, attention_masks, position_ids = [], [], [], []

    for batch, _ in dataloader:
        inp_batch, _, save_kwargs = get_layer0_inputs(model_adapter, batch)
        inps.append(inp_batch.to(device=device))
        outs.append(torch.zeros_like(inp_batch))
        attention_masks.append(_on_device(save_kwargs['attention_mask']))
        position_ids.append(_on_device(save_kwargs['position_ids']))

    layers = model.model.layers
    W_masks = []
    for layer in layers:
        layer_masks = {}
        W_masks.append(layer_masks)
        subset = find_layers(layer, layers=[rot_mask_Linear, nn.Linear])
        # subset.update({'mlp_shortcut_Q': layer.mlp_shortcut_Q, 'attn_shortcut_Q': layer.attn_shortcut_Q})

        wrapped_layers = {name: WrappedGPT(module) for name, module in subset.items()}

        def add_batch(name):
            def tmp(_, inp, out):
                wrapped_layers[name].add_batch(inp[0].data, out.data)
            return tmp

        handles = [subset[name].register_forward_hook(add_batch(name)) for name in wrapped_layers]
        try:
            # First pass: feed the hooks with the unpruned layer.
            for j in range(args.cal_nsamples):
                outs[j] = layer(inps[j], attention_mask=attention_masks[j], position_ids=position_ids[j])[0]
        finally:
            # Always detach the hooks, even if a forward pass fails.
            for h in handles:
                h.remove()

        for name, module in subset.items():
            # print(f"pruning layer {i} name {name}")
            W_metric = torch.abs(module.weight.data) * torch.sqrt(wrapped_layers[name].scaler_row.reshape((1, -1)))

            W_mask = torch.zeros_like(W_metric, dtype=torch.bool)  ## initialize a mask to be all False
            if prune_n != 0:
                # structured n:m sparsity
                for ii in range(0, W_metric.shape[1], prune_m):
                    tmp = W_metric[:, ii:(ii + prune_m)].float()
                    W_mask.scatter_(1, ii + torch.topk(tmp, prune_n, dim=1, largest=False)[1], True)
            else:
                sort_res = torch.sort(W_metric, dim=-1, stable=True)

                # unstructured pruning: lowest-scoring columns of every row
                indices = sort_res[1][:, :int(W_metric.shape[1] * sparsity_ratio)]
                W_mask.scatter_(1, indices, True)

            module.weight.data[W_mask] = 0  ## set weights to zero
            layer_masks[name] = W_mask  # return the mask

        # Second pass: outputs of the pruned layer become the next layer's inputs.
        for j in range(args.cal_nsamples):
            outs[j] = layer(inps[j], attention_mask=attention_masks[j], position_ids=position_ids[j])[0]
        inps, outs = outs, inps

    torch.cuda.empty_cache()

    return W_masks

@torch.no_grad()
def prune_slice(args, model_adapter: ModelAdapter, sparsity, device=torch.device("cuda:0")):
    """
    Zero every weight outside the leading ``slice_dim`` rows/columns.

    ``slice_dim = int((1 - sparsity) * model_adapter.hidden_size)``. Per layer:

    * attention inputs and MLP inputs: columns from ``slice_dim`` on are zeroed;
    * attention outputs: rows from ``slice_dim`` on are zeroed;
    * MLP outputs: rows from ``slice_dim`` on are zeroed, except in the last
      layer, which is left untouched;
    * ``attn_shortcut_Q``: both rows and columns from ``slice_dim`` on;
    * ``mlp_shortcut_Q``: columns always, rows except in the last layer.

    Args:
        args: Unused.
        model_adapter: Adapter exposing ``hidden_size`` and ``get_layers()``;
            each layer adapter provides ``.layer`` and the
            ``get_*_inputs_dict`` / ``get_*_outputs_dict`` accessors.
        sparsity: Fraction of the hidden dimension to slice away.
        device: Unused.

    Returns:
        list[dict]: One dict per layer mapping module name (plus
        ``"attn_shortcut_Q"`` and ``"mlp_shortcut_Q"``) to the boolean mask of
        zeroed positions.
    """
    layers = model_adapter.get_layers()
    W_masks = []

    slice_dim = int((1 - sparsity) * model_adapter.hidden_size)

    def _zero_outside_slice(weight, cut_rows, cut_cols):
        """Zero the trailing rows and/or columns of ``weight`` in place; return the mask."""
        mask = torch.zeros_like(weight, dtype=torch.bool)
        if cut_rows:
            mask[slice_dim:, :] = True
        if cut_cols:
            mask[:, slice_dim:] = True
        weight[mask] = 0
        return mask

    last_index = len(layers) - 1
    for i, layer_adapter in enumerate(layers):
        layer_masks = {}
        W_masks.append(layer_masks)
        layer = layer_adapter.layer
        # The last layer keeps its full output dimension.
        not_last = i < last_index

        for name, module in layer_adapter.get_attention_inputs_dict().items():
            layer_masks[name] = _zero_outside_slice(module.weight.data, cut_rows=False, cut_cols=True)

        for name, module in layer_adapter.get_attention_outputs_dict().items():
            layer_masks[name] = _zero_outside_slice(module.weight.data, cut_rows=True, cut_cols=False)

        for name, module in layer_adapter.get_mlp_inputs_dict().items():
            layer_masks[name] = _zero_outside_slice(module.weight.data, cut_rows=False, cut_cols=True)

        for name, module in layer_adapter.get_mlp_outputs_dict().items():
            layer_masks[name] = _zero_outside_slice(module.weight.data, cut_rows=not_last, cut_cols=False)

        layer_masks["attn_shortcut_Q"] = _zero_outside_slice(
            layer.attn_shortcut_Q.weight.data, cut_rows=True, cut_cols=True
        )
        layer_masks["mlp_shortcut_Q"] = _zero_outside_slice(
            layer.mlp_shortcut_Q.weight.data, cut_rows=not_last, cut_cols=True
        )

    return W_masks

@torch.no_grad()
def prune_slice_subset(args, model_adapter: ModelAdapter, dataloader, subset, sparsity_ratio, last=False, apply_mask=True, device=torch.device("cuda:0"), prune_n=0, prune_m=0):
    """
    Zero every weight outside the leading ``slice_dim`` rows/columns of ``subset``.

    ``slice_dim = int((1 - sparsity_ratio) * model_adapter.hidden_size)``.
    Modules are classified by their key in ``subset``:

    * input projections (``self_attn.q_proj``, ``self_attn.k_proj``,
      ``self_attn.v_proj``, ``fc1``): columns from ``slice_dim`` on are zeroed;
    * output projections (``self_attn.out_proj``, ``fc2``): rows from
      ``slice_dim`` on are zeroed unless ``last`` is True;
    * any other name: columns always, rows unless ``last`` is True.

    Args:
        args: Unused.
        model_adapter: Object exposing ``hidden_size``.
        dataloader: Unused.
        subset: Mapping of name -> module exposing a 2-D ``.weight``.
        sparsity_ratio: Fraction of the hidden dimension to slice away.
        last: When True the row (output) dimension is left intact.
        apply_mask: Unused.
        device: Unused.
        prune_n: Unused.
        prune_m: Unused.

    Returns:
        dict: name -> boolean mask of the zeroed positions.
    """
    W_masks = {}
    slice_dim = int((1 - sparsity_ratio) * model_adapter.hidden_size)

    in_name = ['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'fc1']
    out_name = ['self_attn.out_proj', 'fc2']

    for name, module in subset.items():
        W = module.weight.data
        W_mask = torch.zeros_like(W, dtype=torch.bool)

        # Input projections never lose rows; output projections never lose columns.
        cut_rows = name not in in_name and not last
        cut_cols = name not in out_name

        if cut_rows:
            W_mask[slice_dim:, :] = True
        if cut_cols:
            W_mask[:, slice_dim:] = True

        W[W_mask] = 0
        W_masks[name] = W_mask  # return the mask

    torch.cuda.empty_cache()

    return W_masks


@torch.no_grad()
def prune_sparsegpt(args, model_adapter: ModelAdapter, dataloader, sparsity_ratio, device, apply_mask=True, prune_n=0, prune_m=0):
    """
    Prune the model layer by layer with ``SparseGPT.fasterprune``.

    Layer-0 inputs are obtained with ``get_layer0_inputs`` for every batch.
    For each entry of ``model.model.layers`` the calibration samples are run
    through the layer with forward hooks feeding one ``SparseGPT`` object per
    ``nn.Linear``, each object is pruned and freed, and the layer is run again
    so the next layer sees the pruned outputs.

    Args:
        args: Namespace providing ``cal_nsamples`` (number of samples run
            through each layer; must not exceed the number of batches).
        model_adapter: Adapter exposing the model as ``.model``.
        dataloader: Iterable of batches; ``batch["attention_mask"]`` is read
            when ``apply_mask`` is True.
        sparsity_ratio: Target sparsity passed to ``fasterprune``.
        device: Device the captured inputs and masks are moved to.
        apply_mask: When True the batch attention mask is passed to each
            layer; when False ``attention_mask=None`` is passed instead.
        prune_n: Forwarded to ``fasterprune``.
        prune_m: Forwarded to ``fasterprune``.

    Returns:
        None. No masks are collected by this function.
    """
    ## SparseGPT code available at: https://github.com/IST-DASLab/sparsegpt/tree/f5c25005a61f96a0933ca2f95705a963585aafaa
    model_adapter.model.eval()
    model = model_adapter.model

    layers = model.model.layers

    inps, outs, ignore_masks = [], [], []

    for batch in dataloader:
        inp_batch, _, _ = get_layer0_inputs(model_adapter, batch)
        inps.append(inp_batch.to(device=device))
        outs.append(torch.zeros_like(inp_batch))
        # Keep one entry per batch so ignore_masks[j] is valid even when
        # masking is disabled.
        ignore_masks.append(batch["attention_mask"].to(device=device) if apply_mask else None)

    for i in range(len(layers)):
        layer = layers[i]

        subset = find_layers(layer, layers=[nn.Linear])

        gpts = {name: SparseGPT(module) for name, module in subset.items()}

        def add_batch(name):
            def tmp(_, inp, out):
                gpts[name].add_batch(inp[0].data, out.data)
            return tmp

        handles = [subset[name].register_forward_hook(add_batch(name)) for name in gpts]
        try:
            # First pass: feed the hooks with the unpruned layer.
            for j in range(args.cal_nsamples):
                outs[j] = layer(inps[j], attention_mask=ignore_masks[j])[0]
        finally:
            # Always detach the hooks, even if a forward pass fails.
            for h in handles:
                h.remove()

        for name in gpts:
            print(i, name)
            print('Pruning ...')

            gpts[name].fasterprune(sparsity_ratio, prune_n=prune_n, prune_m=prune_m, percdamp=0.01, blocksize=128)
            gpts[name].free()

        # Second pass: outputs of the pruned layer become the next layer's inputs.
        for j in range(args.cal_nsamples):
            outs[j] = layer(inps[j], attention_mask=ignore_masks[j])[0]

        layers[i] = layer
        torch.cuda.empty_cache()

        inps, outs = outs, inps

    torch.cuda.empty_cache()

@torch.no_grad()
def prune_admm(args, model_adapter: ModelAdapter, dataloader, sparsity_ratio, device, apply_mask=True, prune_n=0, prune_m=0):
    """
    Prune the model layer by layer with ``AdmmPruner.fasterprune``.

    Layer-0 inputs are obtained with ``get_layer0_inputs`` for every batch.
    For each entry of ``model.model.layers`` the calibration samples are run
    through the layer with forward hooks feeding one ``AdmmPruner`` object per
    ``nn.Linear``, each object is pruned and freed, and the layer is run again
    so the next layer sees the pruned outputs.

    Args:
        args: Namespace providing ``cal_nsamples`` (number of samples run
            through each layer; must not exceed the number of batches).
        model_adapter: Adapter exposing the model as ``.model``.
        dataloader: Iterable of batches; ``batch["attention_mask"]`` is read
            when ``apply_mask`` is True.
        sparsity_ratio: Target sparsity passed to ``fasterprune``.
        device: Device the captured inputs and masks are moved to.
        apply_mask: When True the batch attention mask is passed to each
            layer; when False ``attention_mask=None`` is passed instead.
        prune_n: Forwarded to ``fasterprune``.
        prune_m: Forwarded to ``fasterprune``.

    Returns:
        None. No masks are collected by this function.
    """
    model_adapter.model.eval()
    model = model_adapter.model

    layers = model.model.layers

    inps, outs, ignore_masks = [], [], []

    for batch in dataloader:
        inp_batch, _, _ = get_layer0_inputs(model_adapter, batch)
        inps.append(inp_batch.to(device=device))
        outs.append(torch.zeros_like(inp_batch))
        # Keep one entry per batch so ignore_masks[j] is valid even when
        # masking is disabled.
        ignore_masks.append(batch["attention_mask"].to(device=device) if apply_mask else None)

    for i in range(len(layers)):
        layer = layers[i]

        subset = find_layers(layer, layers=[nn.Linear])

        gpts = {name: AdmmPruner(module) for name, module in subset.items()}

        def add_batch(name):
            def tmp(_, inp, out):
                gpts[name].add_batch(inp[0].data, out.data)
            return tmp

        handles = [subset[name].register_forward_hook(add_batch(name)) for name in gpts]
        try:
            # First pass: feed the hooks with the unpruned layer.
            for j in range(args.cal_nsamples):
                outs[j] = layer(inps[j], attention_mask=ignore_masks[j])[0]
        finally:
            # Always detach the hooks, even if a forward pass fails.
            for h in handles:
                h.remove()

        for name in gpts:
            print(i, name)
            print('Pruning ...')

            gpts[name].fasterprune(sparsity_ratio, prune_n=prune_n, prune_m=prune_m, percdamp=0.1)
            gpts[name].free()

        # Second pass: outputs of the pruned layer become the next layer's inputs.
        for j in range(args.cal_nsamples):
            outs[j] = layer(inps[j], attention_mask=ignore_masks[j])[0]

        layers[i] = layer
        torch.cuda.empty_cache()

        inps, outs = outs, inps

    torch.cuda.empty_cache()

@torch.no_grad()
def prune_ablate(args, model, tokenizer, dev, prune_n=0, prune_m=0):
    """
    Prune the model layer by layer with ``AblateGPT``.

    Calibration data is loaded from "c4" via ``get_loaders``. The inputs of
    the first layer are captured with a temporary ``Catcher`` module, then each
    entry of ``model.model.layers`` is run over the samples with forward hooks
    feeding one ``AblateGPT`` object per layer found by ``find_layers``, pruned,
    and run again so the next layer sees the pruned outputs.

    Args:
        args: Namespace providing ``nsamples``, ``seed``, ``prune_method`` and
            ``sparsity_ratio``. ``prune_method`` must be ``"ablate_wanda_seq"``,
            ``"ablate_mag_seq"`` or contain ``"iter"``.
        model: Model exposing ``config``, ``seqlen``, ``hf_device_map``,
            ``parameters()`` and ``model.layers``.
        tokenizer: Passed to ``get_loaders``.
        dev: Default device; overridden per module by ``model.hf_device_map``.
        prune_n: Forwarded to the mask helpers and ``fasterprune``.
        prune_m: Forwarded to the mask helpers and ``fasterprune``.

    Returns:
        None.

    Raises:
        ValueError: If ``args.prune_method`` is not one of the supported
            methods (previously this surfaced as an UnboundLocalError after
            the calibration pass).
    """
    ## SparseGPT code available at: https://github.com/IST-DASLab/sparsegpt/tree/f5c25005a61f96a0933ca2f95705a963585aafaa
    method = args.prune_method
    # Fail fast instead of hitting an unbound `prune_mask` deep inside the loop.
    if method not in ("ablate_wanda_seq", "ablate_mag_seq") and "iter" not in method:
        raise ValueError(f"Unsupported prune_method for prune_ablate: {method!r}")

    print('Starting ...')
    dataloader, _ = get_loaders("c4",nsamples=args.nsamples,seed=args.seed,seqlen=model.seqlen,tokenizer=tokenizer)

    use_cache = model.config.use_cache
    model.config.use_cache = False
    try:
        layers = model.model.layers

        if "model.embed_tokens" in model.hf_device_map:
            dev = model.hf_device_map["model.embed_tokens"]

        dtype = next(iter(model.parameters())).dtype
        inps = torch.zeros(
            (args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev
        )
        cache = {'i': 0, 'attention_mask': None, "position_ids": None}

        class Catcher(nn.Module):
            """Stand-in for layer 0 that records its input, then aborts the forward pass."""
            def __init__(self, module):
                super().__init__()
                self.module = module
            def forward(self, inp, **kwargs):
                inps[cache['i']] = inp
                cache['i'] += 1
                cache['attention_mask'] = kwargs['attention_mask']
                # cache['position_ids'] = kwargs['position_ids']
                raise ValueError

        first_layer = layers[0]
        layers[0] = Catcher(first_layer)
        try:
            for batch in dataloader:
                try:
                    model(batch[0].to(dev))
                except ValueError:
                    # Expected: raised by Catcher once the input has been stored.
                    pass
        finally:
            # Never leave the Catcher installed in the model.
            layers[0] = first_layer
        torch.cuda.empty_cache()

        outs = torch.zeros_like(inps)
        attention_mask = cache['attention_mask']
        # position_ids = cache['position_ids']

        print('Ready.')

        for i in range(len(layers)):
            layer = layers[i]
            if f"model.layers.{i}" in model.hf_device_map:
                dev = model.hf_device_map[f"model.layers.{i}"]
                print(f"layer {i} device {dev}")
                inps, outs = inps.to(dev), outs.to(dev)
                # The captured mask may be None; only move a real tensor.
                if attention_mask is not None:
                    attention_mask = attention_mask.to(dev)

            subset = find_layers(layer)

            gpts = {name: AblateGPT(module) for name, module in subset.items()}

            def add_batch(name):
                def tmp(_, inp, out):
                    gpts[name].add_batch(inp[0].data, out.data)
                return tmp

            handles = [subset[name].register_forward_hook(add_batch(name)) for name in gpts]
            try:
                # First pass: feed the hooks with the unpruned layer.
                for j in range(args.nsamples):
                    outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0]
            finally:
                # Always detach the hooks, even if a forward pass fails.
                for h in handles:
                    h.remove()

            for name in gpts:
                print(i, name)
                print('Pruning ...')

                if method == "ablate_wanda_seq":
                    prune_mask = gpts[name].get_wanda_mask(args.sparsity_ratio, prune_n, prune_m)
                elif method == "ablate_mag_seq":
                    prune_mask = gpts[name].get_mag_mask(args.sparsity_ratio, prune_n, prune_m)
                else:
                    # "iter" methods: no precomputed mask is passed to fasterprune.
                    prune_mask = None

                gpts[name].fasterprune(args, args.sparsity_ratio, mask=prune_mask,
                                            prune_n=prune_n, prune_m=prune_m, percdamp=0.01, blocksize=128)
                gpts[name].free()

            # Second pass: outputs of the pruned layer become the next layer's inputs.
            for j in range(args.nsamples):
                outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0]

            layers[i] = layer
            torch.cuda.empty_cache()

            inps, outs = outs, inps
    finally:
        # Restore the caller's cache setting even if pruning fails.
        model.config.use_cache = use_cache

    torch.cuda.empty_cache()
