from .train_utils import cali_s2_quant
from .args_utils import get_config, create_logger
from .quant_linear import QuantizedLinear, reparametrize_to_gptq
from .s2_utils import load_s2_parameters
from .gptq_utils import gptq_fwrd, save_quantized_weights, load_quantized_weights
import torch.nn as nn
import os


def set_ignore_quantize(model, ignore_quantize=True):
    """Set the ``ignore_quantize`` flag on the layers excluded from quantization.

    The layer list is specific to the CogVideoX-2B transformer layout.
    ``quantize_linear`` in this module returns an ``nn.Linear`` unchanged when
    its ``ignore_quantize`` attribute is truthy, so flagging a layer here keeps
    it out of the ``QuantizedLinear`` replacement.

    Args:
        model: Transformer exposing ``patch_embed.text_proj``,
            ``time_embedding.linear_1``, ``time_embedding.linear_2``,
            ``norm_out.linear`` and ``proj_out``.
        ignore_quantize: Value written to each layer's ``ignore_quantize``
            attribute. Defaults to True (skip quantization). Pass False to
            clear the flag again. Earlier versions ignored this argument and
            always wrote True.
    """
    # Resolve every layer first, so a model with a different layout fails
    # before any flag has been modified.
    excluded_layers = (
        model.patch_embed.text_proj,
        model.time_embedding.linear_1,
        model.time_embedding.linear_2,
        model.norm_out.linear,
        model.proj_out,
    )
    for layer in excluded_layers:
        layer.ignore_quantize = ignore_quantize


def quantize_linear(module, device="cuda", args=None):
    """Recursively replace ``nn.Linear`` layers with ``QuantizedLinear``.

    Args:
        module: Module to process. If it is an ``nn.Linear``, it is either
            returned as-is (when flagged with ``ignore_quantize``) or wrapped
            in a new ``QuantizedLinear``. Otherwise each child is processed
            recursively, and replaced children are written back onto
            ``module`` with ``setattr``.
        device: Device every visited module is moved to. ``None`` leaves
            modules where they are.
        args: Quantization config forwarded to ``QuantizedLinear``. For
            linear layers with a truthy ``higher_bits`` attribute,
            ``args.a_bits`` is temporarily set to 16 while the layer is built
            and restored afterwards.

    Returns:
        The module that should take ``module``'s place: a new
        ``QuantizedLinear`` for a quantized linear layer, otherwise
        ``module`` itself.
    """
    if isinstance(module, nn.Linear):
        if device is not None:
            module = module.to(device)

        # Layers explicitly excluded from quantization stay plain nn.Linear.
        if getattr(module, 'ignore_quantize', False):
            return module

        if not getattr(module, 'higher_bits', False):
            return QuantizedLinear(args, module)

        # Build this layer with a_bits forced to 16 (presumably the
        # activation bit width -- confirm against QuantizedLinear). Restore it
        # in `finally` so a failing constructor cannot leave the shared config
        # stuck at 16 bits for every later layer.
        original_a_bits = args.a_bits
        args.a_bits = 16
        try:
            return QuantizedLinear(args, module)
        finally:
            args.a_bits = original_a_bits

    for name, child in module.named_children():
        new_child = quantize_linear(child, device, args)
        if new_child is not child:
            setattr(module, name, new_child)
    if device is not None:
        module.to(device=device)  # move parent module to device
    return module

def s2quant_model(
    model,
    calib_data,
    wbit,
    abit,
    resume_s2=False,
    use_gptq=False,
    resume_gptq=False,
    model_id=None,
    exp_name=None,
):
    """Quantize ``model`` in place: S2 calibration followed by optional GPTQ.

    Steps, in order:
      1. Load the config and apply ``wbit``/``abit``/``model_id``/``exp_name``
         through ``config.update_from_args``.
      2. Move the model to CUDA, flag the layers excluded from quantization,
         and swap the remaining ``nn.Linear`` layers for ``QuantizedLinear``.
      3. S2 stage: reload saved parameters when ``resume_s2`` is set,
         otherwise calibrate on ``calib_data``.
      4. GPTQ stage. If ``resume_gptq`` is set, reparametrize and load weights
         from ``<exp_dir>/quantized_weights.pth``. Otherwise, if ``use_gptq``
         is set, reparametrize, set ``config.nsamples = 40``, run GPTQ on
         ``calib_data`` and save the result to that same file.

    Returns:
        None. ``model`` is modified in place.

    NOTE(review): reads ``model.device``, which plain ``torch.nn.Module``
    does not define; the model is presumably a diffusers-style model.
    """
    def _gptq_weights_path():
        # Evaluated at the point of use, after the preceding steps have run.
        return os.path.join(config.exp_dir, 'quantized_weights.pth')

    config = get_config()
    # Apply the requested quantization bit widths to the config.
    config.update_from_args(wbit, abit, model_id, exp_name)

    model.to("cuda")
    # Quantize the model's linear layers, skipping the flagged ones.
    set_ignore_quantize(model)
    quantize_linear(model, device=model.device, args=config)
    logger = create_logger(config.exp_dir)

    # S2 stage: calibrate from scratch unless resuming saved parameters.
    if not resume_s2:
        cali_s2_quant(config, model, calib_data, model.device, logger)
    else:
        load_s2_parameters(config, model)

    # GPTQ stage: resuming takes precedence over running a fresh pass.
    if resume_gptq:
        reparametrize_to_gptq(model)
        load_quantized_weights(model, _gptq_weights_path())
    elif use_gptq:
        reparametrize_to_gptq(model)
        config.nsamples = 40
        gptq_weights = gptq_fwrd(model, calib_data, model.device, config)
        save_quantized_weights(gptq_weights, _gptq_weights_path())

    return None
