import copy
import json
import logging
import os
from dataclasses import dataclass, field, fields
from os.path import isdir, join
from typing import Dict, List, Optional, Union

import accelerate
import huggingface_hub
import torch
import torch.nn as nn
import transformers
from accelerate.hooks import remove_hook_from_module
from safetensors import safe_open
from safetensors.torch import load_file as safe_load
from safetensors.torch import save_file as safe_save
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel
from transformers.modeling_utils import no_init_weights
from transformers.utils.generic import ContextManagers
from transformers.utils.hub import (
    CommitOperationAdd,
    PushToHubMixin,
    cached_file,
    create_commit,
    create_repo,
)

from ..nn_modules._fused_base import FusedBaseAttentionModule, FusedBaseMLPModule
from ..nn_modules.qlinear import GeneralQuantLinear
from ..quantization import GPTQ
from ..utils.data_utils import collate_data
from ..utils.import_utils import (
    AUTOGPTQ_CUDA_AVAILABLE,
    EXLLAMA_KERNELS_AVAILABLE,
    EXLLAMAV2_KERNELS_AVAILABLE,
    MARLIN_AVAILABLE,
    QIGEN_AVAILABLE,
    TRITON_AVAILABLE,
    dynamically_import_QuantLinear,
)
from ..utils.marlin_utils import (
    _validate_marlin_compatibility,
    _validate_marlin_device_support,
    prepare_model_for_marlin_load,
)
from ._const import CPU, CUDA_0, SUPPORTED_MODELS
from ._utils import (
    autogptq_post_init,
    find_layers,
    get_checkpoints,
    get_device,
    get_module_by_name_prefix,
    get_module_by_name_suffix,
    make_quant,
    make_sure_no_tensor_in_meta_device,
    move_to_device,
    pack_from_tensors,
    pack_model,
    preprocess_checkpoint_qigen,
    simple_dispatch_model,
    unpack_awq,
)


logger = logging.getLogger(__name__)
handler = logging.StreamHandler()
formatter = logging.Formatter("%(levelname)s - %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.setLevel(logging.INFO)

SYNONYMS = {
    "w_bit": "bits",
    "q_group_size": "group_size",
}


@dataclass
class BaseQuantizeConfig(PushToHubMixin):
    bits: int = field(default=4, metadata={"choices": [2, 3, 4, 8]})
    group_size: int = field(default=-1)
    damp_percent: float = field(default=0.01)
    desc_act: bool = field(default=True)
    static_groups: bool = field(default=False)
    sym: bool = field(default=True)
    true_sequential: bool = field(default=True)
    is_marlin_format: bool = field(default=False)
    model_name_or_path: Optional[str] = field(default=None)
    model_file_base_name: Optional[str] = field(default=None)
    awq_gemm_checkpoint: Optional[bool] = field(default=False)

    def __post_init__(self):
        fields_info = fields(self)

        if self.bits not in fields_info[0].metadata["choices"]:
            raise ValueError(f"only support quantize to {fields_info[0].metadata['choices']} bits.")
        if self.group_size != -1 and self.group_size <= 0:
            raise ValueError("unless equal to -1, group_size must greater then 0.")
        if not (0 < self.damp_percent < 1):
            raise ValueError("damp_percent must between 0 and 1.")

    def save_pretrained(self, save_dir: str, **kwargs):
        with open(join(save_dir, "quantize_config.json"), "w", encoding="utf-8") as f:
            json.dump(self.to_dict(), f, indent=2)

    @classmethod
    def from_pretrained(cls, save_dir: str, **kwargs):
        # Parameters related to loading from Hugging Face Hub
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", False)
        use_auth_token = kwargs.pop("use_auth_token", None)
        revision = kwargs.pop("revision", None)
        subfolder = kwargs.pop("subfolder", None)
        commit_hash = kwargs.pop("_commit_hash", None)

        transformers_config = False
        for quantize_config_filename in [
            "quantize_config.json",
            "quant_config.json",
            "config.json",
        ]:
            if os.path.isdir(save_dir):  # Local
                resolved_config_file = join(save_dir, quantize_config_filename)
            else:  # Remote
                resolved_config_file = cached_file(
                    save_dir,
                    quantize_config_filename,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    resume_download=resume_download,
                    proxies=proxies,
                    use_auth_token=use_auth_token,
                    revision=revision,
                    local_files_only=local_files_only,
                    subfolder=subfolder,
                    _raise_exceptions_for_missing_entries=False,
                    _raise_exceptions_for_connection_errors=False,
                    _commit_hash=commit_hash,
                )
            if resolved_config_file is not None:
                if quantize_config_filename == "config.json":
                    transformers_config = True
                break

        if resolved_config_file is None:
            raise ValueError(
                "No quantize_config.json, quant_config.json or config.json file was found in the model repository."
            )

        field_names = [field.name for field in fields(cls)]
        with open(resolved_config_file, "r", encoding="utf-8") as f:
            args_from_json = json.load(f)

            if transformers_config:
                args_from_json = args_from_json["quantization_config"]

            filtered_args = {"awq_gemm_checkpoint": False}
            for key, val in args_from_json.items():
                if key == "version" and val == "GEMM":
                    filtered_args["awq_gemm_checkpoint"] = True
                elif key in field_names:
                    filtered_args[key] = val
                elif key in SYNONYMS and SYNONYMS[key] in field_names:
                    filtered_args[SYNONYMS[key]] = val
                else:
                    logger.warning(f"ignoring unknown parameter in {quantize_config_filename}: {key}.")

            if filtered_args["awq_gemm_checkpoint"]:
                # AWQ does not reorder the rows.
                filtered_args["desc_act"] = False

            if "sym" not in args_from_json:
                logger.warning(
                    f"The quantization configuration {quantize_config_filename} does not contain an entry `sym` (symetric quantization). This may result in silent errors."
                )

            return cls(**filtered_args)

    def to_dict(self):
        return {
            "bits": self.bits,
            "group_size": self.group_size,
            "damp_percent": self.damp_percent,
            "desc_act": self.desc_act,
            "static_groups": self.static_groups,
            "sym": self.sym,
            "true_sequential": self.true_sequential,
            "model_name_or_path": self.model_name_or_path,
            "model_file_base_name": self.model_file_base_name,
            "is_marlin_format": self.is_marlin_format,
            "quant_method": "gptq",
        }


class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
    layer_type: str = None
    layers_block_name: str = None
    outside_layer_modules: List[str] = None
    inside_layer_modules: List[List[str]] = None
    lm_head_name: str = "lm_head"

    fused_attn_module_type: Optional[FusedBaseAttentionModule] = None
    fused_mlp_module_type: Optional[FusedBaseMLPModule] = None

    def __init__(
        self,
        model: PreTrainedModel,
        quantized: bool,
        quantize_config: BaseQuantizeConfig,
        is_triton_backend: bool = False,
        injected_fused_attention: bool = False,
        injected_fused_mlp: bool = False,
        trainable: bool = False,
    ):
        super().__init__()

        self.model = model
        self.model_type = self.model.config.model_type
        self._quantized = quantized
        self.quantize_config = quantize_config
        self.config = self.model.config

        self.is_triton_backend = is_triton_backend
        self.injected_fused_attention = injected_fused_attention
        self.injected_fused_mlp = injected_fused_mlp
        self.trainable = trainable

    @property
    def quantized(self):
        return self._quantized

    @property
    def hf_device_map(self):
        return getattr(self.model, "hf_device_map", None)

    def _prepare_examples_for_quantization(
        self,
        examples: List[Dict[str, Union[List[int], torch.LongTensor]]],
        batch_size: int = 1,
    ):
        def _convert_tensor_to_list(tensor):
            if isinstance(tensor, torch.Tensor):
                if len(tensor.shape) == 1:
                    tensor = tensor.unsqueeze(0)
                tensor = tensor.long()
                return tensor.cpu().numpy().tolist()
            return [tensor]

        new_examples = []
        for example in examples:
            input_ids = _convert_tensor_to_list(example["input_ids"])
            attention_mask = _convert_tensor_to_list(example["attention_mask"])
            if "labels" in example:
                labels = _convert_tensor_to_list(example["labels"])
            elif "label" in example:
                labels = _convert_tensor_to_list(example["label"])
            elif "label_ids" in example:
                labels = _convert_tensor_to_list(example["label_ids"])
            else:
                labels = copy.deepcopy(input_ids)
            new_examples.append(
                {
                    "input_ids": input_ids,
                    "attention_mask": attention_mask,
                    "labels": labels,
                }
            )
        pad_token_id = self.config.pad_token_id
        if not pad_token_id:
            pad_token_id = self.config.eos_token_id

        new_examples = [
            collate_data(new_examples[start : start + batch_size], pad_token_id)
            for start in range(0, len(new_examples), batch_size)
        ]
        for new_example in new_examples:
            del new_example["labels"]

        return new_examples

    @torch.inference_mode()
    def quantize(
        self,
        examples: List[Dict[str, Union[List[int], torch.LongTensor]]],
        batch_size: int = 1,
        use_triton: bool = False,
        use_cuda_fp16: bool = True,
        autotune_warmup_after_quantized: bool = False,
        cache_examples_on_gpu: bool = True,
    ):
        if self.quantized:
            raise EnvironmentError("can't execute quantize because the model is quantized.")
        if use_triton and not TRITON_AVAILABLE:
            logger.warning("triton is not installed, reset use_triton to False")
            use_triton = False

        device_map = self.hf_device_map
        if device_map:
            for name, device in device_map.items():
                if device == "cpu":
                    logger.info(f"truly offloading {name} to cpu with hook.")
                    module = get_module_by_name_suffix(self.model, name)
                    remove_hook_from_module(module, recurse=True)
                    accelerate.cpu_offload_with_hook(module, CUDA_0)

        layer_inputs = []
        attention_masks = []
        position_ids = []
        layer_input_kwargs = []
        layer_outputs = []

        examples = self._prepare_examples_for_quantization(examples, batch_size)

        def nested_move_to_device(v, device):
            if isinstance(v, torch.Tensor):
                return move_to_device(v, device)
            elif isinstance(v, (list, tuple)):
                return type(v)([nested_move_to_device(e, device) for e in v])
            else:
                return v

        class LayerHijacker(nn.Module):
            """hijack layer's forward pass to cache data"""

            def __init__(self, m, device):
                super().__init__()
                self.module = m
                self.data_device = device if cache_examples_on_gpu else CPU

            def forward(self, inp=None, **kwargs):
                if inp is None:  # some models use all key-value arguments in forward pass call
                    for kwarg_name in ["hidden_states"]:
                        if kwarg_name in kwargs:
                            inp = kwargs[kwarg_name]
                            break
                layer_inputs.append(move_to_device(inp, self.data_device))

                if kwargs["attention_mask"] is not None:
                    attention_masks.append(kwargs["attention_mask"].to(self.data_device))
                else:
                    attention_masks.append(None)

                pos_ids = kwargs.get("position_ids", None)
                if pos_ids is not None:
                    position_ids.append(move_to_device(pos_ids, self.data_device))
                one_kwargs = {}
                for (
                    k,
                    v,
                ) in kwargs.items():  # make sure other arguments also be captured
                    if k not in ["hidden_states", "attention_mask", "position_ids"]:
                        one_kwargs[k] = nested_move_to_device(v, self.data_device)
                layer_input_kwargs.append(one_kwargs)
                raise ValueError

        forward_pass_use_cache = self.model.config.use_cache
        self.model.config.use_cache = False

        num_batches = len(examples)
        layers = get_module_by_name_prefix(self.model, self.layers_block_name)

        force_layer_back_to_cpu = False
        if get_device(layers[0]) == CPU:
            layers[0] = layers[0].to(CUDA_0)
            force_layer_back_to_cpu = True

        cur_layer_device = get_device(layers[0])
        ori_outside_layer_module_devices = {}
        for module_name in self.outside_layer_modules:
            module = get_module_by_name_prefix(self.model, module_name)

            if module is None:
                continue

            ori_outside_layer_module_devices[module_name] = get_device(module)
            if module is not None:
                move_to_device(module, cur_layer_device)

        # get inputs for first layer
        layers[0] = LayerHijacker(layers[0], cur_layer_device)
        for example in examples:
            for k, v in example.items():
                if len(v.shape) == 1:
                    v = v.unsqueeze(0)
                example[k] = move_to_device(v, cur_layer_device)
            try:
                self.model(**example)
            except ValueError:
                pass
        layers[0] = layers[0].module

        move_to_device(layers[0], CPU if force_layer_back_to_cpu else cur_layer_device)
        for module_name in self.outside_layer_modules:
            module = get_module_by_name_prefix(self.model, module_name)
            if module is not None:
                move_to_device(module, ori_outside_layer_module_devices[module_name])

        torch.cuda.empty_cache()

        inside_layer_modules = self.inside_layer_modules
        if not self.quantize_config.true_sequential:
            inside_layer_modules = [sum(inside_layer_modules, [])]
        quantizers = {}
        for i in range(len(layers)):
            logger.info(f"Start quantizing layer {i + 1}/{len(layers)}")
            layer = layers[i]
            force_layer_back_to_cpu = False
            if get_device(layer) == CPU:
                move_to_device(layer, CUDA_0)
                force_layer_back_to_cpu = True
            cur_layer_device = get_device(layer)

            full = find_layers(layer)
                            
            for names in inside_layer_modules:
                subset = {n: full[n] for n in names if n in full}
                gptq = {}
                for name in subset:
                    gptq[name] = GPTQ(subset[name])
                    gptq[name].quantizer.configure(
                        self.quantize_config.bits,
                        perchannel=True,
                        sym=self.quantize_config.sym,
                        mse=False,
                    )

                def add_batch(name):
                    def tmp(_, inp, out):
                        # gptq is mutable.
                        gptq[name].add_batch(inp[0].data, out.data)  # noqa: F821

                    return tmp

                handles = []
                for name in subset:
                    handles.append(subset[name].register_forward_hook(add_batch(name)))
                for j in range(num_batches):
                    layer_input = move_to_device(layer_inputs[j], cur_layer_device)
                    layer_attention_mask = move_to_device(attention_masks[j], cur_layer_device)
                    additional_layer_inputs = {"attention_mask": layer_attention_mask}
                    layer_position_ids = (
                        None if not position_ids else move_to_device(position_ids[j], cur_layer_device)
                    )
                    if layer_position_ids is not None:
                        additional_layer_inputs["position_ids"] = layer_position_ids
                    for k, v in layer_input_kwargs[j].items():
                        additional_layer_inputs[k] = nested_move_to_device(v, cur_layer_device)
                    layer(layer_input, **additional_layer_inputs)
                for h in handles:
                    h.remove()

                for name in subset:
                    logger.info(f"Quantizing {name} in layer {i + 1}/{len(layers)}...")
                    scale, zero, g_idx = gptq[name].fasterquant(
                        percdamp=self.quantize_config.damp_percent,
                        group_size=self.quantize_config.group_size,
                        actorder=self.quantize_config.desc_act,
                        static_groups=self.quantize_config.static_groups,
                    )
                    quantizers[f"{self.layers_block_name}.{i}.{name}"] = (
                        gptq[name].quantizer.to(CPU if force_layer_back_to_cpu else cur_layer_device),
                        move_to_device(scale, CPU if force_layer_back_to_cpu else cur_layer_device),
                        move_to_device(zero, CPU if force_layer_back_to_cpu else cur_layer_device),
                        move_to_device(g_idx, CPU if force_layer_back_to_cpu else cur_layer_device),
                    )
                    gptq[name].free()

            for j in range(num_batches):
                layer_input = move_to_device(layer_inputs[j], cur_layer_device)
                layer_attention_mask = move_to_device(attention_masks[j], cur_layer_device)
                additional_layer_inputs = {"attention_mask": layer_attention_mask}
                layer_position_ids = None if not position_ids else move_to_device(position_ids[j], cur_layer_device)
                if layer_position_ids is not None:
                    additional_layer_inputs["position_ids"] = layer_position_ids
                for k, v in layer_input_kwargs[j].items():
                    additional_layer_inputs[k] = nested_move_to_device(v, cur_layer_device)
                layer_output = move_to_device(
                    layer(layer_input, **additional_layer_inputs)[0],
                    cur_layer_device if cache_examples_on_gpu else CPU,
                )
                layer_outputs.append(layer_output)

            layers[i] = move_to_device(layer, CPU if force_layer_back_to_cpu else cur_layer_device)
            del layer
            del gptq
            del layer_inputs
            layer_inputs, layer_outputs = layer_outputs, []
            torch.cuda.empty_cache()

        pack_model(
            model=self.model,
            quantizers=quantizers,
            bits=self.quantize_config.bits,
            group_size=self.quantize_config.group_size,
            use_triton=use_triton,
            use_cuda_fp16=use_cuda_fp16,
            desc_act=self.quantize_config.desc_act,
            warmup_triton=autotune_warmup_after_quantized,
            force_layer_back_to_cpu=force_layer_back_to_cpu,
        )
        if device_map:
            self.model = remove_hook_from_module(self.model, recurse=True)
            self.model = simple_dispatch_model(self.model, device_map)
        self.model.config.use_cache = forward_pass_use_cache

        self._quantized = True

        torch.cuda.empty_cache()

    @property
    def device(self):
        if not self.hf_device_map:
            return self.model.device
        else:
            device = [d for d in self.hf_device_map.values() if d not in {"disk"}][0]
            return torch.device(device)

    def to(self, device: Union[str, torch.device]):
        self.model.to(device)
        return self

    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def generate(self, **kwargs):
        """shortcut for model.generate"""
        with torch.inference_mode(), torch.amp.autocast(device_type=self.device.type):
            return self.model.generate(**kwargs)

    def prepare_inputs_for_generation(self, *args, **kwargs):
        """shortcut for model.prepare_inputs_for_generation"""
        return self.model.prepare_inputs_for_generation(*args, **kwargs)

    def push_to_hub(
        self,
        repo_id: str,
        save_dir: Optional[str] = None,
        use_safetensors: Optional[bool] = True,
        safetensors_metadata: Optional[Dict[str, str]] = None,
        commit_message: Optional[str] = "Upload of AutoGPTQ quantized model",
        use_auth_token: Optional[Union[bool, str]] = None,
        private: Optional[bool] = None,
        token: Optional[Union[bool, str]] = None,
        create_pr: Optional[bool] = False,
    ) -> str:
        """
        Upload the model to the Hugging Face Hub.

        Parameters:
            repo_id (`str`):
                The name of the repository you want to push your tool to. It should contain your organization name when
                pushing to a given organization.
            save_dir (`str`, *optional*):
                The name of the local folder to save the model to.
                If the model has already been saved, this parameter can be omitted.
            use_safetensors (`bool`, *optional*):
                Save the model using `safetensors`.
                If the model has already been saved, this parameter can be omitted.
            safetensors_metadata: (`dict`, *optional*, defaults to `None`):
                Pass optional metadata dictionary to be saved in the `safetensors` model file(s).
                Metadata is optional and is purely for informational purposes. It does not affect inference.
                If `None`, no metadata will be saved.
            commit_message (`str`, *optional*, defaults to `"Upload tool"`):
                Message to commit while pushing.
            use_auth_token (`bool` or `str`, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
                when running `huggingface-cli login` (stored in `~/.huggingface`). Will default to `True` if `repo_url`
                is not specified.
            private (`bool`, *optional*):
                Whether or not the repository created should be private.
            token (`bool` or `str`, *optional*):
                The token to use as HTTP bearer authorization for remote files. If unset, will use the token generated
                when running `huggingface-cli login` (stored in `~/.huggingface`).
            create_pr (`bool`, *optional*, defaults to `False`):
                Whether or not to create a PR with the uploaded files or directly commit.
        """
        if (
            self.quantize_config.model_name_or_path is None or not isdir(self.quantize_config.model_name_or_path)
        ) and save_dir is None:
            raise ValueError(
                "Quantized model should be saved first, or you can provide save_dir to make sure model is saved to local disk before uploading."
            )

        if save_dir is not None:
            logger.info(f"Saving model to {save_dir}")
            self.save_quantized(save_dir, use_safetensors, safetensors_metadata)

        repo_url = create_repo(
            repo_id=repo_id,
            token=token,
            private=private,
            exist_ok=True,
            repo_type="model",
        )
        repo_id = repo_url.repo_id

        if self.quantize_config.model_name_or_path is not None:
            work_dir = self.quantize_config.model_name_or_path
            operations = [
                CommitOperationAdd(path_or_fileobj=join(work_dir, f), path_in_repo=f) for f in os.listdir(work_dir)
            ]
            logger.info(f"Uploading the following files to {repo_id}: {','.join(os.listdir(work_dir))}")
            return create_commit(
                repo_id=repo_id,
                operations=operations,
                commit_message=commit_message,
                token=use_auth_token,
                create_pr=create_pr,
                repo_type="model",
            )

    def save_quantized(
        self,
        save_dir: str,
        use_safetensors: bool = True,
        safetensors_metadata: Optional[Dict[str, str]] = None,
    ):
        """save quantized model and configs to local disk"""
        os.makedirs(save_dir, exist_ok=True)

        if not self.quantized:
            raise EnvironmentError("can only save quantized model, please execute .quantize first.")

        self.model.to(CPU)

        model_base_name = (
            self.quantize_config.model_file_base_name
            or f"gptq_model-{self.quantize_config.bits}bit-{self.quantize_config.group_size}g"
        )
        if use_safetensors:
            model_save_name = model_base_name + ".safetensors"
            state_dict = self.model.state_dict()
            state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()}
            if safetensors_metadata is None:
                safetensors_metadata = {}
            elif not isinstance(safetensors_metadata, dict):
                raise TypeError("safetensors_metadata must be a dictionary.")
            else:
                logger.debug(f"Received safetensors_metadata: {safetensors_metadata}")
                new_safetensors_metadata = {}
                converted_keys = False
                for key, value in safetensors_metadata.items():
                    if not isinstance(key, str) or not isinstance(value, str):
                        converted_keys = True
                        try:
                            new_key = str(key)
                            new_value = str(value)
                        except Exception as e:
                            raise TypeError(
                                f"safetensors_metadata: both keys and values must be strings and an error occured when trying to convert them: {e}"
                            )
                        if new_key in new_safetensors_metadata:
                            logger.warning(
                                f"After converting safetensors_metadata keys to strings, the key '{new_key}' is duplicated. Ensure that all your metadata keys are strings to avoid overwriting."
                            )
                        new_safetensors_metadata[new_key] = new_value
                safetensors_metadata = new_safetensors_metadata
                if converted_keys:
                    logger.debug(
                        f"One or more safetensors_metadata keys or values had to be converted to str(). Final safetensors_metadata: {safetensors_metadata}"
                    )

            # Format is required to enable Accelerate to load the metadata
            # otherwise it raises an OSError
            safetensors_metadata["format"] = "pt"

            # Store the quantization configuration as safetensors metadata
            from auto_gptq import __version__

            safetensors_metadata["auto_gptq_version"] = str(__version__)
            safetensors_metadata["gptq_bits"] = str(self.quantize_config.bits)
            safetensors_metadata["gptq_group_size"] = str(self.quantize_config.group_size)
            safetensors_metadata["gptq_desc_act"] = str(self.quantize_config.desc_act)
            safetensors_metadata["gptq_damp_percent"] = str(self.quantize_config.damp_percent)
            safetensors_metadata["gptq_is_marlin_format"] = str(self.quantize_config.is_marlin_format)

            safe_save(state_dict, join(save_dir, model_save_name), safetensors_metadata)
        else:
            model_save_name = model_base_name + ".bin"
            torch.save(self.model.state_dict(), join(save_dir, model_save_name))

        self.model.config.quantization_config = self.quantize_config.to_dict()
        self.model.config.save_pretrained(save_dir)
        self.quantize_config.save_pretrained(save_dir)
        self.quantize_config.model_name_or_path = save_dir
        self.quantize_config.model_file_base_name = model_base_name

    def save_pretrained(
        self,
        save_dir: str,
        use_safetensors: bool = True,
        safetensors_metadata: Optional[Dict[str, str]] = None,
        **kwargs,
    ):
        """alias of save_quantized"""
        logger.warning("you are using save_pretrained, which will re-direct to save_quantized.")
        self.save_quantized(save_dir, use_safetensors, safetensors_metadata)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        quantize_config: BaseQuantizeConfig,
        max_memory: Optional[dict] = None,
        trust_remote_code: bool = False,
        torch_dtype: torch.dtype = torch.float16,
        **model_init_kwargs,
    ):
        """load un-quantized pretrained model to cpu"""

        if not torch.cuda.is_available():
            raise EnvironmentError("Load pretrained model to do quantization requires CUDA available.")

        def skip(*args, **kwargs):
            pass

        torch.nn.init.kaiming_uniform_ = skip
        torch.nn.init.uniform_ = skip
        torch.nn.init.normal_ = skip

        # Parameters related to loading from Hugging Face Hub
        cache_dir = model_init_kwargs.pop("cache_dir", None)
        force_download = model_init_kwargs.pop("force_download", False)
        resume_download = model_init_kwargs.pop("resume_download", False)
        proxies = model_init_kwargs.pop("proxies", None)
        local_files_only = model_init_kwargs.pop("local_files_only", False)
        use_auth_token = model_init_kwargs.pop("use_auth_token", None)
        revision = model_init_kwargs.pop("revision", None)
        subfolder = model_init_kwargs.pop("subfolder", "")
        commit_hash = model_init_kwargs.pop("_commit_hash", None)

        cached_file_kwargs = {
            "cache_dir": cache_dir,
            "force_download": force_download,
            "proxies": proxies,
            "resume_download": resume_download,
            "local_files_only": local_files_only,
            "use_auth_token": use_auth_token,
            "revision": revision,
            "subfolder": subfolder,
            "_commit_hash": commit_hash,
        }

        config = AutoConfig.from_pretrained(
            pretrained_model_name_or_path, trust_remote_code=True, **cached_file_kwargs
        )
        if config.model_type not in SUPPORTED_MODELS:
            raise TypeError(f"{config.model_type} isn't supported yet.")

        # enforce some values despite user specified
        model_init_kwargs["torch_dtype"] = torch_dtype
        model_init_kwargs["trust_remote_code"] = trust_remote_code
        if max_memory:
            if "disk" in max_memory:
                raise NotImplementedError("disk offload not support yet.")
            with accelerate.init_empty_weights():
                model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
            model.tie_weights()

            max_memory = accelerate.utils.get_balanced_memory(
                model,
                max_memory=max_memory,
                no_split_module_classes=[cls.layer_type],
                dtype=model_init_kwargs["torch_dtype"],
                low_zero=False,
            )
            model_init_kwargs["device_map"] = accelerate.infer_auto_device_map(
                model,
                max_memory=max_memory,
                no_split_module_classes=[cls.layer_type],
                dtype=model_init_kwargs["torch_dtype"],
            )
            model_init_kwargs["low_cpu_mem_usage"] = True

            del model
        else:
            model_init_kwargs["device_map"] = None
            model_init_kwargs["low_cpu_mem_usage"] = False

        torch.cuda.empty_cache()

        merged_kwargs = {**model_init_kwargs, **cached_file_kwargs}
        model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **merged_kwargs)

        model_config = model.config.to_dict()
        seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"]
        if any(k in model_config for k in seq_len_keys):
            for key in seq_len_keys:
                if key in model_config:
                    model.seqlen = model_config[key]
                    break
        else:
            logger.warning("can't get model's sequence length from model config, will set to 4096.")
            model.seqlen = 4096
        model.eval()

        return cls(model, False, quantize_config)

    @classmethod
    def from_quantized(
        cls,
        model_name_or_path: Optional[str],
        device_map: Optional[Union[str, Dict[str, Union[int, str]]]] = None,
        max_memory: Optional[dict] = None,
        device: Optional[Union[str, int]] = None,
        low_cpu_mem_usage: bool = False,
        use_triton: bool = False,
        use_qigen: bool = False,
        use_marlin: bool = False,
        torch_dtype: Optional[torch.dtype] = None,
        inject_fused_attention: bool = False,
        inject_fused_mlp: bool = False,
        use_cuda_fp16: bool = True,
        quantize_config: Optional[BaseQuantizeConfig] = None,
        model_basename: Optional[str] = None,
        use_safetensors: bool = True,
        trust_remote_code: bool = False,
        warmup_triton: bool = False,
        trainable: bool = False,
        disable_exllama: Optional[bool] = None,
        disable_exllamav2: bool = False,
        **kwargs,
    ):
        """load quantized model from local disk"""
        # If disable_exllamav2 is True, we want to fall back on the exllama kernel and not the cuda/cuda_old ones.
        if disable_exllama is None:
            if disable_exllamav2:
                disable_exllama = False
            else:
                disable_exllama = True

        # Parameters related to loading from Hugging Face Hub
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", False)
        use_auth_token = kwargs.pop("use_auth_token", None)
        revision = kwargs.pop("revision", None)
        subfolder = kwargs.pop("subfolder", "")
        commit_hash = kwargs.pop("_commit_hash", None)

        cached_file_kwargs = {
            "cache_dir": cache_dir,
            "force_download": force_download,
            "proxies": proxies,
            "resume_download": resume_download,
            "local_files_only": local_files_only,
            "use_auth_token": use_auth_token,
            "revision": revision,
            "subfolder": subfolder,
            "_raise_exceptions_for_missing_entries": False,
            "_commit_hash": commit_hash,
        }
        if use_qigen and not QIGEN_AVAILABLE:
            logger.warning("Qigen is not installed, reset use_qigen to False.")
            use_qigen = False
        if use_triton and not TRITON_AVAILABLE:
            logger.warning("Triton is not installed, reset use_triton to False.")
            use_triton = False
        if not disable_exllama and not EXLLAMA_KERNELS_AVAILABLE:
            logger.warning(
                "Exllama kernel is not installed, reset disable_exllama to True. "
                "This may because you installed auto_gptq using a pre-build wheel "
                "on Windows, in which exllama_kernels are not compiled. To use "
                "exllama_kernels to further speedup inference, you can re-install "
                "auto_gptq from source."
            )
            disable_exllama = True
        if not disable_exllamav2 and not EXLLAMAV2_KERNELS_AVAILABLE:
            logger.warning(
                "Exllamav2 kernel is not installed, reset disable_exllamav2 to True. "
                "This may because you installed auto_gptq using a pre-build wheel "
                "on Windows, in which exllama_kernels are not compiled. To use "
                "exllama_kernels to further speedup inference, you can re-install "
                "auto_gptq from source."
            )
            disable_exllamav2 = True
        if not AUTOGPTQ_CUDA_AVAILABLE:
            logger.warning(
                "CUDA kernels for auto_gptq are not installed, this will result in "
                "very slow inference speed. This may because:\n"
                "1. You disabled CUDA extensions compilation by setting BUILD_CUDA_EXT=0 when install auto_gptq from source.\n"
                "2. You are using pytorch without CUDA support.\n"
                "3. CUDA and nvcc are not installed in your device."
            )

        if use_qigen and QIGEN_AVAILABLE:
            logger.warning("QIgen is active. Ignores all settings related to cuda.")
            inject_fused_attention = False
            inject_fused_mlp = False
            use_triton = False
            disable_exllama = True
            disable_exllamav2 = True

        if not disable_exllamav2 and not disable_exllama:
            logger.warning(
                "You have activated both exllama and exllamav2 kernel. Setting disable_exllama to True and keeping disable_exllamav2 to False"
            )
            disable_exllama = True

        # == step1: prepare configs and file names == #
        config = AutoConfig.from_pretrained(
            model_name_or_path,
            trust_remote_code=trust_remote_code,
            **cached_file_kwargs,
        )

        if config.model_type not in SUPPORTED_MODELS:
            raise TypeError(f"{config.model_type} isn't supported yet.")

        if quantize_config is None:
            quantize_config = BaseQuantizeConfig.from_pretrained(model_name_or_path, **cached_file_kwargs, **kwargs)

        if not use_marlin and MARLIN_AVAILABLE:
            unsupported_reason = _validate_marlin_compatibility(quantize_config)
            if unsupported_reason is None and _validate_marlin_device_support():
                logger.info(
                    "You passed a model that is compatible with the Marlin int4*fp16 GPTQ kernel but use_marlin is False. We recommend using `use_marlin=True` to use the optimized Marlin kernels for inference. Example: `model = AutoGPTQForCausalLM.from_quantized(..., use_marlin=True)`."
                )

        if hasattr(quantize_config, "is_marlin_format") and quantize_config.is_marlin_format and not use_marlin:
            raise ValueError(
                "You passed a GPTQ model saved in int4*fp16 GPTQ Marlin kernel format but are loading with use_marlin=False. "
                "Please use `use_marlin=True` to load this model. Example: `model = AutoGPTQForCausalLM.from_quantized(..., use_marlin=True)`."
            )

        if model_basename is None:
            if quantize_config.model_file_base_name:
                possible_model_basenames = [quantize_config.model_file_base_name]
            else:
                possible_model_basenames = [
                    f"gptq_model-{quantize_config.bits}bit-{quantize_config.group_size}g",
                    "model",
                ]
        else:
            possible_model_basenames = [model_basename]

        quantize_config.model_name_or_path = model_name_or_path

        extensions = []
        if use_safetensors:
            extensions.append(".safetensors")
        else:
            extensions += [".bin", ".pt"]

        model_name_or_path = str(model_name_or_path)
        is_local = isdir(model_name_or_path)

        # Retrieve (and if necessary download) the quantized checkpoint(s).
        is_sharded, resolved_archive_file, true_model_basename = get_checkpoints(model_name_or_path=model_name_or_path, extensions=extensions, possible_model_basenames=possible_model_basenames, **cached_file_kwargs)

        quantize_config.model_file_base_name = true_model_basename

        model_save_name = resolved_archive_file  # In case a model is sharded, this would be `model.safetensors.index.json` which may later break.

        if (not disable_exllama or not disable_exllamav2) and trainable:
            logger.warning(
                "QuantLinear with the exllama backend not does support the trainable mode yet, switching to cuda/cuda_old/triton backend."
            )
            disable_exllama = True
            disable_exllamav2 = True

        elif not use_triton and trainable:
            logger.warning(
                "QuantLinear with cuda backend not support trainable mode yet, Switch to the pytorch backend."
            )

        # == step2: convert model to gptq-model (replace Linear with QuantLinear) == #
        def skip(*args, **kwargs):
            pass

        if torch_dtype is None:
            if not use_qigen:
                torch_dtype = torch.float16
            else:
                torch_dtype = torch.float32

        if torch_dtype != torch.float16:
            logger.warning("Overriding use_cuda_fp16 to False since torch_dtype is not torch.float16.")
            use_cuda_fp16 = False

        if not use_qigen:
            torch.nn.init.kaiming_uniform_ = skip
            torch.nn.init.uniform_ = skip
            torch.nn.init.normal_ = skip

            transformers.modeling_utils._init_weights = False

            init_contexts = [no_init_weights()]
            if low_cpu_mem_usage:
                init_contexts.append(accelerate.init_empty_weights(include_buffers=False))

            with ContextManagers(init_contexts):
                model = AutoModelForCausalLM.from_config(
                    config, trust_remote_code=trust_remote_code, torch_dtype=torch_dtype
                )

                layers = find_layers(model)
                ignore_layers = [cls.lm_head_name] + cls.outside_layer_modules
                for name in list(layers.keys()):
                    if any(name.startswith(ignore_layer) for ignore_layer in ignore_layers) or all(
                        not name.endswith(ignore_layer)
                        for sublist in cls.inside_layer_modules
                        for ignore_layer in sublist
                    ):
                        logger.info(f"The layer {name} is not quantized.")
                        del layers[name]

                make_quant(
                    model,
                    layers,
                    quantize_config.bits,
                    quantize_config.group_size,
                    use_triton=use_triton,
                    disable_exllama=disable_exllama,
                    disable_exllamav2=disable_exllamav2,
                    use_cuda_fp16=use_cuda_fp16,
                    desc_act=quantize_config.desc_act,
                    trainable=trainable,
                )
                model.tie_weights()

            # == step3: load checkpoint and dispatch == #
            if isinstance(device_map, str) and device_map not in [
                "auto",
                "balanced",
                "balanced_low_0",
                "sequential",
            ]:
                raise ValueError(
                    "If passing a string for `device_map`, please choose 'auto', 'balanced', 'balanced_low_0' or "
                    "'sequential'."
                )
            if isinstance(device_map, dict):
                max_memory = None
            else:
                if device is None and not device_map and not max_memory:
                    device_map = "auto"
                if device is not None:
                    device = torch.device(device)
                    if not max_memory and not device_map:
                        device_map = {"": device.index if device.type == "cuda" else device.type}
                if not isinstance(device_map, dict) and device_map != "sequential":
                    max_memory = accelerate.utils.get_balanced_memory(
                        model=model,
                        max_memory=max_memory,
                        no_split_module_classes=[cls.layer_type],
                        low_zero=(device_map == "balanced_low_0"),
                    )
            if not isinstance(device_map, dict):
                device_map = accelerate.infer_auto_device_map(
                    model,
                    max_memory=max_memory,
                    no_split_module_classes=[cls.layer_type],
                )

            if low_cpu_mem_usage:
                make_sure_no_tensor_in_meta_device(
                    model,
                    use_triton,
                    quantize_config.desc_act,
                    quantize_config.group_size,
                    bits=quantize_config.bits,
                    disable_exllama=disable_exllama,
                    disable_exllamav2=disable_exllamav2,
                )

            # TODO: move this logic in an awq_utils.py file.
            if quantize_config.awq_gemm_checkpoint:
                if is_sharded:
                    raise ValueError("The loading of sharded checkpoints with AWQ checkpoints is currently not supported. Please raise an issue in AutoGPTQ repository.")

                if use_marlin:
                    raise ValueError(
                        "Tried to load an AWQ kernel with use_marlin=True. This is currently not supported. Please open an issue in AutoGPTQ repository."
                    )

                if is_local:
                    is_cached = os.path.isfile(os.path.join(model_name_or_path, "autogptq_model.safetensors"))
                else:
                    namespace, subfolder = model_name_or_path.split("/")
                    assets_path = huggingface_hub.cached_assets_path(
                        library_name="autogptq",
                        namespace=namespace,
                        subfolder=subfolder,
                    )
                    weight_path = os.path.join(assets_path, "autogptq_model.safetensors")

                    is_cached = os.path.isfile(weight_path)

                if is_cached:
                    if is_local:
                        model_save_name = os.path.join(model_name_or_path, "autogptq_model.safetensors")
                    else:
                        namespace, subfolder = model_name_or_path.split("/")
                        assets_path = huggingface_hub.cached_assets_path(
                            library_name="autogptq",
                            namespace=namespace,
                            subfolder=subfolder,
                        )
                        model_save_name = os.path.join(assets_path, "autogptq_model.safetensors")

                    logger.info(f"Loading an AWQ model, detected a cached repacked weight at {model_save_name}.")
                else:
                    logger.info(
                        "Loading an AWQ model. This requires repacking the weights, and no repacking cached weight was found. Grab a coffee!"
                    )

                    if "safetensors" not in model_save_name:
                        raise NotImplementedError(
                            f"Conversion from AWQ checkpoints is implemented only for safetensors checkpoints, found {model_save_name}"
                        )
                    if quantize_config.bits != 4:
                        raise NotImplementedError(
                            f"Conversion from AWQ checkpoints is supported only for 4 bits models. Found {quantize_config.bits} bits."
                        )
                    gptq_layers = set()
                    non_gptq_params = set()
                    with safe_open(model_save_name, framework="pt") as f:
                        for state_dict_key in f.keys():
                            if (
                                "qweight" not in state_dict_key
                                and "qzeros" not in state_dict_key
                                and "scales" not in state_dict_key
                            ):
                                non_gptq_params.add(state_dict_key)
                                continue

                            # e.g. prefix "model.layers.3.self_attn.k_proj"
                            prefix, _ = state_dict_key.rsplit(".", 1)
                            gptq_layers.add(prefix)

                        new_state_dict = {}

                        for state_dict_key in non_gptq_params:
                            new_state_dict[state_dict_key] = f.get_tensor(state_dict_key)

                        gptq_layers = sorted(gptq_layers)
                        max_layer_name_length = len(max(gptq_layers, key=len))
                        pbar = tqdm(gptq_layers)
                        i = 0
                        for gptq_layer_name in pbar:
                            i += 1
                            desc = f"Unpacking {gptq_layer_name} + '...'"
                            desc = desc + " " * (max_layer_name_length - len(desc))

                            awq_qweight = f.get_tensor(gptq_layer_name + ".qweight")
                            awq_qzeros = f.get_tensor(gptq_layer_name + ".qzeros")
                            awq_scales = f.get_tensor(gptq_layer_name + ".scales")

                            # TODO: add FAST unpacking.
                            unpacked_qweight, unpacked_qzeros = unpack_awq(
                                awq_qweight,
                                awq_qzeros,
                                awq_scales,
                                bits=quantize_config.bits,
                                group_size=quantize_config.group_size,
                            )

                            # TODO: add FAST repacking, this is too slow.
                            desc = f"Repacking {gptq_layer_name}..."
                            desc = desc + " " * (max_layer_name_length + 12 - len(desc))
                            pbar.set_description(desc)
                            gptq_qweight, gptq_qzeros = pack_from_tensors(
                                unpacked_qweight,
                                unpacked_qzeros,
                                awq_scales,
                                bits=quantize_config.bits,
                                group_size=quantize_config.group_size,
                            )

                            new_state_dict[gptq_layer_name + ".qweight"] = gptq_qweight
                            new_state_dict[gptq_layer_name + ".qzeros"] = gptq_qzeros
                            new_state_dict[gptq_layer_name + ".scales"] = awq_scales

                    # Cache the converted model.
                    if is_local:
                        model_save_name = os.path.join(model_name_or_path, "autogptq_model.safetensors")
                        safe_save(new_state_dict, model_save_name)
                    else:
                        namespace, subfolder = model_name_or_path.split("/")
                        assets_path = huggingface_hub.cached_assets_path(
                            library_name="autogptq",
                            namespace=namespace,
                            subfolder=subfolder,
                        )
                        model_save_name = os.path.join(assets_path, "autogptq_model.safetensors")

                        safe_save(new_state_dict, model_save_name)

            if use_marlin:
                if is_sharded:
                    raise ValueError("The loading of sharded checkpoints with Marlin is currently not supported. Please raise an issue in AutoGPTQ repository.")
                if torch.version.hip:
                    raise ValueError("Can not use Marlin int4*fp16 kernel with AMD ROCm version of PyTorch as the kernel is not compatible. Please do not use `use_marlin=True` when using ROCm devices.")
                if not torch.cuda.get_device_capability()[0] >= 8:
                    raise ValueError(f'Can not use Marlin int4*fp16 kernel with a device of compute capability {torch.cuda.get_device_capability()}, the minimum compute capability is 8.0 for Marlin kernel. Please do not use `use_marlin=True`, or please upgrade your GPU ("The more you buy, the more you save." - Taiwanese proverb).')

                # Validate the model can run in Marlin.
                if torch_dtype != torch.float16:
                    raise ValueError("Marlin kernel requires torch_dtype=torch.float16.")
                unsupported_reason = _validate_marlin_compatibility(quantize_config)
                if unsupported_reason is not None:
                    raise ValueError(
                        f"The model {model_name_or_path} can not be converted to use the Marlin kernel for the following reason: {unsupported_reason}, which is not supported by Marlin kernel."
                    )

                # Load the quant linear type we need.
                # TODO: load directy marlin with the right quantlinear class.
                quant_linear_class = dynamically_import_QuantLinear(
                    use_triton=use_triton,
                    desc_act=quantize_config.desc_act,
                    group_size=quantize_config.group_size,
                    bits=quantize_config.bits,
                    disable_exllama=disable_exllama,
                    disable_exllamav2=disable_exllamav2,
                    disable_marlin=True,  # Get the "original" QuantLienar class
                )

                # Prepare model for marlin load.
                #   If stub is marlin serialzed         --> load from directly
                #   If stub has cached marlin version   --> load from the cached versin
                #   Otherwise                           --> convert to marlin, cache, load from cache
                model, model_save_name = prepare_model_for_marlin_load(
                    model_name_or_path=model_name_or_path,
                    model=model,
                    quantize_config=quantize_config,
                    quant_linear_class=quant_linear_class,
                    torch_dtype=torch_dtype,
                    current_model_save_name=model_save_name,
                    device_map=device_map,
                )

                # Disable incompatible optimizations.
                if inject_fused_attention or inject_fused_mlp:
                    # TODO: Validate whether that can be used.
                    logger.info("Disabling fused attention and mlp injection because Marlin kernel is used.")
                    inject_fused_attention = False
                    inject_fused_mlp = False

            accelerate.utils.modeling.load_checkpoint_in_model(
                model,
                dtype=torch_dtype,  # This is very hacky but works due to https://github.com/huggingface/accelerate/blob/bd72a5f1a80d5146554458823f8aeda0a9db5297/src/accelerate/utils/modeling.py#L292
                checkpoint=model_save_name,
                device_map=device_map,
                offload_state_dict=True,
                offload_buffers=True,
            )

            # TODO: Why are we using this custom function and not dispatch_model?
            model = simple_dispatch_model(model, device_map)
        else:
            # Using QiGen.

            if is_sharded:
                raise ValueError("The loading of sharded checkpoints with QiGen is currently not supported. Please raise an issue in AutoGPTQ repository.")

            if quantize_config.desc_act:
                NotImplementedError("desc_act=True is not yet supported with QiGen.")
            model = AutoModelForCausalLM.from_config(
                config, trust_remote_code=trust_remote_code, torch_dtype=torch_dtype
            )

            layers = find_layers(model)
            ignore_layers = [cls.lm_head_name] + cls.outside_layer_modules
            for name in list(layers.keys()):
                if any(name.startswith(ignore_layer) for ignore_layer in ignore_layers):
                    logger.info(f"{name} not been quantized, will be ignored when make_quant.")
                    del layers[name]

            if model_save_name.endswith(".safetensors"):
                checkpoint = safe_load(model_save_name)
            else:
                checkpoint = torch.load(model_save_name)
            make_quant(
                model,
                layers,
                quantize_config.bits,
                quantize_config.group_size,
                use_triton=use_triton,
                disable_exllama=disable_exllama,
                disable_exllamav2=disable_exllamav2,
                use_cuda_fp16=use_cuda_fp16,
                desc_act=quantize_config.desc_act,
                trainable=trainable,
                use_qigen=True,
            )
            preprocess_checkpoint_qigen(
                model,
                layers,
                quantize_config.bits,
                quantize_config.group_size,
                checkpoint,
            )
            model.load_state_dict(checkpoint)

        # == step4: set seqlen == #
        model_config = model.config.to_dict()
        seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"]
        if any(k in model_config for k in seq_len_keys):
            for key in seq_len_keys:
                if key in model_config:
                    model.seqlen = model_config[key]
                    break
        else:
            logger.warning("can't get model's sequence length from model config, will set to 4096.")
            model.seqlen = 4096

        # == step5: (optional) inject optimized module == #
        if inject_fused_attention:
            if cls.fused_attn_module_type is None:
                inject_fused_attention = False
                logger.warning(f"{cls.__name__} hasn't fused attention module yet, will skip inject fused attention.")
            else:
                cls.fused_attn_module_type.inject_to_model(
                    model,
                    use_triton=use_triton,
                    group_size=quantize_config.group_size,
                    use_cuda_fp16=use_cuda_fp16,
                    desc_act=quantize_config.desc_act,
                    trainable=trainable,
                    bits=quantize_config.bits,
                    disable_exllama=disable_exllama,
                    disable_exllamav2=disable_exllamav2,
                )
        if inject_fused_mlp:
            if cls.fused_mlp_module_type is None:
                inject_fused_mlp = False
                logger.warning(f"{cls.__name__} hasn't fused mlp module yet, will skip inject fused mlp.")
            else:
                cls.fused_mlp_module_type.inject_to_model(model, use_triton=use_triton)

        # Any post-initialization that require device information, for example buffers initialization on device.
        model = autogptq_post_init(model, use_act_order=quantize_config.desc_act)

        model.eval()

        # == step6: (optional) warmup triton == #
        if use_triton and warmup_triton:
            from ..nn_modules.qlinear.qlinear_triton import QuantLinear

            QuantLinear.warmup(model, seqlen=model.seqlen)

            if inject_fused_mlp and cls.fused_mlp_module_type is not None:
                cls.fused_mlp_module_type.warmup(model, seqlen=model.seqlen)

        # == step7: make model compatible with peft
        # cls.make_sure_compatible_with_peft(
        #     model,
        #     use_triton,
        #     quantize_config.desc_act,
        #     quantize_config.group_size,
        #     bits=quantize_config.bits,
        #     disable_exllama=disable_exllama,
        #     disable_exllamav2=disable_exllamav2,
        #     use_marlin=use_marlin,
        #     use_qigen=use_qigen,
        # )

        return cls(
            model,
            True,
            quantize_config,
            is_triton_backend=use_triton,
            injected_fused_attention=inject_fused_attention,
            injected_fused_mlp=inject_fused_mlp and use_triton,
            trainable=trainable,
        )

    def warmup_triton(self, enabled: bool = True):
        if not enabled:
            return
        if not TRITON_AVAILABLE:
            logger.warning("triton is not available, skip warmup stage directly.")
            return

        from ..nn_modules.qlinear.qlinear_triton import QuantLinear

        QuantLinear.warmup(self.model, seqlen=self.model.seqlen)

        if self.fused_mlp_module_type is not None:
            self.fused_mlp_module_type.warmup(self.model, seqlen=self.model.seqlen)

    def enable_trainable_mode(self, enabled: bool = True):
        if not self.is_triton_backend and enabled:
            raise NotImplementedError("For now, trainable mode only supports triton backend.")
        for n, m in self.model.named_modules():
            if hasattr(m, "trainable"):
                setattr(m, "trainable", enabled)

    def disable_trainable_mode(self):
        self.enable_trainable_mode(enabled=False)

    @staticmethod
    def make_sure_compatible_with_peft(
        model: PreTrainedModel,
        use_triton: bool,
        desc_act: bool,
        group_size: int,
        bits: int,
        disable_exllama: bool = True,
        disable_exllamav2: bool = False,
        use_marlin: bool = False,
        use_qigen: bool = False,
    ):
        GeneralQuantLinear.inject_to_model(
            model,
            dynamically_import_QuantLinear(use_triton, desc_act, group_size, bits=bits, disable_exllama=disable_exllama, disable_exllamav2=disable_exllamav2, disable_marlin=not use_marlin, use_qigen=use_qigen),
        )

    def __getattr__(self, item):
        try:
            return super().__getattr__(item)
        except Exception:
            return getattr(self.model, item)


__all__ = ["BaseGPTQForCausalLM", "BaseQuantizeConfig"]
