import torch
import torch.nn as nn
from models.int_llama_layer import QuantLlamaDecoderLayer
from models.int_opt_layer import QuantOPTDecoderLayer
from models.int_falcon_layer import QuantFalconDecoderLayer
from quantize.int_linear import QuantLinear
from quantize.layerwrapper import WrappedGPT
# import auto_gptq.nn_modules.qlinear.qlinear_cuda as qlinear_cuda
from contextlib import nullcontext
from datasets import load_dataset
import random
import copy
import math
import utils
import os
import pdb
import gc

def check_sparsity(qlayer, layer_idx=None):
    """Print the fraction of exactly-zero weights in every QuantLinear of *qlayer*.

    Args:
        qlayer (nn.Module): module whose QuantLinear children (located with
            ``find_layers``) are inspected.
        layer_idx (int, optional): index used only to label the summary line.

    Returns:
        float: fraction of zero weights over all inspected linears
        (0.0 when no linear layer is found).
    """
    subset = find_layers(qlayer)
    zero_count = 0
    param_count = 0
    for name, module in subset.items():
        W = module.weight.data
        layer_zeros = (W == 0).sum().item()
        zero_count += layer_zeros
        param_count += W.numel()
        print(f"{name} sparsity {float(layer_zeros) / W.numel():.6f}")

    # Guard against an empty subset instead of dividing by zero.
    sparsity = float(zero_count) / param_count if param_count else 0.0
    label = "layer" if layer_idx is None else f"layer {layer_idx}"
    print(f"{label} sparsity {sparsity:.6f}")
    return sparsity

def return_reorder_indice(input_tensor):
    """Return per-row column indices that group each row of *input_tensor* by sign.

    For every row, the returned indices list, in this order:

    1. the columns holding negative values, in ascending column order
       (their relative order is preserved);
    2. the columns holding zeros (or NaN), in ascending column order;
    3. the columns holding positive values, in descending column order
       (their relative order is reversed).

    Gathering the input with the result therefore puts negatives first and
    positives last. For instance::

        input                     torch.gather(input, 1, result)
        [[ 1., -2.,  3.],         [[-2.,  3.,  1.],
         [-2.,  2., -4.],          [-2., -4.,  2.],
         [ 5.,  6., -7.],          [-7.,  6.,  5.],
         [-6., -7., -4.]]          [-6., -7., -4.]]

    Args:
        input_tensor (torch.Tensor): 2-D tensor (rows x columns).

    Returns:
        torch.Tensor: int64 tensor with the same shape as ``input_tensor``;
        every row is a permutation of ``range(input_tensor.shape[1])``.
    """
    num_cols = input_tensor.shape[1]
    columns = torch.arange(num_cols, device=input_tensor.device, dtype=torch.int64)

    positive_mask = input_tensor > 0
    negative_mask = input_tensor < 0

    # Group rank: 0 = negative, 1 = zero / NaN, 2 = positive.
    group = 1 + positive_mask.to(torch.int64) - negative_mask.to(torch.int64)

    # Position inside the group: ascending column index for negatives and
    # zeros, descending column index for positives.
    within_group = torch.where(positive_mask, num_cols - 1 - columns, columns)

    # Every within-group position is < num_cols, so groups never interleave
    # and all keys in a row are distinct.
    sort_key = group * num_cols + within_group
    return torch.sort(sort_key, dim=1)[1]

def return_flip_par_reorder_indice(input_tensor):
    """Return per-row column indices that reverse the order inside each sign group.

    For every row, the returned indices list, in this order:

    1. the columns holding negative values, in descending column order;
    2. the columns holding zeros (or NaN), in ascending column order;
    3. the columns holding positive values, in descending column order.

    So within the negatives and within the positives the relative order is
    flipped, while negatives always come before positives. For instance::

        input                     torch.gather(input, 1, result)
        [[-2.,  1.,  3.],         [[-2.,  3.,  1.],
         [-4., -2.,  2.],          [-2., -4.,  2.],
         [-7.,  5.,  6.],          [-7.,  6.,  5.],
         [-7., -6., -4.]]          [-4., -6., -7.]]

    Args:
        input_tensor (torch.Tensor): 2-D tensor (rows x columns).

    Returns:
        torch.Tensor: int64 tensor with the same shape as ``input_tensor``;
        every row is a permutation of ``range(input_tensor.shape[1])``.
    """
    num_cols = input_tensor.shape[1]
    columns = torch.arange(num_cols, device=input_tensor.device, dtype=torch.int64)

    positive_mask = input_tensor > 0
    negative_mask = input_tensor < 0

    # Group rank: 0 = negative, 1 = zero / NaN, 2 = positive.
    group = 1 + positive_mask.to(torch.int64) - negative_mask.to(torch.int64)

    # Signed entries are ordered by descending column; zeros by ascending column.
    within_group = torch.where(
        positive_mask | negative_mask, num_cols - 1 - columns, columns
    )

    # Every within-group position is < num_cols, so groups never interleave.
    sort_key = group * num_cols + within_group
    return torch.sort(sort_key, dim=1)[1]

def find_layers(module, layers=None, name=""):
    """
    Recursively find the layers of a certain type in a module.

    Matching uses the exact type (``type(module) in layers``), so subclasses of
    the requested types are NOT matched. Once a module matches, its children
    are not searched.

    Args:
        module (nn.Module): PyTorch module.
        layers (list, optional): List of layer types to find. Defaults to
            ``[QuantLinear]``.
        name (str): Qualified name of ``module`` (used as the key prefix).

    Returns:
        dict: Mapping from dotted qualified name to each matching layer.
    """
    # Resolve the default lazily to avoid a shared mutable default argument.
    if layers is None:
        layers = [QuantLinear]
    if type(module) in layers:
        return {name: module}
    found = {}
    for child_name, child in module.named_children():
        qualified_name = name + "." + child_name if name != "" else child_name
        found.update(find_layers(child, layers=layers, name=qualified_name))
    return found

def get_wanda_mask(args, qlayer, ds_outs, attention_mask = None, position_ids = None, layer_idx=None):
    """Compute a Wanda pruning mask for every QuantLinear in *qlayer*.

    ``qlayer`` is first run over the calibration samples with forward hooks
    attached, so that ``WrappedGPT`` can accumulate its ``scaler_row``
    statistic. Each linear's weights are then scored with
    ``|W| * sqrt(scaler_row)``. The lowest-scoring weights are marked for
    pruning: per row for unstructured sparsity, or per group of ``prune_m``
    columns for N:M sparsity. The boolean mask (True = pruned) is stored as
    ``module.W_mask``.

    Args:
        args: namespace read for ``sparsity_type``, ``sparsity_ratio`` and
            ``nsamples``.
        qlayer: decoder layer whose QuantLinear children are masked.
        ds_outs: per-sample layer inputs. ``ds_outs[j].unsqueeze(0)`` is fed to
            ``qlayer``, and ``ds_outs[j]`` is overwritten in place with the
            layer output.
        attention_mask: forwarded to ``qlayer``.
        position_ids: forwarded to ``qlayer``.
        layer_idx (int, optional): layer index used only in progress messages.
    """
    prune_n, prune_m = 0, 0
    if args.sparsity_type != "unstructured":
        assert args.sparsity_ratio == 0.5, "sparsity ratio must be 0.5 for structured N:M sparsity"
        prune_n, prune_m = map(int, args.sparsity_type.split(":"))

    print(f"sparsity_ratio {args.sparsity_ratio}")
    subset = find_layers(qlayer)
    wrapped_layers = {name: WrappedGPT(module) for name, module in subset.items()}

    def add_batch(name):
        def hook(_, inp, out):
            wrapped_layers[name].add_batch(inp[0].data, out.data)

        return hook

    handles = [subset[name].register_forward_hook(add_batch(name)) for name in wrapped_layers]
    try:
        with torch.no_grad():
            for j in range(args.nsamples):
                ds_outs[j] = qlayer(
                    ds_outs[j].unsqueeze(0),
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                )[0]
    finally:
        # Always detach the hooks, even if the calibration forward fails.
        for h in handles:
            h.remove()
    torch.cuda.empty_cache()

    layer_label = "?" if layer_idx is None else layer_idx
    for name, module in subset.items():
        print(f"pruning layer {layer_label} name {name}")
        W_metric = torch.abs(module.weight.data) * torch.sqrt(
            wrapped_layers[name].scaler_row.reshape((1, -1))
        )

        W_mask = torch.zeros_like(W_metric, dtype=torch.bool)  # all False = keep everything
        if prune_n != 0:
            # structured n:m sparsity: prune the prune_n smallest of every prune_m columns
            for ii in range(0, W_metric.shape[1], prune_m):
                block = W_metric[:, ii : (ii + prune_m)].float()
                W_mask.scatter_(
                    1,
                    ii + torch.topk(block, prune_n, dim=1, largest=False)[1],
                    True,
                )
        else:
            # unstructured pruning: prune the lowest-scoring fraction of each row
            sorted_indices = torch.sort(W_metric, dim=-1, stable=True)[1]
            indices = sorted_indices[:, : int(W_metric.shape[1] * args.sparsity_ratio)]
            W_mask.scatter_(1, indices, True)

        # Stored for both structured and unstructured sparsity.
        module.W_mask = W_mask

    del wrapped_layers
    torch.cuda.empty_cache()


def _dsnot_refine_mask(
    args,
    weight_mask,
    DSNT_metric,
    initial_metric,
    wrapped,
    initial_prune_indices,
    initial_res_indices,
    sparsity_num,
    res_sparsity_num,
):
    """Refine an initial unstructured mask in place by swapping weights row by row (DSnoT).

    The per-row "reconstruction error" starts as the sum of ``DSNT_metric`` over
    the initially pruned weights. In each cycle, every row picks one pruned
    weight to regrow and one kept weight to prune. The side of each sorted
    candidate list is chosen from the sign of the current error. The swap is
    applied only to rows where ``|error| > args.update_threshold``. Unless
    ``args.without_same_sign`` is set, the swap must also keep the error's
    initial sign. A row that stops updating never resumes. The loop ends when
    no row updates or after ``args.max_cycle_time`` cycles.
    """
    metric_for_regrowing = DSNT_metric.clone()
    wanda_metric = initial_metric.clone()

    # Only the pruned weights contribute to the initial error.
    metric_for_regrowing.scatter_(1, initial_res_indices, 0)
    reconstruction_error = torch.sum(metric_for_regrowing, dim=1, keepdim=True)
    initialize_error_sign = torch.sign(reconstruction_error)

    if args.pow_of_var_regrowing:
        metric_for_regrowing /= torch.pow(
            wrapped.var.reshape((1, -1)),
            args.pow_of_var_regrowing,
        )

    # Regrowing candidates, ascending by (possibly variance-scaled) metric.
    _, regrowing_indices_block = torch.sort(metric_for_regrowing, dim=1, stable=True)

    # Pruning candidates: kept weights ordered by sign of their DSNT metric.
    wanda_metric.scatter_(1, initial_prune_indices, float("inf"))
    wanda_res_indices, _ = torch.split(
        torch.sort(wanda_metric, dim=1, stable=True)[1],
        split_size_or_sections=[res_sparsity_num, sparsity_num],
        dim=1,
    )
    reorder_indice_of_pruning_indice = return_reorder_indice(
        torch.gather(DSNT_metric, 1, wanda_res_indices)
    )
    pruning_indices_block = torch.gather(
        wanda_res_indices, 1, reorder_indice_of_pruning_indice
    )

    num_rows = reconstruction_error.shape[0]
    device = reconstruction_error.device

    # Two cursors per row into each candidate list: column 0 walks forward
    # from the start, column 1 walks backward from the end.
    regrow_cursor = torch.zeros((num_rows, 2), device=device, dtype=torch.long)
    regrow_cursor[:, 1] = regrowing_indices_block.shape[-1] - 1
    prune_cursor = torch.zeros((num_rows, 2), device=device, dtype=torch.long)
    prune_cursor[:, 1] = pruning_indices_block.shape[-1] - 1
    cursor_step = torch.ones((num_rows, 2), device=device, dtype=torch.long)
    cursor_step[:, 1] = -1

    update_mask = torch.ones_like(reconstruction_error, dtype=torch.bool)
    cycle_time = 0
    while update_mask.any() and cycle_time < args.max_cycle_time:
        cycle_time += 1

        # regrowing: use the end cursor when the error is positive
        regrow_side = (reconstruction_error > 0).to(torch.int64)
        regrow_pos = torch.gather(regrow_cursor, 1, regrow_side)
        regrowing_indice = torch.gather(regrowing_indices_block, 1, regrow_pos)
        regrowing_metric = DSNT_metric.gather(1, regrowing_indice)
        regrow_cursor.scatter_(
            1, regrow_side, regrow_pos + cursor_step.gather(1, regrow_side)
        )

        # pruning: use the end cursor when the error is negative
        prune_side = (reconstruction_error < 0).to(torch.int64)
        prune_pos = torch.gather(prune_cursor, 1, prune_side)
        pruning_indice = torch.gather(pruning_indices_block, 1, prune_pos)
        pruning_metric = DSNT_metric.gather(1, pruning_indice)
        prune_cursor.scatter_(
            1, prune_side, prune_pos + cursor_step.gather(1, prune_side)
        )

        # change mask
        reconstruction_error_after = (
            reconstruction_error + pruning_metric - regrowing_metric
        )
        above_threshold = abs(reconstruction_error) > args.update_threshold
        if args.without_same_sign:
            update_mask = update_mask & above_threshold
        else:
            update_mask = (
                update_mask
                & above_threshold
                & (initialize_error_sign == torch.sign(reconstruction_error_after))
            )

        weight_mask.scatter_(1, pruning_indice, update_mask)
        weight_mask.scatter_(1, regrowing_indice, ~update_mask)

        reconstruction_error += torch.where(
            update_mask, pruning_metric, torch.zeros_like(pruning_metric)
        )
        reconstruction_error -= torch.where(
            update_mask, regrowing_metric, torch.zeros_like(regrowing_metric)
        )


def get_dsnot_mask(args, qlayer, ds_outs, attention_mask = None, position_ids = None, layer_idx=None):
    """Compute a DSnoT-refined pruning mask for every QuantLinear in *qlayer*.

    ``qlayer`` is first run over the calibration samples with forward hooks
    attached, so that ``WrappedGPT`` accumulates its per-column statistics
    (``scaler_row``, ``sum_metric_row``, and ``var`` when
    ``args.pow_of_var_regrowing`` is set). An initial mask then prunes the
    lowest Wanda scores ``|W| * sqrt(scaler_row)``. For unstructured sparsity,
    that mask is refined by ``_dsnot_refine_mask`` unless the layer is skipped.
    The boolean mask (True = pruned) is stored as ``module.W_mask``.

    Args:
        args: namespace read for sparsity_type, sparsity_ratio, nsamples,
            initial_method, skip_layer, skip_sub_layer, pow_of_var_regrowing,
            max_cycle_time, without_same_sign, update_threshold and, if
            present, without_DSNT (defaults to False; ``args`` is not mutated).
        qlayer: decoder layer whose QuantLinear children are masked.
        ds_outs: per-sample layer inputs. ``ds_outs[j].unsqueeze(0)`` is fed to
            ``qlayer``, and ``ds_outs[j]`` is overwritten in place with the
            layer output.
        attention_mask: forwarded to ``qlayer``.
        position_ids: forwarded to ``qlayer``.
        layer_idx (int, optional): layer index used only in progress messages.
    """
    prune_n, prune_m = 0, 0
    if args.sparsity_type != "unstructured":
        assert args.sparsity_ratio == 0.5, "sparsity ratio must be 0.5 for structured N:M sparsity"
        prune_n, prune_m = map(int, args.sparsity_type.split(":"))

    print(f"sparsity_ratio {args.sparsity_ratio}")
    without_dsnt = getattr(args, "without_DSNT", False)
    subset = find_layers(qlayer)
    wrapped_layers = {
        name: WrappedGPT(module, initial_method=args.initial_method,)
        for name, module in subset.items()
    }

    def add_batch(name):
        def hook(_, inp, out):
            wrapped_layers[name].add_batch(inp[0].data, out.data)

        return hook

    handles = [subset[name].register_forward_hook(add_batch(name)) for name in wrapped_layers]
    try:
        with torch.no_grad():
            for j in range(args.nsamples):
                ds_outs[j] = qlayer(
                    ds_outs[j].unsqueeze(0),
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                )[0]
    finally:
        # Always detach the hooks, even if the calibration forward fails.
        for h in handles:
            h.remove()
    torch.cuda.empty_cache()

    layer_label = "?" if layer_idx is None else layer_idx
    for name, module in subset.items():
        print(f"pruning layer {layer_label} name {name}")
        wrapped = wrapped_layers[name]
        weight = module.weight.data
        DSNT_metric = weight * wrapped.sum_metric_row.reshape((1, -1))
        initial_metric = torch.abs(weight) * torch.sqrt(wrapped.scaler_row.reshape((1, -1)))
        weight_mask = torch.zeros_like(initial_metric, dtype=torch.bool)

        if prune_n != 0:
            # structured n:m sparsity on the Wanda metric (no DSnoT refinement)
            for ii in range(0, initial_metric.shape[1], prune_m):
                block = initial_metric[:, ii : (ii + prune_m)].float()
                weight_mask.scatter_(
                    1,
                    ii + torch.topk(block, prune_n, dim=1, largest=False)[1],
                    True,
                )
        else:
            _, sorted_initial_indice = torch.sort(initial_metric, dim=-1, stable=True)
            sparsity_num = int(initial_metric.shape[1] * args.sparsity_ratio)
            res_sparsity_num = sorted_initial_indice.shape[1] - sparsity_num

            # Lowest-scoring columns are pruned; the rest are kept.
            initial_prune_indices, initial_res_indices = torch.split(
                sorted_initial_indice,
                split_size_or_sections=[sparsity_num, res_sparsity_num],
                dim=1,
            )
            weight_mask.scatter_(1, initial_prune_indices, True)

            # Names without a second component (e.g. "fc1") have no sub-layer.
            name_parts = name.split(".")
            skip_refinement = (
                name_parts[0] == args.skip_layer
                or (len(name_parts) > 1 and name_parts[1] == args.skip_sub_layer)
                or without_dsnt
            )
            if not skip_refinement:
                _dsnot_refine_mask(
                    args,
                    weight_mask,
                    DSNT_metric,
                    initial_metric,
                    wrapped,
                    initial_prune_indices,
                    initial_res_indices,
                    sparsity_num,
                    res_sparsity_num,
                )

        # Stored for both structured and unstructured sparsity.
        module.W_mask = weight_mask

    del wrapped_layers
    torch.cuda.empty_cache()


class TokenizerWrapper:
    """Tiny holder that exposes a sequence of token ids as ``.input_ids``."""

    def __init__(self, input_ids):
        # Keep a direct reference; no copy or validation is performed.
        self.input_ids = input_ids


def get_named_linears(module):
    """Return ``{qualified_name: submodule}`` for every QuantLinear under *module*.

    Walks ``module.named_modules()``, so the root itself (name ``""``) is
    included when it qualifies. The test is ``isinstance``, so QuantLinear
    subclasses are matched too.
    """
    named_linears = {}
    for qualified_name, submodule in module.named_modules():
        if isinstance(submodule, QuantLinear):
            named_linears[qualified_name] = submodule
    return named_linears


def _band_matrix(size, band, off_diag, device):
    """Return a ``size x size`` float matrix with a banded structure.

    The diagonal is 1. Entries whose distance from the diagonal is in
    ``1 .. band - 1`` equal ``off_diag``, and everything else is 0. This is a
    vectorised, value-identical replacement for accumulating ``torch.diag``
    over offsets ``0 .. band - 1`` (diagonal added twice, then one identity
    subtracted). Callers pass ``band >= 1``.
    """
    reach = max(int(band) - 1, 0)
    matrix = torch.ones(size, size, device=device).tril(reach).triu(-reach) * off_diag
    matrix.fill_diagonal_(1.0)
    return matrix


def _smoothing_masks(qlayer, args, num_heads, qkv_band, fc_band, off_diag, device):
    """Build the ``(maskqkv, maskfc)`` pair passed to ``smooth_and_quant_*``.

    ``maskqkv`` is a band matrix sized by ``q_proj``'s input features.
    ``maskfc`` is block-diagonal over attention heads, shaped like the
    attention output projection weight; each head block is a band matrix of
    width ``fc_band``.
    """
    qkv_size = qlayer.self_attn.q_proj.weight.data.size(1)
    maskqkv = _band_matrix(qkv_size, qkv_band, off_diag, device)

    net_name = args.net.lower()
    if "opt" in net_name:
        out_weight = qlayer.self_attn.out_proj.weight.data
    elif "llama" in net_name:
        out_weight = qlayer.self_attn.o_proj.weight.data
    else:
        raise ValueError("Smoothing masks are only implemented for opt/llama/Llama-2")

    maskfc = torch.zeros([out_weight.size(0), out_weight.size(1)], device=device)
    head_size = out_weight.size(0) // num_heads
    head_block = _band_matrix(head_size, fc_band, off_diag, device)
    for head in range(num_heads):
        start = head * head_size
        maskfc[start:start + head_size, start:start + head_size] = head_block
    return maskqkv, maskfc


def affinequant(
    lm,
    args,
    dataloader,
    act_scales,
    act_shifts,
    logger=None,
):
    """Quantize the decoder layers of ``lm.model`` one by one with learned affine smoothing.

    The inputs of the first decoder layer are captured from ``dataloader``.
    Then, for each layer:

    - Full-precision outputs are recorded as training targets.
    - Smoothing parameters are initialised from ``act_scales``/``act_shifts``
      (when ``args.let``).
    - A sparsity mask is computed with ``args.sparsity_method``.
    - The parameters are optimised for ``args.epochs`` epochs, or loaded from
      ``args.resume``.
    - Smoothing and quantization are applied in place, and the layer is moved
      back to CPU.

    When ``args.epochs > 0``, the learned parameters are saved to
    ``<args.output_dir>/affine_parameters.pth`` after every layer.

    Args:
        lm: wrapper exposing ``model``, ``device``, ``seqlen`` and ``_device``.
        args: run configuration namespace.
        dataloader: iterable of batches; ``batch[0]`` is fed to the model.
        act_scales: mapping from qualified linear name to activation scales.
        act_shifts: mapping from qualified linear name to activation shifts.
        logger: logger to use; defaults to this module's logger.

    Returns:
        The model converted to half precision.
    """
    if logger is None:
        import logging

        logger = logging.getLogger(__name__)
    logger.info("Starting ...")

    # move embedding layer and first layer to target device
    model = lm.model
    dev = lm.device
    use_cache = model.config.use_cache
    model.config.use_cache = False
    is_llama = False
    if "llama" in args.net.lower():
        is_llama = True
        layers = model.model.layers
        model.model.embed_tokens = model.model.embed_tokens.to(dev)
        model.model.norm = model.model.norm.to(dev)
        DecoderLayer = QuantLlamaDecoderLayer
        pairs = {
            "q_proj":"qkv",
            "o_proj":"out",
            "up_proj":"fc1"
        }
        layer_name_prefix = "model.layers"
    elif "opt" in args.net.lower():
        layers = model.model.decoder.layers
        model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(dev)
        model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(dev)
        if hasattr(model.model.decoder, "project_out") and model.model.decoder.project_out:
            model.model.decoder.project_out = model.model.decoder.project_out.to(dev)
        if hasattr(model.model.decoder, "project_in") and model.model.decoder.project_in:
            model.model.decoder.project_in = model.model.decoder.project_in.to(dev)
        DecoderLayer = QuantOPTDecoderLayer
        pairs = {
            "q_proj":"qkv",
            "out_proj":"out",
            "fc1":"fc1"
        }
        layer_name_prefix = "model.decoder.layers"
    elif "falcon" in args.net.lower():
        # NOTE(review): no `pairs` is defined for falcon, so `args.let` cannot work here.
        layers = model.transformer.h
        model.transformer.word_embeddings.to(dev)
        model.transformer.ln_f.to(dev)
        model.lm_head.to(dev)
        DecoderLayer = QuantFalconDecoderLayer
        layer_name_prefix = "model.transformer.h"
    else:
        raise ValueError("Only support for opt/llama/Llama-2/falcon now")

    layers[0] = layers[0].to(dev)
    if args.deactive_amp and args.epochs>0:
        dtype = torch.float
        traincast = nullcontext
    else:
        dtype = args.dtype
        traincast = torch.cuda.amp.autocast
    inps = torch.zeros(
        (args.nsamples, lm.seqlen, model.config.hidden_size), dtype=dtype, device=dev
    )
    cache = {"i": 0}

    model_dtype = model.dtype
    with torch.no_grad():
        model.to(args.dtype)

    # catch the first layer input
    class Catcher(nn.Module):
        def __init__(self, module):
            super().__init__()
            self.module = module
            self.is_llama = False

        def forward(self, inp, **kwargs):
            inps[cache["i"]] = inp
            cache["i"] += 1
            cache["attention_mask"] = kwargs["attention_mask"]
            if self.is_llama:
                cache["position_ids"] = kwargs["position_ids"]
            # Abort the forward pass once the first layer's input is captured.
            raise ValueError

    layers[0] = Catcher(layers[0])
    layers[0].is_llama = is_llama

    with torch.no_grad():
        for batch in dataloader:
            if cache["i"] >= args.nsamples:
                break
            try:
                model(batch[0].to(dev))
            except ValueError:
                pass

    with torch.no_grad():
        model.to(model_dtype)
    # move embedding layer and first layer to cpu
    layers[0] = layers[0].module
    layers[0] = layers[0].cpu()
    if "llama" in args.net.lower():
        model.model.embed_tokens = model.model.embed_tokens.cpu()
        model.model.norm = model.model.norm.cpu()
    elif "opt" in args.net.lower():
        model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.cpu()
        model.model.decoder.embed_positions = model.model.decoder.embed_positions.cpu()
        if hasattr(model.model.decoder, "project_out") and model.model.decoder.project_out:
            model.model.decoder.project_out = model.model.decoder.project_out.cpu()
        if hasattr(model.model.decoder, "project_in") and model.model.decoder.project_in:
            model.model.decoder.project_in = model.model.decoder.project_in.cpu()
    elif "falcon" in args.net.lower():
        # Same detection key as above (previously `args.model`, which could disagree).
        model.transformer.word_embeddings = model.transformer.word_embeddings.cpu()
    else:
        raise ValueError("Only support for opt/llama/Llama-2/falcon now")
    torch.cuda.empty_cache()

    # same input of first layer for fp model and quant model
    quant_inps = inps
    fp_inps = copy.deepcopy(inps)   # take output of fp model as input
    fp_inps_2 = copy.deepcopy(inps) if args.aug_loss else None # take output of quantization model as input

    attention_mask = cache["attention_mask"]
    attention_mask_batch = attention_mask.repeat(args.batch_size,1,1,1) if args.deactive_amp else attention_mask.repeat(args.batch_size,1,1,1).float()
    loss_func = torch.nn.MSELoss()
    position_ids = cache["position_ids"] if is_llama else None

    if args.resume:
        affine_parameters = torch.load(args.resume)
    else:
        affine_parameters = {}

    hidden_size = lm.model.config.hidden_size
    num_heads = lm.model.config.num_attention_heads

    for i in range(len(layers)):
        logger.info(f"=== Start quantize layer {i} ===")
        layer = layers[i].to(dev)
        qlayer = DecoderLayer(lm.model.config, layer, args)

        with torch.no_grad():
            qlayer.to(args.dtype)
        # obtain output of full-precision model
        qlayer.set_quant_state(weight_quant=False, act_quant=False)
        if args.epochs > 0:
            with torch.no_grad():
                with torch.cuda.amp.autocast():
                    for j in range(args.nsamples):
                        fp_inps[j] = qlayer(fp_inps[j].unsqueeze(0), attention_mask=attention_mask,position_ids=position_ids)[0]
                        if args.aug_loss:
                            fp_inps_2[j] = qlayer(quant_inps[j].unsqueeze(0), attention_mask=attention_mask,position_ids=position_ids)[0]

        # init smooth parameters
        qlayer.set_quant_state(weight_quant=False, act_quant=True)  # weight will be manually quantized before forward
        qlayer.let = args.let
        use_shift = True

        use_matrix = args.use_matrix
        use_ln_matrix = args.use_ln_matrix
        if args.let:
            # init channel-wise scaling and shift
            if use_matrix:
                qlayer.register_parameter("qkt_smooth_scale",torch.nn.Parameter(torch.eye(layer.self_attn.q_proj.out_features,device=dev, dtype=dtype)))
            else:
                qlayer.register_parameter("qkt_smooth_scale",torch.nn.Parameter(torch.ones(layer.self_attn.q_proj.out_features,device=dev, dtype=dtype)))
            for name,module in qlayer.named_modules():
                if isinstance(module, QuantLinear):
                    for key in pairs.keys():
                        if key in name:
                            act = act_scales[f"{layer_name_prefix}.{i}.{name}"].to(device=dev, dtype=torch.float16).clamp(min=1e-5)
                            weight = module.weight.max(dim=0)[0].clamp(min=1e-5)
                            scale = (act.pow(args.alpha)/weight.pow(1-args.alpha)).clamp(min=1e-5)
                            if use_shift and not is_llama:
                                shift = act_shifts[f"{layer_name_prefix}.{i}.{name}"].to(device=dev, dtype=torch.float16)
                            else:
                                shift = torch.zeros_like(scale)
                            if (pairs[key] == "qkv" or pairs[key] == "fc1") and not use_ln_matrix:
                                qlayer.register_parameter(f"{pairs[key]}_smooth_shift",torch.nn.Parameter(shift.to(args.dtype)))
                                qlayer.register_parameter(f"{pairs[key]}_smooth_scale",torch.nn.Parameter(scale.to(args.dtype)))
                            else:
                                qlayer.register_parameter(f"{pairs[key]}_smooth_shift",torch.nn.Parameter(shift.to(args.dtype)))
                                qlayer.register_parameter(f"{pairs[key]}_smooth_scale",torch.nn.Parameter(torch.diag(scale.to(args.dtype))))

        # affine_parameters only grows after this layer is done, so one check suffices.
        resumed = bool(args.resume and i < len(affine_parameters))
        if resumed:
            qlayer.load_state_dict(affine_parameters[i], strict=False)

        if args.epochs > 0 and not resumed:
            with torch.no_grad():
                qlayer.to(args.dtype)      # required for AMP training

            ##get sparsity mask
            ds_outs = copy.deepcopy(inps)
            if args.sparsity_method == "wanda":
                get_wanda_mask(args, qlayer, ds_outs, attention_mask = attention_mask, position_ids = position_ids)
            elif args.sparsity_method == "dsnot":
                get_dsnot_mask(args, qlayer, ds_outs, attention_mask = attention_mask, position_ids = position_ids)
            else:
                raise NotImplementedError()
            del ds_outs

            # create optimizer
            optimizer = torch.optim.AdamW(
                [{"params":qlayer.let_parameters(use_shift),"lr":args.let_lr}, {"params":qlayer.lwc_parameters(),"lr":args.lwc_lr}],weight_decay=args.wd)
            loss_scaler = utils.NativeScalerWithGradNormCount()

            for epoch in range(args.epochs):
                loss_list = []
                norm_list = []

                # The band width grows linearly from 1 (first epoch) to full (last epoch).
                if args.epochs > 1:
                    qkvmask_num = int((hidden_size-1)/(args.epochs-1)*epoch)+1
                    fc1mask_num = int((hidden_size/num_heads-1)/(args.epochs-1)*epoch)+1
                else:
                    # A single epoch would divide by zero; treat it as the final epoch.
                    qkvmask_num = hidden_size
                    fc1mask_num = hidden_size // num_heads
                maskqkv, maskfc = _smoothing_masks(qlayer, args, num_heads, qkvmask_num, fc1mask_num, 0.1, dev)

                for j in range(args.nsamples//args.batch_size):
                    index = j * args.batch_size
                    # obtain output of quantization model
                    with traincast():
                        qlayer.smooth_and_quant_temporary(args, num_heads, maskqkv, maskfc, use_matrix=use_matrix, use_ln_matrix=use_ln_matrix)
                        quant_out = qlayer(quant_inps[index:index+args.batch_size,], attention_mask=attention_mask_batch,position_ids=position_ids)[0]
                        loss = loss_func(fp_inps[index:index+args.batch_size,], quant_out)
                        if args.aug_loss:
                            loss += loss_func(fp_inps_2[index:index+args.batch_size,], quant_out)

                    if not math.isfinite(loss.item()):
                        logger.info("Loss is NAN, stopping training")
                        pdb.set_trace()

                    loss_list.append(loss.data)
                    optimizer.zero_grad()
                    norm = loss_scaler(loss, optimizer,parameters=qlayer.affine_parameters(use_shift))
                    norm_list.append(norm.data)

                loss_mean = torch.stack(loss_list).mean()
                norm_mean = torch.stack(norm_list).mean()
                logger.info(f"layer {i} iter {epoch} loss:{loss_mean} norm:{norm_mean} max memory_allocated {torch.cuda.max_memory_allocated(lm._device) / 1024**2} ")

            qlayer.clear_temp_variable()
            del optimizer

        if resumed:
            maskqkv, maskfc = _smoothing_masks(qlayer, args, num_heads, hidden_size, hidden_size // num_heads, 0.001, dev)

        # NOTE(review): when args.epochs == 0 and the layer is not resumed,
        # maskqkv/maskfc are never built and the call below raises NameError.
        # real smooth and quantization
        qlayer.smooth_and_quant_inplace(args, num_heads, maskqkv, maskfc,use_matrix=use_matrix,use_ln_matrix=use_ln_matrix)
        qlayer.clear_mask_variable()

        check_sparsity(qlayer)

        torch.cuda.empty_cache()
        if args.epochs>0:
            # update input of quantization model
            with torch.no_grad():
                with torch.cuda.amp.autocast():
                    for j in range(args.nsamples):
                        quant_inps[j] = qlayer(quant_inps[j].unsqueeze(0), attention_mask=attention_mask,position_ids=position_ids)[0]
            qlayer.register_scales_and_zeros()
            layers[i] = qlayer.to("cpu")
            affine_parameters[i] = qlayer.affine_state_dict()
            torch.save(affine_parameters, os.path.join(args.output_dir, "affine_parameters.pth"))
        else:
            qlayer.register_scales_and_zeros()
            qlayer.half()
            layers[i] = qlayer.to("cpu")
        if args.real_quant:
            # NOTE(review): `qlinear_cuda` is only bound if the commented-out
            # auto_gptq import at the top of this file is re-enabled.
            named_linears = get_named_linears(qlayer)
            for name, module in named_linears.items():
                scales = module.weight_quantizer.scales
                zeros = module.weight_quantizer.zeros
                group_size = module.weight_quantizer.group_size
                dim0 = module.weight.shape[0]
                scales = scales.view(dim0,-1)
                zeros = zeros.view(dim0,-1)
                q_linear = qlinear_cuda.QuantLinear(args.wbits, group_size, module.in_features,module.out_features,not module.bias is None)
                q_linear.pack(module.float().cpu(),  scales.float().cpu(), zeros.float().cpu())

                # Replace the module at its dotted path (digits index into containers).
                levels = name.split('.')
                if len(levels) > 1:
                    mod_ = qlayer
                    for l_idx in range(len(levels)-1):
                        if levels[l_idx].isdigit():
                            mod_ = mod_[int(levels[l_idx])]
                        else:
                            mod_ = getattr(mod_, levels[l_idx])
                    setattr(mod_, levels[-1], q_linear)
                else:
                    setattr(qlayer, name, q_linear)
                del module

        del layer
        torch.cuda.empty_cache()

    del inps
    del quant_inps
    del fp_inps
    del fp_inps_2
    torch.cuda.empty_cache()
    gc.collect()
    model.config.use_cache = use_cache
    return model.half()

