import torch
import torch.nn as nn
import tqdm
import gc
import functools
from collections import defaultdict
from typing import List

from transformers.models.bloom.modeling_bloom import BloomForCausalLM
from transformers.models.opt.modeling_opt import OPTForCausalLM
from transformers.models.llama.modeling_llama import LlamaForCausalLM
try:
    from tinychat.models import LlavaLlamaForCausalLM
except ImportError as e:
    pass

from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM

from .auto_scale import auto_scale_block, apply_scale
from .auto_clip import auto_clip_block, apply_clip

# Names exported by `from <this module> import *`. The other helpers defined
# below (get_blocks, move_embed, apply_awq, ...) must be imported by name.
__all__ = ["run_awq"]


def get_named_linears(module):
    """Collect every ``nn.Linear`` reachable from ``module``.

    Returns a dict keyed by the qualified name reported by
    ``module.named_modules()`` (in that iteration order). If ``module`` is
    itself a Linear it appears under the empty name ``""``.
    """
    linears = {}
    for qualified_name, submodule in module.named_modules():
        if isinstance(submodule, nn.Linear):
            linears[qualified_name] = submodule
    return linears


def get_blocks(model):
    """Return the sequence of transformer blocks that AWQ processes for ``model``.

    Models are recognised by exact class name, by ``isinstance`` against the
    transformers OPT/Bloom classes, or by a substring of the lower-cased class
    repr (which includes the module path). For the multimodal models only the
    language-model blocks are returned.

    Raises:
        NotImplementedError: if the model family is not recognised.
    """
    cls_name = model.__class__.__name__
    cls_repr = str(model.__class__).lower()

    if cls_name in ("LlamaForCausalLM", "Qwen2ForCausalLM"):
        return model.model.layers
    if cls_name == "InternVL3":
        # Vision encoder blocks (model.vision_model.encoder.layers) are not included.
        return model.language_model.model.layers
    if cls_name == "LlavaLlamaForCausalLM":
        # Vision tower encoder blocks are likewise not included.
        return model.model.layers
    if isinstance(model, OPTForCausalLM):
        return model.model.decoder.layers
    if isinstance(model, BloomForCausalLM):
        return model.transformer.h
    if "mpt" in cls_repr:
        return model.transformer.blocks
    if any(tag in cls_repr for tag in ("falcon", "bigcode")):
        return model.transformer.h
    if "neox" in cls_repr:
        return model.gpt_neox.layers
    if cls_name == "LlavaLlamaModel":
        return model.llm.model.layers
    raise NotImplementedError(type(model))


def move_embed(model, device):
    """Move the modules that run before the first transformer block to ``device``.

    This covers token/position embeddings, embedding dropout or layer norms,
    and the rotary embedding module where the model has one. It is used by
    run_awq to push calibration data up to block 0 on the GPU and then move
    everything back.

    Args:
        model: a model family also recognised by ``get_blocks``.
        device: any value accepted by ``nn.Module.to`` (e.g. ``"cuda"``).

    Raises:
        NotImplementedError: if the model family is not recognised.
    """
    class_name = model.__class__.__name__
    if isinstance(model, (LlamaForCausalLM, Qwen2ForCausalLM)):
        model.model.embed_tokens = model.model.embed_tokens.to(device)
        model.model.rotary_emb = model.model.rotary_emb.to(device)
    elif class_name == "InternVL3":
        model.language_model.model.embed_tokens = (
            model.language_model.model.embed_tokens.to(device)
        )
        model.language_model.model.rotary_emb = (
            model.language_model.model.rotary_emb.to(device)
        )
        model.vision_model.embeddings.to(device)
    elif class_name == "LlavaLlamaForCausalLM":
        # Matched by name, as in get_blocks. The LlavaLlamaForCausalLM import
        # at the top of the module is optional (tinychat). An isinstance()
        # check here raised NameError for every model handled further down
        # whenever tinychat was not installed.
        model.model.embed_tokens = model.model.embed_tokens.to(device)
        model.model.vision_tower.vision_tower.vision_model.embeddings.to(device)
    elif isinstance(model, OPTForCausalLM):
        model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(device)
        model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(
            device
        )
    elif isinstance(model, BloomForCausalLM):
        model.transformer.word_embeddings = model.transformer.word_embeddings.to(device)
        model.transformer.word_embeddings_layernorm = (
            model.transformer.word_embeddings_layernorm.to(device)
        )
    elif "mpt" in str(model.__class__).lower():
        model.transformer.wte = model.transformer.wte.to(device)
        model.transformer.emb_drop = model.transformer.emb_drop.to(device)
    elif "falcon" in str(model.__class__).lower():
        model.transformer.word_embeddings = model.transformer.word_embeddings.to(device)
    elif "bigcode" in str(model.__class__).lower():
        model.transformer.wte = model.transformer.wte.to(device)
        model.transformer.wpe = model.transformer.wpe.to(device)
        model.transformer.drop = model.transformer.drop.to(device)
    elif "neox" in str(model.__class__).lower():
        model.gpt_neox.embed_in = model.gpt_neox.embed_in.to(device)
        model.gpt_neox.emb_dropout = model.gpt_neox.emb_dropout.to(device)
        model.embed_out = model.embed_out.to(device)
    elif "llavallamamodel" in str(model.__class__).lower():
        model.llm.model.embed_tokens = model.llm.model.embed_tokens.to(device)
    else:
        raise NotImplementedError(type(model))


class _CatcherExit(Exception):
    """Sentinel raised by run_awq's input catcher to abort the forward pass.

    A dedicated type (instead of a bare ValueError) means that a genuine error
    raised before the first block is propagated rather than silently swallowed.
    """


def _cache_input_hook(m, x, y, name, feat_dict):
    """Forward hook: append a detached CPU copy of the first positional input to feat_dict[name]."""
    feat_dict[name].append(x[0].detach().cpu())


@torch.no_grad()
def run_awq(
    model,
    enc,
    w_bit,
    q_config,
    n_samples=512,
    seqlen=512,
    auto_scale=True,
    mse_range=True,
    # some configs for ablation study
    calib_data="pileval",
):
    """Search AWQ scales and clipping thresholds block by block on calibration data.

    1. Calibration samples from ``get_calib_dataset`` are pushed through the
       model until the first block (see ``get_blocks``), to capture that
       block's input tensor and keyword arguments.
    2. Each block is then moved to CUDA in turn. The inputs of all of its
       ``nn.Linear`` sub-modules are recorded with forward hooks.
    3. Scales are searched and applied (``auto_scale``), then clipping
       thresholds are searched and applied (``mse_range``).
    4. The block's output becomes the next block's input.

    The same keyword arguments captured at block 0 are reused for every block.
    CUDA is required (blocks are moved with ``.cuda()``).

    Args:
        model: a model family recognised by ``get_blocks``/``move_embed``.
            ``apply_scale``/``apply_clip`` are called on its blocks.
        enc: tokenizer forwarded to ``get_calib_dataset``.
        w_bit: forwarded to ``auto_scale_block``/``auto_clip_block``.
        q_config: forwarded to ``auto_scale_block``/``auto_clip_block``.
        n_samples: forwarded to ``get_calib_dataset``.
        seqlen: forwarded to ``get_calib_dataset`` as ``block_size``.
        auto_scale: if True, run ``auto_scale_block`` + ``apply_scale`` per block.
        mse_range: if True, run ``auto_clip_block`` + ``apply_clip`` per block.
        calib_data: dataset name forwarded to ``get_calib_dataset``.

    Returns:
        ``{"scale": [...], "clip": [...]}``. Entries are the per-block search
        results with each block's ``get_op_name`` prefix added via
        ``append_str_prefix``.

    Raises:
        RuntimeError: if the forward pass completes without calling block 0.
    """
    from ..utils.calib_data import get_calib_dataset
    from ..utils.module import append_str_prefix, get_op_name

    if "bigcode" in str(model.__class__).lower():
        # otherwise attention_mask will always be on cpu.
        model.transformer.bias = model.transformer.bias.to("cuda")

    layers = get_blocks(model)

    samples = get_calib_dataset(
        data=calib_data, tokenizer=enc, n_samples=n_samples, block_size=seqlen
    )
    samples = torch.cat(samples, dim=0)

    inps = []
    layer_kwargs = {}

    layers[0] = layers[0].cuda()
    move_embed(model, "cuda")

    # get input and kwargs to layer 0
    # with_kwargs is only supported in PyTorch 2.0
    # use this Catcher hack for now
    class Catcher(nn.Module):
        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, inp, **kwargs):
            inps.append(inp)
            layer_kwargs.update(kwargs)
            raise _CatcherExit  # early exit to break later inference

    # Multimodal wrappers: only the language model's forward is run.
    class_name = model.__class__.__name__
    if class_name == "LlavaLlamaModel":
        forward_fn = model.llm
    elif class_name == "InternVL3":
        forward_fn = model.language_model
    else:
        forward_fn = model

    # patch layer 0 to catch input and kwargs
    layers[0] = Catcher(layers[0])
    try:
        forward_fn(samples.to(next(model.parameters()).device))
    except _CatcherExit:  # work with early exit
        pass
    finally:
        # Always unwrap, even if the forward pass failed for another reason.
        layers[0] = layers[0].module
    del samples
    if not inps:
        raise RuntimeError(
            "run_awq: the forward pass completed without reaching the first "
            "block, so its inputs could not be captured"
        )
    inps = inps[0]

    layers[0] = layers[0].cpu()
    move_embed(model, "cpu")

    gc.collect()
    torch.cuda.empty_cache()

    awq_results = {
        "scale": [],
        "clip": [],
    }

    # solve layer by layer
    for i in tqdm.tqdm(range(len(layers)), desc="Running AWQ..."):
        layer = layers[i].cuda()  # nn.Module.cuda() moves in place and returns the module
        named_linears = get_named_linears(layer)

        # firstly, get input features of all linear layers
        input_feat = defaultdict(list)
        handles = [
            linear.register_forward_hook(
                functools.partial(_cache_input_hook, name=name, feat_dict=input_feat)
            )
            for name, linear in named_linears.items()
        ]
        try:
            inps = inps.to(next(layer.parameters()).device)  # in case multi-gpu
            # get output as next layer's input
            outputs = layer(inps, **layer_kwargs)
        finally:
            # Never leave hooks attached, even if the block's forward raised.
            for h in handles:
                h.remove()
        # Blocks normally return a tuple whose first item is the hidden states.
        # If a block returns the hidden-states tensor itself, indexing it with
        # [0] would keep only the first sample, so use it as-is.
        inps = outputs if isinstance(outputs, torch.Tensor) else outputs[0]
        # now solve for scaling and clipping
        input_feat = {k: torch.cat(v, dim=0) for k, v in input_feat.items()}

        # Clear GPU memory
        torch.cuda.empty_cache()

        if auto_scale:  # if it applies, we should also modify the input_feat with scales
            scales_list = auto_scale_block(
                layer,
                layer_kwargs,
                w_bit=w_bit,
                q_config=q_config,
                input_feat=input_feat,
            )
            apply_scale(layer, scales_list, input_feat_dict=input_feat)
            # append prefix to make names global
            awq_results["scale"] += append_str_prefix(
                scales_list, get_op_name(model, layer) + "."
            )

        # Clear GPU memory
        torch.cuda.empty_cache()

        if mse_range:
            clip_list = auto_clip_block(
                layer,
                w_bit=w_bit,
                q_config=q_config,
                input_feat=input_feat,
            )
            apply_clip(layer, clip_list)
            # append prefix to make names global
            awq_results["clip"] += append_str_prefix(
                clip_list, get_op_name(model, layer) + "."
            )

        layer.cpu()  # in place; layers[i] is the same module
        # Haotian: check activation replacement
        del input_feat
        gc.collect()
        torch.cuda.empty_cache()

    return awq_results


def distance_to_2pi(L = 2048, dim = 4096, base = 10000):
    """Locate the "beta point" wavelength index for a token length and report distances past it.

    Wavelengths are ``2*pi * base ** (2*i / dim)`` for ``i`` in ``1..dim``
    (presumably rotary-embedding wavelengths). This array is increasing, so
    ``np.searchsorted`` finds the first index whose wavelength is ``>= L``.
    The beta point is one past that index.

    Args:
        L: token length to compare wavelengths against. Previously this
            argument was ignored and 2048 was always used.
        dim: number of hidden dimensions in the wavelength formula.
        base: exponent base in the wavelength formula.

    Returns:
        ``(distance, beta_point)``, where
        ``distance = wavelength[beta_point:] - wavelength[beta_point]``.

    Raises:
        ValueError: if the beta point falls outside the wavelength array
            (``L`` too large for ``dim``/``base``).
    """
    import numpy as np
    import math
    hid_dim = np.array(range(1, dim + 1))
    wavelength = 2 * math.pi * np.power(base, (2 * hid_dim) / dim)
    beta_point = np.searchsorted(wavelength, L) + 1
    if beta_point >= len(wavelength):
        raise ValueError(
            f"L={L} has no beta point inside the {len(wavelength)} wavelengths "
            f"(dim={dim}, base={base})"
        )
    distance = wavelength[beta_point:] - wavelength[beta_point]
    print(f'distance to full 2pi rotation for dimentions after {beta_point} are: ', distance)
    return distance, beta_point

def log_rescale(original_array, a = 3, b = 1, epsilon = 1e-9):
    """Linearly map ``log(original_array + epsilon)`` onto the range between ``a`` and ``b``.

    The element with the largest log value maps to ``a`` and the smallest to
    ``b``; everything else is interpolated linearly in log space. With the
    defaults the output lies in ``[1, 3]``.

    Args:
        original_array: array-like of non-negative values.
        a: output value for the largest input.
        b: output value for the smallest input.
        epsilon: small constant added before the log to avoid ``log(0)``.

    Returns:
        A NumPy array with the same shape as ``original_array``. If all log
        values are equal (zero span), every element maps to ``a`` instead of
        the NaN that 0/0 would produce.
    """
    import numpy as np
    values = np.asarray(original_array)
    # Step 1: Logarithmic transformation
    log_array = np.log(values + epsilon)
    log_min, log_max = log_array.min(), log_array.max()
    log_span = log_max - log_min
    # Step 2: Reverse scaling
    if log_span == 0:
        # Every element equals log_max, which the formula maps to `a`.
        reversed_rescaled_array = np.full_like(log_array, a)
    else:
        reversed_rescaled_array = a + ((log_max - log_array) * (b - a)) / log_span
    print(f'Reversed Rescaled Array (Logarithmic Scaling) to range ({a}, {b})')
    print(f'NEW Rescale Array {reversed_rescaled_array}')
    return reversed_rescaled_array

def apply_awq(model, awq_results, temp = None, beta_point = 1440, 
              max_tokenlength = 2048, dynamic_with_log_distance = False, exclude_value_proj = False, args = None):
    """Apply previously searched AWQ scales and clipping values to ``model``.

    Args:
        model: the model that the entries in ``awq_results`` refer to.
        awq_results: dict with "scale" and "clip" lists (as returned by
            run_awq). Only ``awq_results["scale"][0]`` is inspected or
            modified here; its element 2 is treated as the scale tensor.
        temp: optional temperature. When set, diagnostics for a copy of the
            first scale tensor are printed, and ``temp`` is forwarded to
            ``apply_scale``. Per the original note, the temperature rescaling
            itself was moved into ``apply_scale``.
        beta_point: channel index from which rescaling starts; forwarded to
            ``apply_scale``. Overwritten when ``dynamic_with_log_distance``.
        max_tokenlength: length passed to ``distance_to_2pi`` when
            ``dynamic_with_log_distance`` is set.
        dynamic_with_log_distance: if True, channels ``beta_point:`` of the
            first scale tensor are multiplied in place by
            ``log_rescale(distance)``.
        exclude_value_proj: call ``apply_scale_exvalue`` with the rescaled
            first scale tensor instead of ``apply_scale``.
        args: forwarded to ``apply_scale``.

    Raises:
        ValueError: if ``exclude_value_proj`` is set but neither ``temp`` nor
            ``dynamic_with_log_distance`` produced a rescaled tensor. Before
            this check existed, that case raised UnboundLocalError.
    """
    # try to modify scale list here (Ye)
    new_scale = None  # set only by the rescaling branches below
    if dynamic_with_log_distance:
        distance, beta_point = distance_to_2pi(max_tokenlength)
        rescale_array = log_rescale(distance)
        print(f'apply rescale based on Logarithmic distance (to full 2pi rotation) after beta_point {beta_point}')
        # NOTE: no clone, so the first scale tensor in awq_results is modified in place.
        new_scale = awq_results["scale"][0][2]
        print('old scale: ', new_scale)
        # Explicit conversion so this also works when the scale tensor is on CUDA.
        new_scale[beta_point:] *= torch.as_tensor(rescale_array, device=new_scale.device)
        print('new scale: ', new_scale)
        awq_results["scale"][0] = awq_results["scale"][0][:2] + (new_scale,) + awq_results["scale"][0][3:]
    if temp is not None:
        print(f'apply temperature = {temp} to the awq scale')
        print(f'rescale betapoint = {beta_point}')
        print('awq_results: ', awq_results["scale"][0])

        # The actual temperature rescaling lives in apply_scale. This works on
        # a copy of the first scale tensor, only for the diagnostics below and
        # for apply_scale_exvalue when exclude_value_proj is set.
        new_scale = awq_results["scale"][0][2].detach().clone()
        import numpy as np
        n = new_scale[beta_point:].shape[0]
        print('----in the if tmep != none block----')
        print('scale shape ', awq_results["scale"][0][2].shape)
        print('original scale: ', new_scale)

        # topk fails when k exceeds the channel count, so cap it.
        k = min(50, new_scale.shape[-1])
        topk_values, topk_indices = torch.topk(new_scale, k=k)
        print(f'Top {k} scale values: ', topk_values)
        print(f'Top {k} scale indices: ', topk_indices)

        scaling_factor_array = np.linspace(2.4, 1.4, n)
        new_scale[beta_point:] *= torch.as_tensor(scaling_factor_array, device=new_scale.device)
        print('new scale: ', new_scale)
    if exclude_value_proj:
        print('Excluding value projection layer from scaling')
        if new_scale is None:
            raise ValueError(
                "exclude_value_proj=True requires temp or "
                "dynamic_with_log_distance to be set (no rescaled scale to apply)"
            )
        # NOTE(review): apply_scale_exvalue is neither defined nor imported in
        # the visible part of this module; unless it is provided elsewhere
        # this raises NameError — confirm where it should come from.
        apply_scale_exvalue(model, awq_results["scale"], new_scale)
    else:
        apply_scale(model, awq_results["scale"], beta_point = beta_point, args=args, temp = temp)
    apply_clip(model, awq_results["clip"])



def apply_awq_ntk(model, awq_results, temp = None, beta_point = 1440, 
              max_tokenlength = 2048, dynamic_with_log_distance = False, exclude_value_proj = False, args = None):
    """Apply AWQ scales/clipping, passing a temperature-derived factor to ``apply_scale``.

    When ``temp`` is set, ``scaling_factor_with_temp = 1 + 1/temp`` is
    forwarded to ``apply_scale`` and diagnostics for a copy of the first
    scale tensor are printed. When ``temp`` is None, ``apply_scale`` is
    called without that keyword, so its own default applies (the same call
    shape ``apply_awq`` uses). Previously this path raised UnboundLocalError.

    Args:
        model: the model that the entries in ``awq_results`` refer to.
        awq_results: dict with "scale" and "clip" lists (as returned by run_awq).
        temp: optional temperature; ``temp == 0`` raises ZeroDivisionError.
        beta_point: forwarded to ``apply_scale``.
        max_tokenlength: accepted for signature parity with apply_awq; unused.
        dynamic_with_log_distance: accepted for signature parity with apply_awq; unused.
        exclude_value_proj: call ``apply_scale_exvalue`` with a copy of the
            first scale tensor instead of ``apply_scale``.
        args: forwarded to ``apply_scale``.

    Raises:
        ValueError: if ``exclude_value_proj`` is set without ``temp``.
    """
    # try to modify scale list here (Ye)
    scaling_factor_with_temp = None
    new_scale = None
    if temp is not None:
        print(f'apply temperature = {temp} to the awq scale')
        scaling_factor_with_temp = 1 + (1 / temp) # can also decayed with channels
        print('awq_results: ', awq_results["scale"][0])

        # The rescaling itself is implemented in apply_scale; this copy is
        # only printed, and handed to apply_scale_exvalue when exclude_value_proj.
        new_scale = awq_results["scale"][0][2].detach().clone()
        print('----in the if tmep != none block----')
        print('scale shape ', awq_results["scale"][0][2].shape)
        print('original scale: ', new_scale)

    if exclude_value_proj:
        print('Excluding value projection layer from scaling')
        if new_scale is None:
            raise ValueError("exclude_value_proj=True requires temp to be set")
        # NOTE(review): apply_scale_exvalue is neither defined nor imported in
        # the visible part of this module — confirm where it should come from.
        apply_scale_exvalue(model, awq_results["scale"], new_scale)
    else:
        scale_kwargs = {"beta_point": beta_point, "args": args}
        if scaling_factor_with_temp is not None:
            scale_kwargs["scaling_factor_with_temp"] = scaling_factor_with_temp
        apply_scale(model, awq_results["scale"], **scale_kwargs)
    apply_clip(model, awq_results["clip"])
    
    
def apply_awq_search(model, awq_results, scale, channel):
    """Apply AWQ results using ``apply_scale_search`` for scales, then the stored clipping.

    ``scale`` and ``channel`` are passed straight through to
    ``apply_scale_search``.

    NOTE(review): ``apply_scale_search`` is neither defined nor imported in
    the visible part of this module; unless it is provided elsewhere, this
    call raises NameError — confirm where it should come from.
    """
    scale_entries = awq_results["scale"]
    apply_scale_search(model, scale_entries, scale, channel)
    clip_entries = awq_results["clip"]
    apply_clip(model, clip_entries)
