import math
import time
import tqdm
import torch
import torch.nn as nn
import logging
import os

from .quant_utils import WeightQuantizer
from .train_utils import set_embed_to_device
from .quant_linear import QuantizedLinear

# Keep cuDNN convolutions and CUDA matmuls in full FP32 instead of TF32
# for the whole quantization run (set once, at import time).
torch.backends.cudnn.allow_tf32 = False
torch.backends.cuda.matmul.allow_tf32 = False

def cleanup_memory(verbose=True) -> None:
    """Run Python garbage collection and release cached CUDA memory.

    Args:
        verbose: when True and CUDA is available, log the total memory reserved
            by the CUDA caching allocator (summed over all visible devices)
            before and after the cleanup, tagged with the calling function.
    """
    import gc
    import inspect

    try:
        origin = f' (from {inspect.stack()[1].function})'
    except (ValueError, KeyError):
        origin = ''

    def reserved_bytes() -> int:
        # Bytes reserved by the caching allocator, summed over every visible GPU.
        device_ids = range(torch.cuda.device_count())
        return sum(torch.cuda.memory_reserved(device=idx) for idx in device_ids)

    gib = 1024 ** 3
    before = reserved_bytes()

    # Collect unreachable Python objects first so the tensors they own go back
    # to the caching allocator before empty_cache releases unused blocks.
    gc.collect()

    if not torch.cuda.is_available():
        return

    torch.cuda.empty_cache()
    after = reserved_bytes()
    if not verbose:
        return
    logging.info(
        f"GPU memory{origin}: {before / gib:.2f} -> {after / gib:.2f} GB"
        f" ({(after - before) / gib:.2f} GB)"
    )

def find_qlayers(module, layers=None, name=''):
    """Recursively collect sub-modules whose exact type is in ``layers``.

    Matching uses ``type(m) in layers``, so subclasses of a listed type are NOT
    matched. A matched module is returned as a leaf: its own children are not
    searched.

    Args:
        module: root ``nn.Module`` to search.
        layers: module types to collect; defaults to ``[torch.nn.Linear]``.
        name: dotted name prefix of ``module`` (empty for the root).

    Returns:
        dict mapping dotted sub-module names (as in ``named_modules``) to modules.
    """
    # None sentinel instead of a shared mutable list default.
    if layers is None:
        layers = [torch.nn.Linear]
    if type(module) in layers:
        return {name: module}
    found = {}
    for child_name, child in module.named_children():
        full_name = f'{name}.{child_name}' if name else child_name
        found.update(find_qlayers(child, layers=layers, name=full_name))
    return found


class GPTQ:
    """GPTQ state for quantizing the weight of a single linear layer.

    Calibration activations are folded into a Hessian proxy ``self.H`` via
    :meth:`add_batch`; :meth:`fasterquant` then quantizes ``layer.weight`` in
    place. The caller must assign ``self.quantizer`` (an object providing
    ``ready()``, ``find_params(W)`` and ``quantize(x)``) before calling
    :meth:`fasterquant`.
    """

    def __init__(self, layer):
        self.layer = layer
        self.dev = self.layer.weight.device
        # Only the weight shape is needed here, so avoid cloning the weight.
        self.rows = layer.weight.shape[0]
        self.columns = layer.weight.shape[1]
        self.H = torch.zeros((self.columns, self.columns), device=self.dev)
        self.nsamples = 0

    def add_batch(self, inp, out):
        """Fold one batch of layer inputs into the running Hessian estimate.

        Args:
            inp: layer input whose last dim equals ``self.columns``. A 2-D input
                counts as one sample; for higher-rank input the first dim is the
                number of samples and all leading dims are flattened into rows.
            out: layer output; unused (kept for the forward-hook signature).
        """
        if inp.dim() == 2:
            inp = inp.unsqueeze(0)
        num_new = inp.shape[0]
        inp = inp.reshape(-1, inp.shape[-1]).t()
        # Keep H == 2 / nsamples * sum(x x^T) over all rows seen so far.
        self.H *= self.nsamples / (self.nsamples + num_new)
        self.nsamples += num_new
        inp = math.sqrt(2 / self.nsamples) * inp.float()
        self.H += inp.matmul(inp.t())

    def fasterquant(
        self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False, static_groups=False
    ):
        """Quantize ``self.layer.weight`` in place with GPTQ.

        An attempt that raises (e.g. a Cholesky failure) or produces NaNs is
        retried with the dampening factor doubled, up to 5 attempts in total.
        ``self.H`` is not modified, so every attempt starts from the same Hessian.

        Args:
            blocksize: number of columns processed per lazy-update block.
            percdamp: initial dampening, as a fraction of the mean Hessian diagonal.
            groupsize: columns per quantization group, or -1 for no grouping.
            actorder: process columns in order of decreasing Hessian diagonal.
            static_groups: compute all group parameters up front.

        Raises:
            ValueError: the result still contains NaNs after the last attempt.
            Exception: whatever the last attempt raised, if it raised.
        """
        max_retries = 5
        current_percdamp = percdamp
        for attempt in range(1, max_retries + 1):
            last_attempt = attempt == max_retries
            try:
                quantized = self._quantize_once(
                    blocksize, current_percdamp, groupsize, actorder, static_groups
                )
            except Exception as exc:
                if last_attempt:
                    raise
                logging.warning('GPTQ attempt failed (%s), percdamp was %f', exc, current_percdamp)
            else:
                if not torch.isnan(quantized).any():
                    self.layer.weight.data = quantized
                    return
                logging.warning('NaN in weights, now percdamp is %f', current_percdamp)
                if last_attempt:
                    raise ValueError('NaN in weights after all retries')
            current_percdamp *= 2

    def _quantize_once(self, blocksize, percdamp, groupsize, actorder, static_groups):
        """Run one GPTQ pass and return the quantized weight in the layer's dtype.

        Neither ``self.H`` nor the layer weight is modified. ``self.quantizer``
        may be updated (new group parameters, or replaced by a static group).
        """
        W = self.layer.weight.data.clone().float()

        if not self.quantizer.ready():
            self.quantizer.find_params(W)

        # Work on a copy so a retry starts again from the undamped Hessian.
        H = self.H.clone()
        dead = torch.diag(H) == 0
        H[dead, dead] = 1
        W[:, dead] = 0

        if static_groups:
            import copy
            groups = []
            for i in range(0, self.columns, groupsize):
                quantizer = copy.deepcopy(self.quantizer)
                quantizer.find_params(W[:, i:(i + groupsize)])
                groups.append(quantizer)

        if actorder:
            perm = torch.argsort(torch.diag(H), descending=True)
            W = W[:, perm]
            H = H[perm][:, perm]
            invperm = torch.argsort(perm)

        Q = torch.zeros_like(W)

        damp = percdamp * torch.mean(torch.diag(H))
        diag = torch.arange(self.columns, device=self.dev)
        H[diag, diag] += damp
        # Upper Cholesky factor of H^-1.
        H = torch.linalg.cholesky(H)
        H = torch.cholesky_inverse(H)
        Hinv = torch.linalg.cholesky(H, upper=True)

        for i1 in range(0, self.columns, blocksize):
            i2 = min(i1 + blocksize, self.columns)
            count = i2 - i1

            W1 = W[:, i1:i2].clone()
            Q1 = torch.zeros_like(W1)
            Err1 = torch.zeros_like(W1)
            Hinv1 = Hinv[i1:i2, i1:i2]

            for i in range(count):
                w = W1[:, i]
                d = Hinv1[i, i]

                if groupsize != -1:
                    if not static_groups:
                        if (i1 + i) % groupsize == 0:
                            self.quantizer.find_params(W[:, (i1 + i):(i1 + i + groupsize)])
                    else:
                        idx = i1 + i
                        if actorder:
                            idx = perm[idx]
                        self.quantizer = groups[idx // groupsize]

                q = self.quantizer.quantize(w.unsqueeze(1)).flatten()
                Q1[:, i] = q

                err1 = (w - q) / d
                # Spread this column's error over the remaining columns of the block.
                W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
                Err1[:, i] = err1

            Q[:, i1:i2] = Q1
            # Lazy batch update: propagate the whole block's error to later columns.
            W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])

        # Synchronizing is only meaningful (and only possible) for CUDA tensors.
        if Q.is_cuda:
            torch.cuda.synchronize()

        if actorder:
            Q = Q[:, invperm]

        return Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype)

    def free(self):
        """Drop the accumulated statistics and release cached GPU memory."""
        self.H = None
        self.Losses = None
        self.Trace = None
        cleanup_memory(verbose=False)
        
        
class _StopForward(Exception):
    """Raised by the block-0 input catcher to stop the model forward early."""


def _run_block(layer, inps, temb, image_rotary_emb, j):
    """Run ``layer`` on the ``j``-th cached calibration sample and return its output."""
    return layer(
        hidden_states=inps["hidden_states"][j],
        encoder_hidden_states=inps["encoder_hidden_states"][j],
        temb=temb[j],
        image_rotary_emb=image_rotary_emb,
    )


def _capture_block_inputs(model, layers, dataloader, dev, nsamples):
    """Record the keyword inputs that ``model`` passes to ``layers[0]``.

    The embeddings and the first block are moved to ``dev`` for the capture and
    moved back to CPU afterwards, even if the forward raises.

    Returns:
        ``(inps, temb, image_rotary_emb)``. ``inps`` holds per-sample lists under
        "hidden_states" and "encoder_hidden_states"; ``temb`` is the per-sample
        list of ``temb`` values; ``image_rotary_emb`` is the value seen for the
        last captured sample (None if nothing was captured).
    """
    set_embed_to_device(model, dev)
    layers[0] = layers[0].to(dev)
    dtype = next(iter(model.parameters())).dtype

    inps = {"hidden_states": [], "encoder_hidden_states": []}
    cache = {"temb": [], "image_rotary_emb": None}

    class Catcher(nn.Module):
        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, **kwargs):
            inps["hidden_states"].append(kwargs["hidden_states"])
            inps["encoder_hidden_states"].append(kwargs["encoder_hidden_states"])
            cache["temb"].append(kwargs["temb"])
            # Only the last value is kept: assumed identical for every sample.
            cache["image_rotary_emb"] = kwargs["image_rotary_emb"]
            # Abort the rest of the forward; only block-0 inputs are needed.
            raise _StopForward

    layers[0] = Catcher(layers[0])
    try:
        for batch in dataloader:
            if len(cache["temb"]) >= nsamples:
                break
            input_dict = {}
            for key in batch.keys():
                value = batch[key]
                input_dict[key] = value.to(dev, dtype=dtype) if torch.is_tensor(value) else value
            try:
                model(**input_dict)
            except _StopForward:
                pass
    finally:
        layers[0] = layers[0].module.cpu()
        set_embed_to_device(model, "cpu")
    return inps, cache["temb"], cache["image_rotary_emb"]


def _quantize_block(layer, inps, temb, image_rotary_emb, nsamples, sequential, args):
    """GPTQ-quantize the listed sub-layers of one block in place, group by group.

    For each group in ``sequential``, the block is re-run on all ``nsamples``
    cached inputs to collect statistics. Groups quantized earlier already have
    their quantized weights at that point.

    Returns:
        dict mapping every ``torch.nn.Linear`` sub-layer name of ``layer`` to its module.
    """
    w_bits = args.w_bits
    w_sym = not args.w_asym
    full = find_qlayers(layer, layers=[torch.nn.Linear])
    for names in sequential:
        subset = {n: full[n] for n in names}

        gptq = {}
        for name, module in subset.items():
            print(f'{name}', end='  ', flush=True)
            gptq[name] = GPTQ(module)
            gptq[name].quantizer = WeightQuantizer()
            gptq[name].quantizer.configure(
                w_bits, perchannel=True, sym=w_sym, mse=args.gptq_mse
            )

        def make_hook(state):
            def hook(_, inp, out):
                state.add_batch(inp[0].data, out.data)
            return hook

        handles = [
            module.register_forward_hook(make_hook(gptq[name]))
            for name, module in subset.items()
        ]
        try:
            for j in range(nsamples):
                _run_block(layer, inps, temb, image_rotary_emb, j)
        finally:
            for handle in handles:
                handle.remove()

        for name in subset:
            gptq[name].fasterquant(
                percdamp=args.percdamp, groupsize=args.w_groupsize,
                actorder=args.act_order, static_groups=False,
            )
            gptq[name].free()
    return full


@torch.no_grad()
def gptq_fwrd(model, dataloader, dev, args, sequential=None):
    '''
    GPTQ-quantize the linear sub-layers of ``model.transformer_blocks`` in place.

    Adapted from the GPTQ repo. Blocks are processed one at a time on ``dev``.
    The outputs of each quantized block become the calibration inputs of the
    next block.

    Args:
        model: model exposing ``transformer_blocks``; it is moved to CPU first.
        dataloader: iterable of dict-like batches passed as keyword arguments to
            ``model``; at most ``args.nsamples`` batches are consumed.
        dev: device used to calibrate and quantize each block.
        args: namespace providing ``nsamples``, ``w_bits``, ``w_asym``,
            ``gptq_mse``, ``w_groupsize``, ``percdamp`` and ``act_order``.
        sequential: optional list of groups of sub-layer names, quantized in
            order; defaults to the norm1/attn1/ff layers listed below.

    Returns:
        dict mapping ``f'layer_{i}_{name}'`` to a CPU clone of the quantized weight.

    Raises:
        ValueError: if no calibration sample could be captured.
    '''
    logging.info('-----GPTQ Quantization-----')
    if sequential is None:
        sequential = [
            ['norm1.linear.linear'],
            ['attn1.to_q.linear', 'attn1.to_k.linear', 'attn1.to_v.linear'],
            ['attn1.to_out.0.linear'],
            # NOTE(review): 'norm1.linear.linear' used to be listed a second time
            # here, so it was quantized twice. The intent was probably
            # 'norm2.linear.linear' — confirm, then pass `sequential` to include it.
            ['ff.net.0.proj.linear'],
            ['ff.net.2.linear'],
        ]

    model.to("cpu")
    layers = model.transformer_blocks

    inps, temb, image_rotary_emb = _capture_block_inputs(
        model, layers, dataloader, dev, args.nsamples
    )
    del dataloader
    torch.cuda.empty_cache()

    # The dataloader may yield fewer batches than requested.
    nsamples = len(temb)
    if nsamples == 0:
        raise ValueError('No GPTQ calibration samples were captured; is the dataloader empty?')
    if nsamples < args.nsamples:
        logging.warning(
            'Only %d of %d requested calibration samples were captured', nsamples, args.nsamples
        )

    quantized_weights = {}
    for i in range(len(layers)):
        print(f'\nLayer {i}:', flush=True, end=' ')
        logging.info(f'Layer {i}:')
        layer = layers[i].to(dev)
        full = _quantize_block(layer, inps, temb, image_rotary_emb, nsamples, sequential, args)

        # Outputs of the quantized block are the calibration inputs of the next one.
        outs = {"hidden_states": [], "encoder_hidden_states": []}
        for j in range(nsamples):
            hidden_states, encoder_hidden_states = _run_block(
                layer, inps, temb, image_rotary_emb, j
            )
            outs["hidden_states"].append(hidden_states)
            outs["encoder_hidden_states"].append(encoder_hidden_states)

        layers[i] = layer.cpu()
        # This block's weights are final; snapshot them now that they are on CPU.
        for names in sequential:
            for name in names:
                quantized_weights[f'layer_{i}_{name}'] = full[name].weight.data.clone()
        del layer
        torch.cuda.empty_cache()
        inps = outs

    cleanup_memory(verbose=True)
    logging.info('-----GPTQ Quantization Done-----\n')
    return quantized_weights

def save_quantized_weights(weights, save_path):
    """
    Save quantized weights to ``save_path`` with ``torch.save``.

    Missing parent directories are created.

    Args:
        weights: dict of quantized weights (name -> tensor).
        save_path: destination file path; may be a bare file name.
    """
    save_dir = os.path.dirname(save_path)
    # dirname() is '' for a bare file name, and os.makedirs('') raises.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    torch.save(weights, save_path)
    logging.info(f'Quantized weights saved to {save_path}')

def load_quantized_weights(model, weights_path, sequential=None):
    """
    Copy quantized weights from a file into ``model.transformer_blocks``.

    Keys are looked up as ``f'layer_{i}_{name}'`` for block index ``i`` and each
    sub-layer ``name`` in ``sequential``. Missing keys are logged and skipped.

    Args:
        model: model instance exposing ``transformer_blocks``.
        weights_path: path to the weights file (a dict saved with ``torch.save``).
        sequential: optional list of groups of sub-layer names to load;
            defaults to the norm1/attn1/ff layers listed below.

    Raises:
        FileNotFoundError: if ``weights_path`` does not exist.
        KeyError: if a listed name is not a ``torch.nn.Linear`` inside a block.
    """
    if not os.path.exists(weights_path):
        raise FileNotFoundError(f"权重文件不存在: {weights_path}")

    # NOTE(review): torch.load unpickles the file — only load trusted files.
    # Load on CPU so files saved from GPU tensors also open on CPU-only hosts;
    # copy_ below moves each tensor onto its parameter's device.
    weights = torch.load(weights_path, map_location="cpu")

    if sequential is None:
        sequential = [
            ['norm1.linear.linear'],
            ['attn1.to_q.linear', 'attn1.to_k.linear', 'attn1.to_v.linear'],
            ['attn1.to_out.0.linear'],
            # NOTE(review): 'norm1.linear.linear' used to be listed a second time
            # here; the intent was probably 'norm2.linear.linear' — confirm.
            ['ff.net.0.proj.linear'],
            ['ff.net.2.linear'],
        ]

    for i, layer in enumerate(model.transformer_blocks):
        full = find_qlayers(layer, layers=[torch.nn.Linear])
        for names in sequential:
            subset = {n: full[n] for n in names}
            for name, module in subset.items():
                weight_key = f'layer_{i}_{name}'
                if weight_key in weights:
                    module.weight.data.copy_(weights[weight_key])
                else:
                    logging.warning(f'未找到权重: {weight_key}')

    logging.info(f'量化权重已从 {weights_path} 加载完成')


@torch.no_grad()
def rtn_fwrd(model, dev, args):
    '''
    Round-to-nearest (RTN) quantization of every ``torch.nn.Linear`` in each block.

    Adapted from the GPTQ repo. Each block is moved to ``dev``, its linear
    weights are replaced in place by their quantized values, and the block is
    moved back to CPU. Layers whose name contains 'lm_head' are skipped.

    Args:
        model: model exposing ``model.layers`` or ``transformer_blocks``.
        dev: device used while quantizing each block.
        args: namespace providing ``w_groupsize`` (must be -1), ``w_bits``,
            ``w_asym`` and ``gptq_mse``.

    Returns:
        dict mapping ``'model.layers.{i}.{name}'`` to the quantizer (moved to CPU).
    '''
    assert args.w_groupsize ==-1, "Groupsize not supported in RTN!"
    # `model.model.layers` is the LLaMA-style layout; otherwise use the
    # `transformer_blocks` layout that gptq_fwrd works on.
    inner = getattr(model, 'model', None)
    layers = inner.layers if hasattr(inner, 'layers') else model.transformer_blocks
    torch.cuda.empty_cache()

    w_bits = args.w_bits
    w_sym = not args.w_asym
    quantizers = {}

    for i in tqdm.tqdm(range(len(layers)), desc="(RtN Quant.) Layers"):
        layer = layers[i].to(dev)
        subset = find_qlayers(layer, layers=[torch.nn.Linear])

        for name, module in subset.items():
            if 'lm_head' in name:
                continue  # left unquantized

            quantizer = WeightQuantizer()
            quantizer.configure(w_bits, perchannel=True, sym=w_sym, mse=args.gptq_mse)
            W = module.weight.data
            quantizer.find_params(W)
            module.weight.data = quantizer.quantize(W).to(W.dtype)
            quantizers['model.layers.%d.%s' % (i, name)] = quantizer.cpu()

        layers[i] = layer.cpu()
        torch.cuda.empty_cache()
        del layer

    cleanup_memory(verbose=True)
    return quantizers
