import time 
import heapq 
import torch 
import torch.nn as nn 
from .sparsegpt import SparseGPT 
from .layerwrapper import WrappedGPT
from .data import get_loaders 

from .ablate import AblateGPT 
from .gptree import GPTree
from lib.eval import eval_ppl
from lib.model_utils import get_layers, get_seqlen

import numpy as np
import os

import matplotlib.pyplot as plt
import seaborn as sns


def find_layers(module, layers=[nn.Linear], name=''):
    """
    Collect every sub-module whose exact type is listed in ``layers``.

    The search is depth-first over ``module.named_children()``. A module
    whose type matches is returned as-is and is NOT descended into.

    Args:
        module (nn.Module): Root of the search.
        layers (list): Layer types to look for (matched with ``type(m) in layers``,
            so subclasses of a listed type do not match).
        name (str): Dotted path of ``module``; used as the key prefix.

    Returns:
        dict: Maps the dotted path of each matching module to the module itself.
            If ``module`` itself matches, the dict has the single key ``name``.
    """
    # A matching module terminates the recursion on this branch.
    if type(module) in layers:
        return {name: module}

    # Children of an unnamed root use their bare name; otherwise "<parent>.<child>".
    prefix = '' if name == '' else name + '.'

    found = {}
    for child_name, child in module.named_children():
        found.update(find_layers(child, layers=layers, name=prefix + child_name))
    return found

def check_sparsity(model):
    """
    Report the fraction of exactly-zero weights in the model's decoder layers.

    For every decoder layer, all sub-modules returned by ``find_layers`` are
    inspected and a per-layer sparsity line is printed.

    Args:
        model: A causal-LM whose class name contains one of ``OPT``,
            ``GPTNeoX``, ``GPTNeo``, ``Qwen`` or ``Llama``.

    Returns:
        float: zeros / total parameters over all inspected weights
            (``0.0`` when no weights were found).

    Raises:
        ValueError: If the model class is not one of the supported families.
    """
    use_cache = model.config.use_cache
    model.config.use_cache = False

    try:
        model_class = model.__class__.__name__
        if 'OPT' in model_class:
            layers = model.model.decoder.layers
        elif 'GPTNeoX' in model_class:
            # Must be tested before 'GPTNeo', which is a substring of 'GPTNeoX'.
            layers = model.gpt_neox.layers
        elif 'GPTNeo' in model_class:
            layers = model.transformer.h
        elif 'Qwen' in model_class or 'Llama' in model_class:
            layers = model.model.layers
        else:
            raise ValueError(f"Unsupported model type {model_class}")

        count = 0
        total_params = 0
        for i, layer in enumerate(layers):
            sub_count = 0
            sub_params = 0
            for module in find_layers(layer).values():
                W = module.weight.data
                sub_count += (W == 0).sum().item()
                sub_params += W.numel()

            count += sub_count
            total_params += sub_params

            # A layer without any matching sub-module has nothing to measure.
            layer_sparsity = float(sub_count) / sub_params if sub_params else 0.0
            print(f"layer {i} sparsity {layer_sparsity:.6f}")
    finally:
        # Restore the caller's setting even if inspection fails part-way.
        model.config.use_cache = use_cache

    return float(count) / total_params if total_params else 0.0

def prepare_calibration_input(model, dataloader, device, seqlen: int = 512):
    """
    Capture the hidden states fed to the first decoder layer for each batch.

    A forward hook on ``layers[0]`` stores the layer's first positional input
    and then aborts the forward pass, so only the embedding stage and the
    first layer are executed per batch.

    Args:
        model: Causal-LM whose class name contains ``OPT``, ``GPTNeoX``,
            ``GPTNeo``, ``Qwen`` or ``Llama``.
        dataloader: Sized iterable of batches; ``batch[0]`` holds the token
            ids and is truncated to ``seqlen`` along dim 1.
        device: Device for the inputs; overridden by the embedding's entry in
            ``model.hf_device_map`` when the model has one.
        seqlen (int): Sequence length of the captured activations.

    Returns:
        tuple: ``(inps, outs, attention_mask)`` where ``inps`` has shape
            ``(len(dataloader), seqlen, hidden_size)``, ``outs`` is a zero
            tensor of the same shape, and ``attention_mask`` is always ``None``.

    Raises:
        ValueError: If the model class is unsupported. Errors raised by the
            model itself also propagate instead of being silently ignored.
    """

    class _CaptureDone(Exception):
        """Private signal used to abort the forward pass after the capture."""

    use_cache = model.config.use_cache
    model.config.use_cache = False

    try:
        mc = model.__class__.__name__
        if   "OPT"     in mc: layers = model.model.decoder.layers
        elif "GPTNeoX" in mc: layers = model.gpt_neox.layers
        elif "GPTNeo"  in mc: layers = model.transformer.h
        elif "Qwen"    in mc: layers = model.model.layers
        elif 'Llama'   in mc: layers = model.model.layers
        else: raise ValueError(f"Unsupported model type {mc}")

        if hasattr(model, "hf_device_map"):
            # Inputs must live where the embedding layer was placed.
            for key in ("model.embed_tokens", "transformer.wte", "model.tok_embeddings"):
                if key in model.hf_device_map:
                    device = model.hf_device_map[key]
                    break

        dtype = next(model.parameters()).dtype
        NS    = len(dataloader)

        inps = torch.zeros((NS, seqlen, model.config.hidden_size), dtype=dtype, device=device)
        outs = torch.zeros_like(inps)
        cache = {"i": 0}

        def capture_hook(module, args, output):
            inps[cache["i"]] = args[0].detach()
            cache["i"] += 1
            # A dedicated exception type keeps genuine model errors
            # distinguishable from this intentional early exit.
            raise _CaptureDone

        hook_handle = layers[0].register_forward_hook(capture_hook)
        try:
            for batch in dataloader:
                input_ids = batch[0][:, :seqlen].to(device)

                try:
                    if hasattr(model.config, "model_type") and "qwen" in model.config.model_type.lower():
                        attention_mask = torch.ones_like(input_ids)
                        position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
                        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

                        model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids)
                    else:
                        model(input_ids)
                except _CaptureDone:
                    pass
        finally:
            # Never leave the aborting hook attached to the model.
            hook_handle.remove()
    finally:
        model.config.use_cache = use_cache

    # No mask is recorded during capture; callers receive None.
    attention_mask = None
    return inps, outs, attention_mask

def return_given_alpha(alpha, sort_res, W_metric, tmp_metric, sum_before):
    """
    Build a per-row pruning mask from a cumulative-score budget.

    For each row, the budget is ``alpha * sum_before``. The threshold is the
    entry of ``sort_res[0]`` at the last position whose ``tmp_metric`` value
    is still within the budget; every ``W_metric`` entry at or below that
    threshold is marked.

    Args:
        alpha: Scalar multiplier applied to ``sum_before``.
        sort_res: Indexable whose element 0 is a 2-D tensor of per-row values
            (presumably the result of ``torch.sort`` on ``W_metric`` — confirm
            against the caller).
        W_metric: 2-D score tensor.
        tmp_metric: 2-D tensor compared against the budget (presumably the
            row-wise cumulative sum of the sorted scores — confirm).
        sum_before: One value per row.

    Returns:
        tuple: ``(W_mask, cur_sparsity)`` — the boolean mask and the fraction
            of ``True`` entries in it (a 0-dim tensor).
    """
    row_budget = (sum_before * alpha).reshape((-1, 1))
    within_budget = tmp_metric <= row_budget

    # Index of the last in-budget position in each row.
    last_idx = within_budget.sum(dim=1, keepdims=True) - 1
    sorted_scores = sort_res[0]
    thres = torch.gather(sorted_scores, dim=1, index=last_idx)

    W_mask = W_metric <= thres
    selected = (W_mask == True).sum()
    cur_sparsity = selected / W_mask.numel()
    return W_mask, cur_sparsity

def prune_magnitude(args, model, tokenizer, device=torch.device("cuda:0"), prune_n=0, prune_m=0):
    """
    Zero the smallest-magnitude weights of every linear layer, in place.

    Args:
        args: Namespace providing ``sparsity_ratio`` (used for unstructured pruning).
        model: Causal-LM whose class name contains ``OPT``, ``GPTNeoX``,
            ``GPTNeo``, ``Qwen`` or ``Llama``.
        tokenizer: Unused; kept for a uniform pruner signature.
        device: Unused; kept for a uniform pruner signature.
        prune_n (int): When non-zero, prune ``prune_n`` weights in every group
            of ``prune_m`` consecutive input columns (N:M sparsity).
        prune_m (int): Group width for N:M sparsity.

    Raises:
        ValueError: If the model class is not one of the supported families.
    """
    model_class = model.__class__.__name__
    if 'OPT' in model_class:
        layers = model.model.decoder.layers
    elif 'GPTNeoX' in model_class:
        # Must be tested before 'GPTNeo', which is a substring of 'GPTNeoX'.
        layers = model.gpt_neox.layers
    elif 'GPTNeo' in model_class:
        layers = model.transformer.h
    elif 'Qwen' in model_class or 'Llama' in model_class:
        layers = model.model.layers
    else:
        raise ValueError(f"Unsupported model type {model_class}")

    for layer in layers:
        for module in find_layers(layer).values():
            W = module.weight.data
            W_metric = torch.abs(W)
            if prune_n != 0:
                W_mask = torch.zeros_like(W, dtype=torch.bool)
                # Visit each group of prune_m columns once.
                for ii in range(0, W_metric.shape[1], prune_m):
                    tmp = W_metric[:, ii:(ii + prune_m)].float()
                    W_mask.scatter_(1, ii + torch.topk(tmp, prune_n, dim=1, largest=False)[1], True)
            else:
                flat = W_metric.flatten()
                # Sort on the GPU when one exists, but keep working on CPU-only hosts.
                if torch.cuda.is_available():
                    flat = flat.cuda()
                # Clamp so that sparsity_ratio == 1.0 does not index past the end.
                kth = min(int(W.numel() * args.sparsity_ratio), W.numel() - 1)
                thresh = torch.sort(flat)[0][kth].cpu()
                W_mask = (W_metric <= thresh)

            W[W_mask] = 0

def prune_wanda(args, model, tokenizer, device=torch.device("cuda:0"), prune_n=0, prune_m=0):
    """
    Prune the model layer by layer using a weight-times-activation score.

    For each decoder layer the calibration samples are run through the layer
    while ``WrappedGPT`` wrappers collect statistics on every linear
    sub-module. Each weight is scored as ``|W| * sqrt(scaler_row)`` and the
    lowest-scoring entries of every output row are zeroed. The layer is then
    re-run so the next layer sees the pruned layer's outputs.

    NOTE(review): the per-layer forward reads ``model.model.rotary_emb`` and
    passes ``position_embeddings``; the OPT / GPT-Neo(X) branches of the
    dispatch below presumably do not support that — confirm before use.

    Args:
        args: Namespace providing ``nsamples``, ``seed`` and ``sparsity_ratio``.
        model: Causal-LM to prune in place.
        tokenizer: Tokenizer forwarded to ``get_loaders``.
        device: Device for the calibration inputs.
        prune_n (int): When non-zero, use N:M sparsity with this N.
        prune_m (int): Group width M for N:M sparsity.

    Raises:
        ValueError: If the model class is not one of the supported families.
    """
    use_cache = model.config.use_cache
    model.config.use_cache = False

    print("loading calibration data")
    dataloader, _ = get_loaders(
        "c4",
        nsamples=args.nsamples,
        seed=args.seed,
        seqlen=512,
        tokenizer=tokenizer,
    )
    print("dataset loading complete")

    inps, outs, attention_mask = prepare_calibration_input(
        model, dataloader, device
    )

    mc = model.__class__.__name__
    if   "OPT"     in mc: layers = model.model.decoder.layers
    elif "GPTNeoX" in mc: layers = model.gpt_neox.layers
    elif "GPTNeo"  in mc: layers = model.transformer.h
    elif "Qwen"    in mc: layers = model.model.layers
    elif 'Llama'   in mc: layers = model.model.layers
    else: raise ValueError(f"Unsupported model type {mc}")

    MAX_SEQ = 512
    # Models loaded without a device map have no `hf_device_map` attribute.
    device_map = getattr(model, "hf_device_map", None) or {}

    def run_layer(layer, src, dst, mask):
        """Run every calibration sample in `src` through `layer` into `dst`."""
        with torch.no_grad():
            for j in range(args.nsamples):
                inp  = src[j].unsqueeze(0)[:, :MAX_SEQ]
                attn = (mask[j].unsqueeze(0)[:, :MAX_SEQ]
                        if mask is not None else None)
                pos  = torch.arange(inp.size(1), device=inp.device).unsqueeze(0)
                pos_emb = model.model.rotary_emb(inp, pos)
                dst[j] = layer(
                    inp,
                    attention_mask=attn,
                    position_ids=pos,
                    position_embeddings=pos_emb,
                )[0]

    for i, layer in enumerate(layers):
        subset = find_layers(layer)

        if f"model.layers.{i}" in device_map:
            dev = device_map[f"model.layers.{i}"]
            inps, outs, attention_mask = (
                inps.to(dev), outs.to(dev),
                None if attention_mask is None else attention_mask.to(dev)
            )

        wrapped = {n: WrappedGPT(m) for n, m in subset.items()}

        def add_batch(name):
            def _hook(_, x, y):
                wrapped[name].add_batch(x[0].data, y.data)
            return _hook

        hds = [subset[n].register_forward_hook(add_batch(n)) for n in wrapped]
        try:
            # First pass: collect statistics with the still-dense layer.
            run_layer(layer, inps, outs, attention_mask)
        finally:
            # Never leave statistic hooks attached, even if the pass fails.
            for hd in hds:
                hd.remove()

        for name, mod in subset.items():
            print(f"pruning layer {i}  name {name}")

            W_metric = (torch.abs(mod.weight.data) *
                        torch.sqrt(wrapped[name].scaler_row.reshape(1, -1)))
            mask = torch.zeros_like(W_metric, dtype=torch.bool)

            if prune_n:                               # N:M
                for col in range(0, W_metric.size(1), prune_m):
                    blk = W_metric[:, col:col+prune_m]
                    idx = torch.topk(blk, prune_n, dim=1,
                                     largest=False).indices
                    mask.scatter_(1, col + idx, True)
            else:                                     # unstructured
                k = int(W_metric.size(1) * args.sparsity_ratio)
                idx = torch.topk(W_metric, k, dim=1,
                                 largest=False).indices
                mask.scatter_(1, idx, True)

            mod.weight.data[mask] = 0

        # Second pass: recompute outputs with the pruned weights.
        run_layer(layer, inps, outs, attention_mask)

        # This layer's outputs become the next layer's inputs.
        inps, outs = outs, inps

    model.config.use_cache = use_cache
    torch.cuda.empty_cache()

@torch.no_grad()
def prune_sparsegpt(args, model, tokenizer, device=torch.device("cuda:0"), prune_n=0, prune_m=0):
    """
    Prune the model layer by layer with ``SparseGPT``.

    The inputs to the first decoder layer are captured with a forward
    pre-hook while the full model runs on each calibration batch. Each layer
    is then run on those inputs with a ``SparseGPT`` instance attached to
    every linear sub-module, pruned via ``fasterprune``, and re-run so the
    next layer receives the pruned layer's outputs.

    NOTE(review): the per-layer forward reads ``model.model.rotary_emb`` and
    passes ``position_embeddings``; the OPT / GPT-Neo(X) branches of the
    dispatch below presumably do not support that — confirm before use.

    Args:
        args: Namespace providing ``nsamples``, ``seed`` and ``sparsity_ratio``.
        model: Causal-LM to prune in place.
        tokenizer: Tokenizer forwarded to ``get_loaders``.
        device: Device for the calibration inputs; overridden by the
            embedding's entry in ``model.hf_device_map`` when present.
        prune_n (int): N for N:M sparsity, forwarded to ``fasterprune``.
        prune_m (int): M for N:M sparsity, forwarded to ``fasterprune``.

    Raises:
        ValueError: If the model class is not one of the supported families.
    """
    print("Starting SparseGPT ...")

    dataloader, _ = get_loaders(
        "c4",
        nsamples=args.nsamples,
        seed=args.seed,
        seqlen=512,
        tokenizer=tokenizer,
    )

    use_cache = model.config.use_cache
    model.config.use_cache = False

    mc = model.__class__.__name__
    if   "OPT"     in mc: layers = model.model.decoder.layers
    elif "GPTNeoX" in mc: layers = model.gpt_neox.layers
    elif "GPTNeo"  in mc: layers = model.transformer.h
    elif "Qwen"    in mc: layers = model.model.layers
    elif 'Llama'   in mc: layers = model.model.layers
    else: raise ValueError(f"Unsupported model type {mc}")

    # Models loaded without a device map have no `hf_device_map` attribute.
    device_map = getattr(model, "hf_device_map", None) or {}

    # Place the inputs on the embedding's device; checked in priority order.
    for key in ("model.embed_tokens", "transformer.wte", "model.tok_embeddings"):
        if key in device_map:
            device = device_map[key]
            break

    SEQLEN = 512
    NS     = len(dataloader)
    dtype  = next(model.parameters()).dtype

    inps = torch.zeros((NS, SEQLEN, model.config.hidden_size),
                       dtype=dtype, device=device, requires_grad=False)
    outs = torch.zeros_like(inps)
    captured = {"i": 0}

    def grab_input(_, x):
        # Pre-hook: x holds the layer's positional arguments.
        inps[captured["i"]] = x[0].detach()
        captured["i"] += 1

    hook0 = layers[0].register_forward_pre_hook(grab_input)
    try:
        for batch in dataloader:
            model(batch[0][:, :SEQLEN].to(device), use_cache=False)
    finally:
        hook0.remove()

    # No mask is recorded during capture.
    attention_mask = None

    print("Calibration data ready.")

    def run_layer(layer, src, dst, mask):
        """Run every calibration sample in `src` through `layer` into `dst`."""
        for j in range(NS):
            inp  = src[j].unsqueeze(0)            # [1, 512, h]
            attn = mask[j].unsqueeze(0) if mask is not None else None
            pos  = torch.arange(inp.size(1), device=inp.device).unsqueeze(0)
            pos_emb = model.model.rotary_emb(inp, pos)
            dst[j] = layer(
                inp,
                attention_mask=attn,
                position_ids=pos,
                position_embeddings=pos_emb,
            )[0]

    for i, layer in enumerate(layers):
        if f"model.layers.{i}" in device_map:
            dev = device_map[f"model.layers.{i}"]
            inps = inps.to(dev)
            outs = outs.to(dev)
            if attention_mask is not None:
                attention_mask = attention_mask.to(dev)

        subset = find_layers(layer)
        gpts   = {n: SparseGPT(m) for n, m in subset.items()}

        def add_batch(name):
            def hook(_, x, y):
                gpts[name].add_batch(x[0].data, y.data)
            return hook

        hooks = [subset[n].register_forward_hook(add_batch(n)) for n in gpts]
        try:
            # First pass: accumulate statistics with the still-dense layer.
            run_layer(layer, inps, outs, attention_mask)
        finally:
            # Never leave statistic hooks attached, even if the pass fails.
            for h in hooks:
                h.remove()

        for name, sgpt in gpts.items():
            print(f"Layer {i:02d}  {name}  pruning ...")
            sgpt.fasterprune(args.sparsity_ratio,
                             prune_n=prune_n, prune_m=prune_m,
                             percdamp=0.01, blocksize=128)
            sgpt.free()

        # Second pass: recompute outputs with the pruned weights.
        run_layer(layer, inps, outs, attention_mask)

        # This layer's outputs become the next layer's inputs.
        inps, outs = outs, inps
        torch.cuda.empty_cache()

    model.config.use_cache = use_cache
    torch.cuda.empty_cache()
    print("SparseGPT finished.")

def prune_pruner_zero(args, model,tokenizer, device=torch.device("cuda:0"), prune_n=0, prune_m=0, engine=None, target_layers=None, layer_sparsity=None):
    """
    Prune the model layer by layer with a score computed by ``engine``.

    For each targeted decoder layer the calibration samples are run through
    the layer while ``WrappedGPT`` wrappers collect statistics. Every linear
    sub-module is scored with ``engine.forward(|W|, G, X)`` where ``G`` comes
    from the file at ``args.gradient_path`` (key ``"<name>_layer_<i>"``) and
    ``X`` is the wrapper's ``scaler_row``; the lowest-scoring entries of each
    row are zeroed.

    Layers excluded by ``target_layers`` are left untouched but are still run
    on the calibration samples, so that later layers are calibrated on the
    activations they actually receive.

    NOTE(review): the per-layer forward reads ``model.model.rotary_emb`` and
    passes ``position_embeddings``; the OPT / GPT-Neo(X) branches of the
    dispatch below presumably do not support that — confirm before use.

    Args:
        args: Namespace providing ``nsamples``, ``seed``, ``sparsity_ratio``
            and ``gradient_path``.
        model: Causal-LM to prune in place.
        tokenizer: Tokenizer forwarded to ``get_loaders``.
        device: Device for the calibration inputs.
        prune_n (int): When non-zero, use N:M sparsity with this N.
        prune_m (int): Group width M for N:M sparsity.
        engine: Object exposing ``forward(W, G, X)`` returning the score tensor.
        target_layers: Optional container of layer indices to prune; falsy
            means every layer.
        layer_sparsity: Optional dict mapping layer index to its sparsity
            ratio; missing layers fall back to ``args.sparsity_ratio``.

    Raises:
        ValueError: If the model class is unsupported, or a layer has to be
            pruned and ``engine`` is ``None``.
    """
    use_cache = model.config.use_cache
    model.config.use_cache = False

    dataloader, _ = get_loaders(
        "c4",
        nsamples=args.nsamples,
        seed=args.seed,
        seqlen=512,
        tokenizer=tokenizer,
    )
    inps, outs, attention_mask = prepare_calibration_input(
        model, dataloader, device, seqlen=512
    )

    # SECURITY NOTE: torch.load unpickles the file; only point
    # args.gradient_path at files from a trusted source.
    gradients = torch.load(args.gradient_path, map_location="cpu")

    mc = model.__class__.__name__
    if   "OPT"     in mc: layers = model.model.decoder.layers
    elif "GPTNeoX" in mc: layers = model.gpt_neox.layers
    elif "GPTNeo"  in mc: layers = model.transformer.h
    elif "Qwen"    in mc: layers = model.model.layers
    elif 'Llama'   in mc: layers = model.model.layers
    else: raise ValueError(f"Unsupported model type {mc}")

    SEQLEN = 512
    NS     = args.nsamples
    # Models loaded without a device map have no `hf_device_map` attribute.
    device_map = getattr(model, "hf_device_map", None) or {}

    def run_layer(layer, src, dst, mask):
        """Run every calibration sample in `src` through `layer` into `dst`."""
        with torch.no_grad():
            for j in range(NS):
                inp  = src[j].unsqueeze(0)             # [1, 512, h]
                attn = mask[j].unsqueeze(0) if mask is not None else None
                pos  = torch.arange(SEQLEN, device=inp.device).unsqueeze(0)
                pos_emb = model.model.rotary_emb(inp, pos)
                dst[j] = layer(
                    inp,
                    attention_mask=attn,
                    position_ids=pos,
                    position_embeddings=pos_emb,
                )[0]

    for i, layer in enumerate(layers):
        if f"model.layers.{i}" in device_map:
            device = device_map[f"model.layers.{i}"]
            inps  = inps.to(device)
            outs  = outs.to(device)
            if attention_mask is not None:
                attention_mask = attention_mask.to(device)

        if target_layers and i not in target_layers:
            # Not pruned, but activations must still flow through this layer;
            # otherwise later layers would be calibrated on stale inputs.
            run_layer(layer, inps, outs, attention_mask)
            inps, outs = outs, inps
            continue

        if engine is None:
            raise ValueError("prune_pruner_zero requires an `engine` exposing forward(W, G, X)")

        sp = layer_sparsity.get(i, args.sparsity_ratio) if layer_sparsity else args.sparsity_ratio

        subset  = find_layers(layer)
        wrappers = {n: WrappedGPT(m) for n, m in subset.items()}

        def add_batch(n):
            def hook(_, x, y):
                wrappers[n].add_batch(x[0].data, y.data)
            return hook

        hooks = [subset[n].register_forward_hook(add_batch(n)) for n in wrappers]
        try:
            # First pass: collect statistics with the still-dense layer.
            run_layer(layer, inps, outs, attention_mask)
        finally:
            # Never leave statistic hooks attached, even if the pass fails.
            for h in hooks:
                h.remove()

        GPTree.PowerExponents.current_layer = i
        for name, mod in subset.items():
            print(f"pruning layer {i:02d}  {name}")

            for layer_idx, (w,g) in GPTree.PowerExponents.layer_exponents.items():
                if layer_idx == i:
                    print(f"layer {layer_idx}: w={w:.4f} g={g:.4f}")

            idx_key = f"{name}_layer_{i}"
            W, X = mod.weight.data.abs(), wrappers[name].scaler_row.reshape(1, -1)
            G    = gradients[idx_key].to(device=mod.weight.device, dtype=torch.float32)

            metric = engine.forward(W.float(), G, X.float())

            mask = torch.zeros_like(metric, dtype=torch.bool)
            if prune_n:                               # N:M
                for col in range(0, metric.size(1), prune_m):
                    blk = metric[:, col:col+prune_m]
                    idx = torch.topk(blk, prune_n, largest=False, dim=1).indices
                    mask.scatter_(1, col+idx, True)
            else:                                     # unstructured
                k = int(metric.size(1) * sp)
                idx = torch.topk(metric, k, largest=False, dim=1).indices
                mask.scatter_(1, idx, True)

            mod.weight.data[mask] = 0

        # Second pass: recompute outputs with the pruned weights.
        run_layer(layer, inps, outs, attention_mask)

        # This layer's outputs become the next layer's inputs.
        inps, outs = outs, inps
        torch.cuda.empty_cache()

    model.config.use_cache = use_cache
    torch.cuda.empty_cache()
    print("Pruner-Zero finished.")


def prune_x_pruner(args, model, tokenizer, device=torch.device("cuda:0"), prune_n=0, prune_m=0, engine=None, target_layers=None, layer_sparsity=None):
    """
    Prune the model layer by layer with a score computed by ``engine``.

    For each targeted decoder layer the calibration samples are run through
    the layer while ``WrappedGPT`` wrappers collect statistics. Every linear
    sub-module is scored with ``engine.forward(|W|, G, X)`` where ``G`` comes
    from the file at ``args.gradient_path`` (key ``"<name>_layer_<i>"``) and
    ``X`` is the wrapper's ``scaler_row``; the lowest-scoring entries of each
    row are zeroed.

    Layers excluded by ``target_layers`` are left untouched but are still run
    on the calibration samples, so that later layers are calibrated on the
    activations they actually receive.

    NOTE(review): the per-layer forward reads ``model.model.rotary_emb`` and
    passes ``position_embeddings``; the OPT / GPT-Neo(X) branches of the
    dispatch below presumably do not support that — confirm before use.

    Args:
        args: Namespace providing ``nsamples``, ``seed``, ``sparsity_ratio``
            and ``gradient_path``.
        model: Causal-LM to prune in place.
        tokenizer: Tokenizer forwarded to ``get_loaders``.
        device: Device for the calibration inputs.
        prune_n (int): When non-zero, use N:M sparsity with this N.
        prune_m (int): Group width M for N:M sparsity.
        engine: Object exposing ``forward(W, G, X)`` returning the score tensor.
        target_layers: Optional container of layer indices to prune; falsy
            means every layer.
        layer_sparsity: Optional dict mapping layer index to its sparsity
            ratio; missing layers fall back to ``args.sparsity_ratio``.

    Raises:
        ValueError: If the model class is unsupported, or a layer has to be
            pruned and ``engine`` is ``None``.
    """
    use_cache = model.config.use_cache
    model.config.use_cache = False

    dataloader, _ = get_loaders(
        "c4",
        nsamples=args.nsamples,
        seed=args.seed,
        seqlen=512,
        tokenizer=tokenizer,
    )
    inps, outs, attention_mask = prepare_calibration_input(
        model, dataloader, device, seqlen=512
    )

    # SECURITY NOTE: torch.load unpickles the file; only point
    # args.gradient_path at files from a trusted source.
    gradients = torch.load(args.gradient_path, map_location="cpu")

    mc = model.__class__.__name__
    if   "OPT"     in mc: layers = model.model.decoder.layers
    elif "GPTNeoX" in mc: layers = model.gpt_neox.layers
    elif "GPTNeo"  in mc: layers = model.transformer.h
    elif "Qwen"    in mc: layers = model.model.layers
    elif 'Llama'   in mc: layers = model.model.layers
    else: raise ValueError(f"Unsupported model type {mc}")

    SEQLEN = 512
    NS     = args.nsamples
    # Models loaded without a device map have no `hf_device_map` attribute.
    device_map = getattr(model, "hf_device_map", None) or {}

    def run_layer(layer, src, dst, mask):
        """Run every calibration sample in `src` through `layer` into `dst`."""
        with torch.no_grad():
            for j in range(NS):
                inp  = src[j].unsqueeze(0)             # [1, 512, h]
                attn = mask[j].unsqueeze(0) if mask is not None else None
                pos  = torch.arange(SEQLEN, device=inp.device).unsqueeze(0)
                pos_emb = model.model.rotary_emb(inp, pos)
                dst[j] = layer(
                    inp,
                    attention_mask=attn,
                    position_ids=pos,
                    position_embeddings=pos_emb,
                )[0]

    for i, layer in enumerate(layers):
        if f"model.layers.{i}" in device_map:
            device = device_map[f"model.layers.{i}"]
            inps  = inps.to(device)
            outs  = outs.to(device)
            if attention_mask is not None:
                attention_mask = attention_mask.to(device)

        if target_layers and i not in target_layers:
            # Not pruned, but activations must still flow through this layer;
            # otherwise later layers would be calibrated on stale inputs.
            run_layer(layer, inps, outs, attention_mask)
            inps, outs = outs, inps
            continue

        if engine is None:
            raise ValueError("prune_x_pruner requires an `engine` exposing forward(W, G, X)")

        sp = layer_sparsity.get(i, args.sparsity_ratio) if layer_sparsity else args.sparsity_ratio

        subset  = find_layers(layer)
        wrappers = {n: WrappedGPT(m) for n, m in subset.items()}

        def add_batch(n):
            def hook(_, x, y):
                wrappers[n].add_batch(x[0].data, y.data)
            return hook

        hooks = [subset[n].register_forward_hook(add_batch(n)) for n in wrappers]
        try:
            # First pass: collect statistics with the still-dense layer.
            run_layer(layer, inps, outs, attention_mask)
        finally:
            # Never leave statistic hooks attached, even if the pass fails.
            for h in hooks:
                h.remove()

        GPTree.PowerExponents.current_layer = i
        for name, mod in subset.items():
            print(f"pruning layer {i:02d}  {name}")

            for layer_idx, (w,g) in GPTree.PowerExponents.layer_exponents.items():
                if layer_idx == i:
                    print(f"layer {layer_idx}: w={w:.4f} g={g:.4f}")

            idx_key = f"{name}_layer_{i}"
            W, X = mod.weight.data.abs(), wrappers[name].scaler_row.reshape(1, -1)
            G    = gradients[idx_key].to(device=mod.weight.device, dtype=torch.float32)

            metric = engine.forward(W.float(), G, X.float())

            mask = torch.zeros_like(metric, dtype=torch.bool)
            if prune_n:                               # N:M
                for col in range(0, metric.size(1), prune_m):
                    blk = metric[:, col:col+prune_m]
                    idx = torch.topk(blk, prune_n, largest=False, dim=1).indices
                    mask.scatter_(1, col+idx, True)
            else:                                     # unstructured
                k = int(metric.size(1) * sp)
                idx = torch.topk(metric, k, largest=False, dim=1).indices
                mask.scatter_(1, idx, True)

            mod.weight.data[mask] = 0

        # Second pass: recompute outputs with the pruned weights.
        run_layer(layer, inps, outs, attention_mask)

        # This layer's outputs become the next layer's inputs.
        inps, outs = outs, inps
        torch.cuda.empty_cache()

    model.config.use_cache = use_cache
    torch.cuda.empty_cache()
    # The completion message previously named the wrong pruner (copy-paste).
    print("X-Pruner finished.")