# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Classes and functions related to PEFT, modified mainly from those in:
peft/tuners/tuners_utils.py
peft/tuners/lora/layer.py
peft/tuners/lora/bnb.py
peft/tuners/lora/model.py
peft/peft_model.py
"""
import os
import enum
from contextlib import contextmanager
from dataclasses import dataclass, field
from huggingface_hub import hf_hub_download
import inspect
from itertools import chain
import re
from typing import List, Optional, Union, Dict
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F

from tqdm import tqdm

from accelerate import dispatch_model, infer_auto_device_map
from accelerate.hooks import AlignDevicesHook, add_hook_to_module, remove_hook_from_submodules
from accelerate.utils import get_balanced_memory
from peft import (
    PeftModel, PeftConfig, PromptLearningConfig,
    PeftModelForCausalLM,
    PeftType,  # defines model name to str mapping, e.g., PeftType.LORA = "LORA"
    PEFT_TYPE_TO_CONFIG_MAPPING,
)
from peft.import_utils import is_bnb_available, is_bnb_4bit_available
from peft.mapping import MODEL_TYPE_TO_PEFT_MODEL_MAPPING
from peft.peft_model import PEFT_TYPE_TO_MODEL_MAPPING
from peft.tuners import (
    LoraModel,
    LoraConfig,
)
from peft.tuners.lora import (
    Conv2d as LoraConv2d, 
    Embedding as LoraEmbedding, 
    Linear as LoraLinear, 
    LoraLayer,
    QuantLinear as LoraQuantLinear
)
from peft.tuners.tuners_utils import BaseTuner
from peft.utils import (
    _set_adapter, 
    set_peft_model_state_dict, 
    _prepare_prompt_learning_config,
    ModulesToSaveWrapper,
    infer_device,
    load_peft_weights,
    get_auto_gptq_quant_linear,
    get_quantization_config,
    CONFIG_NAME,
    _get_batch_size,
)
if is_bnb_available():
    import bitsandbytes as bnb

    from peft.tuners.lora import Linear8bitLt as LoraLinear8bitLt
if is_bnb_4bit_available():
    from peft.tuners.lora import Linear4bit as LoraLinear4bit

from transformers import PreTrainedModel
from transformers.pytorch_utils import Conv1D


@dataclass
class MoLoraConfig(LoraConfig):
    """Configuration for a model that hosts several LoRA sets (MoLoRA).

    Extends `LoraConfig` with `molora_strategy`, and re-declares
    `layers_to_transform` as a string field that is decoded in `__post_init__`.
    """
    # Add layers_to_transform and change its type to avoid the following error in hf_argparser.py
    # Only `Union[X, NoneType]` (i.e., `Optional[X]`) is allowed for `Union` because the argument 
    # parser only supports one type per argument. Problem encountered in field 'layers_to_transform'.
    layers_to_transform: str = field(
        default=None,
        metadata={
            "help": "The layer indexes to transform, is this argument is specified, PEFT will transform only the layers indexes that are specified inside this list. If a single integer is passed, PEFT will transform only the layer at this index. "
            "This only works when target_modules is a list of str."
        },
    )
    molora_strategy: str = field(
        default="average",
        metadata={
            "help": "Strategy used with mixture sets of Lora. Support: 'parallel', 'average'. "
            "'parallel': given k sets of lora, the batch is split into k sub-batches, each passed to one lora, and concatenated in the end."
            "'average': given k sets of lora, their weights is averaged."
        },
    )

    def __post_init__(self):
        """Run the parent post-init, then decode a string `layers_to_transform`.

        The literal string "None" becomes `None`. Any other string is evaluated
        as a Python expression and must produce an int or a list of ints.
        Non-string values are left untouched.

        Raises:
            TypeError: If the evaluated value is neither an int nor a list of ints.
        """
        super().__post_init__()
        if not isinstance(self.layers_to_transform, str):
            return
        if self.layers_to_transform == "None":
            self.layers_to_transform = None
            return

        # SECURITY NOTE(review): eval() executes arbitrary Python, so this string must only
        # come from a trusted source (e.g. the operator's own command line). It is kept
        # because expressions such as "list(range(8))" may be relied upon; switch to
        # ast.literal_eval if only plain literals need to be supported.
        layers = eval(self.layers_to_transform)
        is_int_list = isinstance(layers, list) and all(isinstance(layer, int) for layer in layers)
        # Raise explicitly instead of asserting: asserts vanish under `python -O` and
        # would surface as AssertionError rather than the intended TypeError.
        if not (isinstance(layers, int) or is_int_list):
            raise TypeError(
                "Expect layers_to_transform to be int or a list of ints; got {}".format(layers))
        self.layers_to_transform = layers

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, subfolder: str = None, **kwargs):
        r"""
        This method loads the configuration of your adapter model from a directory.

        Args:
            pretrained_model_name_or_path (`str`):
                The directory or the Hub repository id where the configuration is saved.
            subfolder (`str`, *optional*):
                Sub-directory, inside the directory or the Hub repository, that holds the configuration file.
            kwargs (additional keyword arguments, *optional*):
                Additional keyword arguments passed along to the child class initialization.

        Returns:
            The configuration object. Its class is looked up in `PEFT_TYPE_TO_CONFIG_MAPPING` when the
            loaded file contains a `peft_type` entry, and is `cls` otherwise.

        Raises:
            ValueError: If the configuration file is found neither locally nor through `hf_hub_download`.
        """
        path = (
            os.path.join(pretrained_model_name_or_path, subfolder)
            if subfolder is not None
            else pretrained_model_name_or_path
        )

        hf_hub_download_kwargs, class_kwargs, _ = cls._split_kwargs(kwargs)

        # Prefer a local file; only query the Hub when it does not exist on disk.
        local_config_file = os.path.join(path, CONFIG_NAME)
        if os.path.isfile(local_config_file):
            config_file = local_config_file
        else:
            try:
                config_file = hf_hub_download(
                    pretrained_model_name_or_path, CONFIG_NAME, subfolder=subfolder, **hf_hub_download_kwargs
                )
            except Exception as err:
                # Chain the original error so the real cause (network, auth, typo, ...) stays visible.
                raise ValueError(f"Can't find '{CONFIG_NAME}' at '{pretrained_model_name_or_path}'") from err

        loaded_attributes = cls.from_json_file(config_file)

        if "peft_type" in loaded_attributes:
            peft_type = loaded_attributes["peft_type"]
            config_cls = PEFT_TYPE_TO_CONFIG_MAPPING[peft_type]
        else:
            config_cls = cls

        config = config_cls(**class_kwargs)
        # Copy over only the attributes the chosen config class knows about;
        # unknown keys in the file are ignored.
        for key, value in loaded_attributes.items():
            if hasattr(config, key):
                setattr(config, key, value)

        return config


class MoLinear(nn.Linear, LoraLayer):
    """This class implements a dense lora layer similar to peft.tuners.lora.Linear, except
    it allows multiple lora weights of potentially different hyperparameters to co-exist.
    At forward, different sets of lora weights are used according to a certain MoLora strategy.

    Args:
        adapter_name: Name of one adapter, or a list of names (one per lora set).
        in_features: Input feature size, forwarded to `LoraLayer.__init__`.
        out_features: Output feature size, forwarded to `LoraLayer.__init__`.
        r: Lora rank; a single value is shared by all adapters, a list gives one value per adapter.
        lora_alpha: Lora alpha; a single value or one value per adapter.
        lora_dropout: Lora dropout; a single value or one value per adapter.
        fan_in_fan_out: Set this to True if the layer to replace stores weight like (fan_in, fan_out).
        is_target_conv_1d_layer: Stored on the layer as-is.
        molora_strategy (keyword-only, default "average"):
            "parallel": given k sets of lora, the batch is split into k sub-batches, each passed 
            to one lora, and concatenated in the end.
            "average": given k sets of lora, their weights is averaged.
        init_lora_weights (keyword-only, default True): Forwarded to `update_layer`.

    Raises:
        NotImplementedError: If `molora_strategy` is not one of `MoLoraStrategy`.
        TypeError: If `adapter_name` is neither a str nor a list.
    """
    MoLoraStrategy = ['parallel', 'average']
    get_delta_weight = LoraLinear.get_delta_weight
    _linear = LoraLinear._linear

    def __init__(
        self,
        adapter_name: Union[str, List[str]],
        in_features: int,
        out_features: int,
        r: Union[int, List[int]] = 0,
        lora_alpha: Union[int, List[int]] = 1,
        lora_dropout: Union[float, List[float]] = 0.0,
        fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
        is_target_conv_1d_layer: bool = False,
        **kwargs,
    ):
        self.molora_strategy = kwargs.pop('molora_strategy', 'average')
        # Explicit raise (not assert) so the check survives `python -O` and has the intended type.
        if self.molora_strategy not in self.MoLoraStrategy:
            raise NotImplementedError(
                "Expect molora_strategy to be one of {}; got {}".format(self.MoLoraStrategy, self.molora_strategy))
        init_lora_weights = kwargs.pop("init_lora_weights", True)
        # this gets the init from nn.Linear's super perspective, i.e.
        # nn.Module.__init__, which should always be called
        super(nn.Linear, self).__init__()
        # Note that we don't use self._init_empty_weights() for Linear because it is a bit slower and the benefit of
        # added robustness is not big enough for Linear.

        LoraLayer.__init__(self, in_features=in_features, out_features=out_features)

        self.fan_in_fan_out = fan_in_fan_out

        if isinstance(adapter_name, str):
            adapter_name = [adapter_name]
        if not isinstance(adapter_name, list):
            raise TypeError(
                "Expect adapter_name to be str or List(str); got {}".format(type(adapter_name)))
        num_lora = len(adapter_name)
        # Broadcast scalar hyperparameters to one value per adapter. Both int and float
        # scalars are accepted for alpha/dropout so that e.g. `lora_dropout=0` or
        # `lora_alpha=16.0` no longer reaches zip() as a non-iterable.
        if isinstance(r, int):
            r = [r] * num_lora
        if isinstance(lora_alpha, (int, float)):
            lora_alpha = [lora_alpha] * num_lora
        if isinstance(lora_dropout, (int, float)):
            lora_dropout = [lora_dropout] * num_lora
        # NOTE: zip() stops at the shortest sequence, so per-adapter lists are expected
        # to have exactly one entry per adapter name.
        for _adapter_name, _r, _lora_alpha, _lora_dropout in zip(adapter_name, r, lora_alpha, lora_dropout):
            self.update_layer(_adapter_name, _r, _lora_alpha, _lora_dropout, init_lora_weights)

        self.is_target_conv_1d_layer = is_target_conv_1d_layer
        self.set_adapter(adapter_name)  # this will freeze the pre-trained weight matrix

    @property
    def merged(self) -> bool:
        """True when at least one adapter is recorded in `merged_adapters`."""
        return bool(self.merged_adapters)

    def merge(self, safe_merge: bool = False) -> None:
        """
        Merge the active adapter weights into the base weights

        With several active adapters the mean of their delta weights is merged
        ("average" strategy). With the "parallel" strategy and more than one
        active adapter nothing is merged and a warning is emitted.

        Args:
            safe_merge (`bool`, *optional*):
                If True, the merge operation will be performed in a copy of the original weights and check for NaNs
                before merging the weights. This is useful if you want to check if the merge operation will produce
                NaNs. Defaults to `False`.

        Raises:
            ValueError: If `safe_merge` is True and the merged weights contain non-finite values.
        """
        if self.merged:
            warnings.warn(
                f"Already following adapters were merged {','.join(self.merged_adapters)}. "
                f"You are now additionally merging {','.join(self.active_adapter)}."
            )
        active_adapter = [
            adapter_name for adapter_name in self.active_adapter if (
                self.r.get(adapter_name, 0) > 0 and adapter_name in self.lora_A
            )]
        if not active_adapter:
            return
        if len(active_adapter) > 1 and self.molora_strategy == 'parallel':
            warnings.warn("Lora weights cannot be merged for molora_strategy parallel.")
            return

        lora_weights = [self.get_delta_weight(adapter_name) for adapter_name in active_adapter]
        if len(lora_weights) > 1:  # corresponds to average molora_strategy
            lora_weights = torch.mean(torch.stack(lora_weights, dim=0), dim=0)
        else:
            lora_weights = lora_weights[0]

        # Note that safe_merge will be slower than the normal merge
        # because of the out-place addition and nan check.
        if safe_merge:
            merged_weights = self.weight.data + lora_weights
            # Explicit raise so the safety check cannot be optimized away with `python -O`.
            if not torch.isfinite(merged_weights).all():
                raise ValueError(
                    "NaNs detected in the merged weights. Some adapter seems to be broken"
                )
            self.weight.data = merged_weights
        else:
            self.weight.data += lora_weights
        self.merged_adapters.extend(active_adapter)

    def unmerge(self):
        """Subtract the delta weights of the active, previously merged adapters from the base weights.

        Mirrors `merge`: the mean delta is removed when several adapters qualify,
        and the operation is refused (with a warning) only for the "parallel"
        strategy with more than one adapter.
        """
        if not self.merged:
            warnings.warn("Already unmerged. Nothing to do.")
            return
        active_adapter = [
            adapter_name for adapter_name in self.active_adapter if (
                self.r.get(adapter_name, 0) > 0 
                and adapter_name in self.lora_A 
                and adapter_name in self.merged_adapters
            )]
        if not active_adapter:
            return
        # Same guard as merge(): a single adapter can be merged under 'parallel', so it
        # must also be possible to unmerge it; otherwise the layer stays merged forever
        # and forward() with adapters disabled keeps returning merged results.
        if len(active_adapter) > 1 and self.molora_strategy == 'parallel':
            warnings.warn("Lora weights cannot be unmerged for molora_strategy parallel.")
            return

        lora_weights = [self.get_delta_weight(adapter_name) for adapter_name in active_adapter]
        if len(lora_weights) > 1:
            lora_weights = torch.mean(torch.stack(lora_weights, dim=0), dim=0)
        else:
            lora_weights = lora_weights[0]
        self.weight.data -= lora_weights
        self.merged_adapters = [
            adapter_name for adapter_name in self.merged_adapters if adapter_name not in active_adapter]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the base linear layer plus the active lora update(s).

        Args:
            x: Input tensor. For the "parallel" strategy it is split along dim 0
                into sub-batches, one per active adapter.

        Returns:
            The output tensor, cast back to the dtype of `x`.

        Raises:
            RuntimeError: If adapters are enabled and unmerged but none is usable.
            NotImplementedError: If `molora_strategy` is unknown.
        """
        previous_dtype = x.dtype

        if self.disable_adapters:
            if self.merged:
                self.unmerge()
            result = self._linear(x)
        elif self.merged:
            result = self._linear(x)
        else:
            result = self._linear(x)

            active_adapter = [
                adapter_name for adapter_name in self.active_adapter if (
                    self.r.get(adapter_name, 0) > 0 and adapter_name in self.lora_A
                )]
            num_adapter = len(active_adapter)
            if num_adapter == 0:
                raise RuntimeError("No activate LoRA adapters.")
            elif num_adapter == 1:
                adapter_name = active_adapter[0]
                x = x.to(self.lora_A[adapter_name].weight.dtype)

                result += (
                    self.lora_B[adapter_name](
                        self.lora_A[adapter_name](self.lora_dropout[adapter_name](x))
                    )
                    * self.scaling[adapter_name]
                )
            elif self.molora_strategy == 'parallel':
                # NOTE(review): when the batch size is not a multiple of num_adapter, the
                # trailing chunk(s) produced by torch.split receive no lora update.
                chunk_size = x.shape[0]//num_adapter
                x = list(torch.split(x, chunk_size, dim=0))
                result = list(torch.split(result, chunk_size, dim=0))
                for index, adapter_name in enumerate(active_adapter):
                    # somehow lora is in float32
                    x[index] = x[index].to(self.lora_A[adapter_name].weight.dtype)

                    lora_output = (
                        self.lora_B[adapter_name](
                            self.lora_A[adapter_name](self.lora_dropout[adapter_name](x[index]))
                        )
                        * self.scaling[adapter_name]
                    )
                    # torch.split returns views of one tensor; autograd does not allow the
                    # outputs of such a multi-output view op to be modified in place, so add
                    # out of place and cast back to the chunk dtype (what `+=` would yield).
                    base_chunk = result[index]
                    result[index] = (base_chunk + lora_output).to(base_chunk.dtype)
                result = torch.cat(result, dim=0)
            elif self.molora_strategy == 'average':
                lora_results = []
                for adapter_name in active_adapter:
                    # somehow lora is in float32
                    x = x.to(self.lora_A[adapter_name].weight.dtype)

                    lora_results.append(
                        self.lora_B[adapter_name](
                            self.lora_A[adapter_name](self.lora_dropout[adapter_name](x))
                        )
                        * self.scaling[adapter_name]
                    )
                result += torch.stack(lora_results, dim=0).mean(dim=0)
            else:
                raise NotImplementedError(
                    "molora_strategy {} has not been implemented.".format(self.molora_strategy))

        result = result.to(previous_dtype)

        return result


if is_bnb_available():

    class MoLinear8bitLt(torch.nn.Module, LoraLayer):
        """This class implements a dense lora layer similar to peft.tuners.lora.Linear8bitLt, except
        it allows multiple lora weights of potentially different hyperparameters to co-exist.
        At forward, different sets of lora weights are used according to a certain MoLora strategy.

        Unlike LoraLinear, Linear8bitLt is a wrapper around bnb.nn.Linear8bitLt.

        Args:
            adapter_name: Name of one adapter, or a list of names (one per lora set).
            base_layer: The wrapped 8-bit linear layer; its `in_features`/`out_features`
                are forwarded to `LoraLayer.__init__`.
            r: Lora rank; a single value is shared by all adapters, a list gives one value per adapter.
            lora_alpha: Lora alpha; a single value or one value per adapter.
            lora_dropout: Lora dropout; a single value or one value per adapter.
            molora_strategy (keyword-only, default "parallel"): "parallel" or "average".
            init_lora_weights (keyword-only, default True): Forwarded to `update_layer`.

        Raises:
            NotImplementedError: If `molora_strategy` is not one of `MoLoraStrategy`.
            TypeError: If `adapter_name` is neither a str nor a list.
        """
        # Lora implemented in a dense layer
        MoLoraStrategy = ['parallel', 'average']
        # NOTE(review): borrowed from the dense lora layer; verify that it does not rely on
        # attributes (e.g. fan_in_fan_out) that this wrapper never sets.
        get_delta_weight = LoraLinear.get_delta_weight

        def __init__(
            self,
            adapter_name: Union[str, List[str]],
            base_layer,
            r: Union[int, List[int]] = 0,
            lora_alpha: Union[int, List[int]] = 1,
            lora_dropout: Union[float, List[float]] = 0.0,
            **kwargs,
        ):
            # NOTE(review): the default here is 'parallel' whereas MoLinear and MoLoraConfig
            # default to 'average' -- confirm this difference is intended.
            self.molora_strategy = kwargs.pop('molora_strategy', 'parallel')
            # Explicit raise (not assert) so the check survives `python -O` and has the intended type.
            if self.molora_strategy not in self.MoLoraStrategy:
                raise NotImplementedError(
                    "Expect molora_strategy to be one of {}; got {}".format(self.MoLoraStrategy, self.molora_strategy))

            super().__init__()
            LoraLayer.__init__(self, in_features=base_layer.in_features, out_features=base_layer.out_features)
            self.base_layer = base_layer

            init_lora_weights = kwargs.pop("init_lora_weights", True)
            if isinstance(adapter_name, str):
                adapter_name = [adapter_name]
            if not isinstance(adapter_name, list):
                raise TypeError(
                    "Expect adapter_name to be str or List(str); got {}".format(type(adapter_name)))
            num_lora = len(adapter_name)
            # Broadcast scalar hyperparameters to one value per adapter. Both int and float
            # scalars are accepted for alpha/dropout so that e.g. `lora_dropout=0` or
            # `lora_alpha=16.0` no longer reaches zip() as a non-iterable.
            if isinstance(r, int):
                r = [r] * num_lora
            if isinstance(lora_alpha, (int, float)):
                lora_alpha = [lora_alpha] * num_lora
            if isinstance(lora_dropout, (int, float)):
                lora_dropout = [lora_dropout] * num_lora
            # NOTE: zip() stops at the shortest sequence, so per-adapter lists are expected
            # to have exactly one entry per adapter name.
            for _adapter_name, _r, _lora_alpha, _lora_dropout in zip(adapter_name, r, lora_alpha, lora_dropout):
                self.update_layer(_adapter_name, _r, _lora_alpha, _lora_dropout, init_lora_weights)

            self.set_adapter(adapter_name)

        def _dequantize_base_layer_weight(self, dtype, device):
            """Return the base layer's weight dequantized, cast to `dtype` and moved to `device`."""
            weight = self.base_layer.weight
            state = self.base_layer.state
            if state.SCB is None:
                state.SCB = weight.SCB

            # Dequantize the result of identity matrix and int8 weight because bitsandbytes does not support int8
            # dequantization directly
            im = torch.eye(weight.data.shape[-1]).contiguous().half().to(weight.device)
            im, imt, SCim, SCimt, coo_tensorim = bnb.functional.double_quant(im)
            im, Sim = bnb.functional.transform(im, "col32")
            if state.CxB is None:
                state.CxB, state.SB = bnb.functional.transform(weight.data, to_order=state.formatB)
            out32, Sout32 = bnb.functional.igemmlt(im, state.CxB, Sim, state.SB)
            output = bnb.functional.mm_dequant(out32, Sout32, SCim, state.SCB, bias=None).t()

            return output.to(dtype).to(device)

        def merge(self, safe_merge: bool = False):
            """
            Merge the active adapter weights into the base weights

            With several active adapters the mean of their delta weights is merged
            ("average" strategy). With the "parallel" strategy and more than one
            active adapter nothing is merged and a warning is emitted.

            Args:
                safe_merge (`bool`, *optional*):
                    If True, the merge operation will be performed in a copy of the original weights and check for NaNs
                    before merging the weights. This is useful if you want to check if the merge operation will produce
                    NaNs. Defaults to `False`.

            Raises:
                ValueError: If `safe_merge` is True and the merged weights contain non-finite values.
            """
            if self.merged:
                warnings.warn(
                    f"Already following adapters were merged {','.join(self.merged_adapters)}. "
                    f"You are now additionally merging {','.join(self.active_adapter)}."
                )

            active_adapter = [
                adapter_name for adapter_name in self.active_adapter if (
                    self.r.get(adapter_name, 0) > 0 and adapter_name in self.lora_A)]
            if not active_adapter:
                return

            warnings.warn(
                "Merge lora module to 8-bit linear may get different generations due to rounding errors."
            )
            if len(active_adapter) > 1 and self.molora_strategy == 'parallel':
                warnings.warn("Lora weights cannot be merged for molora_strategy parallel.")
                return
            lora_weights = [self.get_delta_weight(adapter_name) for adapter_name in active_adapter]
            if len(lora_weights) > 1:  # corresponds to average molora_strategy
                lora_weights = torch.mean(torch.stack(lora_weights, dim=0), dim=0)
            else:
                lora_weights = lora_weights[0]

            w_data = self._dequantize_base_layer_weight(lora_weights.dtype, lora_weights.device)
            w_data = w_data + lora_weights

            if safe_merge and not torch.isfinite(w_data).all():
                raise ValueError(
                    f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
                )

            # Rebuild the int8 parameter from the updated full-precision weights.
            self.base_layer.weight = bnb.nn.Int8Params(
                w_data.to("cpu"), requires_grad=False, has_fp16_weights=self.base_layer.weight.has_fp16_weights
            ).to(self.base_layer.weight.device)
            self.base_layer.state.reset_grads()
            self.merged_adapters.extend(active_adapter)

        def unmerge(self):
            """Subtract the delta weights of the active, previously merged adapters from the base weights.

            Mirrors `merge`: the mean delta is removed when several adapters qualify,
            and the operation is refused (with a warning) only for the "parallel"
            strategy with more than one adapter.
            """
            if not self.merged:
                warnings.warn("Already unmerged. Nothing to do.")
                return
            active_adapter = [
                adapter_name for adapter_name in self.active_adapter if (
                    self.r.get(adapter_name, 0) > 0 
                    and adapter_name in self.lora_A 
                    and adapter_name in self.merged_adapters
                )]
            if not active_adapter:
                return
            # Same guard as merge(): a single adapter can be merged under 'parallel', so it
            # must also be possible to unmerge it; otherwise the layer stays merged forever
            # and forward() with adapters disabled keeps returning merged results.
            if len(active_adapter) > 1 and self.molora_strategy == 'parallel':
                warnings.warn("Lora weights cannot be unmerged for molora_strategy parallel.")
                return
            lora_weights = [self.get_delta_weight(adapter_name) for adapter_name in active_adapter]
            if len(lora_weights) > 1:
                lora_weights = torch.mean(torch.stack(lora_weights, dim=0), dim=0)
            else:
                lora_weights = lora_weights[0]
            w_data = self._dequantize_base_layer_weight(lora_weights.dtype, lora_weights.device)
            w_data = w_data - lora_weights

            # Rebuild the int8 parameter from the updated full-precision weights.
            self.base_layer.weight = bnb.nn.Int8Params(
                w_data.to("cpu"), requires_grad=False, has_fp16_weights=self.base_layer.weight.has_fp16_weights
            ).to(self.base_layer.weight.device)
            self.base_layer.state.reset_grads()
            self.merged_adapters = [
                adapter_name for adapter_name in self.merged_adapters if adapter_name not in active_adapter]

        def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
            """Apply the wrapped 8-bit layer plus the active lora update(s).

            Args:
                x: Input tensor. For the "parallel" strategy it is split along dim 0
                    into sub-batches, one per active adapter.
                *args, **kwargs: Forwarded to the wrapped base layer.

            Returns:
                The output tensor, cast back to the dtype of `x`.

            Raises:
                RuntimeError: If adapters are enabled and unmerged but none is usable.
                NotImplementedError: If `molora_strategy` is unknown.
            """
            previous_dtype = x.dtype

            if self.disable_adapters:
                if self.merged:
                    self.unmerge()
                # Forward the extra arguments here too, like the other two paths do.
                result = self.base_layer(x, *args, **kwargs)
            elif self.merged:
                result = self.base_layer(x, *args, **kwargs)
            else:
                result = self.base_layer(x, *args, **kwargs)

                # Outside autocast the input/output dtypes are converted by hand.
                requires_conversion = not torch.is_autocast_enabled()
                active_adapter = [
                    adapter_name for adapter_name in self.active_adapter if (
                        self.r.get(adapter_name, 0) > 0 and adapter_name in self.lora_A
                    )]

                num_adapter = len(active_adapter)
                if num_adapter == 0:
                    raise RuntimeError("No activate LoRA adapters.")
                elif num_adapter == 1:
                    adapter_name = active_adapter[0]
                    if requires_conversion:
                        x = x.to(self.lora_A[adapter_name].weight.dtype)

                    output = (
                        self.lora_B[adapter_name](
                            self.lora_A[adapter_name](self.lora_dropout[adapter_name](x))
                        )
                        * self.scaling[adapter_name]
                    )
                    if requires_conversion:
                        output = output.to(result.dtype)
                    result += output
                elif self.molora_strategy == 'parallel':
                    # NOTE(review): when the batch size is not a multiple of num_adapter, the
                    # trailing chunk(s) produced by torch.split receive no lora update.
                    chunk_size = x.shape[0]//num_adapter
                    x = list(torch.split(x, chunk_size, dim=0))
                    result = list(torch.split(result, chunk_size, dim=0))
                    for index, adapter_name in enumerate(active_adapter):
                        # somehow lora is in float32
                        if requires_conversion:
                            x[index] = x[index].to(self.lora_A[adapter_name].weight.dtype)

                        output = (
                            self.lora_B[adapter_name](
                                self.lora_A[adapter_name](self.lora_dropout[adapter_name](x[index]))
                            )
                            * self.scaling[adapter_name]
                        )

                        base_chunk = result[index]
                        if requires_conversion:
                            # `result` is a list of chunks at this point, so the dtype must be
                            # read from the chunk (a list has no `.dtype`).
                            output = output.to(base_chunk.dtype)
                        # torch.split returns views of one tensor; autograd does not allow the
                        # outputs of such a multi-output view op to be modified in place, so add
                        # out of place and cast back to the chunk dtype (what `+=` would yield).
                        result[index] = (base_chunk + output).to(base_chunk.dtype)
                    result = torch.cat(result, dim=0)
                elif self.molora_strategy == 'average':
                    lora_results = []
                    for adapter_name in active_adapter:
                        if requires_conversion:
                            x = x.to(self.lora_A[adapter_name].weight.dtype)

                        lora_results.append(
                            self.lora_B[adapter_name](
                                self.lora_A[adapter_name](self.lora_dropout[adapter_name](x))
                            )
                            * self.scaling[adapter_name]
                        )

                    output = torch.stack(lora_results, dim=0).mean(dim=0)
                    if requires_conversion:
                        output = output.to(result.dtype)
                    result += output
                else:
                    raise NotImplementedError(
                        "molora_strategy {} has not been implemented.".format(self.molora_strategy))

            result = result.to(previous_dtype)

            return result

    
class MoLoraModel(LoraModel):
    """
    Creates Low Rank Adapter (Lora) model with multiple Lora sets from a pretrained transformers model.

    Args:
        model ([`~transformers.PreTrainedModel`]): The model to be adapted.
        config ([`LoraConfig`]): The configuration of the Lora model.

    Returns:
        `torch.nn.Module`: The Lora model.

    Example:

        ```py
        >>> from transformers import AutoModelForSeq2SeqLM
        >>> from peft import LoraModel, LoraConfig

        >>> config = LoraConfig(
        ...     peft_type="LORA",
        ...     task_type="SEQ_2_SEQ_LM",
        ...     r=8,
        ...     lora_alpha=32,
        ...     target_modules=["q", "v"],
        ...     lora_dropout=0.01,
        ...     molora_strategy="parallel",
        ... )

        >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
        >>> adapter_name = ["s0", "s1", "s2"]  # Lora with 3 sets of parameters
        >>> lora_model = MoLoraModel(model, {name: config for name in adapter_name}, adapter_name)
        ```

    **Attributes**:
        - **model** ([`~transformers.PreTrainedModel`]) -- The model to be adapted.
        - **peft_config** ([`LoraConfig`]): The configuration of the Lora model.
    """
    def __init__(self, model, config: dict, adapter_name: Union[str, List[str]]):
        """
        Wrap `model` and inject one set of LoRA parameters per adapter name.

        Args:
            model: The model to be adapted.
            config: Either a mapping from adapter name to its config, or a single `PeftConfig`, which is
                then shared by every name in `adapter_name`.
            adapter_name: One adapter name or a list of adapter names; all of them become active.

        Raises:
            ValueError: If `adapter_name` is neither a `str` nor a `list`, or if one of the adapters has no
                entry in the config mapping.
        """
        # `super(BaseTuner, self)` resolves to the class that follows `BaseTuner` in the MRO, so
        # `BaseTuner.__init__` is skipped on purpose; the tuner setup is done by hand below.
        super(BaseTuner, self).__init__()
        self.model = model
        self.forward = self.model.forward

        # Normalise the names first so that a single shared config can be expanded per adapter
        # (using a list as a dict key would fail).
        if isinstance(adapter_name, str):
            adapter_name = [adapter_name]
        if not isinstance(adapter_name, list):
            raise ValueError(
                "Expect adapter_name to be str or List[str]; got {}".format(type(adapter_name)))

        if isinstance(config, PeftConfig):
            config = {name: config for name in adapter_name}
        self.peft_config = config

        # Explicit raise instead of `assert`, which is stripped when Python runs with -O.
        missing = [name for name in adapter_name if name not in self.peft_config]
        if missing:
            raise ValueError("No config found for adapter(s) {}".format(missing))
        self.active_adapter = adapter_name

        # transformers models have a .config attribute, whose presence is assumed later on
        if not hasattr(self, "config"):
            self.config = {"model_type": "custom"}

        for _adapter_name in adapter_name:
            self.inject_adapter(self.model, _adapter_name)

        # Copy the peft_config in the injected model.
        self.model.peft_config = self.peft_config

    def _unload_and_optionally_merge(self, merge=True, progressbar: bool = False, safe_merge: bool = False):
        r"""
        This method removes the lora modules and optionally merges them into the base model.
        This is needed if someone wants to use the base model as a standalone model.

        Args:
            merge (`bool`):
                whether to merge the adapter weights into the layer before the lora wrapper is replaced
            progressbar (`bool`):
                whether to show a progressbar indicating the unload and merge process
            safe_merge (`bool`):
                whether to activate the safe merging check to check if there is any potential Nan in the adapter
                weights

        Returns:
            The wrapped model (`self.model`) after the lora wrappers have been replaced.

        Raises:
            ValueError: If `merge` is requested for a model whose `quantization_method` is `"gptq"`.
        """
        if merge and getattr(self.model, "quantization_method", None) == "gptq":
            raise ValueError("Cannot merge LORA layers when the model is gptq quantized")

        # Snapshot the module names up front: modules are swapped while iterating.
        key_list = [key for key, _ in self.model.named_modules() if "lora" not in key]
        desc = "Unloading " + ("and merging " if merge else "") + "model"
        for key in tqdm(key_list, disable=not progressbar, desc=desc):
            try:
                parent, target, target_name = _get_submodules(self.model, key)
            except AttributeError:
                # The key no longer resolves (e.g. its parent was already replaced).
                continue
            if isinstance(target, LoraLayer):
                if isinstance(target, nn.Embedding):
                    new_module = torch.nn.Embedding(target.in_features, target.out_features)
                elif isinstance(target, nn.Conv2d):
                    new_module = torch.nn.Conv2d(
                        target.in_channels,
                        target.out_channels,
                        kernel_size=target.kernel_size,
                        stride=target.stride,
                        padding=target.padding,
                        dilation=target.dilation,
                    )
                elif is_bnb_available() and isinstance(target, LoraLinear8bitLt):
                    bias = target.base_layer.bias is not None
                    new_module = bnb.nn.Linear8bitLt(
                        target.in_features,
                        target.out_features,
                        bias=bias,
                        has_fp16_weights=target.base_layer.state.has_fp16_weights,
                        memory_efficient_backward=target.base_layer.state.memory_efficient_backward,
                        threshold=target.base_layer.state.threshold,
                        index=target.base_layer.index,
                        device=target.base_layer.weight.device,
                    )
                elif is_bnb_4bit_available() and isinstance(target, LoraLinear4bit):
                    bias = target.base_layer.bias is not None
                    new_module = bnb.nn.Linear4bit(
                        target.in_features,
                        target.out_features,
                        bias=bias,
                        compute_dtype=target.base_layer.compute_dtype,
                        compress_statistics=target.base_layer.weight.compress_statistics,
                        quant_type=target.base_layer.weight.quant_type,
                        device=target.base_layer.weight.device,
                    )
                else:
                    bias = target.bias is not None
                    if getattr(target, "is_target_conv_1d_layer", False):
                        new_module = Conv1D(target.out_features, target.in_features)
                    else:
                        new_module = torch.nn.Linear(target.in_features, target.out_features, bias=bias)

                # Merge and swap for EVERY kind of lora layer. Doing this only in the final
                # `else` branch left the Embedding/Conv2d/8-bit/4-bit wrappers in place and
                # discarded the replacement modules built above.
                if merge:
                    target.merge(safe_merge=safe_merge)
                self._replace_module(parent, target_name, new_module, target)

            # save any additional trainable modules part of `modules_to_save`
            if isinstance(target, ModulesToSaveWrapper):
                if isinstance(target.active_adapter, str):
                    modules_to_save = target.modules_to_save[target.active_adapter]
                else:  # list
                    modules_to_save = torch.nn.ModuleDict({
                        adapter_name: target.modules_to_save[adapter_name] for adapter_name in target.active_adapter})
                setattr(parent, target_name, modules_to_save)

        return self.model

    @staticmethod
    def _create_new_module(lora_config, adapter_name, target, **kwargs):
        """
        Build the adapter layer that will replace `target`.

        MoLoRA-aware layers (`MoLinear8bitLt`, `MoLinear`) are used for 8-bit and plain linear / `Conv1D`
        targets; the remaining target types fall back to the stock LoRA layers (see the TODOs).

        Args:
            lora_config: Adapter config; its `fan_in_fan_out` is overwritten when it contradicts the
                target type.
            adapter_name: Name of the adapter the new layer is created for.
            target: The module to be wrapped/replaced.
            **kwargs: Layer options. `loaded_in_8bit`, `loaded_in_4bit` and `bias` are consumed here;
                `gptq_quantization_config` is read; `fan_in_fan_out` is required for linear / `Conv1D`
                targets.

        Returns:
            The newly created adapter layer.

        Raises:
            ValueError: If `target` is of an unsupported type.
        """
        gptq_quantization_config = kwargs.get("gptq_quantization_config", None)
        AutoGPTQQuantLinear = get_auto_gptq_quant_linear(gptq_quantization_config)

        loaded_in_8bit = kwargs.pop("loaded_in_8bit", False)
        loaded_in_4bit = kwargs.pop("loaded_in_4bit", False)
        bias = kwargs.pop("bias", False)

        # The stock LoRA layers are not MoLoRA-aware, so the MoLoRA-only option is not forwarded
        # to them. NOTE(review): some of those constructors appear to pass unknown keyword
        # arguments on to the torch base class, where this one would be rejected - confirm.
        stock_kwargs = {name: value for name, value in kwargs.items() if name != "molora_strategy"}

        # `is_bnb_available()` guard mirrors the 4-bit branch below.
        if loaded_in_8bit and is_bnb_available() and isinstance(target, bnb.nn.Linear8bitLt):
            eightbit_kwargs = kwargs.copy()
            eightbit_kwargs.update(
                {
                    "has_fp16_weights": target.state.has_fp16_weights,
                    "memory_efficient_backward": target.state.memory_efficient_backward,
                    "threshold": target.state.threshold,
                    "index": target.index,
                    "molora_strategy": getattr(lora_config, "molora_strategy", "average")
                }
            )
            new_module = MoLinear8bitLt(adapter_name, target, **eightbit_kwargs)
        elif loaded_in_4bit and is_bnb_4bit_available() and isinstance(target, bnb.nn.Linear4bit):
            fourbit_kwargs = dict(stock_kwargs)
            fourbit_kwargs.update(
                {
                    "compute_dtype": target.compute_dtype,
                    "compress_statistics": target.weight.compress_statistics,
                    "quant_type": target.weight.quant_type,
                }
            )
            # TODO implement MoLinear4bit
            new_module = LoraLinear4bit(adapter_name, target, **fourbit_kwargs)
        elif AutoGPTQQuantLinear is not None and isinstance(target, AutoGPTQQuantLinear):
            # TODO implement MoQuantLinear
            new_module = LoraQuantLinear(adapter_name, target, **stock_kwargs)
            target.weight = target.qweight
        elif isinstance(target, torch.nn.Embedding):
            # TODO implement MoLoraEmbedding
            embedding_kwargs = dict(stock_kwargs)
            embedding_kwargs.pop("fan_in_fan_out", None)
            in_features, out_features = target.num_embeddings, target.embedding_dim
            new_module = LoraEmbedding(adapter_name, in_features, out_features, **embedding_kwargs)
        elif isinstance(target, torch.nn.Conv2d):
            # TODO implement MoLoraConv2d
            out_channels, in_channels = target.weight.size()[:2]
            kernel_size = target.weight.size()[2:]
            stride = target.stride
            padding = target.padding
            new_module = LoraConv2d(
                adapter_name, in_channels, out_channels, kernel_size, stride, padding, **stock_kwargs)
        else:
            if isinstance(target, torch.nn.Linear):
                in_features, out_features = target.in_features, target.out_features
                if kwargs["fan_in_fan_out"]:
                    warnings.warn(
                        "fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
                        "Setting fan_in_fan_out to False."
                    )
                    kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False
            elif isinstance(target, Conv1D):
                in_features, out_features = (
                    target.weight.ds_shape if hasattr(target.weight, "ds_shape") else target.weight.shape
                )
                kwargs["is_target_conv_1d_layer"] = True
                if not kwargs["fan_in_fan_out"]:
                    warnings.warn(
                        "fan_in_fan_out is set to False but the target module is `Conv1D`. "
                        "Setting fan_in_fan_out to True."
                    )
                    kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = True
            else:
                raise ValueError(
                    f"Target module {target} is not supported. "
                    "Currently, only `torch.nn.Linear` and `Conv1D` are supported."
                )
            new_module = MoLinear(adapter_name, in_features, out_features, bias=bias, **kwargs)

        return new_module

    def _create_and_replace(
        self,
        lora_config,
        adapter_name,
        target,
        target_name,
        parent,
        current_key,
        **optional_kwargs,
    ):
        r"""
        Add the adapter `adapter_name` to `target`.

        If `target` already is a `LoraLayer`, the adapter is appended to the layer's active adapters and the
        layer is updated in place. Otherwise a new adapter layer is created and swapped in for `target`.

        Args:
            lora_config (`PeftConfig`):
                The adapter config.
            adapter_name (`str`):
                The adapter name.
            target (`nn.Module`):
                The target module.
            target_name (`str`):
                The target module's name.
            parent (`nn.Module`):
                The parent module.
            current_key (`str`):
                Name of the current module to be considered.
            **optional_kwargs (`dict`):
                The optional keyword arguments to pass to deal with particular cases (e.g. 8bit, 4bit quantization)

        Raises:
            ValueError: If `current_key` is `None`.
        """
        if current_key is None:
            raise ValueError("Current Key shouldn't be `None`")
        # Regexp matching - Find key which matches current target_name in patterns provided.
        # Raw f-string: `\.` inside a plain string literal is an invalid escape sequence.
        pattern_keys = list(chain(lora_config.rank_pattern.keys(), lora_config.alpha_pattern.keys()))
        target_name_key = next(filter(lambda key: re.match(rf".*\.{key}$", current_key), pattern_keys), current_key)

        r = lora_config.rank_pattern.get(target_name_key, lora_config.r)
        alpha = lora_config.alpha_pattern.get(target_name_key, lora_config.lora_alpha)
        bias = hasattr(target, "bias") and target.bias is not None
        kwargs = {
            "r": r,
            "lora_alpha": alpha,
            "lora_dropout": lora_config.lora_dropout,
            "fan_in_fan_out": lora_config.fan_in_fan_out,
            "init_lora_weights": lora_config.init_lora_weights,
            "molora_strategy": getattr(lora_config, "molora_strategy", "average"),
            "loaded_in_8bit": optional_kwargs.pop("loaded_in_8bit", False),
            "loaded_in_4bit": optional_kwargs.pop("loaded_in_4bit", False),
            "bias": bias,
        }

        quantization_config = get_quantization_config(self.model, method="gptq")
        if quantization_config is not None:
            kwargs["gptq_quantization_config"] = quantization_config

        if isinstance(target, LoraLayer):
            # Existing adapter layer: activate the new adapter and add its parameters in place.
            if adapter_name not in target._active_adapter:
                target._active_adapter.append(adapter_name)
            if isinstance(target, torch.nn.Conv2d):
                update_layer = target.update_layer_conv2d
            elif isinstance(target, torch.nn.Embedding):
                update_layer = target.update_layer_embedding
            else:
                update_layer = target.update_layer
            update_layer(
                adapter_name,
                r,
                alpha,
                lora_config.lora_dropout,
                lora_config.init_lora_weights,
            )
        else:
            new_module = self._create_new_module(lora_config, adapter_name, target, **kwargs)
            if adapter_name not in self.active_adapter:  # self.active_adapter is a list
                # adding an additional adapter: it is not automatically trainable
                new_module.requires_grad_(False)
            self._replace_module(parent, target_name, new_module, target)


# Point the LORA entries of PEFT's registries at the MoLoRA model and config classes.
PEFT_TYPE_TO_MODEL_MAPPING[PeftType.LORA] = MoLoraModel


PEFT_TYPE_TO_CONFIG_MAPPING["LORA"] = MoLoraConfig


def _get_submodules(model, key):
    """
    Resolve the dotted module path `key` inside `model`.

    Returns:
        A `(parent, target, target_name)` tuple: the module that owns the target, the target module
        itself, and the last component of `key`. For a key without dots the parent is `model`.
    """
    parent_path, _, target_name = key.rpartition(".")
    parent = model.get_submodule(parent_path)
    target = model.get_submodule(key)
    return parent, target, target_name


class MoPeftModel(PeftModel):
    """
    Base model encompassing various Peft methods. Modified from PeftModel to support 
    multiple sets of LoRA parameters.

    Args:
        model ([`~transformers.PreTrainedModel`]): The base transformer model used for Peft.
        peft_config ([`PeftConfig`]): The configuration of the Peft model.


    **Attributes**:
        - **base_model** ([`~transformers.PreTrainedModel`]) -- The base transformer model used for Peft.
        - **peft_config** ([`PeftConfig`]) -- The configuration of the Peft model.
        - **modules_to_save** (`list` of `str`) -- The list of sub-module names to save when
        saving the model.
        - **prompt_encoder** ([`PromptEncoder`]) -- The prompt encoder used for Peft if
        using [`PromptLearningConfig`].
        - **prompt_tokens** (`torch.Tensor`) -- The virtual prompt tokens used for Peft if
        using [`PromptLearningConfig`].
        - **transformer_backbone_name** (`str`) -- The name of the transformer
        backbone in the base model if using [`PromptLearningConfig`].
        - **word_embeddings** (`torch.nn.Embedding`) -- The word embeddings of the transformer backbone
        in the base model if using [`PromptLearningConfig`].
    """
    def __init__(
        self,
        model,
        peft_config: Union[PeftConfig, List[PeftConfig]],
        adapter_name: Union[str, List[str]] = "default"
        ):
        """
        Wrap `model` with one adapter per entry of `adapter_name`.

        Args:
            model: The base model to adapt.
            peft_config: A single config (reused for every adapter) or a list of configs that is paired
                with `adapter_name` in order.
            adapter_name: One adapter name or a list of adapter names; all of them become active.

        Raises:
            TypeError: If `adapter_name` is not a `str` / list of `str`, or `peft_config` is not a
                `PeftConfig` / list of `PeftConfig`.
        """
        super(PeftModel, self).__init__()  # initialize PushToHubMixin, torch.nn.Module
        self.modules_to_save = None
        if isinstance(adapter_name, str):
            adapter_name = [adapter_name]
        # Explicit raises instead of `assert`, which is stripped when Python runs with -O.
        if not (isinstance(adapter_name, list) and all(isinstance(name, str) for name in adapter_name)):
            raise TypeError(
                "Expect adapter_name to be str or List[str]; got {}".format(type(adapter_name)))
        self.active_adapter = adapter_name  # list
        if isinstance(peft_config, PeftConfig):
            peft_config = [peft_config] * len(adapter_name)
        if not (isinstance(peft_config, list) and all(isinstance(config, PeftConfig) for config in peft_config)):
            raise TypeError("Expect peft_config to be a list of PeftConfig; got {}".format(peft_config))
        self.peft_type = peft_config[0].peft_type  # all peft_config should have same type

        self._is_prompt_learning = peft_config[0].is_prompt_learning
        if self._is_prompt_learning:
            self._peft_config = dict(zip(adapter_name, peft_config))
            self.base_model = model
            self.add_adapter(adapter_name, peft_config)
        else:
            self._peft_config = None  # use self.base_model.peft_config instead
            self.base_model = PEFT_TYPE_TO_MODEL_MAPPING[self.peft_type](
                model, dict(zip(adapter_name, peft_config)), adapter_name
            )
            self.set_additional_trainable_modules(peft_config, adapter_name)

        self.config = getattr(self.base_model, "config", {"model_type": "custom"})
        if getattr(model, "is_gradient_checkpointing", True):
            model = self._prepare_model_for_gradient_checkpointing(model)

        # the `pretraining_tp` is set for some models to simulate Tensor Parallelism during inference to avoid
        # numerical differences, https://github.com/pytorch/pytorch/issues/76232 - to avoid any unexpected
        # behavior we disable that in this line.
        if hasattr(self.base_model, "config") and hasattr(self.base_model.config, "pretraining_tp"):
            self.base_model.config.pretraining_tp = 1

    @property
    def peft_config(self) -> Dict[str, PeftConfig]:
        """Adapter configs by name: stored on this object for prompt learning, on `base_model` otherwise."""
        return self._peft_config if self._is_prompt_learning else self.base_model.peft_config

    @peft_config.setter
    def peft_config(self, value: Dict[str, PeftConfig]):
        """Store `value` in the same place the getter reads from."""
        if not self._is_prompt_learning:
            self.base_model.peft_config = value
            return
        self._peft_config = value
        
    @property
    def active_adapters(self):
        """
        Active adapters as reported by `base_model`; when `base_model` has no `active_adapters`
        attribute, `self.active_adapter` is returned instead (a single name is wrapped in a list).
        """
        try:
            return self.base_model.active_adapters
        except AttributeError:
            fallback = self.active_adapter
            return [fallback] if isinstance(fallback, str) else fallback

    @classmethod
    def from_pretrained(
        cls,
        model: PreTrainedModel,
        model_id: Union[str, os.PathLike],
        adapter_name: str = "default",
        is_trainable: bool = False,
        config: Optional[Union[PeftConfig, List[PeftConfig]]] = None,
        **kwargs
    ):
        r"""
        Instantiate a PEFT model with one or more adapters from pretrained configurations and weights.

        Args:
            model ([`~transformers.PreTrainedModel`]):
                The model to be adapted. The model should be initialized with the
                [`~transformers.PreTrainedModel.from_pretrained`] method from the 🤗 Transformers library.
            model_id (`str`, `os.PathLike` or `dict`):
                Where to load the adapter(s) from. Either a single id, which is loaded under `adapter_name`,
                or a dict mapping each adapter name to its id. An id can be:
                    - A string, the `model id` of a PEFT configuration hosted inside a model repo on the Hugging Face
                      Hub.
                    - A path to a directory containing a PEFT configuration file saved using the `save_pretrained`
                      method (`./my_peft_config_directory/`).
            adapter_name (`str`, *optional*, defaults to `"default"`):
                The name of the adapter to be loaded. Ignored when `model_id` is a dict, whose keys are used
                instead.
            is_trainable (`bool`, *optional*, defaults to `False`):
                Whether the adapter should be trainable or not. If `False`, the adapter will be frozen and used
                for inference.
            config ([`~peft.PeftConfig`] or list of them, *optional*):
                The configuration object(s) to use instead of automatically loaded configurations. A single
                config is reused for every adapter; a list is paired with the adapter names in order.
            kwargs: (`optional`):
                Additional keyword arguments passed along to the specific PEFT configuration class.
                subfolder (`str`):
                    Provide it if adapter_name is not "default" and model_id does not contain adapter_name.

        Raises:
            TypeError: If `config` is a list that contains something other than `PeftConfig`.
            ValueError: If `config` has an unsupported type, if the configs disagree on `task_type`, or if a
                prompt learning adapter is requested as trainable.
        """
        from peft.mapping import MODEL_TYPE_TO_PEFT_MODEL_MAPPING

        # Anything that is not already a name -> id mapping is a single id (str or path-like).
        if not isinstance(model_id, dict):
            model_id = {adapter_name: model_id}
        adapter_name = list(model_id.keys())

        # load the config
        if config is None:
            first_model_id = model_id[adapter_name[0]]
            config_class = PEFT_TYPE_TO_CONFIG_MAPPING[
                PeftConfig._get_peft_type(
                    first_model_id,
                    subfolder=kwargs.get("subfolder", None),
                    revision=kwargs.get("revision", None),
                    cache_dir=kwargs.get("cache_dir", None),
                    use_auth_token=kwargs.get("use_auth_token", None),
                )
            ]
            config = [config_class.from_pretrained(model_id[_adapter_name], **kwargs) for _adapter_name in adapter_name]
        elif isinstance(config, PeftConfig):
            # A single config is shared by all adapters; everything below expects a list.
            config = [config] * len(adapter_name)
        elif isinstance(config, list):
            if not all(isinstance(_config, PeftConfig) for _config in config):
                raise TypeError("Expect config to be a list of PeftConfig; got {}".format(config))
        else:
            raise ValueError(
                f"The input config must be a PeftConfig or List[PeftConfig], got {config.__class__}")

        if (getattr(model, "hf_device_map", None) is not None) and len(
            set(model.hf_device_map.values()).intersection({"cpu", "disk"})
        ) > 0:
            remove_hook_from_submodules(model)

        if is_trainable and any(_config.is_prompt_learning for _config in config):
            raise ValueError("Cannot set a prompt learning adapter to trainable when loading pretrained adapter.")
        for _config in config:
            _config.inference_mode = not is_trainable

        if len(config) > 1:
            task_type = set([_config.task_type for _config in config])
            if len(task_type) != 1:
                raise ValueError("Multiple configs have different task_type {}".format(task_type))
            task_type = list(task_type)[0]
        else:
            task_type = config[0].task_type

        if task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys():
            model = cls(model, config, adapter_name)
        else:
            model = MODEL_TYPE_TO_PEFT_MODEL_MAPPING[task_type](model, config, adapter_name)
        # Forward `is_trainable`; otherwise `load_adapter` falls back to its default (False) and
        # switches the model to eval mode even when a trainable adapter was requested.
        model.load_adapter(model_id, adapter_name, is_trainable=is_trainable, **kwargs)
        return model

    def add_adapter(self, adapter_name: Union[str, List[str]], peft_config: PeftConfig):
        """
        Register one or more adapters on the model.

        Args:
            adapter_name: One adapter name or a list of adapter names.
            peft_config: A single config (reused for every name) or a list of configs paired with
                `adapter_name` in order.

        Raises:
            ValueError: If any config has a peft type different from `self.peft_type`.
        """
        if isinstance(adapter_name, str):
            adapter_name = [adapter_name]
        # A single config is expanded to one entry per adapter. (Checking for `dict` here never
        # matched a `PeftConfig`, so a single config ended up being iterated directly.)
        if isinstance(peft_config, PeftConfig):
            peft_config = [peft_config] * len(adapter_name)

        unexpected_types = [
            _peft_config.peft_type for _peft_config in peft_config if _peft_config.peft_type != self.peft_type
        ]
        if unexpected_types:
            raise ValueError(
                f"Cannot combine adapters with different peft types. "
                f"Found {self.peft_type} and {unexpected_types}."
            )
        if any(_peft_config.is_prompt_learning for _peft_config in peft_config):
            self._peft_config = dict(zip(adapter_name, peft_config))
            dict_config = self.config.to_dict() if hasattr(self.config, "to_dict") else self.config
            peft_config = _prepare_prompt_learning_config(peft_config, dict_config)
            # TODO how to handle multiple PromptLearning?
            self._setup_prompt_encoder(adapter_name)
        elif any(_peft_config.is_adaption_prompt for _peft_config in peft_config):
            self.base_model.add_adapter(adapter_name, peft_config)
        else:
            self.peft_config.update(dict(zip(adapter_name, peft_config)))
            self.base_model.add_adapter(adapter_name, peft_config)

        self.set_additional_trainable_modules(peft_config, adapter_name)

    def _set_trainable(self, adapter_name: str):
        """
        Wrap every module whose name ends with one of the `modules_to_save` entries for `adapter_name`
        in a `ModulesToSaveWrapper`, or register the adapter on the wrapper if one is already in place.
        """
        # Names are collected before the loop because modules get replaced while walking them.
        module_names = [name for name, _ in self.named_modules()]
        if isinstance(self.modules_to_save, dict):
            wanted_suffixes = self.modules_to_save[adapter_name]
        else:
            wanted_suffixes = self.modules_to_save

        for module_name in module_names:
            if not any(module_name.endswith(suffix) for suffix in wanted_suffixes):
                continue
            parent, target, target_name = _get_submodules(self, module_name)
            if not isinstance(target, ModulesToSaveWrapper):
                wrapper = ModulesToSaveWrapper(target, adapter_name)
                wrapper.set_adapter(adapter_name)
                setattr(parent, target_name, wrapper)
                continue
            target.update(adapter_name)
            target.set_adapter(target.active_adapter)

    def set_additional_trainable_modules(self, peft_config, adapter_name):
        """
        Record the `modules_to_save` of the given config(s) and make those modules trainable.

        Args:
            peft_config: A single config when `adapter_name` is a `str`; otherwise a sequence of configs
                paired with `adapter_name` in order.
            adapter_name: One adapter name (`self.modules_to_save` is then kept as a flat set) or a list
                of adapter names (`self.modules_to_save` is then a dict from adapter name to a set).
        """
        if isinstance(adapter_name, str):
            if getattr(peft_config, "modules_to_save", None) is not None:
                if self.modules_to_save is None:
                    self.modules_to_save = set(peft_config.modules_to_save)
                else:
                    self.modules_to_save.update(peft_config.modules_to_save)
                self._set_trainable(adapter_name)
        elif isinstance(adapter_name, list):
            for one_adapter_name, one_peft_config in zip(adapter_name, peft_config):
                extra_modules = getattr(one_peft_config, "modules_to_save", None)
                if extra_modules is None:
                    continue
                if self.modules_to_save is None:
                    self.modules_to_save = {}
                # `setdefault` creates the per-adapter set on first use; indexing the dict directly
                # raised KeyError for every adapter after the first one.
                self.modules_to_save.setdefault(one_adapter_name, set()).update(extra_modules)
                self._set_trainable(one_adapter_name)

    def load_adapter(
        self,
        model_id: Union[str, dict],
        adapter_name: Union[str, List[str]],
        is_trainable: bool = False,
        **kwargs
    ):
        """
        Load the weights (and, if not registered yet, the configs) of one or more adapters.

        Args:
            model_id: A single id, loaded under `adapter_name` (its first entry if a list is given), or a
                dict mapping each adapter name to its id.
            adapter_name: Adapter name(s). Ignored when `model_id` is a dict, whose keys are used instead.
            is_trainable: If `False`, freshly loaded configs are put in inference mode and the model is
                switched to eval mode after loading.
            **kwargs: Split by `self._split_kwargs`; the first part is forwarded to the config and weight
                loaders, the rest is read for the device-dispatch options (`device_map`, `max_memory`,
                `offload_folder`, `offload_index`).

        Returns:
            A list with the result of `set_peft_model_state_dict` for each adapter, in adapter order.

        Raises:
            ValueError: If a prompt learning adapter is requested as trainable.
        """
        hf_hub_download_kwargs, kwargs = self._split_kwargs(kwargs)
        torch_device = infer_device()

        # Anything that is not already a name -> id mapping is a single id (str or path-like).
        if not isinstance(model_id, dict):
            model_id = {(adapter_name[0] if isinstance(adapter_name, list) else adapter_name): model_id}
        adapter_name = list(model_id.keys())

        if any(_adapter_name not in self.peft_config for _adapter_name in adapter_name):
            first_model_id = model_id[adapter_name[0]]
            # load the config
            config_class = PEFT_TYPE_TO_CONFIG_MAPPING[
                PeftConfig._get_peft_type(
                    first_model_id,
                    **hf_hub_download_kwargs,
                )
            ]
            peft_config = [
                config_class.from_pretrained(
                    model_id[_adapter_name], **hf_hub_download_kwargs
                ) for _adapter_name in adapter_name
            ]

            if is_trainable and any(_config.is_prompt_learning for _config in peft_config):
                raise ValueError("Cannot set a prompt learning adapter to trainable when loading pretrained adapter.")
            for _config in peft_config:
                _config.inference_mode = not is_trainable
            self.add_adapter(adapter_name, peft_config)

        # load weights if any
        adapters_weights = [
            load_peft_weights(
                model_id[_adapter_name], device=torch_device, **hf_hub_download_kwargs
            ) for _adapter_name in adapter_name
        ]

        # load the weights into the model
        load_result = [
            set_peft_model_state_dict(self, _adapters_weights, adapter_name=_adapter_name)
            for _adapters_weights, _adapter_name in zip(adapters_weights, adapter_name)
        ]
        if (
            (getattr(self, "hf_device_map", None) is not None)
            and (len(set(self.hf_device_map.values()).intersection({"cpu", "disk"})) > 0)
            and len(self.peft_config) == 1
        ):
            device_map = kwargs.get("device_map", "auto")
            max_memory = kwargs.get("max_memory", None)
            offload_dir = kwargs.get("offload_folder", None)
            offload_index = kwargs.get("offload_index", None)

            dispatch_model_kwargs = {}
            # Safety checker for previous `accelerate` versions
            # `offload_index` was introduced in https://github.com/huggingface/accelerate/pull/873/
            if "offload_index" in inspect.signature(dispatch_model).parameters:
                dispatch_model_kwargs["offload_index"] = offload_index

            no_split_module_classes = self._no_split_modules

            if device_map != "sequential":
                max_memory = get_balanced_memory(
                    self,
                    max_memory=max_memory,
                    no_split_module_classes=no_split_module_classes,
                    low_zero=(device_map == "balanced_low_0"),
                )
            if isinstance(device_map, str):
                device_map = infer_auto_device_map(
                    self, max_memory=max_memory, no_split_module_classes=no_split_module_classes
                )
            dispatch_model(
                self,
                device_map=device_map,
                offload_dir=offload_dir,
                **dispatch_model_kwargs,
            )
            hook = AlignDevicesHook(io_same_device=True)
            if any(isinstance(self.peft_config[_adapter_name], PromptLearningConfig) for _adapter_name in adapter_name):
                remove_hook_from_submodules(self.prompt_encoder)
            add_hook_to_module(self.get_base_model(), hook)

        # Set model in evaluation mode to deactivate Dropout modules by default
        if not is_trainable:
            self.eval()
        return load_result

    def get_base_model(self):
        """
        Returns the base model: `self.base_model` itself for prompt learning, otherwise the model
        it wraps (`self.base_model.model`). With several active adapters the first one decides.
        """
        active_config = self.active_peft_config
        reference = active_config[0] if isinstance(active_config, list) else active_config
        if reference.is_prompt_learning:
            return self.base_model
        return self.base_model.model

    def set_adapter(self, adapter_name: Union[str, List[str]]):
        """
        Sets the active adapter(s).

        Raises:
            ValueError: If any of the requested adapters has no registered config.
        """
        if isinstance(adapter_name, str):
            adapter_name = [adapter_name]
        known_configs = self.peft_config
        if not all(_adapter_name in known_configs for _adapter_name in adapter_name):
            raise ValueError(f"Adapter {adapter_name} not found.")
        self.active_adapter = adapter_name
        # The tuner only needs to be told when at least one adapter is not a prompt learning one.
        all_prompt_learning = all(
            known_configs[_adapter_name].is_prompt_learning for _adapter_name in adapter_name
        )
        if not all_prompt_learning:
            self.base_model.set_adapter(adapter_name)
        _set_adapter(self, adapter_name)

    @property
    def base_model_torch_dtype(self):
        """The `dtype` attribute of `base_model`, or `None` when it has none."""
        base_dtype = getattr(self.base_model, "dtype", None)
        return base_dtype

    @property
    def active_peft_config(self):
        """Config(s) of the active adapter(s): a list when `active_adapter` is a list, else one config."""
        active = self.active_adapter
        if not isinstance(active, list):
            return self.peft_config[active]
        return [self.peft_config[name] for name in active]
        
    @contextmanager
    def disable_adapter(self):
        """
        Context manager that disables the adapter module(s) and restores them on exit.

        For prompt learning, `forward` and `prepare_inputs_for_generation` are temporarily pointed at
        the base model's methods; otherwise the adapter layers of the tuner are disabled.
        """
        # Decide once, so that setup and teardown take the same path even if the active adapters
        # are changed inside the `with` body.
        is_prompt_learning = any(
            self.peft_config[adapter_name].is_prompt_learning for adapter_name in self.active_adapter
        )
        if is_prompt_learning:
            # TODO: consider replacing this patching of methods with a more robust mechanism: setting a flag and
            # letting the underyling methods deal with it, same as how LoRA does it.
            old_forward = self.forward
            old_prepare_inputs_for_generation = self.prepare_inputs_for_generation
            self.forward = self.base_model.forward
            self.prepare_inputs_for_generation = self.base_model.prepare_inputs_for_generation
            try:
                yield
            finally:
                self.forward = old_forward
                # Restore the patched attribute itself; assigning to
                # `self.old_prepare_inputs_for_generation` left the patch in place.
                self.prepare_inputs_for_generation = old_prepare_inputs_for_generation
        else:
            try:
                self.base_model.disable_adapter_layers()
                yield
            finally:
                self.base_model.enable_adapter_layers()


class MoPeftModelForCausalLM(MoPeftModel, PeftModelForCausalLM):
    """
    Causal language modeling PEFT wrapper that is initialised through `MoPeftModel`.

    Args:
        model ([`~transformers.PreTrainedModel`]): Base transformer model.
        peft_config ([`PeftConfig`]): Peft config.
        adapter_name (`str`): Name passed on to `MoPeftModel.__init__`; defaults to `"default"`.
    """
    def __init__(self, model, peft_config: PeftConfig, adapter_name="default"):
        MoPeftModel.__init__(self, model, peft_config, adapter_name)
        # Keep a handle on the base model's own `prepare_inputs_for_generation`.
        self.base_model_prepare_inputs_for_generation = self.base_model.prepare_inputs_for_generation

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        task_ids=None,
        **kwargs,
    ):
        """
        Run the wrapped base model, injecting prompt-learning inputs when required.

        When `active_peft_config` is a list, its first entry decides how the call is dispatched.

        * Non prompt-learning adapters: the arguments are handed to the base model unchanged.
          For `model_type == "mpt"`, `inputs_embeds` is not forwarded and passing it raises
          an `AssertionError`.
        * Prefix tuning: `get_prompt(batch_size)` is passed as `past_key_values`.
        * Other prompt-learning types: the virtual-token prompts are prepended to the input
          embeddings, and `labels` (if given) are left-padded with `-100`.

        In the prompt-learning paths, `attention_mask` (if given) is extended with ones for the
        virtual tokens, and `position_ids` / `token_type_ids` are dropped with a warning.
        """
        active_config = self.active_peft_config
        # With several active adapters, the first config drives everything below.
        lead_config = active_config[0] if isinstance(active_config, list) else active_config
        is_prompt_learning = lead_config.is_prompt_learning
        peft_type = lead_config.peft_type

        if not is_prompt_learning:
            is_mpt = self.base_model.config.model_type == "mpt"
            if is_mpt and inputs_embeds is not None:
                raise AssertionError("forward in MPTForCausalLM does not support inputs_embeds")

            forward_args = {"input_ids": input_ids, "attention_mask": attention_mask}
            if not is_mpt:
                # MPT's forward does not take `inputs_embeds`, so it is only added for other models.
                forward_args["inputs_embeds"] = inputs_embeds
            forward_args["labels"] = labels
            forward_args["output_attentions"] = output_attentions
            forward_args["output_hidden_states"] = output_hidden_states
            forward_args["return_dict"] = return_dict
            return self.base_model(**forward_args, **kwargs)

        batch_size = _get_batch_size(input_ids, inputs_embeds)
        num_virtual_tokens = lead_config.num_virtual_tokens

        if attention_mask is not None:
            # The virtual tokens are always attended to: prepend a block of ones.
            virtual_mask = torch.ones(batch_size, num_virtual_tokens).to(attention_mask.device)
            attention_mask = torch.cat([virtual_mask, attention_mask], dim=1)

        unsupported_inputs = (
            ("position_ids", "Position ids are not supported for parameter efficient tuning. Ignoring position ids."),
            ("token_type_ids", "Token type ids are not supported for parameter efficient tuning. Ignoring token type ids"),
        )
        for key, message in unsupported_inputs:
            if kwargs.get(key, None) is not None:
                warnings.warn(message)
                kwargs[key] = None

        kwargs["attention_mask"] = attention_mask
        kwargs["labels"] = labels
        kwargs["output_attentions"] = output_attentions
        kwargs["output_hidden_states"] = output_hidden_states
        kwargs["return_dict"] = return_dict

        if peft_type == PeftType.PREFIX_TUNING:
            prefix_key_values = self.get_prompt(batch_size)
            return self.base_model(
                input_ids=input_ids,
                inputs_embeds=inputs_embeds,
                past_key_values=prefix_key_values,
                **kwargs,
            )

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        if labels is not None:
            # -100 labels for the virtual-token positions, prepended to the real labels.
            virtual_labels = torch.full((batch_size, num_virtual_tokens), -100).to(labels.device)
            kwargs["labels"] = torch.cat([virtual_labels, labels], dim=1)
        prompts = self.get_prompt(batch_size=batch_size, task_ids=task_ids).to(inputs_embeds.dtype)
        inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
        return self.base_model(inputs_embeds=inputs_embeds, **kwargs)


# Register MoPeftModelForCausalLM under the "CAUSAL_LM" key of the mapping
# imported from peft.mapping, replacing any entry already stored there.
MODEL_TYPE_TO_PEFT_MODEL_MAPPING["CAUSAL_LM"] = MoPeftModelForCausalLM


def load_safetensors(file_path, device="cpu"):
    """Read every tensor stored in a ``.safetensors`` file into a dict.

    Args:
        file_path: ``str`` or ``os.PathLike`` pointing at an existing regular
            file whose name ends with ``.safetensors``.
        device: Device string handed to ``safetensors.safe_open``; defaults to
            ``"cpu"``.

    Returns:
        A dict mapping each key found in the file to its tensor.

    Raises:
        FileNotFoundError: If ``file_path`` is not an existing file or does not
            end with ``.safetensors``.
    """
    # Normalise first so that pathlib.Path inputs work with the suffix check.
    path = os.fspath(file_path)
    if not (os.path.isfile(path) and path.endswith(".safetensors")):
        raise FileNotFoundError("File path is not a valid safetensor path: {}".format(file_path))

    # Imported only after validation so that a bad path is reported as such,
    # rather than being masked by an ImportError when safetensors is absent.
    from safetensors import safe_open

    with safe_open(path, framework="pt", device=device) as f:
        return {k: f.get_tensor(k) for k in f.keys()}


def load_peft_weights_into_model(model, ckpt_folder, adapter_name="default"):
    """Load a PEFT adapter checkpoint found in ``ckpt_folder`` into ``model``.

    The folder is searched for the first existing file among, in this order:
    ``adapter_model.safetensors``, ``pytorch_model.safetensors``,
    ``adapter_model.bin`` and ``pytorch_model.bin``. The weights are loaded on
    CPU and passed to ``set_peft_model_state_dict`` under ``adapter_name``.

    Args:
        model: Model that receives the adapter weights.
        ckpt_folder: Directory that contains the checkpoint file.
        adapter_name: Adapter name forwarded to ``set_peft_model_state_dict``.

    Returns:
        A tuple ``(model, valid_ckpt_path)`` where ``valid_ckpt_path`` is the
        path of the checkpoint file that was loaded.

    Raises:
        FileNotFoundError: If none of the candidate files exists in ``ckpt_folder``.
        RuntimeError: If the checkpoint file was read but contains no weights.
    """
    candidate_names = (
        "adapter_model.safetensors",
        "pytorch_model.safetensors",
        "adapter_model.bin",
        "pytorch_model.bin",
    )
    candidate_paths = (os.path.join(ckpt_folder, name) for name in candidate_names)
    valid_ckpt_path = next((path for path in candidate_paths if os.path.isfile(path)), None)
    if valid_ckpt_path is None:
        raise FileNotFoundError(f"Cannot find a checkpoint in {ckpt_folder}.")

    if valid_ckpt_path.endswith(".safetensors"):
        adapters_weights = load_safetensors(valid_ckpt_path, device="cpu")
    else:
        # map_location="cpu" matches the safetensors branch and lets checkpoints
        # saved from a GPU be loaded on hosts without one.
        # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
        adapters_weights = torch.load(valid_ckpt_path, map_location="cpu")

    # Explicit raise instead of `assert`: asserts vanish under `python -O` and
    # would surface as AssertionError rather than the intended RuntimeError.
    if not adapters_weights:
        raise RuntimeError("CKPT File is empty: {}".format(valid_ckpt_path))

    set_peft_model_state_dict(model, adapters_weights, adapter_name=adapter_name)
    return model, valid_ckpt_path