import gc
import torch
import torch.nn as nn

from transformers.models.bloom.modeling_bloom import BloomBlock, BloomGelu
from transformers.models.opt.modeling_opt import OPTDecoderLayer
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMSNorm
from transformers.activations import GELUActivation
from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm, Qwen2DecoderLayer
from scaled_rope.modeling_llama_yarn import LlamaDecoderLayer as LlamaDecoderLayer_yarn
from scaled_rope.modeling_llama_yarn import LlamaRMSNorm as LlamaRMSNorm_yarn
from .qmodule import ScaledActivation
from ..utils.module import get_op_by_name, get_op_name, set_op_by_name

# Names exported by `from <this module> import *`; the remaining functions in
# this file are helpers that are not part of the star-import surface.
__all__ = ["auto_scale_block", "apply_scale"]


@torch.no_grad()
def get_weight_scale(weight, q_group_size=-1):
    """Return the column-wise mean of group-normalised weight magnitudes.

    Every element's absolute value is divided by the largest absolute value
    found along dim 1 of its row -- or of its ``q_group_size``-wide group when
    ``q_group_size > 0`` -- then the tensor is viewed back to the input shape
    and averaged over dim 0.

    Args:
        weight: weight tensor; must be viewable as ``(-1, q_group_size)`` when
            grouping is enabled.
        q_group_size: group width; any value ``<= 0`` disables grouping.

    Returns:
        Tensor with the shape of ``weight`` minus its leading dimension.
    """
    original_shape = weight.shape
    grouped = weight.view(-1, q_group_size) if q_group_size > 0 else weight
    magnitude = grouped.abs()
    group_peak = magnitude.amax(dim=1, keepdim=True)
    normalised = (magnitude / group_peak).view(original_shape)
    return normalised.mean(0)


@torch.no_grad()
def get_act_scale(x):
    """Return the mean absolute value of ``x`` for each index of its last dim.

    All leading dimensions are flattened together before averaging, so the
    result is a 1-D tensor of length ``x.shape[-1]``.
    """
    n_channels = x.shape[-1]
    flattened = x.abs().view(-1, n_channels)
    return flattened.mean(0)


@torch.no_grad()
def scale_ln_fcs(ln, fcs, scales):
    """Divide a norm layer by ``scales`` and multiply its consumers by them.

    All parameters are modified in place.

    Args:
        ln: module with a ``weight`` (and optionally a ``bias``) that is
            divided element-wise by ``scales``.
        fcs: one module or a list of modules whose ``weight`` columns are
            multiplied by ``scales``.
        scales: 1-D tensor; it is moved to ``ln.weight``'s device and dtype.

    Raises:
        AssertionError: if any parameter of ``ln`` or ``fcs`` contains NaN
            after scaling.
    """
    linear_layers = fcs if isinstance(fcs, list) else [fcs]

    scales = scales.to(ln.weight.device).to(ln.weight.dtype)

    # Shrink the norm's output ...
    ln.weight.div_(scales)
    ln_bias = getattr(ln, "bias", None)
    if ln_bias is not None:
        ln_bias.div_(scales)

    # ... and grow the matching input columns of every consumer.
    column_scale = scales.view(1, -1)
    for linear in linear_layers:
        linear.weight.mul_(column_scale)

    # Sanity check: scaling must not have produced NaNs anywhere.
    for mod in [ln] + linear_layers:
        for param in mod.parameters():
            assert torch.isnan(param).sum() == 0


@torch.no_grad()
def scale_fc_fc(fc1, fc2, scales):
    """Divide the output of ``fc1`` by ``scales`` and multiply ``fc2``'s input by them.

    All parameters are modified in place. Only the trailing ``scales.size(0)``
    output rows of ``fc1`` are rescaled, so ``fc1`` may produce more features
    than ``fc2`` consumes; the bias is sliced the same way as the weight so
    both stay consistent.

    Args:
        fc1: ``nn.Linear`` producing the activations.
        fc2: ``nn.Linear`` consuming the scaled slice of those activations.
        scales: 1-D tensor; it is moved to ``fc1.weight``'s device and dtype.

    Raises:
        AssertionError: if either layer is not an ``nn.Linear`` or if any
            parameter contains NaN after scaling.
    """
    assert isinstance(fc1, nn.Linear)
    assert isinstance(fc2, nn.Linear)
    # assert fc1.out_features == fc2.in_features

    scales = scales.to(fc1.weight.device).to(fc1.weight.dtype)
    n_scaled = scales.size(0)

    # Only the last `n_scaled` output features of fc1 are rescaled.
    fc1.weight[-n_scaled:].div_(scales.view(-1, 1))
    if fc1.bias is not None:
        # Slice the bias like the weight; dividing the whole bias fails with a
        # broadcast error whenever fc1.out_features != n_scaled.
        fc1.bias[-n_scaled:].div_(scales.view(-1))

    fc2.weight.mul_(scales.view(1, -1))

    for p in fc1.parameters():
        assert torch.isnan(p).sum() == 0
    for p in fc2.parameters():
        assert torch.isnan(p).sum() == 0


@torch.no_grad()
def scale_gelu_fc(gelu, fc, scales):
    """Multiply the input columns of ``fc`` by ``scales`` (in place).

    The activation module itself is not modified here; it is only
    type-checked. ``nn.SiLU`` is accepted alongside the GELU variants because
    ``apply_scale_original`` in this file dispatches ``nn.SiLU`` activations
    to this function.

    Args:
        gelu: activation module that feeds ``fc``.
        fc: ``nn.Linear`` whose weight columns are multiplied by ``scales``.
        scales: 1-D tensor; it is moved to ``fc.weight``'s device and dtype.

    Raises:
        AssertionError: on an unsupported activation type, a non-linear
            ``fc``, or NaN parameters after scaling.
    """
    assert isinstance(gelu, (nn.GELU, BloomGelu, GELUActivation, nn.SiLU))
    assert isinstance(fc, nn.Linear)

    fc.weight.mul_(scales.view(1, -1).to(fc.weight.device).to(fc.weight.dtype))

    for p in fc.parameters():
        assert torch.isnan(p).sum() == 0


@torch.no_grad()
def auto_scale_block(module, module_kwargs, w_bit, q_config, input_feat):
    """Grid-search per-channel scales for the linear layers of one block.

    For each group of linear layers that share an input, candidate scales
    ``mean(|x|) ** ratio`` are tried for 20 ratios in ``[0, 1)``. The weights
    are scaled, pseudo-quantized and un-scaled, and the ratio whose output is
    closest (MSE) to the un-quantized output of the inspected module wins.
    Weights are restored after every trial.

    Args:
        module: the block to analyse. Supported: ``OPTDecoderLayer``,
            ``LlamaDecoderLayer``, ``Qwen2DecoderLayer``, ``BloomBlock`` and
            classes whose repr contains "mpt", "falcon", "bigcode" or "neox".
        module_kwargs: keyword arguments forwarded to the inspected
            sub-module's forward. The ``"use_cache"`` key, if present, is
            removed from this dict in place.
        w_bit: bit width passed to ``pseudo_quantize_tensor`` as ``n_bit``;
            ``None`` disables quantization during the search.
        q_config: extra keyword arguments for ``pseudo_quantize_tensor``.
        input_feat: mapping from layer name (relative to ``module``) to the
            input activations recorded for that layer.

    Returns:
        A list of ``(prev_op_name, layer_names, scales)`` tuples where
        ``layer_names`` is a tuple of names and ``scales`` is a CPU tensor.

    Raises:
        NotImplementedError: for an unsupported block type.
        Exception: if no ratio produced a loss below infinity (e.g. every
            loss was NaN).
    """
    from .quantizer import pseudo_quantize_tensor

    # firstly, get the weight quantize function
    if w_bit is not None:

        def w_quantize_func(p):
            return pseudo_quantize_tensor(
                p,
                n_bit=w_bit,
                **q_config,
            ).detach()

    else:

        def w_quantize_func(p):
            return p

    # Removes the key from the caller's dict in place (unchanged behaviour).
    module_kwargs.pop("use_cache", None)

    # find the best scale ratio
    def _search_module_scale(block, linears2scale: list, x, kwargs=None):
        # w: co, ci
        # x: n, ci
        kwargs = {} if kwargs is None else kwargs
        x = x.to(next(block.parameters()).device)
        # Gradients are already disabled by the decorator on auto_scale_block.
        org_out = block(x, **kwargs)
        if isinstance(org_out, tuple):
            org_out = org_out[0]

        x_max = get_act_scale(x)

        best_error = float("inf")
        best_ratio = -1
        best_scales = None

        n_grid = 20
        history = []

        # Snapshot the weights so every trial starts from the originals.
        # copy=True matters when the block already lives on the CPU: there
        # `.cpu()` returns the very same storage, the in-place scaling below
        # would corrupt the snapshot and the restore would reload scaled
        # weights instead of the originals.
        org_sd = {
            k: v.to("cpu", copy=True) for k, v in block.state_dict().items()
        }
        for step in range(n_grid):
            ratio = step / n_grid
            scales = x_max.pow(ratio).clamp(min=1e-4).view(-1)
            # Normalise so that max * min == 1.
            scales = scales / (scales.max() * scales.min()).sqrt()
            for fc in linears2scale:
                fc_scales = scales.view(1, -1).to(fc.weight.device)
                # scale -> pseudo-quantize -> un-scale
                fc.weight.mul_(fc_scales)
                fc.weight.data = w_quantize_func(fc.weight.data) / fc_scales
            out = block(x, **kwargs)
            if isinstance(out, tuple):
                out = out[0]

            loss = (
                (org_out - out).float().pow(2).mean().item()
            )  # float prevents overflow
            history.append(loss)
            if loss < best_error:
                best_error = loss
                best_ratio = ratio
                best_scales = scales
            block.load_state_dict(org_sd)
        if best_ratio == -1:
            print(history)
            raise Exception(
                f"scale search found no usable ratio; losses per ratio: {history}"
            )
        # print(best_ratio)
        best_scales = best_scales.view(-1)

        assert torch.isnan(best_scales).sum() == 0, best_scales
        return best_scales.detach()

    def _auto_get_scale(prev_op, layers, inp, module2inspect=None, kwargs=None):
        # module2inspect: if given, we will check the output diff of this module instead of layers
        if module2inspect is None:
            assert len(layers) == 1
            module2inspect = layers[0]

        scales = _search_module_scale(module2inspect, layers, inp, kwargs)
        scales = scales.detach().cpu()
        # prev_op_name, [layer_name], scale
        return (
            get_op_name(module, prev_op),
            tuple([get_op_name(module, m) for m in layers]),
            scales,
        )

    scales_list = []  # return the searched scales

    if isinstance(module, OPTDecoderLayer):
        # attention input
        scales_list.append(
            _auto_get_scale(
                prev_op=module.self_attn_layer_norm,
                layers=[
                    module.self_attn.q_proj,
                    module.self_attn.k_proj,
                    module.self_attn.v_proj,
                ],
                inp=input_feat["self_attn.q_proj"],
                module2inspect=module.self_attn,
                kwargs=module_kwargs,
            )
        )
        # attn out
        scales_list.append(
            _auto_get_scale(
                prev_op=module.self_attn.v_proj,
                layers=[module.self_attn.out_proj],
                inp=input_feat["self_attn.out_proj"],
            )
        )
        # fc1
        scales_list.append(
            _auto_get_scale(
                prev_op=module.final_layer_norm,
                layers=[module.fc1],
                inp=input_feat["fc1"],
            )
        )
        # fc2
        scales_list.append(
            _auto_get_scale(
                prev_op=module.fc1,
                layers=[module.fc2],
                inp=input_feat["fc2"],
            )
        )

    # NOTE(review): LlamaDecoderLayer_yarn is imported by this file but is not
    # matched here -- confirm whether yarn blocks should take this branch too.
    elif isinstance(module, (LlamaDecoderLayer, Qwen2DecoderLayer)):
        # attention input
        scales_list.append(
            _auto_get_scale(
                prev_op=module.input_layernorm,
                layers=[
                    module.self_attn.q_proj,
                    module.self_attn.k_proj,
                    module.self_attn.v_proj,
                ],
                inp=input_feat["self_attn.q_proj"],
                module2inspect=module.self_attn,
                kwargs=module_kwargs,
            )
        )
        # attn out: only searched when v_proj and o_proj weights have the
        # same shape.
        # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
        if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
            scales_list.append(
                _auto_get_scale(
                    prev_op=module.self_attn.v_proj,
                    layers=[module.self_attn.o_proj],
                    inp=input_feat["self_attn.o_proj"],
                )
            )
        # fc1
        scales_list.append(
            _auto_get_scale(
                prev_op=module.post_attention_layernorm,
                layers=[module.mlp.gate_proj, module.mlp.up_proj],
                inp=input_feat["mlp.gate_proj"],
                module2inspect=module.mlp,
            )
        )
        # fc2
        scales_list.append(
            _auto_get_scale(
                prev_op=module.mlp.up_proj,
                layers=[module.mlp.down_proj],
                inp=input_feat["mlp.down_proj"],
            )
        )

    elif isinstance(module, BloomBlock):
        # attention input
        scales_list.append(
            _auto_get_scale(
                prev_op=module.input_layernorm,
                layers=[module.self_attention.query_key_value],
                inp=input_feat["self_attention.query_key_value"],
                module2inspect=module,
                kwargs=module_kwargs,
            )
        )
        # attn out (query_key_value -> self_attention.dense) is not searched.
        # Please refer to https://github.com/mit-han-lab/llm-awq/issues/2#issuecomment-1606297469
        # fc1
        scales_list.append(
            _auto_get_scale(
                prev_op=module.post_attention_layernorm,
                layers=[module.mlp.dense_h_to_4h],
                inp=input_feat["mlp.dense_h_to_4h"],
                module2inspect=module,
                kwargs=module_kwargs,
            )
        )
        # fc2
        scales_list.append(
            _auto_get_scale(
                prev_op=module.mlp.gelu_impl,
                layers=[module.mlp.dense_4h_to_h],
                inp=input_feat["mlp.dense_4h_to_h"],
            )
        )
    elif "mpt" in str(module.__class__).lower():
        # attention input
        scales_list.append(
            _auto_get_scale(
                prev_op=module.norm_1,
                layers=[module.attn.Wqkv],
                inp=input_feat["attn.Wqkv"],
                module2inspect=module.attn,
                kwargs=module_kwargs,
            )
        )

        # attn out
        scales_list.append(
            _auto_get_scale(
                prev_op=module.attn.Wqkv,
                layers=[module.attn.out_proj],
                inp=input_feat["attn.out_proj"],
            )
        )
        # fc1
        scales_list.append(
            _auto_get_scale(
                prev_op=module.norm_2,
                layers=[module.ffn.up_proj],
                inp=input_feat["ffn.up_proj"],
                module2inspect=module.ffn,
            )
        )
        # fc2
        scales_list.append(
            _auto_get_scale(
                prev_op=module.ffn.act,
                layers=[module.ffn.down_proj],
                inp=input_feat["ffn.down_proj"],
            )
        )

    elif "falcon" in str(module.__class__).lower():
        # attn out (query_key_value -> self_attention.dense) is not searched.
        # Haotian: TBD: need to handle repeated scales for MQ
        # fc1, as long as it is scaled, everything is screwed up
        if "falcon-7b" in str(module.__class__).lower():
            scales_list.append(
                _auto_get_scale(
                    prev_op=module.input_layernorm,
                    layers=[
                        module.mlp.dense_h_to_4h,
                        module.self_attention.query_key_value,
                    ],
                    inp=input_feat["self_attention.query_key_value"],
                    module2inspect=module,
                    kwargs=module_kwargs,
                )
            )
        elif "falcon-40b" in str(module.__class__).lower():
            scales_list.append(
                _auto_get_scale(
                    prev_op=module.ln_attn,
                    layers=[module.self_attention.query_key_value],
                    inp=input_feat["self_attention.query_key_value"],
                    module2inspect=module,
                    kwargs=module_kwargs,
                )
            )
            scales_list.append(
                _auto_get_scale(
                    prev_op=module.ln_mlp,
                    layers=[module.mlp.dense_h_to_4h],
                    inp=input_feat["mlp.dense_h_to_4h"],
                    module2inspect=module,
                    kwargs=module_kwargs,
                )
            )
        else:
            raise NotImplementedError(
                "Unknown Falcon architecture, currently only falcon-7b and falcon-40b are supported"
            )
        # fc2
        scales_list.append(
            _auto_get_scale(
                prev_op=module.mlp.act,
                layers=[module.mlp.dense_4h_to_h],
                inp=input_feat["mlp.dense_4h_to_h"],
            )
        )
    elif "bigcode" in str(module.__class__).lower():
        # attention input
        scales_list.append(
            _auto_get_scale(
                prev_op=module.ln_1,
                layers=[module.attn.c_attn],
                inp=input_feat["attn.c_attn"],
                module2inspect=module.attn,
                kwargs=module_kwargs,
            )
        )
        # fc1
        scales_list.append(
            _auto_get_scale(
                prev_op=module.ln_2,
                layers=[module.mlp.c_fc],
                inp=input_feat["mlp.c_fc"],
                module2inspect=module.mlp,
            )
        )
        # fc2
        scales_list.append(
            _auto_get_scale(
                prev_op=module.mlp.act,
                layers=[module.mlp.c_proj],
                inp=input_feat["mlp.c_proj"],
            )
        )
    elif "neox" in str(module.__class__).lower():
        # attention input
        scales_list.append(
            _auto_get_scale(
                prev_op=module.input_layernorm,
                layers=[module.attention.query_key_value],
                inp=input_feat["attention.query_key_value"],
                module2inspect=module.attention,
                kwargs=module_kwargs,
            )
        )
        # fc1
        scales_list.append(
            _auto_get_scale(
                prev_op=module.post_attention_layernorm,
                layers=[module.mlp.dense_h_to_4h],
                inp=input_feat["mlp.dense_h_to_4h"],
                module2inspect=module.mlp,
            )
        )
        # fc2
        scales_list.append(
            _auto_get_scale(
                prev_op=module.mlp.act,
                layers=[module.mlp.dense_4h_to_h],
                inp=input_feat["mlp.dense_4h_to_h"],
            )
        )
    else:
        raise NotImplementedError(f"{type(module)} not supported yet!")

    return scales_list


def apply_scale_original(module, scales_list, input_feat_dict=None):
    """Apply searched scales to ``module`` without any extra rescaling.

    For every ``(prev_op_name, layer_names, scales)`` entry the producer
    (``prev_op``) is divided by ``scales`` and the consuming layers are
    multiplied by them. The involved sub-modules are moved to CUDA for the
    update and back to the CPU afterwards.

    Args:
        module: root module; operator names are resolved relative to it with
            ``get_op_by_name``.
        scales_list: iterable of ``(prev_op_name, layer_names, scales)``.
        input_feat_dict: optional mapping from layer name to its recorded
            input tensor. When given, each tensor is divided in place by the
            scales so that it matches the rescaled producer.

    Raises:
        NotImplementedError: if ``prev_op`` is of an unsupported type.
    """
    for prev_op_name, layer_names, scales in scales_list:
        prev_op = get_op_by_name(module, prev_op_name)
        layers = [get_op_by_name(module, name) for name in layer_names]

        # nn.Module.cuda() moves parameters in place. The former
        # `scales.cuda()` / `scales.cpu()` statements were dropped:
        # Tensor.cuda()/cpu() return a new tensor and the result was
        # discarded, so they had no effect. The scale_* helpers move `scales`
        # to the right device themselves.
        prev_op.cuda()
        for layer in layers:
            layer.cuda()

        if isinstance(prev_op, nn.Linear):
            assert len(layers) == 1
            scale_fc_fc(prev_op, layers[0], scales)
        elif isinstance(prev_op, (nn.LayerNorm, LlamaRMSNorm, Qwen2RMSNorm)):
            scale_ln_fcs(prev_op, layers, scales)
        elif isinstance(prev_op, (nn.GELU, BloomGelu, GELUActivation, nn.SiLU)):
            # The activation is replaced by a ScaledActivation wrapper; only
            # the consuming linear layer is rescaled directly.
            # NOTE(review): scale_gelu_fc type-checks the activation it is
            # given -- confirm that it accepts nn.SiLU.
            new_module = ScaledActivation(prev_op, scales)
            set_op_by_name(module, prev_op_name, new_module)
            scale_gelu_fc(prev_op, layers[0], scales)
        else:
            raise NotImplementedError(f"prev_op {type(prev_op)} not supported yet!")

        # apply the scaling to input feat if given; prepare it for clipping
        if input_feat_dict is not None:
            for layer_name in layer_names:
                inp = input_feat_dict[layer_name]
                inp.div_(scales.view(1, -1).to(inp.device).to(inp.dtype))

        prev_op.cpu()
        for layer in layers:
            layer.cpu()
        
def _parse_channel_indices(spec):
    """Parse a comma-separated string of channel indices into a list of ints."""
    try:
        return [int(x) for x in spec.split(',')]
    except ValueError:
        raise ValueError("All elements in the individual_channel list must be integers.")


def apply_scale(module, scales_list, input_feat_dict=None, beta_point=None,
                args=None, temp=None):
    """Apply searched scales to ``module``, optionally rescaling them first.

    For every ``(prev_op_name, layer_names, scales)`` entry the scales may
    first be modified according to ``args`` (see below), then the producer
    (``prev_op``) is divided by them and the consuming layers are multiplied
    by them. Sub-modules are moved to CUDA for the update and back to the CPU
    afterwards.

    The optional rescaling operates in place on views of ``scales``, so the
    tensors held in ``scales_list`` are modified as well.

    Args:
        module: root module; operator names are resolved relative to it with
            ``get_op_by_name``.
        scales_list: iterable of ``(prev_op_name, layer_names, scales)``.
        input_feat_dict: optional mapping from layer name to its recorded
            input tensor; each is divided in place by the final scales.
        beta_point: index from which scales are rescaled; in per-head mode
            ``ceil(beta_point / 32)`` is used inside each head. It is reset
            to 0 for the rest of the call once ``args.ntk`` is set and a
            per-head rescale runs.
        args: optional namespace. Attributes read here:
            ``rescale_attention_all``, ``recale_specific_layer``,
            ``rescale_per_head``, ``ntk``, ``individual_channel_up``,
            ``individual_channel_down``, ``individual_channel_scale``,
            ``individual_channel_value``, ``scale_invert``,
            ``use_search_result``, ``search_result_path`` and
            ``individual_channel``. ``None`` skips all rescaling.
        temp: optional temperature; when given, the per-head factor becomes
            the scalar ``1 + 1 / temp``.

    Raises:
        ValueError: if a channel-index or channel-value string cannot be
            parsed.
        NotImplementedError: if ``prev_op`` is of an unsupported type.
    """
    import math

    import numpy as np
    from model_loader import HadamardLinear

    # Hard-coded geometry of the per-head rescaling (kept from the original).
    num_heads = 32
    head_dim = 128  # presumably the per-head width -- TODO confirm

    for i, (prev_op_name, layer_names, scales) in enumerate(scales_list):
        prev_op = get_op_by_name(module, prev_op_name)
        layers = [get_op_by_name(module, name) for name in layer_names]

        # nn.Module.cuda() moves parameters in place. The former
        # `scales.cuda()` / `scales.cpu()` statements discarded their result
        # and were therefore no-ops; they have been removed.
        prev_op.cuda()
        for layer in layers:
            layer.cuda()
        print('prev_op_name: ', prev_op_name)
        print('layer_names: ', layer_names)
        print('scales: ', scales)
        print('scales shape: ', scales.shape)

        if args is not None:
            if args.rescale_attention_all:
                # rescale for every attention layer
                if i == 0:
                    print("rescaling all attention layer")
                if 'q_proj' in layer_names[0]:
                    print(f"rescaling {i}th attention layer")
                    if args.rescale_per_head:
                        # rescale per head
                        print(f"rescale per head")
                        scales = scales.view(num_heads, -1)
                        n = scales[0][math.ceil(beta_point / num_heads):].shape[0]
                        scaling_factor_array_head = np.linspace(2.4, 1.4, n)
                        if temp is not None:
                            scaling_factor_array_head = 1 + (1 / temp)
                        if args.ntk is not None:
                            # Takes precedence over `temp` in this branch.
                            scaling_factor_array_head = np.linspace(1.4, 2.4, head_dim)
                            beta_point = 0
                            print('ntk re-scaling')

                        print('n: ', n)

                        if args.individual_channel_up is not None and args.individual_channel_scale is not None:
                            print(f'individully rescale channel numer: {args.individual_channel_up} up with factor {args.individual_channel_scale}')
                            individual_channel_list = _parse_channel_indices(args.individual_channel_up)
                            for head in range(num_heads):
                                scales[head][individual_channel_list] *= args.individual_channel_scale

                            # check if surpress some channels
                            if args.individual_channel_down is not None:
                                print(f'individully rescale channel numer: {args.individual_channel_down} down with factor {args.individual_channel_scale}')
                                individual_channel_list = _parse_channel_indices(args.individual_channel_down)
                                for head in range(num_heads):
                                    scales[head][individual_channel_list] /= args.individual_channel_scale

                        # test for individual channel up or down with individual channel value weight
                        elif args.individual_channel_up is not None and args.individual_channel_value is not None:
                            print(f'individully rescale channel numer: {args.individual_channel_up} up with values {args.individual_channel_value}')
                            individual_channel_list = _parse_channel_indices(args.individual_channel_up)
                            try:
                                raw_values = [float(x) for x in args.individual_channel_value.split(',')]
                            except ValueError:
                                raise ValueError("All elements in the individual_value list must be float.")

                            # Negate, then min-max normalise into [1.4, 2.4].
                            individual_value_list = -torch.tensor(raw_values, dtype=torch.float32)
                            min_val = individual_value_list.min()
                            max_val = individual_value_list.max()
                            normalized_values = 1.4 + (individual_value_list - min_val) * (2.4 - 1.4) / (max_val - min_val)
                            if args.scale_invert:
                                normalized_values = 2.4-normalized_values + 1.4
                            individual_value_list = normalized_values
                            print('individual_value_list_normlized: ', individual_value_list)

                            # Already a tensor: move it once instead of
                            # re-wrapping it with torch.tensor() per head.
                            value_factors = individual_value_list.to(scales.device)
                            for head in range(num_heads):
                                scales[head][individual_channel_list] *= value_factors

                            # check if surpress some channels (todo)
                            if args.individual_channel_down is not None:
                                print(f'individully rescale channel numer: {args.individual_channel_down} down with values {args.individual_channel_value}')
                                individual_channel_list = _parse_channel_indices(args.individual_channel_down)
                                # NOTE(review): individual_channel_value is the
                                # comma-separated string parsed above, so this
                                # division cannot succeed as written -- the
                                # intended divisor needs to be confirmed.
                                for head in range(num_heads):
                                    scales[head][individual_channel_list] /= args.individual_channel_value

                        # test for individual channel up or down with individual channel value weight using search result
                        elif args.use_search_result and args.search_result_path is not None:
                            print(f'individully rescale use search result')
                            individual_value_list = []
                            with open(args.search_result_path, "r") as f:
                                for line in f:
                                    # Example line: best scale for channel 0 is 2.0 with ppl_difference -32.633899211883545
                                    parts = line.strip().split()
                                    if "is" in parts:
                                        idx = parts.index("is")
                                        scale = float(parts[idx + 1])
                                        if scale == 0:
                                            scale = 1.0
                                        individual_value_list.append(scale)
                            print('individual_scale_list: ', individual_value_list)
                            individual_channel_list = range(head_dim)

                            value_factors = torch.tensor(individual_value_list, device=scales.device)
                            for head in range(num_heads):
                                scales[head][individual_channel_list] *= value_factors

                            # check if surpress some channels (todo)
                            if args.individual_channel_down is not None:
                                print(f'individully rescale channel numer: {args.individual_channel_down} down with values {args.individual_channel_value}')
                                individual_channel_list = _parse_channel_indices(args.individual_channel_down)
                                # NOTE(review): the type of
                                # individual_channel_value is not established
                                # in this branch -- confirm it is numeric.
                                for head in range(num_heads):
                                    scales[head][individual_channel_list] /= args.individual_channel_value
                        else:
                            print(f'rescale all channels in heads')
                            for head in range(num_heads):
                                scales[head][math.ceil(beta_point / num_heads):] *= scaling_factor_array_head

                        scales = scales.view(-1)
                    else:
                        n = scales[beta_point:].shape[0]
                        scaling_factor_array = np.linspace(2.4, 1.4, n)
                        scales[beta_point:] *= scaling_factor_array
                        # print('new scale: ', new_scale)

            elif args.recale_specific_layer == i:
                print(f"rescaling first attention layer only")
                print(f"beta_point: {beta_point}")
                print('layer_names', layer_names)
                print('original scale(inside version): ', scales)
                if args.rescale_per_head:
                    # rescale per head
                    print(f"rescale per head")
                    scales = scales.view(num_heads, -1)
                    n = scales[0][math.ceil(beta_point / num_heads):].shape[0]
                    scaling_factor_array_head = np.linspace(1.4, 2, n)
                    # scaling_factor_array_head = 1+1.1*np.exp(-1*scaling_factor_array_head)[::-1]
                    if args.ntk is not None:
                        scaling_factor_array_head = np.linspace(1.4, 2.4, head_dim)
                        beta_point = 0
                        print('ntk re-scaling')
                    if temp is not None:
                        # Takes precedence over `ntk` in this branch.
                        scaling_factor_array_head = 1 + (1 / temp)
                    print('n: ', n)
                    print("scaling_factor_array_head: ", scaling_factor_array_head)

                    if args.individual_channel is not None and args.individual_channel_scale is not None:
                        print(f'individully rescale channel numer: {args.individual_channel} with factor {args.individual_channel_scale}')
                        for head in range(num_heads):
                            scales[head][args.individual_channel] *= args.individual_channel_scale
                    else:
                        print(f'rescale all channels in heads')
                        for head in range(num_heads):
                            scales[head][math.ceil(beta_point / num_heads):] *= scaling_factor_array_head

                    scales = scales.view(-1)
                else:
                    n = scales[beta_point:].shape[0]
                    scaling_factor_array = np.linspace(2.25, 1.5, n)
                    scales[beta_point:] *= scaling_factor_array
                print('new scale(inside version): ', scales)
            else:
                print(f"not rescaling anything, appy original awq scale")

        if isinstance(prev_op, (nn.Linear, HadamardLinear)):
            assert len(layers) == 1
            scale_fc_fc(prev_op, layers[0], scales)
        elif isinstance(
            prev_op,
            # LlamaRMSNorm_yarn adds support for yarn models.
            (nn.LayerNorm, LlamaRMSNorm, Qwen2RMSNorm, LlamaRMSNorm_yarn),
        ):
            scale_ln_fcs(prev_op, layers, scales)
        elif isinstance(prev_op, (nn.GELU, BloomGelu, GELUActivation)):
            # The activation is replaced by a ScaledActivation wrapper; only
            # the consuming linear layer is rescaled directly.
            new_module = ScaledActivation(prev_op, scales)
            set_op_by_name(module, prev_op_name, new_module)
            scale_gelu_fc(prev_op, layers[0], scales)
        else:
            raise NotImplementedError(f"prev_op {type(prev_op)} not supported yet!")

        # apply the scaling to input feat if given; prepare it for clipping
        if input_feat_dict is not None:
            for layer_name in layer_names:
                inp = input_feat_dict[layer_name]
                inp.div_(scales.view(1, -1).to(inp.device))

        prev_op.cpu()
        for layer in layers:
            layer.cpu()