import torch
import torch.nn as nn
from omniquant.int_llama_layer import QuantLlamaDecoderLayer
from omniquant.int_opt_layer import QuantOPTDecoderLayer
from omniquant.int_falcon_layer import QuantFalconDecoderLayer
from omniquant.int_mistral_layer import QuantMistralDecoderLayer
from omniquant.int_linear import QuantLinear
from contextlib import nullcontext
import copy
import math
import omniquant.omni_utils as omni_utils
import os
import pdb
import gc
import os
import sys
import random
import numpy as np
import time
from pprint import pprint
import torch.nn as nn
from omniquant.LMClass import LMClass
from utils.calib import *
import pdb


# Let cuDNN benchmark candidate kernels and cache the fastest one per input shape.
torch.backends.cudnn.benchmark = True
from omniquant.omni_utils import let_parameters, lwc_parameters, get_omni_parameters, \
    omni_state_dict, register_scales_and_zeros, smooth_and_quant_temporary, \
    smooth_and_quant_inplace, clear_temp_variable, set_quant_state
try:
    import auto_gptq.nn_modules.qlinear.qlinear_cuda as qlinear_cuda
    import auto_gptq.nn_modules.qlinear.qlinear_triton as qlinear_triton
except:
    print("auto_gptq is required for real quantization")
from utils.common import *


def get_named_linears(module):
    """Return ``{qualified_name: submodule}`` for every QuantLinear under *module*.

    Names are the dotted paths produced by ``module.named_modules()``, in the
    same traversal order.
    """
    found = {}
    for qualified_name, submodule in module.named_modules():
        if not isinstance(submodule, QuantLinear):
            continue
        found[qualified_name] = submodule
    return found


def add_new_module(name, original_module, added_module):
    """Place *added_module* at the dotted path *name* inside *original_module*.

    Every path component except the last is resolved against the current
    object: all-digit components are used as integer indices
    (``parent[int(part)]``, e.g. into an ``nn.ModuleList``), all others via
    ``getattr``. The final component is always assigned with ``setattr``.
    """
    parent_path, separator, attr_name = name.rpartition('.')
    parent = original_module
    if separator:
        for part in parent_path.split('.'):
            parent = parent[int(part)] if part.isdigit() else getattr(parent, part)
    setattr(parent, attr_name, added_module)


def omniquant(
    lm,
    args,
    dataloader,
    act_scales,
    act_shifts,
):
    """Calibrate and quantize the decoder layers of ``lm.model`` one layer at a time.

    The inputs of the first decoder layer are captured by running calibration
    batches through the model. Each layer is then wrapped in its quantized
    counterpart. Its learnable parameters (LET smoothing scales/shifts and/or
    LWC parameters, as selected by ``args``) are trained so the quantized
    output matches the full-precision output. The layer is then smoothed and
    quantized in place, and its outputs become the inputs of the next layer.

    Args:
        lm: wrapper exposing ``model``, ``device``, ``_device`` and ``seqlen``.
        args: run configuration (``net``, ``epochs``, ``let``, ``abits``,
            ``aug_loss``, ``deactive_amp``, ``omni_cal_nsamples``,
            ``omni_batch_size``, ``omni_resume``, ``omni_save_dir``, learning
            rates and quantizer parameter dicts).
        dataloader: iterable of calibration batches; ``batch[0]`` is fed to
            the model.
        act_scales: mapping ``"<layer_name_prefix>.<i>.<linear name>" -> tensor``
            used to initialise LET scales (only read when ``args.let``).
        act_shifts: mapping with the same keys used to initialise LET shifts
            when channel-wise shifting is enabled. It may be ``None``, in which
            case shifts start at zero.

    Returns:
        ``lm.model`` with every decoder layer replaced by its quantized
        version, moved to CPU.

    Raises:
        ValueError: for an unsupported model family; for ``args.let`` on a
            family without a smoothing-pair table (falcon, mixtral); or when
            ``epochs > 0`` but ``omni_cal_nsamples < omni_batch_size``, which
            would leave zero training steps per epoch.
    """
    logging.info("Starting ...")

    net = args.net.lower()
    # Fail fast: with fewer samples than one batch, every epoch would be
    # empty and averaging the (empty) loss list would crash after the
    # expensive calibration pass.
    if args.epochs > 0 and args.omni_cal_nsamples < args.omni_batch_size:
        raise ValueError(
            f"omni_cal_nsamples ({args.omni_cal_nsamples}) must be >= "
            f"omni_batch_size ({args.omni_batch_size}) when epochs > 0"
        )

    # move embedding layer and first layer to target device
    model = lm.model
    dev = lm.device
    use_cache = model.config.use_cache
    model.config.use_cache = False
    is_llama = False
    # Maps a linear-layer name fragment to the prefix of the LET parameters
    # (``<prefix>_smooth_scale`` / ``<prefix>_smooth_shift``) registered for it.
    # Families without such a table keep ``None``.
    pairs = None
    if "llama" in net:
        is_llama = True
        layers = model.model.layers
        model.model.embed_tokens = model.model.embed_tokens.to(dev)
        model.model.norm = model.model.norm.to(dev)
        DecoderLayer = QuantLlamaDecoderLayer
        pairs = {
            "q_proj": "qkv",
            "o_proj": "out",
            "up_proj": "fc1"
        }
        layer_name_prefix = "model.layers"
    elif "mistral" in net:
        is_llama = True
        layers = model.model.layers
        model.model.embed_tokens = model.model.embed_tokens.to(dev)
        model.model.norm = model.model.norm.to(dev)
        DecoderLayer = QuantMistralDecoderLayer
        pairs = {
            "q_proj": "qkv",
            "o_proj": "out",
            "up_proj": "fc1"
        }
        layer_name_prefix = "model.layers"
    elif "opt" in net:
        layers = model.model.decoder.layers
        model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(dev)
        model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(dev)
        if hasattr(model.model.decoder, "project_out") and model.model.decoder.project_out:
            model.model.decoder.project_out = model.model.decoder.project_out.to(dev)
        if hasattr(model.model.decoder, "project_in") and model.model.decoder.project_in:
            model.model.decoder.project_in = model.model.decoder.project_in.to(dev)
        DecoderLayer = QuantOPTDecoderLayer
        pairs = {
            "q_proj": "qkv",
            "out_proj": "out",
            "fc1": "fc1"
        }
        layer_name_prefix = "model.decoder.layers"
    elif "falcon" in net:
        layers = model.transformer.h
        model.transformer.word_embeddings.to(dev)
        model.transformer.ln_f.to(dev)
        model.lm_head.to(dev)
        DecoderLayer = QuantFalconDecoderLayer
        layer_name_prefix = "model.transformer.h"
    elif "mixtral" in net:
        is_llama = True   # same as llama except for the FFN
        layers = model.model.layers
        model.model.embed_tokens = model.model.embed_tokens.to(dev)
        model.model.norm = model.model.norm.to(dev)
        layer_name_prefix = "model.layers"
    else:
        raise ValueError("Only support for opt/llama/Llama-2/falcon/mixtral now")

    if args.let and pairs is None:
        raise ValueError(f"LET (args.let) is not supported for model {args.net}")

    layers[0] = layers[0].to(dev)
    if args.deactive_amp and args.epochs > 0:
        dtype = torch.float
        traincast = nullcontext
    else:
        dtype = torch.float16
        traincast = torch.cuda.amp.autocast
    inps = torch.zeros(
        (args.omni_cal_nsamples, lm.seqlen, model.config.hidden_size), dtype=dtype, device=dev
    )
    cache = {"i": 0}

    # Catch the first layer's input: the Catcher stores the hidden states and
    # the keyword arguments it needs, then raises ValueError to abort the rest
    # of the forward pass.
    class Catcher(nn.Module):
        def __init__(self, module):
            super().__init__()
            self.module = module
            self.is_llama = False

        def forward(self, inp, **kwargs):
            inps[cache["i"]] = inp
            cache["i"] += 1
            cache["attention_mask"] = kwargs["attention_mask"]
            if self.is_llama:
                cache["position_ids"] = kwargs["position_ids"]
            raise ValueError

    layers[0] = Catcher(layers[0])
    layers[0].is_llama = is_llama

    with torch.no_grad():
        for batch in dataloader:
            if cache["i"] >= args.omni_cal_nsamples:
                break
            try:
                model(batch[0].to(dev))
            except ValueError:
                pass

    # move embedding layer and first layer back to cpu
    layers[0] = layers[0].module
    layers[0] = layers[0].cpu()
    if "llama" in net or "mixtral" in net or "mistral" in net:
        model.model.embed_tokens = model.model.embed_tokens.cpu()
        model.model.norm = model.model.norm.cpu()
    elif "opt" in net:
        model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.cpu()
        model.model.decoder.embed_positions = model.model.decoder.embed_positions.cpu()
        if hasattr(model.model.decoder, "project_out") and model.model.decoder.project_out:
            model.model.decoder.project_out = model.model.decoder.project_out.cpu()
        if hasattr(model.model.decoder, "project_in") and model.model.decoder.project_in:
            model.model.decoder.project_in = model.model.decoder.project_in.cpu()
    elif "falcon" in net:
        # Same case-insensitive test on args.net as the setup branch above;
        # checking the raw model path here could miss e.g. ".../Falcon-7B".
        model.transformer.word_embeddings = model.transformer.word_embeddings.cpu()
    else:
        raise ValueError("Only support for opt/llama/Llama-2/falcon/mixtral now")
    torch.cuda.empty_cache()

    # same input of first layer for fp model and quant model
    quant_inps = inps
    fp_inps = copy.deepcopy(inps)   # take output of fp model as input
    fp_inps_2 = copy.deepcopy(inps) if args.aug_loss else None  # take output of quantization model as input

    attention_mask = cache["attention_mask"]

    if attention_mask is not None:
        attention_mask_batch = attention_mask.repeat(
            args.omni_batch_size, 1, 1, 1) if args.deactive_amp else attention_mask.repeat(args.omni_batch_size, 1, 1, 1).float()
    else:
        logging.info(
            "No attention mask caught from the first layer."
            " Seems that model's attention works without a mask.")
        attention_mask_batch = None

    loss_func = torch.nn.MSELoss()
    if is_llama:
        position_ids = cache["position_ids"]
    else:
        position_ids = None

    if args.omni_resume:
        omni_parameters = torch.load(args.omni_resume)
    else:
        omni_parameters = {}

    for i in range(len(layers)):
        logging.info(f"=== Start quantize layer {i} ===")
        layer = layers[i].to(dev)
        if "mixtral" in net:
            # For mixtral only LWC is used, which is achieved by simply
            # replacing each Linear with a QuantLinear.
            qlayer = copy.deepcopy(layer)
            for name, module in qlayer.named_modules():
                if isinstance(module, torch.nn.Linear) and "gate" not in name:       # do not quantize gate
                    quantlinear = QuantLinear(module, args.weight_quant_params, args.act_quant_params)
                    add_new_module(name, qlayer, quantlinear)
        else:
            qlayer = DecoderLayer(lm.model.config, layer, args)
        qlayer = qlayer.to(dev)

        # obtain output of full-precision model
        set_quant_state(qlayer, weight_quant=False, act_quant=False)
        if args.epochs > 0:
            with torch.no_grad():
                with torch.cuda.amp.autocast():
                    for j in range(args.omni_cal_nsamples):
                        fp_inps[j] = qlayer(fp_inps[j].unsqueeze(0), attention_mask=attention_mask,
                                            position_ids=position_ids)[0]
                        if args.aug_loss:
                            fp_inps_2[j] = qlayer(quant_inps[j].unsqueeze(0), attention_mask=attention_mask,
                                                  position_ids=position_ids)[0]
        # init smooth parameters
        set_quant_state(qlayer, weight_quant=False, act_quant=True)  # weight will be manually quantized before forward
        qlayer.let = args.let
        # Channel-wise shifting is disabled for llama-family models and for
        # weight-only quantization (abits == 16).
        use_shift = not (is_llama or args.abits == 16)
        if args.let:
            # init channel-wise scaling and shift
            qlayer.register_parameter("qkt_smooth_scale", torch.nn.Parameter(
                torch.ones(layer.self_attn.q_proj.out_features, device=dev, dtype=dtype)))
            for name, module in qlayer.named_modules():
                if isinstance(module, QuantLinear):
                    for key in pairs.keys():
                        if key in name:
                            act = act_scales[f"{layer_name_prefix}.{i}.{name}"].to(
                                device=dev, dtype=dtype).clamp(min=1e-5)
                            weight = module.weight.abs().max(dim=0)[0].clamp(min=1e-5)
                            scale = (act.pow(args.smooth_quant_alpha) /
                                     weight.pow(1 - args.smooth_quant_alpha)).clamp(min=1e-5)
                            if use_shift and act_shifts is not None:
                                shift = act_shifts[f"{layer_name_prefix}.{i}.{name}"].to(device=dev, dtype=dtype)
                            else:
                                # No shift statistics (or shifting disabled):
                                # start from zero.
                                shift = torch.zeros_like(scale)
                            qlayer.register_parameter(f"{pairs[key]}_smooth_shift", torch.nn.Parameter(shift))
                            qlayer.register_parameter(f"{pairs[key]}_smooth_scale", torch.nn.Parameter(scale))

        if args.omni_resume:
            qlayer.load_state_dict(omni_parameters[i], strict=False)

        if args.epochs > 0:
            with torch.no_grad():
                qlayer.float()      # required for AMP training
            # create optimizer
            optimizer = torch.optim.AdamW(
                [{"params": let_parameters(qlayer, use_shift), "lr": args.let_lr},
                 {"params": lwc_parameters(qlayer), "lr": args.lwc_lr}],
                weight_decay=args.wd)

            # Cosine-annealing learning-rate schedule, stepped once per batch.
            # The trailing partial batch is dropped.
            steps_per_epoch = args.omni_cal_nsamples // args.omni_batch_size
            total_steps = args.epochs * steps_per_epoch
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer,
                T_max=total_steps,
                eta_min=1e-7  # minimum learning rate
            )

            loss_scaler = omni_utils.NativeScalerWithGradNormCount()

            for epoch in range(args.epochs):
                loss_list = []
                norm_list = []
                for j in range(steps_per_epoch):
                    index = j * args.omni_batch_size
                    # obtain output of quantization model
                    with traincast():
                        smooth_and_quant_temporary(qlayer, args, is_llama)
                        quant_out = qlayer(quant_inps[index:index + args.omni_batch_size,],
                                           attention_mask=attention_mask_batch, position_ids=position_ids)[0]
                        loss = loss_func(fp_inps[index:index + args.omni_batch_size,], quant_out)
                        if args.aug_loss:
                            loss += loss_func(fp_inps_2[index:index + args.omni_batch_size,], quant_out)
                    if not math.isfinite(loss.item()):
                        logging.info("Loss is NAN, stopping training")
                        # Deliberately drops into the debugger so the diverged
                        # state can be inspected.
                        pdb.set_trace()

                    loss_list.append(loss.detach().cpu())
                    optimizer.zero_grad()
                    norm = loss_scaler(loss, optimizer, parameters=get_omni_parameters(qlayer, use_shift)).cpu()
                    norm_list.append(norm.data)

                    # advance the learning-rate schedule once per optimizer step
                    scheduler.step()

                loss_mean = torch.stack(loss_list).mean()
                norm_mean = torch.stack(norm_list).mean()
                current_lr = scheduler.get_last_lr()[0]  # lr of the first (LET) param group
                logging.info(
                    f"layer {i} iter {epoch} loss:{loss_mean} norm:{norm_mean} lr:{current_lr:.2e} max memory_allocated {torch.cuda.max_memory_allocated(lm._device) / 1024**2} ")
            clear_temp_variable(qlayer)
            del optimizer
            del scheduler
        qlayer.half()
        # real smooth and quantization
        smooth_and_quant_inplace(qlayer, args, is_llama)
        if args.epochs > 0:
            # update input of quantization model
            with torch.no_grad():
                with traincast():
                    for j in range(args.omni_cal_nsamples):
                        quant_inps[j] = qlayer(quant_inps[j].unsqueeze(0), attention_mask=attention_mask,
                                               position_ids=position_ids)[0]
            layers[i] = qlayer.to("cpu")
            # Checkpoint the learned parameters after every layer.
            omni_parameters[i] = omni_state_dict(qlayer)
            os.makedirs(args.omni_save_dir, exist_ok=True)
            torch.save(omni_parameters, os.path.join(args.omni_save_dir, "omni_parameters.pth"))
        else:
            layers[i] = qlayer.to("cpu")
        del layer
        torch.cuda.empty_cache()

    del inps
    del quant_inps
    del fp_inps
    del fp_inps_2
    torch.cuda.empty_cache()
    gc.collect()
    model.config.use_cache = use_cache
    return model


def main_omniquant(args):
    """Load the model described by ``args``, configure its quantizers and run OmniQuant.

    Seeds the Python, NumPy and torch RNGs, then loads the model through
    ``LMClass`` with gradients disabled. Next it attaches the weight,
    activation and matmul quantizer parameter dicts to ``args``. Finally it
    loads the calibration data and, when ``args.let`` is set, the activation
    statistics, and calls ``omniquant``, which replaces ``lm.model``'s
    decoder layers with quantized ones.

    Mutates ``args``: sets ``deactive_amp``, ``net``, the ``*_quant_params``
    dicts, ``act_scales`` and ``act_shifts``.

    Returns:
        The ``LMClass`` instance holding the quantized model.
    """
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # check: training needs at least one family of learnable parameters
    if args.epochs > 0:
        assert args.lwc or args.let, "epochs > 0 requires args.lwc or args.let"

    args.deactive_amp = False
    args.net = args.model.split('/')[-1]
    lm = LMClass(args)
    lm.seqlen = args.omni_cal_seqlen
    lm.model.eval()
    for param in lm.model.parameters():
        param.requires_grad = False

    args.weight_quant_params = {
        "dynamic_method": 'per_channel',
        "lwc": args.lwc,
        "args": args,
        "mtype": "linear",
    }
    args.act_quant_params = {
        "dynamic_method": 'per_token',
        "args": args,
        "mtype": "linear",
        "btype": "a",
    }
    args.q_quant_params = {
        "args": args,
        "mtype": "matmul",
        "btype": "A",
        "dynamic_method": 'per_token',
    }
    args.k_quant_params = {
        "args": args,
        "mtype": "matmul",
        "btype": "B",
        "dynamic_method": 'per_token',
    }
    args.v_quant_params = {
        "args": args,
        "mtype": "matmul",
        "btype": "B",
        "dynamic_method": 'per_token',
    }
    args.p_quant_params = {
        "args": args,
        "mtype": "",
        "btype": "",
    }

    distribute_model(lm.model)

    # Activation statistics used to initialise LET. Scales and shifts are
    # different statistics, so they must be read from different files.
    # NOTE(review): confirm the shift files are generated under
    # ./omniquant/act_shifts/ alongside ./omniquant/act_scales/.
    setattr(args, "act_scales", f'./omniquant/act_scales/{args.net}.pt')
    setattr(args, "act_shifts", f'./omniquant/act_shifts/{args.net}.pt')

    # quantization
    logging.info("=== start quantization ===")
    tick = time.time()
    dataloader = get_loaders(
        args.omni_cal_dataset,
        nsamples=args.omni_cal_nsamples,
        seed=args.seed,
        model=args.model,
        seqlen=lm.seqlen,
        eval_mode=False,
    )
    act_scales = None
    act_shifts = None
    if args.let:
        act_scales = torch.load(args.act_scales)
        # omniquant only reads shifts when channel-wise shifting is enabled
        # (never for llama-family models or weight-only runs), so a missing
        # shift file is tolerated instead of aborting those runs.
        if os.path.exists(args.act_shifts):
            act_shifts = torch.load(args.act_shifts)
        else:
            logging.info(f"act_shifts file {args.act_shifts} not found; passing act_shifts=None")
    omniquant(
        lm,
        args,
        dataloader,
        act_scales,
        act_shifts,
    )
    logging.info(time.time() - tick)

    return lm
