import torch
import tqdm
import data_utils
import utils
import numpy as np
import random
from torch.nn.utils import prune
import os
import transformers
from datasets import load_dataset

def fix_seed(seed):
    """
    Seed every random number generator used by this module and force
    deterministic execution settings.

    Args:
        seed (int): Seed applied to `random`, NumPy and PyTorch (CPU and CUDA).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    # Ask cuDNN for deterministic kernels and disable its autotuner.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    # Restrict PyTorch to a single intra-op thread.
    torch.set_num_threads(1)
    
def check_sparsity(model):
    """
    Print the fraction of zero-valued weights per decoder layer and return
    the overall fraction.

    Only the modules returned by `find_layers` (called with its default
    layer types) are inspected for each entry of `model.model.layers`.

    Args:
        model: A model exposing `config.use_cache` and `model.layers`.

    Returns:
        float: Zero weights divided by total weights over all inspected
        modules, or 0.0 when no weights were found.
    """
    use_cache = model.config.use_cache
    model.config.use_cache = False

    zero_total = 0
    param_total = 0
    try:
        for i, layer in enumerate(model.model.layers):
            layer_zeros = 0
            layer_params = 0
            for module in find_layers(layer).values():
                W = module.weight.data
                # Reduce once and reuse the result for both tallies.
                layer_zeros += (W == 0).sum().item()
                layer_params += W.numel()

            zero_total += layer_zeros
            param_total += layer_params

            # A layer without any matching module reports 0 instead of
            # dividing by zero.
            layer_ratio = float(layer_zeros) / layer_params if layer_params else 0.0
            print(f"layer {i} sparsity {layer_ratio:.6f}")
    finally:
        # Restore the caller's cache setting even if inspection fails.
        model.config.use_cache = use_cache

    return float(zero_total) / param_total if param_total else 0.0

def prepare_calibration_input(model, dataloader, device, nsamples=128):
    """
    Capture the inputs that reach the first decoder layer for a set of
    calibration batches.

    The first layer is temporarily replaced by a wrapper that records its
    input and keyword arguments and then aborts the forward pass, so only
    the modules that run before layer 0 are executed.

    Args:
        model: A model exposing `config.use_cache`, `config.hidden_size`,
            `seqlen` and `model.layers`.
        dataloader: Iterable of batches; `batch[0]` is fed to the model.
        device: Device the model and the batches are moved to.
        nsamples (int): Number of rows in the capture buffer. At most this
            many batches are consumed. Defaults to 128.

    Returns:
        tuple: `(inps, outs, attention_mask, position_ids)` where `inps` has
        shape `(nsamples, model.seqlen, hidden_size)`, `outs` is a zero
        tensor like `inps`, and the last two are the values seen by layer 0
        on the last captured batch (None when they were not passed).
    """
    use_cache = model.config.use_cache
    model.config.use_cache = False
    layers = model.model.layers

    dtype = next(iter(model.parameters())).dtype
    inps = torch.zeros(
        (nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=device
    )
    inps.requires_grad = False
    cache = {'i': 0, 'attention_mask': None, "position_ids": None}

    class _StopForward(ValueError):
        """Signals that layer 0's input has been recorded."""

    class Catcher(nn.Module):
        def __init__(self, module):
            super().__init__()
            self.module = module
        def forward(self, inp, **kwargs):
            inps[cache['i']] = inp
            cache['i'] += 1
            # `.get` tolerates callers that do not pass these keywords.
            cache['attention_mask'] = kwargs.get('attention_mask')
            cache['position_ids'] = kwargs.get('position_ids')
            raise _StopForward

    first_layer = layers[0]
    layers[0] = Catcher(first_layer)
    try:
        model = model.to(device)
        for batch in dataloader:
            # Stop before the capture buffer would overflow.
            if cache['i'] >= nsamples:
                break
            try:
                model(batch[0].to(device))
            except _StopForward:
                # Only the wrapper's own signal is swallowed; genuine errors
                # from the model propagate to the caller.
                pass
    finally:
        # Always put the real layer and the cache setting back, even when
        # the forward pass fails.
        layers[0] = first_layer
        model.config.use_cache = use_cache

    outs = torch.zeros_like(inps)
    attention_mask = cache['attention_mask']
    position_ids = cache['position_ids']

    return inps, outs, attention_mask, position_ids

def find_layers(module, layers=None, name=''):
    """
    Recursively find the layers of a certain type in a module.

    Matching uses the exact type of each module (`type(module) in layers`),
    so subclasses of the listed types are not matched. Recursion stops at a
    matching module; its children are not searched.

    Args:
        module (nn.Module): PyTorch module.
        layers (list): List of layer types to find. Defaults to
            `torch.nn.Linear` only when omitted or None.
        name (str): Name of the module.

    Returns:
        dict: Dictionary of layers of the given type(s) within the module,
        keyed by their dotted path relative to `module`.
    """
    # A None sentinel avoids sharing one mutable default list across calls.
    if layers is None:
        layers = (torch.nn.Linear,)

    if type(module) in layers:
        return {name: module}

    res = {}
    for child_name, child in module.named_children():
        qualified = name + '.' + child_name if name != '' else child_name
        res.update(find_layers(child, layers=layers, name=qualified))
    return res

def prune_model(model, args, after):
    """
    Prune the model based on the provided arguments.

    The model's decoder layers are first wrapped by `convert`. For the
    "sparsegpt" method, the input of the first decoder layer is captured for
    one calibration batch, then each decoder layer is processed in turn: the
    layer is moved to the compute device, the input of every linear module
    found by `quant_utils.find_qlayers` is collected with a forward hook
    while the layer runs on its own input, and the module is pruned with
    `SparseGPT`.

    Args:
        model: The model to be pruned.
        args: Arguments containing pruning configurations. Reads `seed`,
            `prune_method`, `eval_dataset`, `model`, `hf_token`,
            `prune_ratio` and `act_order`.
        after: Unused by this function.

    Returns:
        The converted (and pruned) model.

    Raises:
        ValueError: If `args.prune_method` is not "sparsegpt".
        RuntimeError: If no calibration input could be captured.
    """
    model = convert(model).eval()
    target_subids = []
    fix_seed(args.seed)

    if args.prune_method == "sparsegpt":
        from sparsegpt_utils import SparseGPT
        import quant_utils

        dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"dev: {dev}")
        calibloader = data_utils.get_loaders(
            args.eval_dataset,
            seed=args.seed,
            model=args.model,
            seqlen=model.seqlen,
            hf_token=args.hf_token,
            eval_mode=False
        )

        layers = model.model.layers

        # Only the modules that run before the first decoder layer (plus the
        # shared rotary embedding used by every wrapped layer) need to live
        # on the compute device for the capture pass.
        model.model.embed_tokens = model.model.embed_tokens.to(dev)
        model.model.norm = model.model.norm.to(dev)
        model.model.rotary_emb = model.model.rotary_emb.to(dev)
        for name, buf in model.model.rotary_emb.named_buffers():
            setattr(model.model.rotary_emb, name, buf.to(dev))

        captured = []
        cache = {'attention_mask': None, 'position_ids': None}

        class _CaptureDone(Exception):
            """Signals that the first decoder layer's input was recorded."""

        class Catcher(nn.Module):
            def __init__(self, module):
                super().__init__()
                self.module = module
            def forward(self, inp, **kwargs):
                captured.append(inp)
                cache['attention_mask'] = kwargs.get('attention_mask', None)
                cache['position_ids'] = kwargs.get('position_ids', None)
                raise _CaptureDone

        # Capture once. The previous implementation ran this pass twice and
        # discarded the first result.
        first_layer = layers[0]
        layers[0] = Catcher(first_layer)
        try:
            with torch.no_grad():
                for batch in calibloader:
                    try:
                        model(batch[0].to(dev))
                    except _CaptureDone:
                        # One captured batch is all the loop below consumes.
                        break
        finally:
            # Never leave the wrapper installed, even if the forward fails.
            layers[0] = first_layer
        torch.cuda.empty_cache()

        if not captured:
            raise RuntimeError(
                "No calibration input was captured; the calibration loader "
                "yielded no batch that reached the first decoder layer."
            )

        hidden_states = captured[0]
        attention_mask = cache['attention_mask']
        position_ids = cache['position_ids']

        for i in tqdm.tqdm(range(len(layers))):
            model.model.layers[i] = model.model.layers[i].to(dev)
            layer = model.model.layers[i]

            # Keep this layer's own input: it is what the linear modules
            # inside the layer must be calibrated on.
            layer_input = hidden_states

            # The input for the next layer is computed before this layer is
            # pruned, i.e. from the unpruned weights.
            with torch.no_grad():
                hidden_states = layer(
                    layer_input,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                )[0]

            subset = quant_utils.find_qlayers(layer, layers=[torch.nn.Linear])
            for name, module in subset.items():
                inputs = []

                def save_input(mod, inp, out):
                    inputs.append(inp[0].detach())

                handle = module.register_forward_hook(save_input)
                try:
                    # Run the layer on its real input (not on layer 0's
                    # input). Modules pruned earlier in this loop already
                    # affect what later modules of the same layer see.
                    with torch.no_grad():
                        _ = layer(
                            layer_input,
                            attention_mask=attention_mask,
                            position_ids=position_ids
                        )
                finally:
                    handle.remove()

                if not inputs:
                    print(f"[WARNING] {name} hook did not collect any inputs.")
                    continue

                sparsegpt = SparseGPT(module)
                for inp in inputs:
                    sparsegpt.add_batch(inp, None)
                sparsegpt.sparseprune(sparsity=args.prune_ratio, actorder=args.act_order)
                sparsegpt.free()

            model.model.layers[i] = model.model.layers[i].cpu()
            torch.cuda.empty_cache()

        print(f"Pruned {args.prune_ratio * 100}% based on {args.prune_method} pruning")

    else:
        raise ValueError(f"Unknown pruning method: {args.prune_method}")

    # No sublayer ids are selected by the "sparsegpt" path, so this loop does
    # nothing there.
    target_subids.sort()

    for target_subid in target_subids:
        turn_off(model, target_subid)
    utils.cleanup_memory(verbos=True)
    return model

"""
File: onoff.py
- An implementation of OnOff_LlamaDecoderLayer for sublayer pruning
- This source code is written based on the following GitHub repository:
    https://github.com/jiwonsong-dev/SLEB
"""

import torch
import torch.nn as nn
from typing import List, Optional, Tuple, Union
import warnings

def get_device(
    obj: Union[torch.Tensor, nn.Module]
) -> Union[str, torch.device]:
    """
    Get the device of a given obj (tensor or module)

    For a module, the device of its first parameter is returned; when the
    module has no parameters, the device of its first buffer is used.

    Raises:
        ValueError: If a module has neither parameters nor buffers, so no
            device can be determined.
    """

    if isinstance(obj, torch.Tensor):
        return obj.device

    # `next(..., None)` avoids leaking a bare StopIteration to the caller.
    first = next(obj.parameters(), None)
    if first is None:
        first = next(obj.buffers(), None)
    if first is None:
        raise ValueError(
            f"Cannot determine the device of {type(obj).__name__}: "
            "it has no parameters or buffers."
        )
    return first.device

def move_inputs(
    inps: torch.Tensor,
    atts: torch.Tensor,
    pos: torch.Tensor,
    layer,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Bring the three inputs onto the device reported by `get_device(layer)`.

    Args:
        inps: Hidden states, or None.
        atts: Attention mask, or None.
        pos: Position ids, or None.
        layer: Tensor or module whose device is the target.

    Returns:
        The three inputs in the same order; None values and tensors already
        on the target device are returned unchanged.
    """

    target = get_device(layer)

    def _relocate(tensor):
        # Transfer only when there is a tensor and it lives elsewhere.
        if tensor is not None and tensor.device != target:
            return tensor.to(target)
        return tensor

    return _relocate(inps), _relocate(atts), _relocate(pos)

class OnOff_LlamaDecoderLayer(nn.Module):
    """
    A decoder-layer wrapper whose attention (MHA) and MLP sublayers can be
    switched off independently.

    The wrapper reuses the sublayers of `original_decoder_layer`. A
    switched-off sublayer is skipped entirely, so the hidden states pass
    through it unchanged.

    Args:
        original_decoder_layer: Layer providing `hidden_size`, `self_attn`,
            `mlp`, `input_layernorm` and `post_attention_layernorm`.
        rotary_emb: Optional callable `(hidden_states, position_ids) ->
            (cos, sin)` used to build `position_embeddings` for the attention
            module when the caller does not supply them.
    """

    def __init__(self, original_decoder_layer, rotary_emb=None):
        super().__init__()
        self.hidden_size = original_decoder_layer.hidden_size
        self.self_attn = original_decoder_layer.self_attn
        self.mlp = original_decoder_layer.mlp
        self.input_layernorm = original_decoder_layer.input_layernorm
        self.post_attention_layernorm = original_decoder_layer.post_attention_layernorm
        self.rotary_emb = rotary_emb

        self.pass_layer = False
        self.pass_mha = False
        self.pass_mlp = False

    def turn_off(self, mha=True, mlp=True):
        """
        Turn off either an MHA or an MLP sublayer
        """
        if mha:
            self.pass_mha = True
        if mlp:
            self.pass_mlp = True
        # The whole layer is skipped only when both sublayers are off.
        self.pass_layer = (self.pass_mha and self.pass_mlp)

    def turn_on(self, mha=True, mlp=True):
        """
        Turn on either an MHA or an MLP sublayer
        """
        if mha:
            self.pass_mha = False
        if mlp:
            self.pass_mlp = False
        self.pass_layer = (self.pass_mha and self.pass_mlp)

    def _run_mha(
        self,
        hidden_states,
        attention_mask,
        position_ids,
        past_key_value,
        output_attentions,
        use_cache,
        extra_kwargs,
    ):
        """
        Run layernorm + attention + residual; shared by `forward` and
        `do_mha_forward` so both behave identically.

        Returns:
            tuple: `(hidden_states, self_attn_weights, present_key_value)`.
        """
        # Align inputs with the attention weights before any computation.
        hidden_states, attention_mask, position_ids = move_inputs(
            hidden_states, attention_mask, position_ids, self.self_attn
        )
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        attn_kwargs = {
            "hidden_states": hidden_states,
            "attention_mask": attention_mask,
            "position_ids": position_ids,
            "past_key_value": past_key_value,
            "output_attentions": output_attentions,
            "use_cache": use_cache,
        }
        # Build rotary embeddings only when the caller did not pass them and
        # a rotary module is available.
        if "position_embeddings" not in extra_kwargs and self.rotary_emb is not None:
            cos, sin = self.rotary_emb(hidden_states, position_ids)
            attn_kwargs["position_embeddings"] = (cos, sin)

        # Caller-supplied extras never override the arguments set above.
        for key, value in extra_kwargs.items():
            attn_kwargs.setdefault(key, value)

        attn_out = self.self_attn(**attn_kwargs)

        if len(attn_out) == 3:
            hidden_states, self_attn_weights, present_key_value = attn_out
        elif len(attn_out) == 2:
            hidden_states, self_attn_weights = attn_out
            present_key_value = None
        else:
            raise RuntimeError(
                f"Unexpected number of attention outputs: {len(attn_out)}"
            )

        if residual.device != hidden_states.device:
            residual = residual.to(hidden_states.device)

        hidden_states = residual + hidden_states
        return hidden_states, self_attn_weights, present_key_value

    def _run_mlp(self, hidden_states):
        """
        Run layernorm + MLP + residual; shared by `forward` and
        `do_mlp_forward`.
        """
        # Align the input with the MLP weights before any computation.
        hidden_states, _, _ = move_inputs(hidden_states, None, None, self.mlp)
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)

        if residual.device != hidden_states.device:
            residual = residual.to(hidden_states.device)

        return residual + hidden_states

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        The forward function of OnOff_LlamaDecoderLayer

        Args:
            hidden_states (torch.Tensor): a hidden state
            attention_mask (Optional[torch.Tensor], optional): an attention mask. Defaults to None.
            position_ids (Optional[torch.LongTensor], optional): a position ids. Defaults to None.
            past_key_value (Optional[Tuple[torch.Tensor]], optional): past key and values. Defaults to None.
            output_attentions (Optional[bool], optional): Whether or not to return the attentions tensors of all attention layers.. Defaults to False.
            use_cache (Optional[bool], optional): use KV cache or not. Defaults to False.
            **kwargs: Extra keyword arguments forwarded to the attention
                module (for example `position_embeddings`).

        Returns:
            outputs: a tuple of outputs. The hidden states come first,
            followed by the attention weights when `output_attentions` is
            set and the key/value cache entry when `use_cache` is set.
        """
        # skip this decoder layer
        if self.pass_layer:
            outputs = (hidden_states,)

            if output_attentions:
                outputs += (None,)

            if use_cache:
                outputs += (past_key_value,)

            return outputs

        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )

        # Self Attention
        if self.pass_mha:
            # skipping mha
            self_attn_weights = None
            present_key_value = past_key_value
        else:
            hidden_states, self_attn_weights, present_key_value = self._run_mha(
                hidden_states,
                attention_mask,
                position_ids,
                past_key_value,
                output_attentions,
                use_cache,
                kwargs,
            )

        # Fully Connected
        if not self.pass_mlp:
            hidden_states = self._run_mlp(hidden_states)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs

    def do_mlp_forward(self, hidden_states):
        """
        A forward function for MLP sublayers

        Returns the input unchanged when the MLP sublayer is turned off.
        """
        if self.pass_mlp:
            return hidden_states
        return self._run_mlp(hidden_states)

    def do_mha_forward(self, hidden_states, attention_mask=None, position_ids=None,
                       past_key_value=None, output_attentions=False, use_cache=False,
                       **kwargs):
        """
        A forward function for MHA sublayers

        Behaves like the attention part of `forward`: inputs are moved to
        the attention module's device before the layernorm, rotary position
        embeddings are supplied when available, and extra keyword arguments
        are forwarded to the attention module.

        Returns:
            tuple: `(hidden_states, self_attn_weights, present_key_value)`;
            `(hidden_states, None, None)` when the MHA sublayer is off.
        """
        if self.pass_mha:
            return hidden_states, None, None
        return self._run_mha(
            hidden_states,
            attention_mask,
            position_ids,
            past_key_value,
            output_attentions,
            use_cache,
            kwargs,
        )

def convert(model):
    """
    Wrap every decoder layer of the model in an OnOff_LlamaDecoderLayer.

    All wrappers share the model's `model.rotary_emb` module. The layers are
    replaced in place and the same model object is returned.
    """
    decoder = model.model
    shared_rotary = decoder.rotary_emb
    for idx in range(len(decoder.layers)):
        original = decoder.layers[idx]
        decoder.layers[idx] = OnOff_LlamaDecoderLayer(original, rotary_emb=shared_rotary)
    return model

def turn_off(model, sublayer_idx):
    """
    Switch off the sublayer with the given flat index.

    Sublayers are numbered two per decoder block: even indices address the
    MHA sublayer and odd indices the MLP sublayer of block
    `sublayer_idx // 2`.
    """
    block_idx, parity = divmod(sublayer_idx, 2)
    targets_mha = parity == 0
    block = model.model.layers[block_idx]
    block.turn_off(mha=targets_mha, mlp=not targets_mha)

def turn_on(model, sublayer_idx):
    """
    Switch on the sublayer with the given flat index.

    Sublayers are numbered two per decoder block: even indices address the
    MHA sublayer and odd indices the MLP sublayer of block
    `sublayer_idx // 2`.
    """
    block_idx, parity = divmod(sublayer_idx, 2)
    targets_mha = parity == 0
    block = model.model.layers[block_idx]
    block.turn_on(mha=targets_mha, mlp=not targets_mha)