from transformers.utils.bitsandbytes import *
from transformers import BitsAndBytesConfig
import torch
from torch import  nn
import bitsandbytes as bnb
from lavin import Tokenizer
import lavin.eval_model
from .int_llama_layer import QuantTransformerBlock
from .int_linear import QuantLinear
import copy
from fairscale.nn.model_parallel.layers import (
    ParallelEmbedding,
    RowParallelLinear,
    ColumnParallelLinear,
)
from contextlib import nullcontext
from .misc import NativeScalerWithGradNormCount,create_logger
from .datautils import get_loaders
import math
import pdb
import time
import json
import numpy as np

def _replace_with_bnb_linear(
    model, modules_to_not_convert=None, current_key_name=None, quantization_config=None, has_been_replaced=False
):
    """
    Private method that wraps the recursion for module replacement.

    Walks ``model`` recursively and swaps every ``nn.Linear``, fairscale
    ``ColumnParallelLinear`` or ``RowParallelLinear`` child for a bitsandbytes
    layer: ``bnb.nn.Linear8bitLt`` when
    ``quantization_config.quantization_method() == "llm_int8"``, otherwise
    ``bnb.nn.Linear4bit``. The replacement layers are constructed only from
    ``in_features`` / ``out_features`` / bias presence; the original weights
    are NOT copied here.

    Args:
        model: module whose children are replaced in place.
        modules_to_not_convert: names to leave untouched. A child is skipped if
            its own name is in this collection, or if any entry is a substring
            of the child's dotted path. ``None`` means "convert everything".
        current_key_name: dotted-path stack shared across the recursion
            (internal; callers normally omit it).
        quantization_config: a ``BitsAndBytesConfig``-like object; required as
            soon as a convertible layer is found.
        has_been_replaced: accumulator threaded through the recursion.

    Returns:
        ``(model, has_been_replaced)``: the same model object and whether at
        least one layer was converted successfully.
    """
    # ``None`` used to reach ``name not in modules_to_not_convert`` and raise
    # TypeError on the first linear child.
    if modules_to_not_convert is None:
        modules_to_not_convert = []
    if current_key_name is None:
        current_key_name = []

    for name, module in model.named_children():
        current_key_name.append(name)

        is_linear = (
            isinstance(module, nn.Linear)
            or isinstance(module, ColumnParallelLinear)
            or isinstance(module, RowParallelLinear)
        )
        if (
            is_linear
            and name not in modules_to_not_convert
            and not any(key in ".".join(current_key_name) for key in modules_to_not_convert)
        ):
            if quantization_config.quantization_method() == "llm_int8":
                model._modules[name] = bnb.nn.Linear8bitLt(
                    module.in_features,
                    module.out_features,
                    module.bias is not None,
                    has_fp16_weights=quantization_config.llm_int8_has_fp16_weight,
                    threshold=quantization_config.llm_int8_threshold,
                )
                has_been_replaced = True
            elif (
                quantization_config.llm_int8_skip_modules is not None
                and name in quantization_config.llm_int8_skip_modules
            ):
                # The config's own skip list is honoured on the 4-bit path only.
                pass
            else:
                model._modules[name] = bnb.nn.Linear4bit(
                    module.in_features,
                    module.out_features,
                    module.bias is not None,
                    quantization_config.bnb_4bit_compute_dtype,
                    compress_statistics=quantization_config.bnb_4bit_use_double_quant,
                    quant_type=quantization_config.bnb_4bit_quant_type,
                )
                has_been_replaced = True
            # Force requires grad to False to avoid unexpected errors
            # (applies to the kept module too when it was skipped above).
            model._modules[name].requires_grad_(False)

        if len(list(module.children())) > 0:
            _, has_been_replaced = _replace_with_bnb_linear(
                module,
                modules_to_not_convert,
                current_key_name,
                quantization_config,
                has_been_replaced=has_been_replaced,
            )
        # Remove the last key for recursion
        current_key_name.pop(-1)
    return model, has_been_replaced


def quant_model_bnb(model, quant_bit='4bit', keep_in_fp32_modules=None,
                    quantization_config=None):
    """Replace the linear layers of ``model`` with bitsandbytes quantized layers.

    Args:
        model: module to convert; its children are replaced in place.
        quant_bit: ``'4bit'`` or ``'8bit'``; only consulted when
            ``quantization_config`` is ``None`` and a default config is built.
        keep_in_fp32_modules: names (or dotted-path substrings) of modules to
            leave unconverted. ``None`` (the default) converts every eligible
            layer.
        quantization_config: a ``BitsAndBytesConfig``. When ``None``, a default
            is built: NF4 with double quantization and fp16 compute for 4-bit,
            LLM.int8 with threshold 6.0 and no fp16 weights for 8-bit.

    Returns:
        The same ``model`` object.

    Raises:
        ValueError: if ``quantization_config`` is ``None`` and ``quant_bit`` is
            neither ``'4bit'`` nor ``'8bit'``. Previously such a value silently
            produced a config with both load flags ``False``.
    """
    # Avoid a shared mutable default list.
    if keep_in_fp32_modules is None:
        keep_in_fp32_modules = []

    if quantization_config is None:
        if quant_bit not in ('4bit', '8bit'):
            raise ValueError(
                f"quant_bit must be '4bit' or '8bit', got {quant_bit!r}"
            )
        # set default quantization config
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=quant_bit == '4bit',
            load_in_8bit=quant_bit == '8bit',
            llm_int8_threshold=6.0,
            llm_int8_has_fp16_weight=False,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4',
        )
    model, _ = _replace_with_bnb_linear(
        model, modules_to_not_convert=keep_in_fp32_modules, quantization_config=quantization_config
    )

    return model

def omniquant(model, args):
    """Calibrate or load per-block OmniQuant parameters and install quantized blocks.

    Every entry of ``model.layers`` is wrapped in ``QuantTransformerBlock``,
    processed on ``args.device``, moved back to CPU and written back into
    ``model.layers``. The model is modified in place; nothing is returned.

    Two modes, selected by ``args``:

    * ``args.calib_epochs > 0`` -- calibration. Token embeddings of the
      calibration set (``args.calib_dataset``, ``args.nsamples`` samples,
      cached under ``args.cache_dir``) are fed through the blocks one at a
      time. For each block, ``qlayer.lwc_parameters()`` are optimized with
      AdamW (lr ``args.lwc_lr``, weight decay ``args.wd``) for
      ``args.calib_epochs`` epochs. The objective is the MSE between the
      quantized block's output and the full-precision block's output. The
      per-layer ``qlayer.omni_state_dict()`` results collected so far are
      re-saved after every layer to
      ``{args.output_dir}/omni_parameters_{args.calib_dataset}_{args.nsamples}.pth``.
    * otherwise, when ``args.quant_resume`` is set -- a dict indexed by layer
      number is loaded from that file and applied to each block with
      ``load_state_dict(strict=False)``.

    In both modes, blocks with index ``> args.start_layer`` get ``add_scaling()``.
    If ``args.scaling_resume`` is set, its ``'model'`` entry (keys with
    ``'module.'`` removed) is then loaded into ``model`` non-strictly.

    NOTE(review): if ``args.calib_epochs <= 0`` and ``args.quant_resume`` is
    falsy, ``omni_parameters`` is never bound and the layer loop raises
    ``UnboundLocalError``.
    NOTE(review): NaN outputs and non-finite losses drop into
    ``pdb.set_trace()``, which blocks non-interactive runs.
    """
    dev = torch.device(args.device)

    # Calibration with quantized activations (abits < 16) runs in fp32 without
    # autocast; every other case uses fp16 tensors under CUDA autocast.
    if args.abits < 16 and args.calib_epochs > 0:
        dtype = torch.float
        traincast = nullcontext
    else:
        dtype = torch.float16
        traincast = torch.cuda.amp.autocast
    print("==== weight quant params ====")
    print(args.weight_quant_params)
    print("==== act quant params ====")
    print(args.act_quant_params)
    print(args)

    if args.calib_epochs > 0:
        logger = create_logger(args.output_dir)
        logger.info(args)
        logger.info(dev)
        '''
        load calibration dataset
        '''

        # Calibration batches are cached on disk, keyed by model family,
        # dataset name and sample count; rebuilt with get_loaders on a miss.
        cache_dataloader = f'{args.cache_dir}/dataloader_{args.model_family}_{args.calib_dataset}_{args.nsamples}.cache'
        import os
        if os.path.exists(cache_dataloader):
            dataloader = torch.load(cache_dataloader)
            logger.info(f"load calibration from {cache_dataloader}")
        else:
            dataloader, _ = get_loaders(
                args.calib_dataset,
                nsamples=args.nsamples,
                seed=args.seed,
                model=args.llama_model_path,
                seqlen=model.params.max_seq_len,
                args = args,
                )
            torch.save(dataloader, cache_dataloader)
        '''
        obatain input activation
        '''
        logger.info("Experiment:{}_{}_{}_{}".format(args.calib_dataset,args.nsamples,args.model_family,cache_dataloader))

        # Embedding outputs for every calibration sample are the input of
        # block 0; buffer shape is (nsamples, max_seq_len, dim).
        tok_emb = model.tok_embeddings.to(dev)
        inps = torch.zeros(
            (args.nsamples, model.params.max_seq_len, model.params.dim), dtype=dtype
        ).to(dev)

        k = 0
        # Per-sample forward arguments, reused unchanged for every block.
        start_poses = []
        freqs_cises=[]
        masks=[]
        omni_parameters = {}
        with torch.no_grad():
            for batch in dataloader:
                # NOTE(review): assumes batch[0] has batch size 1 and length
                # max_seq_len, and that the dataloader yields at most
                # args.nsamples batches -- confirm.
                inps[k] = tok_emb(batch[0].to(dev))
                k += 1
                start_pos = 0
                start_poses.append(start_pos)

                _bsz, seqlen = batch[0].shape
                # Slice of model.freqs_cis for this sequence length
                # (presumably rotary position frequencies).
                freqs_cis = model.freqs_cis[:seqlen].to(dev)
                freqs_cises.append(freqs_cis)

                # Causal mask: -inf strictly above the diagonal, cast to inps' dtype.
                mask = None
                mask = torch.full((1, 1, seqlen, seqlen), float("-inf")).to(dev)
                mask = torch.triu(mask, diagonal=0 + 1).type_as(inps)
                # Additionally masks key position 0 for every query row after
                # the first. NOTE(review): intent not evident here -- confirm.
                mask[:,:,1:,0]=float("-inf")
                masks.append(mask)
        # fp_inps follows the full-precision path, quant_inps the path through
        # already-quantized blocks; both start from the embeddings.
        fp_inps = copy.deepcopy(inps)
        quant_inps = copy.deepcopy(inps)

    elif args.quant_resume:
        '''
        load pretrained quant parameters.
        '''
        # Indexed by layer number below (presumably a file written by the
        # calibration branch of this function).
        omni_parameters = torch.load(args.quant_resume)


    layers = model.layers
    for i in range(len(layers)):
        qlayer = QuantTransformerBlock(layers[i], args)
        qlayer = qlayer.to(dev)
        # Start with both weight and activation quantization disabled.
        qlayer.set_quant_state(weight_quant=False, act_quant=False)

        # block-wise construction
        if args.calib_epochs > 0:
            logger.info(f"=== Start quantize layer {i} ===")
            # Full-precision outputs of this block; overwriting fp_inps in
            # place makes them the next block's full-precision inputs.
            with torch.no_grad():
                with traincast():
                    for j in range(args.nsamples):
                        fp_inps[j] = qlayer(fp_inps[j].unsqueeze(0),
                                            start_poses[j], freqs_cises[j], masks[j])
                        if torch.isnan(fp_inps[j]).any():
                            logger.info("fp output is NAN, stopping training")
                            pdb.set_trace()

            # The block is optimized in fp32.
            with torch.no_grad():
                qlayer.float()
            loss_func = torch.nn.MSELoss()
            # Only the parameters returned by qlayer.lwc_parameters() are trained.
            optimizer = torch.optim.AdamW(
                [ {"params":qlayer.lwc_parameters(),"lr":args.lwc_lr}],weight_decay=args.wd)
            loss_scaler = NativeScalerWithGradNormCount()

            for epoch in range(args.calib_epochs):
                loss_list = []
                norm_list = []
                for j in range(args.nsamples):
                    # obtain output of quantization model
                    with traincast():
                        qlayer.smooth_and_quant_temporary() # activate weight quantization
                        quant_out = qlayer(quant_inps[j].unsqueeze(0),
                                        start_poses[j], freqs_cises[j], masks[j])
                        # Quantized output vs. this block's full-precision
                        # target. NOTE(review): the target fp_inps[j] is
                        # (seqlen, dim); if quant_out keeps a leading batch dim,
                        # MSELoss broadcasts with a size-mismatch warning.
                        loss = loss_func(quant_out,fp_inps[j])
                        if torch.isnan(quant_out).any():
                            logger.info("quant output is NAN, stopping training")
                            pdb.set_trace()
                    if not math.isfinite(loss.item()):
                        logger.info("Loss is NAN, stopping training")
                        pdb.set_trace()

                    loss_list.append(loss.data)
                    optimizer.zero_grad()
                    # Presumably runs backward + optimizer step and returns the
                    # gradient norm of the lwc parameters (logged below).
                    norm = loss_scaler(loss, optimizer,parameters=qlayer.lwc_parameters())
                    norm_list.append(norm.data)
                loss_mean = torch.stack(loss_list).mean()
                norm_mean = torch.stack(norm_list).mean()
                # logger.info(f"layer {i} iter {epoch} loss:{loss_mean} ")
                logger.info(f"layer {i} iter {epoch} loss:{loss_mean} norm:{norm_mean} max memory_allocated {torch.cuda.max_memory_allocated(dev) / 1024**2} ")
            qlayer.clear_temp_variable()
            del optimizer
            qlayer.half()

            # In-place quantization only when start_layer exceeds the number of
            # layers; otherwise fake quantization is applied.
            if args.start_layer > len(layers):
                qlayer.smooth_and_quant_inplace()
            else:
                qlayer.smooth_and_quant_fake()
            # Push the quantized stream through the now-quantized block to form
            # the next block's quantized inputs.
            # NOTE(review): when traincast is nullcontext, quant_inps are fp32
            # while qlayer is now fp16 -- confirm the block handles this.
            with torch.no_grad():
                with traincast():
                    for j in range(args.nsamples):
                        quant_inps[j] = qlayer(quant_inps[j].unsqueeze(0), start_poses[j], freqs_cises[j], masks[j])
            if i > args.start_layer:
                qlayer.add_scaling()
            layers[i] = qlayer.to("cpu")
            omni_parameters[i] = qlayer.omni_state_dict()
            # Re-save every layer's parameters collected so far after each layer.
            torch.save(omni_parameters, os.path.join(args.output_dir, f"omni_parameters_{args.calib_dataset}_{args.nsamples}.pth"))
        else:
            # load pretrained quantization parameters
            print(f"=== Start loading layer {i} ===")
            # NOTE(review): raises UnboundLocalError when neither the
            # calibration branch nor quant_resume bound omni_parameters.
            qlayer.load_state_dict(omni_parameters[i],strict=False)
            print(f"check layer {i} quantization parameters:{qlayer.attention.wq.weight_quantizer.upbound_factor}")
            qlayer.half()
            qlayer.smooth_and_quant_fake()
            # qlayer.smooth_and_quant_inplace()
            if i > args.start_layer:
                qlayer.add_scaling()
            layers[i] = qlayer.to("cpu")

    if args.scaling_resume:
        print("loading scaling parameters")
        # Take the checkpoint's 'model' entry, remove any 'module.' substring
        # from its keys (presumably a DataParallel/DDP prefix), load non-strictly.
        ckpt, state_dict= torch.load(args.scaling_resume)['model'], {}
        for key in ckpt:
            state_dict[key.replace('module.','')]=ckpt[key]
        model.load_state_dict(state_dict, strict=False)
    # `layers` is model.layers, already mutated in place; re-assigned as well.
    setattr(model,'layers',layers)
    if args.calib_epochs>0:
        # Release the calibration activation buffers before emptying the cache.
        del inps
        del quant_inps
        del fp_inps

    torch.cuda.empty_cache()
    

def quant_model_omni(model,args):
    """Recursively swap the linear layers of ``model`` for ``QuantLinear`` wrappers.

    Every direct or nested child that is an ``nn.Linear``, fairscale
    ``ColumnParallelLinear`` or ``RowParallelLinear`` is replaced in place by
    ``QuantLinear(child, args.weight_quant_params, args.act_quant_params)``.
    Gradients are disabled on each wrapper. No weights are loaded here.

    Returns:
        The same ``model`` object.
    """
    for child_name, child in model.named_children():
        if isinstance(child, (nn.Linear, ColumnParallelLinear, RowParallelLinear)):
            model._modules[child_name] = QuantLinear(
                child, args.weight_quant_params, args.act_quant_params
            )
            model._modules[child_name].requires_grad_(False)
        # Recurse into containers; `child` is still the original module here.
        if any(True for _ in child.children()):
            quant_model_omni(child, args)
    return model

def omniquant_from_checkpoint(
    model,
    args,
):
    """Rebuild quantized blocks from saved parameters and quantize them in place.

    Each entry of ``model.layers`` is wrapped in ``QuantTransformerBlock``.
    When ``args.resume`` is set, the dict loaded from that file is indexed by
    layer number and applied with ``load_state_dict(strict=False)``.

    When ``args.scaling_resume`` is set, two things happen. First, blocks with
    index greater than ``args.start_layer`` get ``add_scaling()``. Then the
    checkpoint's ``'model'`` entry, with ``'module.'`` removed from every key,
    is loaded into ``model`` non-strictly.

    Finally every block is moved to ``"cuda"``, ``smooth_and_quant_inplace()``
    is applied, and the block is moved back to ``"cpu"``.

    ``model.layers`` is modified in place; returns ``None``.
    """
    print("Starting ...")
    blocks = model.layers
    saved_params = torch.load(args.resume) if args.resume else {}

    for idx, block in enumerate(blocks):
        print(f"=== Start loading layer {idx} ===")
        wrapped = QuantTransformerBlock(block, args)
        if args.resume:
            wrapped.load_state_dict(saved_params[idx], strict=False)

        # Scaling is added before the scaling checkpoint below is loaded.
        if args.scaling_resume and idx > args.start_layer:
            wrapped.add_scaling()
        blocks[idx] = wrapped

    if args.scaling_resume:
        raw_state = torch.load(args.scaling_resume)['model']
        stripped = {key.replace('module.', ''): value for key, value in raw_state.items()}
        model.load_state_dict(stripped, strict=False)

    # real smooth and quantization, one block at a time on the GPU
    for idx, block in enumerate(blocks):
        on_gpu = block.to("cuda")
        on_gpu.smooth_and_quant_inplace()
        blocks[idx] = on_gpu.to("cpu")
