import torch
import typing
import transformers
import os
import logging
import functools

from . import utils

# Aliases for the Hugging Face model / decoder-layer classes. The helpers below
# dispatch on these via ``isinstance`` checks and ``==`` comparisons.
LLAMA_MODEL = transformers.models.llama.modeling_llama.LlamaForCausalLM
LLAMA_LAYER = transformers.models.llama.modeling_llama.LlamaDecoderLayer

QWEN2_VL_MODEL = transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLForConditionalGeneration
QWEN2_VL_LAYER = transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLDecoderLayer

# NOTE(review): the Qwen2.5-VL aliases are not used by any helper visible in
# this section of the file — confirm whether other code relies on them.
QWEN2_5_VL_MODEL = transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration
QWEN2_5_VL_LAYER = transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLDecoderLayer


def get_model_class(model_path):
    """Map a checkpoint path/name to a short model-family tag.

    Args:
        model_path: string that is searched for a known family marker.

    Returns:
        ``"qwen2_vl"`` when the path contains ``"Qwen2-VL"``, or
        ``"qwen2_5_vl"`` when it contains ``"Qwen2.5-VL"``.

    Raises:
        ValueError: if neither marker occurs in ``model_path``.
    """
    # Markers are checked in this order; the first match wins.
    known_families = (
        ("Qwen2-VL", "qwen2_vl"),
        ("Qwen2.5-VL", "qwen2_5_vl"),
    )
    for marker, family in known_families:
        if marker in model_path:
            return family
    raise ValueError(f"Unknown model type for {model_path}. Please specify the model type explicitly.")
    
def model_type_extractor(model):
    """Return the supported model class that ``model`` is an instance of.

    Args:
        model: a loaded model object.

    Returns:
        ``LLAMA_MODEL`` or ``QWEN2_VL_MODEL`` (checked in that order).

    Raises:
        ValueError: if ``model`` is an instance of neither class.
    """
    for candidate in (LLAMA_MODEL, QWEN2_VL_MODEL):
        if isinstance(model, candidate):
            return candidate
    raise ValueError(f'Unknown model type {model}')

def skip(*args, **kwargs):
    """Accept any arguments, do nothing, and return ``None``.

    The model loaders in this file assign this function over several
    ``torch.nn.init`` initializers so that weight initialization becomes a
    no-op, which saves time while a model is being constructed.
    """
    return None

def get_rope_function_name(model):
    """Return the name of the rotary-position-embedding function for ``model``.

    Only ``LLAMA_MODEL`` instances are supported.

    Raises:
        NotImplementedError: for any other model.
    """
    if not isinstance(model, LLAMA_MODEL):
        raise NotImplementedError
    return "apply_rotary_pos_emb"


def get_layers(model):
    """Return ``model.model.layers`` for a supported model instance.

    Raises:
        NotImplementedError: if ``model`` is neither a ``LLAMA_MODEL`` nor a
            ``QWEN2_VL_MODEL`` instance.
    """
    # Both supported families expose their decoder layers at the same path.
    if isinstance(model, (LLAMA_MODEL, QWEN2_VL_MODEL)):
        return model.model.layers
    raise NotImplementedError


def get_llama(model_name, hf_token):
    """Load a Llama causal-LM checkpoint and tag it with a sequence length.

    Args:
        model_name: checkpoint name or path passed to ``from_pretrained``.
        hf_token: Hugging Face access token forwarded to ``from_pretrained``
            (may be ``None``).

    Returns:
        The loaded ``LlamaForCausalLM`` with ``model.seqlen`` set to 2048.
    """
    # Replace the random initializers with no-ops to save time during
    # construction. NOTE: this patches ``torch.nn.init`` for the whole process
    # and is not undone afterwards.
    torch.nn.init.kaiming_uniform_ = skip
    torch.nn.init.uniform_ = skip
    torch.nn.init.normal_ = skip
    # ``token`` replaces the deprecated ``use_auth_token`` keyword.
    model = transformers.LlamaForCausalLM.from_pretrained(
        model_name,
        torch_dtype='auto',
        token=hf_token,
        low_cpu_mem_usage=True,
    )
    model.seqlen = 2048
    logging.info('---> Loading {} Model with seq_len: {}'.format(model_name, model.seqlen))
    return model


def get_qwen2_vl(model_name):
    """Load a Qwen2-VL conditional-generation checkpoint onto the CPU.

    Args:
        model_name: checkpoint name or path passed to ``from_pretrained``.

    Returns:
        The loaded ``Qwen2VLForConditionalGeneration`` model.
    """
    # Turn the random initializers into no-ops to save time during
    # construction. This patches ``torch.nn.init`` process-wide.
    for init_name in ('kaiming_uniform_', 'uniform_', 'normal_'):
        setattr(torch.nn.init, init_name, skip)

    load = transformers.Qwen2VLForConditionalGeneration.from_pretrained
    model = load(model_name, torch_dtype='auto', device_map='cpu')
    logging.info('---> Loading {} Model'.format(model_name))
    return model


def get_model(model_name, hf_token=None):
    """Load a model, choosing the loader from substrings of ``model_name``.

    Args:
        model_name: checkpoint name or path. ``'llama'`` selects the Llama
            loader; otherwise ``'Qwen2-VL'`` selects the Qwen2-VL loader.
        hf_token: access token, only forwarded to the Llama loader.

    Returns:
        The model returned by the selected loader.

    Raises:
        ValueError: if ``model_name`` contains neither substring.
    """
    if 'llama' in model_name:
        return get_llama(model_name, hf_token)
    if 'Qwen2-VL' in model_name:
        return get_qwen2_vl(model_name)
    raise ValueError(f'Unknown model {model_name}')


def get_model_type(model):
    """Return the supported model class that ``model`` is an instance of.

    Returns:
        ``QWEN2_VL_MODEL`` or ``LLAMA_MODEL`` (checked in that order).

    Raises:
        ValueError: if ``model`` is an instance of neither class.
    """
    for supported_type in (QWEN2_VL_MODEL, LLAMA_MODEL):
        if isinstance(model, supported_type):
            return supported_type
    raise ValueError(f'Unknown model type {model}')

def get_embeddings(model, model_type) -> list[torch.nn.Module]:
    """Return the token-embedding module of ``model`` wrapped in a list.

    Args:
        model: the loaded model.
        model_type: ``LLAMA_MODEL`` or ``QWEN2_VL_MODEL``.

    Raises:
        ValueError: for any other ``model_type``.
    """
    if model_type not in (LLAMA_MODEL, QWEN2_VL_MODEL):
        raise ValueError(f'Unknown model type {model_type}')
    return [model.model.embed_tokens]


def get_transformer_layers(model, model_type):
    """Return the decoder layers of ``model`` as a new Python list.

    Args:
        model: the loaded model.
        model_type: ``LLAMA_MODEL`` or ``QWEN2_VL_MODEL``.

    Raises:
        ValueError: for any other ``model_type``.
    """
    if model_type not in (LLAMA_MODEL, QWEN2_VL_MODEL):
        raise ValueError(f'Unknown model type {model_type}')
    return list(model.model.layers)


def get_lm_head(model, model_type):
    """Return ``model.lm_head`` for a supported model type.

    Args:
        model: the loaded model.
        model_type: ``LLAMA_MODEL`` or ``QWEN2_VL_MODEL``.

    Raises:
        ValueError: for any other ``model_type``.
    """
    if model_type not in (LLAMA_MODEL, QWEN2_VL_MODEL):
        raise ValueError(f'Unknown model type {model_type}')
    return model.lm_head

def get_pre_head_layernorm(model, model_type):
    """Return the final normalization module (``model.model.norm``).

    Args:
        model: the loaded model.
        model_type: ``LLAMA_MODEL`` or ``QWEN2_VL_MODEL``.

    Returns:
        ``model.model.norm``, after a sanity check that it is the RMSNorm
        class expected for the given model family.

    Raises:
        ValueError: for any other ``model_type``.
    """
    if model_type == LLAMA_MODEL:
        pre_head_layernorm = model.model.norm
        assert isinstance(pre_head_layernorm,
                          transformers.models.llama.modeling_llama.LlamaRMSNorm)
    # ``elif`` (not a second ``if``): otherwise the Llama case falls through
    # to the ``else`` branch below and always raises.
    elif model_type == QWEN2_VL_MODEL:
        pre_head_layernorm = model.model.norm
        assert isinstance(pre_head_layernorm,
                          transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2RMSNorm)
    else:
        raise ValueError(f'Unknown model type {model_type}')
    return pre_head_layernorm

def get_mlp_bottleneck_size(model):
    """Return ``model.config.intermediate_size`` for a supported model.

    The model type is resolved with ``get_model_type``, which raises
    ``ValueError`` for unsupported models.

    Raises:
        ValueError: if the model type is not supported.
    """
    model_type = get_model_type(model)
    if model_type not in (LLAMA_MODEL, QWEN2_VL_MODEL):
        raise ValueError(f'Unknown model type {model_type}')
    return model.config.intermediate_size


# same with llm
def replace_modules(
    root: torch.nn.Module,
    type_to_replace,
    new_module_factory,
    replace_layers: bool,
) -> None:
    """Swap every module of a given type for one built by a factory.

    Walks the module tree under ``root`` depth-first. Each child that is an
    instance of ``type_to_replace`` is replaced in place on its parent; the
    children of a replaced module are not visited.

    Args:
        root: module whose descendants are examined.
        type_to_replace: type (or tuple of types) accepted by ``isinstance``.
        new_module_factory: callable that builds the replacement. It is
            called as ``factory(module, index)`` when ``replace_layers`` is
            true and as ``factory(module)`` otherwise. If it returns ``None``
            the original module is kept.
        replace_layers: when true, the child's attribute name is converted
            with ``int(...)`` and passed to the factory as the second
            argument, so matching children must have integer-like names.
    """
    for child_name, child in root.named_children():
        if not isinstance(child, type_to_replace):
            # Not a match: descend only if there is something below it.
            if next(child.children(), None) is not None:
                replace_modules(child, type_to_replace, new_module_factory, replace_layers)
            continue

        if replace_layers:
            replacement = new_module_factory(child, int(child_name))
        else:
            replacement = new_module_factory(child)

        if replacement is not None:
            setattr(root, child_name, replacement)


class RMSN(torch.nn.Module):
    """
    Root Mean Square Normalization (RMSN) layer without a learned scale.

    Follows the LlamaRMSNorm implementation:
    https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L75

    A ``weight`` parameter of shape ``(1,)`` is registered, but ``forward``
    does not use it.
    """

    def __init__(self, mean_dim: int, eps=1e-5):
        """
        Args:
            mean_dim: divisor applied to the sum of squares over the last axis.
            eps: constant added to the mean square before the reciprocal
                square root.
        """
        super().__init__()
        self.mean_dim = mean_dim
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.zeros(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` over its last axis and return it in its input dtype."""
        orig_dtype = x.dtype
        # float16 inputs are computed in float32, then cast back at the end.
        hidden = x.to(torch.float32) if orig_dtype == torch.float16 else x
        mean_square = hidden.pow(2).sum(-1, keepdim=True) / self.mean_dim
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return (hidden * inv_rms).to(orig_dtype)


class FPInputsCache:
    """
    Records, via forward hooks, the inputs seen by a set of named modules.

    ``fp_cache`` maps each module name to a list with one entry per hooked
    forward call.
    """
    def __init__(self, sequential):
        """
        Args:
            sequential: indexable whose first four items are concatenated
                with ``+`` to form the list of module names to track.
        """
        group0, group1, group2, group3 = (
            sequential[0], sequential[1], sequential[2], sequential[3]
        )
        self.names = group0 + group1 + group2 + group3
        self.fp_cache = {name: [] for name in self.names}
        self.handles = []

    def cache_fp_input(self, m, inp, out, name):
        """Forward-hook body: store the first positional input of the module.

        A 3-D input is flattened to 2-D (keeping the last axis); the tensor
        is then transposed before being appended to ``fp_cache[name]``.
        """
        first_input = inp[0].detach()
        if first_input.dim() == 3:
            first_input = first_input.reshape((-1, first_input.shape[-1]))
        self.fp_cache[name].append(first_input.t())

    def add_hook(self, full):
        """Register a forward hook on ``full[name]`` for every tracked name."""
        for name in self.names:
            hook = functools.partial(self.cache_fp_input, name=name)
            self.handles.append(full[name].register_forward_hook(hook))

    def clear_hook(self):
        """Remove every registered hook and release cached CUDA memory."""
        for handle in self.handles:
            handle.remove()
        self.handles = []
        torch.cuda.empty_cache()

    def clear_cache(self):
        """Reset the stored inputs of every tracked name to an empty list."""
        for name in self.names:
            self.fp_cache[name] = []
            
            
class OutputGradCache:
    """
    Records, via full backward hooks, a per-row importance score derived from
    the gradient with respect to each hooked module's output.

    ``grad_cache`` maps every name in ``names_all`` to a list with one score
    tensor per hooked backward call; only ``names_to_hook`` receive hooks.
    """

    def __init__(self, sequential, grad_acton, grad_norm, grad_clip, grad_temperature, grad_clip_times, grad_clip_high_only):
        """
        Args:
            sequential: indexable whose first four items are lists of module
                names; their concatenation forms ``names_all``.
            grad_acton: which groups to hook: ``'qkv'`` (group 0),
                ``'qkvo'`` (groups 0-1) or ``'all'`` (groups 0-3).
            grad_norm: ``'l1'`` or ``'l2'``; the reduction over the last axis.
            grad_clip: when truthy, clamp the scores and renormalize.
            grad_temperature: exponent applied to the scores when != 1.0.
            grad_clip_times: upper clamp bound; its reciprocal is the lower
                bound unless ``grad_clip_high_only`` is set.
            grad_clip_high_only: when truthy, only the upper bound is applied.

        Raises:
            ValueError: if ``grad_acton`` is not one of the supported values.
        """
        self.grad_norm = grad_norm
        self.grad_clip = grad_clip
        self.grad_clip_times = grad_clip_times
        self.grad_clip_high_only = grad_clip_high_only
        self.grad_temperature = grad_temperature

        group0, group1, group2, group3 = (
            sequential[0], sequential[1], sequential[2], sequential[3]
        )
        self.names_all = group0 + group1 + group2 + group3

        if grad_acton not in ('qkv', 'qkvo', 'all'):
            raise ValueError(f"Unsupported grad_acton: {grad_acton}")
        if grad_acton == 'qkv':
            self.names_to_hook = group0
        elif grad_acton == 'qkvo':
            self.names_to_hook = group0 + group1
        else:
            self.names_to_hook = group0 + group1 + group2 + group3

        self.grad_cache = {name: [] for name in self.names_all}
        self.handles = []

    @staticmethod
    def _rescale_to_unit_mean(scores):
        """Scale ``scores`` so that they sum to their element count."""
        return scores.numel() * scores / scores.sum()

    def cache_output_grad(self, module, grad_input, grad_output, name):
        """Backward-hook body: reduce the output gradient to per-row scores.

        Raises:
            ValueError: if ``self.grad_norm`` is neither ``'l1'`` nor ``'l2'``.
        """
        # Per the original note the gradient is (bsz, seqlen, ch); a 3-D
        # tensor is flattened to 2-D, keeping the last axis.
        grad = grad_output[0].detach()
        if grad.dim() == 3:
            grad = grad.reshape((-1, grad.shape[-1]))

        # Collapse the last axis into one value per row.
        if self.grad_norm == 'l1':
            scores = grad.abs().mean(-1)
        elif self.grad_norm == 'l2':
            scores = grad.abs().pow(2).sum(-1).pow(0.5)
        else:
            raise ValueError(f"Unsupported grad_norm: {self.grad_norm}")

        scores = self._rescale_to_unit_mean(scores)

        if self.grad_clip:
            bounds = {'max': self.grad_clip_times}
            if not self.grad_clip_high_only:
                bounds['min'] = 1/self.grad_clip_times
            scores = torch.clamp(scores, **bounds)
            scores = self._rescale_to_unit_mean(scores)

        if self.grad_temperature != 1.0:
            scores = scores.pow(self.grad_temperature)
            scores = self._rescale_to_unit_mean(scores)

        self.grad_cache[name].append(scores)

    def add_hook(self, full):
        """Register a full backward hook on ``full[name]`` for each hooked name."""
        for name in self.names_to_hook:
            hook = functools.partial(self.cache_output_grad, name=name)
            self.handles.append(full[name].register_full_backward_hook(hook))

    def clear_hook(self):
        """Remove every registered hook and release cached CUDA memory."""
        for handle in self.handles:
            handle.remove()
        self.handles = []
        torch.cuda.empty_cache()

    def clear_cache(self):
        """Reset the stored scores of every name in ``names_all``."""
        for name in self.names_all:
            self.grad_cache[name] = []


def get_layer_io_save_path(args):
    """Build the path ``<args.save_path>/layer_io/<layer_idx>.pt``.

    ``args.layer_idx`` is zero-padded to at least three digits.
    """
    file_name = f'{args.layer_idx:03d}.pt'
    return os.path.join(args.save_path, 'layer_io', file_name)


def capture_layer_io(model_type, layer, layer_input):
    """Run ``layer`` on ``layer_input`` and capture sub-module inputs/outputs.

    Forward hooks record the inputs of ``k_proj``, ``o_proj``, ``gate_proj``
    and ``down_proj`` and the output of ``v_proj``. Each captured tensor is
    detached and moved to the CPU.

    Args:
        model_type: must be ``LLAMA_MODEL``; no other type is supported.
        layer: decoder layer exposing ``self_attn`` and ``mlp`` sub-modules.
        layer_input: tensor indexed along its first axis; each slice of
            length one is moved to ``utils.DEV`` and passed to ``layer``.

    Returns:
        ``{'input': {...}, 'output': {...}}`` where every value is the
        concatenation (along dim 0) of the tensors captured per sequence.

    Raises:
        ValueError: if ``model_type`` is not supported. (Previously this case
            ran every forward pass and then crashed on an unbound local.)
        AttributeError: if a hooked projection is found on neither
            ``layer.self_attn`` nor ``layer.mlp``.
    """
    if model_type != LLAMA_MODEL:
        raise ValueError(f'Unknown model type {model_type}')

    def hook_factory(module_name, captured_vals, is_input):
        def hook(module, input, output):
            if is_input:
                captured_vals[module_name].append(input[0].detach().cpu())
            else:
                captured_vals[module_name].append(output.detach().cpu())
        return hook

    def find_submodule(module_name):
        # Look in the attention block first, then in the MLP block.
        for parent in (layer.self_attn, layer.mlp):
            candidate = getattr(parent, module_name, None)
            if candidate is not None:
                return candidate
        raise AttributeError(f'{module_name} not found in layer.self_attn or layer.mlp')

    captured_inputs = {
        'k_proj': [],  # q_proj, v_proj has the same input as k_proj
        'o_proj': [],
        'gate_proj': [],  # up_proj has the same input as gate_proj
        'down_proj': []
    }

    captured_outputs = {
        'v_proj': [],
    }

    handles = []
    try:
        for name in captured_inputs:
            handles.append(
                find_submodule(name).register_forward_hook(hook_factory(name, captured_inputs, True))
            )
        for name in captured_outputs:
            handles.append(
                find_submodule(name).register_forward_hook(hook_factory(name, captured_outputs, False))
            )

        # Process each sequence in the batch one by one to avoid OOM.
        for seq_idx in range(layer_input.shape[0]):
            seq = layer_input[seq_idx:seq_idx + 1].to(utils.DEV)
            layer(seq)
    finally:
        # Always detach the hooks, even when a forward pass raises.
        for h in handles:
            h.remove()

    # Concatenate the per-sequence captures for each sub-layer across the batch.
    for module_name in captured_inputs:
        captured_inputs[module_name] = torch.cat(captured_inputs[module_name], dim=0)
    for module_name in captured_outputs:
        captured_outputs[module_name] = torch.cat(captured_outputs[module_name], dim=0)

    return {
        'input': captured_inputs,
        'output': captured_outputs
    }


from transformers.models.llama.modeling_llama import LlamaForCausalLM
from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM
'''
From AWQ repo: https://github.com/mit-han-lab/llm-awq/blob/main/awq/quantize/pre_quant.py
'''

def get_blocks(model):
    """Return the decoder-layer container of ``model``.

    Dispatch is on the class *name* of ``model`` rather than on the class
    object. ``LlavaLlamaModel`` keeps its layers under
    ``model.llm.model.layers``; every other supported class uses
    ``model.model.layers``.

    Raises:
        NotImplementedError: if the class name is not recognised.
    """
    class_name = model.__class__.__name__

    if class_name == "LlavaLlamaModel":
        return model.llm.model.layers

    direct_layout = (
        "LlamaForCausalLM",
        "Qwen2ForCausalLM",
        "LlavaLlamaForCausalLM",
        "LlavaQwenForCausalLM",
        "Qwen2VLForConditionalGeneration",
        "Qwen2_5_VLForConditionalGeneration",
    )
    if class_name in direct_layout:
        return model.model.layers

    raise NotImplementedError(type(model))

def move_embed(model, device):
    """Move the embedding-side modules of ``model`` to ``device`` in place.

    Dispatch is on the class *name* of ``model``:

    * ``LlamaForCausalLM`` / ``Qwen2ForCausalLM``: ``embed_tokens``,
      ``rotary_emb`` and ``norm`` under ``model.model`` are moved.
    * ``LlavaLlamaForCausalLM`` / ``LlavaLlamaModel``: only ``embed_tokens``
      is moved. The language model is taken from ``model.llm`` when that
      attribute exists and from ``model`` itself otherwise, so both layouts
      handled by ``get_blocks`` work here too.

    Args:
        model: the model whose modules are moved.
        device: target passed to each module's ``.to(...)``.

    Raises:
        NotImplementedError: if the class name is not recognised.
    """
    class_name = model.__class__.__name__

    if class_name in ("LlamaForCausalLM", "Qwen2ForCausalLM"):
        backbone = model.model
        backbone.embed_tokens = backbone.embed_tokens.to(device)
        backbone.rotary_emb = backbone.rotary_emb.to(device)
        backbone.norm = backbone.norm.to(device)
    elif class_name in ("LlavaLlamaForCausalLM", "LlavaLlamaModel"):
        # Wrapper-style models nest the language model under ``.llm``;
        # otherwise the model itself is the language model.
        language_model = getattr(model, "llm", model)
        backbone = language_model.model
        backbone.embed_tokens = backbone.embed_tokens.to(device)
    else:
        raise NotImplementedError(type(model))
