import gc
from logging import getLogger
from typing import Tuple

import torch
from accelerate.utils import find_tied_parameters
from safetensors.torch import save_file as safe_save
from tqdm import tqdm

from ..nn_modules.qlinear.qlinear_marlin import QuantLinear as MarlinQuantLinear
from ..nn_modules.qlinear.qlinear_marlin import _get_perms, unpack_qzeros
from ..quantization import CHECKPOINT_FORMAT, QUANT_METHOD, BaseQuantizeConfig
from .accelerate_utils import load_checkpoint_in_model
from .import_utils import MARLIN_AVAILABLE, MARLIN_EXCEPTION
from .modeling_utils import recurse_getattr, recurse_setattr


# The compiled Marlin extension is imported only when the availability flag is set;
# `convert_to_marlin(..., repack=True)` relies on it for `gptq_repack`.
if MARLIN_AVAILABLE:
    import autogptq_marlin_cuda

# Module-level logger named after this module.
logger = getLogger(__name__)


def prepare_model_for_marlin_load(
    model,
    quantize_config: BaseQuantizeConfig,
    quant_linear_class,
    torch_dtype,
    current_model_save_name,
    device_map,
):
    """
    Swap the model's QuantLinear layers for Marlin ones and pick the checkpoint file to use.

    Three cases are handled, in this order:

    1. The checkpoint is already in the Marlin format: layers are overridden without repacking
       and `current_model_save_name` is returned.
    2. A Marlin cache file exists for this GPTQ checkpoint (as reported by
       `quantize_config.get_cache_file_path`): layers are overridden without repacking and the
       cache path is returned.
    3. Otherwise the GPTQ checkpoint is loaded into `model`, its weights are repacked, tied
       parameters are untied, and the resulting state dict is written to the cache path.

    Arguments:
        model: The model whose `quant_linear_class` layers are to be replaced.
        quantize_config (`BaseQuantizeConfig`): Quantization config; `convert_to_marlin` updates
            its `checkpoint_format`.
        quant_linear_class: The QuantLinear class currently used by `model`.
        torch_dtype: Forwarded as `dtype` to `load_checkpoint_in_model` (case 3 only).
        current_model_save_name: Path of the checkpoint on disk.
        device_map: Forwarded to `load_checkpoint_in_model` (case 3 only).

    Returns:
        A `(model, model_save_name)` tuple, where `model_save_name` is the Marlin-format file.
    """
    # Case 1: the file (e.g. model.safetensors) already holds Marlin-serialized weights.
    if quantize_config.checkpoint_format == CHECKPOINT_FORMAT.MARLIN:
        model_save_name = current_model_save_name
        logger.info(f"Loading a GPTQ model, detected Marlin serialized format at {model_save_name}.")
        model = convert_to_marlin(model, quant_linear_class, quantize_config, repack=False)
        return model, model_save_name

    model_save_name, is_cached = quantize_config.get_cache_file_path(
        quant_method=QUANT_METHOD.GPTQ,
        checkpoint_format=CHECKPOINT_FORMAT.MARLIN,
    )

    # Case 2: a previously repacked copy is cached locally, so no repacking is needed.
    if is_cached:
        logger.info(
            f"Loading a GPTQ model, detected a cached repacked weight for Marlin kernel at {model_save_name}."
        )
        model = convert_to_marlin(model, quant_linear_class, quantize_config, repack=False)
        return model, model_save_name

    # Case 3: load the GPTQ checkpoint, repack it, and cache the result.
    # TODO: Avoid loading the model with wrong QuantLinear, and directly use
    # Marlin ones. The repacking can be done directly on the safetensors, just
    # as for AWQ checkpoints.
    load_checkpoint_in_model(
        model,
        # Passing the dtype here is very hacky but works due to
        # https://github.com/huggingface/accelerate/blob/bd72a5f1a80d5146554458823f8aeda0a9db5297/src/accelerate/utils/modeling.py#L292
        dtype=torch_dtype,
        checkpoint=current_model_save_name,
        device_map=device_map,
        offload_state_dict=True,
        offload_buffers=True,
    )
    model = convert_to_marlin(model, quant_linear_class, quantize_config, repack=True)

    # Safetensors is unable to save tied weights, so every tied entry gets its own copy.
    # Reference: https://github.com/huggingface/safetensors/issues/202
    for weight_group in find_tied_parameters(model):
        for param_name in weight_group:
            shared = recurse_getattr(model, param_name)
            untied = shared.clone()
            if isinstance(shared, torch.nn.Parameter):
                # Keep parameters registered as parameters after the copy.
                untied = torch.nn.Parameter(untied)
            recurse_setattr(model, param_name, untied)

    # Cache the converted model.
    safe_save(model.state_dict(), model_save_name)

    return model, model_save_name


# Validate marlin support
def _validate_marlin_device_support() -> Tuple[bool, bool]:
    """
        Validates if the current device is compatible and optimized for Marlin.
        ref: https://github.com/IST-DASLab/marlin?tab=readme-ov-file#requirements

        Returns:
            Tuple[bool, bool]: The first indicates if CUDA device is compatible for Marlin,
                               the second indicates if CUDA device is optimized for Marlin.
        """
    supported = False
    optimized = False

    # >=hopper is compatible but not optimized
    if torch.cuda.get_device_capability()[0] >= 9:
        supported = True
        optimized = False
    # ampere and ada are supported and optimized
    elif torch.cuda.get_device_capability()[0] >= 8:
        supported = True
        optimized = True

    return supported, optimized


# Adapted from https://github.com/rib-2/marlin/tree/conversion
def _validate_marlin_compatibility(cfg: BaseQuantizeConfig):
    """
    Check whether a quantization config can be served by the Marlin kernel.

    The checks run in order and the first failing one determines the message.

    Arguments:
        cfg (`BaseQuantizeConfig`): The quantization config to inspect. The fields read are
            `bits`, `group_size`, `sym`, `desc_act` and `quant_method`.

    Returns:
        `None` when every check passes, otherwise a string describing the first
        incompatibility found.
    """
    if not MARLIN_AVAILABLE:
        return f"AutoGPTQ is not compiled with the Marlin kernel, with the following error: {MARLIN_EXCEPTION}"
    if cfg.bits != 4:
        return f"The quantized model uses a bitwidth different than 4 (found {cfg.bits})"
    if cfg.group_size not in (128, -1):
        # Report the offending value itself rather than the name of the attribute.
        return f"The quantized model uses a group size that is not 128 or -1 (found {cfg.group_size})"
    if not cfg.sym:
        return "The quantized model uses asymmetric quantization"
    if cfg.desc_act:
        return "The quantized model uses act-order (also called desc-act) scheme"
    if cfg.quant_method == QUANT_METHOD.AWQ:
        return "awq_gemm format is currently not compatible with marlin"
    return None


@torch.no_grad()
def convert_to_marlin(model, model_quantlinear, quantization_config: BaseQuantizeConfig, repack: bool, strict: bool = False):
    """
    Converts GPTQ-packed weights to the Marlin format. This assumes that the model already meets Marlin kernel constraints.

    Every submodule of `model` that is an instance of `model_quantlinear` is replaced, on its parent,
    by a Marlin QuantLinear. On completion `quantization_config.checkpoint_format` is set to Marlin.

    Arguments:
        model:
            The model whose quantized linear layers are replaced in place.
        model_quantlinear:
            The QuantLinear class to look for in `model`.
        quantization_config (`BaseQuantizeConfig`):
            Quantization config; its `checkpoint_format` is overwritten.
        repack (`bool`):
            Whether to repack the qweights from `model` into the Marlin's QuantLinear layers.
            When `False`, the new layers are created on the meta device and only their
            `workspace` is allocated here.
        strict (`bool`, defaults to `False`):
            Only used when `repack=True`. Unpacks each layer's qzeros and raises if any
            value differs from 8.

    Returns:
        The same `model` object, with its layers replaced.

    Raises:
        ValueError: If `strict` is set and a layer's unpacked qzeros are not all equal to 8.
    """
    if repack:
        message = "Repacking weights to be compatible with Marlin kernel..."
    else:
        # TODO: load directly Marlin QuantLinear.
        message = "Overriding QuantLinear layers to use Marlin's QuantLinear..."

    # Only needed for the progress bar; the temporary list is discarded right away so it
    # does not keep the old layers alive.
    num_modules = len(list(model.named_modules()))

    # Iterate lazily (no snapshot) so each replaced layer can be freed as soon as it is swapped out.
    for name, module in tqdm(model.named_modules(), desc=message, total=num_modules):
        if not isinstance(module, model_quantlinear):
            continue

        # Split on the last dot. For a direct child of `model` there is no dot: the parent
        # name is then "" (which `get_submodule` resolves to `model`) and the layer name is
        # the full name, not the name minus its first character.
        parent_name, _, layer_name = name.rpartition(".")

        # We could use `torch.count_nonzero(module.bias) > 0` here to discard zero bias, but this has issues when
        # loading weights from checkpoints holding zero bias.
        with torch.device("meta"):
            new_module = MarlinQuantLinear(
                bits=4,
                group_size=module.group_size,
                infeatures=module.infeatures,
                outfeatures=module.outfeatures,
                bias=module.bias is not None,
                trainable=False,
            )

        # workspace is never in the state_dict, thus we need to allocate it manually.
        new_module.workspace = torch.zeros(module.outfeatures // 128 * 16, dtype=torch.int, device=module.device)

        # Repack the quantized weight and permute the scales into Marlin's layout.
        if repack:
            marlin_repacked_weight = autogptq_marlin_cuda.gptq_repack(module.qweight)

            if strict:
                dequantized_qzeros = unpack_qzeros(module.qzeros)

                if not torch.all(dequantized_qzeros == 8):
                    raise ValueError(
                        "Marlin kernel is compatible only with checkpoints using symmetric quantization. "
                        f"Found non-symmetric quantization for the weight {name}."
                    )

            _, _scale_perm, _scale_perm_single = _get_perms()

            s = module.scales.data.clone()
            if module.group_size != module.infeatures:
                s = s.reshape((1, -1))
                s = s.reshape((-1, len(_scale_perm)))[:, _scale_perm]
            else:
                s = s.reshape((-1, len(_scale_perm_single)))[:, _scale_perm_single]
            s = s.reshape((-1, module.outfeatures)).contiguous()

            new_module.B = marlin_repacked_weight
            new_module.s = s
            new_module.bias = module.bias

            new_module = new_module.to(module.device)

        # Save to parent.
        parent_module = model.get_submodule(parent_name)
        setattr(parent_module, layer_name, new_module)

        # Free cuda memory.
        del module
        if repack:
            del marlin_repacked_weight
        gc.collect()

    # Set quantization config to be Marlin.
    quantization_config.checkpoint_format = CHECKPOINT_FORMAT.MARLIN

    return model
