# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import os
import warnings
from contextlib import contextmanager
import sys

from typing import Any, Literal, Optional, Union
import torch
from accelerate import dispatch_model, infer_auto_device_map
from accelerate.hooks import AlignDevicesHook, add_hook_to_module, remove_hook_from_submodules
from accelerate.utils import get_balanced_memory
from huggingface_hub import hf_hub_download
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers import PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput, TokenClassifierOutput
from transformers.utils import PushToHubMixin
from peft.utils.other import set_additional_trainable_modules
from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer

from .tuners import LoraModel, PrefixEncoder, PromptEmbedding, PromptEncoder
from .utils import (
    TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING,
    WEIGHTS_NAME,
    PeftConfig,
    PeftType,
    PromptLearningConfig,
    TaskType,
    _set_trainable,
    get_peft_model_state_dict,
    set_peft_model_state_dict,
    shift_tokens_right,
    infer_device,
    load_peft_weights,
    _prepare_prompt_learning_config,
)
# Maps a PEFT type to the substring that marks that tuner's parameters in state-dict keys
# (used by `add_adapter` to validate adapter names and by `load_adapter` to filter missing keys).
# NOTE(review): the mapping is empty here, so `.get(...)` always returns None / the default "" and
# `load_adapter` filters missing keys by adapter name only — confirm whether it should be populated.
PEFT_TYPE_TO_PREFIX_MAPPING: dict[PeftType, str] = {}

class PeftModel(PushToHubMixin, torch.nn.Module):
    """
    Parameter-Efficient Fine-Tuning Model. Base model encompassing various Peft methods.

    Args:
        model ([`PreTrainedModel`]): The base transformer model used for Peft.
        peft_config ([`PeftConfig`]): The configuration of the Peft model.


    **Attributes**:
        - **base_model** ([`PreTrainedModel`]) -- The base transformer model used for Peft.
        - **peft_config** ([`PeftConfig`]) -- The configuration of the Peft model.
        - **modules_to_save** (`list` of `str`) -- The list of sub-module names to save when
        saving the model.
        - **prompt_encoder** ([`PromptEncoder`]) -- The prompt encoder used for Peft if
        `isinstance(self.peft_config, PromptLearningConfig)`.
        - **prompt_tokens** (`torch.Tensor`) -- The virtual prompt tokens used for Peft if
        `isinstance(self.peft_config, PromptLearningConfig)`.
        - **transformer_backbone_name** (`str`) -- The name of the transformer
        backbone in the base model if `isinstance(self.peft_config, PromptLearningConfig)`.
        - **word_embeddings** (`torch.nn.Embedding`) -- The word embeddings of the transformer backbone
        in the base model if `isinstance(self.peft_config, PromptLearningConfig)`.
    """

    def __init__(self, model, peft_config: PeftConfig):
        """
        Wrap ``model`` with the PEFT method described by ``peft_config``.

        A prompt-learning config keeps ``model`` as ``base_model`` and builds a prompt encoder next to it; any
        other config replaces ``base_model`` with a ``LoraModel`` around ``model``. When the config lists
        ``modules_to_save``, those modules are made trainable.

        Args:
            model ([`PreTrainedModel`]): The base transformer model.
            peft_config ([`PeftConfig`]): The PEFT configuration.
        """
        super().__init__()
        self.peft_config = peft_config
        self.base_model = model
        self.config = model.config
        self.modules_to_save = None
        self.peft_type = peft_config.peft_type

        if not isinstance(peft_config, PromptLearningConfig):
            self.base_model = LoraModel(peft_config, model)
        else:
            self._setup_prompt_encoder()

        configured_modules = getattr(self.peft_config, "modules_to_save", None)
        if configured_modules is not None:
            self.modules_to_save = configured_modules
            _set_trainable(self)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def save_pretrained(self, save_directory, **kwargs):
        r"""
        Save the adapter weights and the adapter configuration to ``save_directory``.

        The saved files can be re-loaded with `PeftModel.from_pretrained` and are also what `push_to_hub` uploads.

        Args:
            save_directory (`str`):
                Directory where the adapter model and configuration files will be saved (will be created if it does not
                exist).
            **kwargs:
                Optional keyword arguments. Only ``state_dict`` is read here: it is forwarded to
                `get_peft_model_state_dict`. Other keys are ignored.

        Raises:
            ValueError: If ``save_directory`` is an existing file.
        """
        if os.path.isfile(save_directory):
            raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")
        os.makedirs(save_directory, exist_ok=True)

        # save only the trainable (adapter) weights
        output_state_dict = get_peft_model_state_dict(self, kwargs.get("state_dict", None))
        torch.save(output_state_dict, os.path.join(save_directory, WEIGHTS_NAME))

        # Record which base model the adapter belongs to, if not already known.
        if self.peft_config.base_model_name_or_path is None:
            source_model = (
                self.base_model if isinstance(self.peft_config, PromptLearningConfig) else self.base_model.model
            )
            self.peft_config.base_model_name_or_path = source_model.__dict__.get("name_or_path", None)

        # The config is written with inference_mode=True; the caller's value is restored even if saving fails.
        inference_mode = self.peft_config.inference_mode
        self.peft_config.inference_mode = True
        try:
            self.peft_config.save_pretrained(save_directory)
        finally:
            self.peft_config.inference_mode = inference_mode

    @classmethod
    def from_pretrained(cls, model, model_id_or_dict, lora_config=None, **kwargs):
        r"""
        Instantiate a PEFT model from pretrained LoRA configuration(s) and adapter weights.

        Args:
            model (`transformers.PreTrainedModel` or `PeftModel`):
                The model to be adapted. The model should be initialized with the `from_pretrained` method from the
                `transformers` library. A model that is already a `PeftModel` is not wrapped again.
            model_id_or_dict (`str`, `os.PathLike` or `dict`):
                Where to load the adapter(s) from. Can be either:
                    - A string (or path), the `model id` of a Lora configuration hosted inside a model repo on the
                      Hugging Face Hub, or a path to a directory saved with the `save_pretrained` method, e.g.
                      ``./my_lora_config_directory/``.
                    - A dict mapping task names to such model ids / directories. Entries are loaded in dict order
                      and the i-th entry is loaded with ``lora_id=i``. The config comes from ``lora_config`` or from
                      the first entry, and its ``lora_nums`` must equal the number of entries.
            lora_config (`PeftConfig`, *optional*):
                Configuration to use instead of reading it from ``model_id_or_dict``.
            kwargs:
                - ``task``: forwarded to `set_peft_model_state_dict`. In the dict case, the entry whose key equals
                  ``task`` is loaded with ``use_all=True``.
                - ``device_map`` (default ``"auto"``) and ``max_memory``: used to re-dispatch the model when it
                  has an ``hf_device_map``.

        Returns:
            The adapted model.

        Raises:
            ValueError: If ``model_id_or_dict`` has an unsupported type or is an empty dict, if adapter weights
                cannot be found, or if ``lora_nums`` does not match the number of dict entries.
        """
        from .mapping import MODEL_TYPE_TO_PEFT_MODEL_MAPPING, PEFT_TYPE_TO_CONFIG_MAPPING

        def _load_config(model_id):
            # Resolve the concrete config class from the stored peft_type, then load it.
            peft_type = PeftConfig.from_pretrained(model_id).peft_type
            return PEFT_TYPE_TO_CONFIG_MAPPING[peft_type].from_pretrained(model_id)

        def _wrap_model(base_model, config):
            # Strip accelerate hooks; the model is re-dispatched at the end of `from_pretrained`.
            if getattr(base_model, "hf_device_map", None) is not None:
                remove_hook_from_submodules(base_model)
            # avoid re-initializing the model
            if isinstance(base_model, PeftModel):
                return base_model
            if config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING:
                return cls(base_model, config)
            return MODEL_TYPE_TO_PEFT_MODEL_MAPPING[config.task_type](base_model, config)

        def _load_adapter_weights(model_id):
            # Prefer a local file; otherwise try to download it from the Hub.
            filename = os.path.join(model_id, WEIGHTS_NAME)
            if not os.path.exists(filename):
                try:
                    filename = hf_hub_download(model_id, WEIGHTS_NAME)
                except Exception as err:
                    raise ValueError(
                        f"Can't find weights for {model_id} in {model_id} or in the Hugging Face Hub. "
                        f"Please check that the file {WEIGHTS_NAME} is present at {model_id}."
                    ) from err
            # NOTE(review): `torch.load` unpickles the file; only load adapters from trusted sources.
            return torch.load(
                filename, map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu")
            )

        if isinstance(model_id_or_dict, dict):
            if not model_id_or_dict:
                raise ValueError("model_id_or_dict is an empty dictionary; at least one adapter entry is required.")
            if "route" not in model_id_or_dict:
                warnings.warn(
                    "The model_id_or_dict is a dictionary, but does not contain the key 'route'. "
                    "This may lead to unexpected behavior."
                )
            # TODO: load the route weights into the model. Note that a "route" entry is currently iterated like any
            # other adapter entry below and counts towards `lora_nums`.

            num_adapters = len(model_id_or_dict)
            target = kwargs.get("task", None)

            for lora_id, (task, model_id) in enumerate(model_id_or_dict.items()):
                if lora_id == 0:
                    # The first entry supplies the shared config and the model is wrapped exactly once.
                    if lora_config is not None:
                        config = lora_config
                    else:
                        config = _load_config(model_id)
                        config.lora_nums = num_adapters
                        warnings.warn(
                            f"The lora config is not provided, using the default config from {model_id}. "
                        )
                    # Explicit check rather than `assert`, which is stripped under `python -O`.
                    if config.lora_nums != num_adapters:
                        raise ValueError(f"lora_nums should be {num_adapters} but got {config.lora_nums}")
                    model = _wrap_model(model, config)

                adapters_weights = _load_adapter_weights(model_id)
                # load the weights into the model
                model = set_peft_model_state_dict(
                    model,
                    adapters_weights,
                    lora_id=lora_id,
                    use_all=(target == task),
                    task=task,
                )

        elif isinstance(model_id_or_dict, (str, os.PathLike)):
            model_id = os.fspath(model_id_or_dict)
            # load the config
            if lora_config is not None:
                config = lora_config
            else:
                config = _load_config(model_id)
                warnings.warn(
                    f"The lora config is not provided, using the default config from {model_id}. "
                )
            model = _wrap_model(model, config)

            adapters_weights = _load_adapter_weights(model_id)
            # load the weights into the model
            model = set_peft_model_state_dict(model, adapters_weights, task=kwargs.get("task", None))
        else:
            raise ValueError(
                f"model_id_or_dict should be a string or a dict, but got {type(model_id_or_dict)}"
            )

        if getattr(model, "hf_device_map", None) is not None:
            # How layers are distributed across devices (e.g., GPU 0, 1, etc.).
            device_map = kwargs.get("device_map", "auto")
            max_memory = kwargs.get("max_memory", None)
            # Module types that shouldn't be split across devices (e.g., attention blocks).
            no_split_module_classes = model._no_split_modules
            if device_map != "sequential":
                max_memory = get_balanced_memory(
                    model,
                    max_memory=max_memory,
                    no_split_module_classes=no_split_module_classes,
                    low_zero=(device_map == "balanced_low_0"),
                )
            if isinstance(device_map, str):
                device_map = infer_auto_device_map(
                    model, max_memory=max_memory, no_split_module_classes=no_split_module_classes
                )
            model = dispatch_model(model, device_map=device_map)
            hook = AlignDevicesHook(io_same_device=True)
            if model.peft_config.peft_type == PeftType.LORA:
                add_hook_to_module(model.base_model.model, hook)
            else:
                remove_hook_from_submodules(model.prompt_encoder)
                add_hook_to_module(model.base_model, hook)

        return model

    def _setup_prompt_encoder(self):
        """
        Freeze the base model's children and build the prompt encoder for a prompt-learning config.

        Sets ``self.transformer_backbone_name`` (when a `PreTrainedModel` child exists), ``self.word_embeddings``
        (the submodule owning the first backbone parameter whose leading dimension equals ``config.vocab_size``),
        ``self.prompt_encoder`` and ``self.prompt_tokens``. Fills in ``peft_config.num_transformer_submodules``
        when it is unset (2 for `TaskType.SEQ_2_SEQ_LM`, otherwise 1).

        Raises:
            ValueError: If ``peft_config.peft_type`` is not PROMPT_TUNING, P_TUNING or PREFIX_TUNING.
        """
        transformer_backbone = None
        for name, module in self.base_model.named_children():
            # Freeze every child of the base model.
            for param in module.parameters():
                param.requires_grad = False
            if isinstance(module, PreTrainedModel) and transformer_backbone is None:
                # The first `PreTrainedModel` child is used as the transformer backbone.
                transformer_backbone = module
                self.transformer_backbone_name = name
        if transformer_backbone is None:
            # No `PreTrainedModel` child: search the base model itself instead of failing on `None` below.
            transformer_backbone = self.base_model

        if self.peft_config.num_transformer_submodules is None:
            self.peft_config.num_transformer_submodules = (
                2 if self.peft_config.task_type == TaskType.SEQ_2_SEQ_LM else 1
            )

        # The word-embedding module is located as the owner of the first parameter with `vocab_size` rows.
        for named_param, value in transformer_backbone.named_parameters():
            if value.shape[0] == self.base_model.config.vocab_size:
                self.word_embeddings = transformer_backbone.get_submodule(named_param.replace(".weight", ""))
                break

        if self.peft_config.peft_type == PeftType.PROMPT_TUNING:
            prompt_encoder = PromptEmbedding(self.peft_config, self.word_embeddings)
        elif self.peft_config.peft_type == PeftType.P_TUNING:
            prompt_encoder = PromptEncoder(self.peft_config)
        elif self.peft_config.peft_type == PeftType.PREFIX_TUNING:
            prompt_encoder = PrefixEncoder(self.peft_config)
        else:
            raise ValueError("Not supported")
        self.prompt_encoder = prompt_encoder
        # Virtual-token ids 0 .. num_virtual_tokens * num_transformer_submodules - 1.
        self.prompt_tokens = torch.arange(
            self.peft_config.num_virtual_tokens * self.peft_config.num_transformer_submodules
        ).long()

    def get_prompt_embedding_to_save(self):
        """
        Compute the prompt embeddings to store in a checkpoint. Only meaningful for non-LoRA (prompt-learning) PEFT types.

        Returns:
            The prompt encoder's output for a single-row batch, with that batch dimension indexed away, detached
            from the graph and moved to CPU.
        """
        cfg = self.peft_config
        single_row = self.prompt_tokens.unsqueeze(0).expand(1, -1).to(self.device)
        if cfg.peft_type == PeftType.PREFIX_TUNING:
            # Prefix tuning encodes only the first `num_virtual_tokens` ids.
            single_row = single_row[:, : cfg.num_virtual_tokens]
        encoded = self.prompt_encoder(single_row)
        return encoded[0].detach().cpu()

    def get_prompt(self, batch_size: int):
        """
        Returns the virtual prompts to use for Peft. Only applicable when `peft_config.peft_type != PeftType.LORA`.

        Args:
            batch_size (`int`): Number of copies of the virtual prompt to produce (leading batch dimension).

        Returns:
            For `PeftType.PREFIX_TUNING`: the past key/values derived from the prefix encoder output, as the tuple of
            chunks produced by `.split(num_transformer_submodules * 2)` along the first dimension, optionally passed
            through a model-type-specific post-processing function.
            Otherwise: the prompt embeddings, with `batch_size` as the leading dimension.
        """
        # (batch_size, len(self.prompt_tokens)) matrix of virtual-token ids, moved to the model device.
        prompt_tokens = self.prompt_tokens.unsqueeze(0).expand(batch_size, -1).to(self.device)
        if self.peft_config.peft_type == PeftType.PREFIX_TUNING:
            # Prefix tuning encodes only the first `num_virtual_tokens` ids.
            prompt_tokens = prompt_tokens[:, : self.peft_config.num_virtual_tokens]
            if self.peft_config.inference_mode:
                # Inference mode reads the encoder's embedding table directly instead of running the encoder.
                past_key_values = self.prompt_encoder.embedding.weight.repeat(batch_size, 1, 1)
            else:
                past_key_values = self.prompt_encoder(prompt_tokens)
            # -> (batch, num_virtual_tokens, num_layers * 2, num_attention_heads, head_dim);
            # presumably the factor 2 is one key and one value per layer — TODO confirm.
            past_key_values = past_key_values.view(
                batch_size,
                self.peft_config.num_virtual_tokens,
                self.peft_config.num_layers * 2,
                self.peft_config.num_attention_heads,
                self.peft_config.token_dim // self.peft_config.num_attention_heads,
            )
            if self.peft_config.num_transformer_submodules == 2:
                # Two submodules (the default for seq2seq LM tasks): duplicate the slots along dim 2.
                past_key_values = torch.cat([past_key_values, past_key_values], dim=2)
            # permute -> (num_layers * 2 * submodules, batch, heads, num_virtual_tokens, head_dim), then split dim 0
            # into chunks of `num_transformer_submodules * 2`, i.e. `num_layers` chunks.
            past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(
                self.peft_config.num_transformer_submodules * 2
            )
            # Some architectures need their past_key_values rearranged; apply the registered hook if any.
            if TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING.get(self.config.model_type, None) is not None:
                post_process_fn = TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING[self.config.model_type]
                past_key_values = post_process_fn(past_key_values)
            return past_key_values
        else:
            # Prompt tuning / P-tuning: the encoder output is used directly as prompt embeddings.
            if self.peft_config.inference_mode:
                prompts = self.prompt_encoder.embedding.weight.repeat(batch_size, 1, 1)
            else:
                prompts = self.prompt_encoder(prompt_tokens)

            return prompts

    def print_trainable_parameters(self, verbose=False):
        """
        Print the number of trainable parameters, the total number of parameters and the trainable percentage.

        Args:
            verbose (`bool`, *optional*, defaults to `False`):
                If `True`, additionally print one line per parameter with its name, shape and trainable status.
        """
        trainable_params = 0
        all_param = 0
        for name, param in self.named_parameters():
            num_params = param.numel()
            # if using DS Zero 3 and the weights are initialized empty
            if num_params == 0 and hasattr(param, "ds_numel"):
                num_params = param.ds_numel

            all_param += num_params
            if param.requires_grad:
                trainable_params += num_params
            if verbose:
                status = "trainable" if param.requires_grad else "freezing"
                print(f"{status}, name: {name}, params shape: {param.data.shape}")

        # Avoid ZeroDivisionError for a model without parameters.
        trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
        print(
            f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {trainable_pct}"
        )

    def __getattr__(self, name: str):
        """Forward missing attributes to the wrapped module.

        Lookup order: `torch.nn.Module.__getattr__` (parameters, buffers, submodules), then ``self.base_model``.

        Raises:
            AttributeError: If neither lookup finds ``name``, or if ``base_model`` itself is not set.
        """
        try:
            return super().__getattr__(name)  # defer to nn.Module's logic
        except AttributeError:
            if name == "base_model":
                # `base_model` itself is missing (e.g. during unpickling or before __init__ finished). Forwarding
                # would call __getattr__("base_model") again and recurse until RecursionError.
                raise
            return getattr(self.base_model, name)

    def forward(self, *args, **kwargs):  # pylint: disable=E0202
        """
        Run the unwrapped model returned by `get_base_model`, passing all arguments through unchanged.
        """
        underlying = self.get_base_model()
        return underlying(*args, **kwargs)

    @contextmanager
    def disable_adapter(self):
        """
        Context manager that temporarily disables the adapter module.

        For prompt-learning configs, ``self.forward`` is swapped for the base model's ``forward``. Otherwise the
        tuner's adapter layers are disabled. The previous state is restored on exit, including when the body of
        the ``with`` statement raises.
        """
        is_prompt_learning = isinstance(self.peft_config, PromptLearningConfig)
        if is_prompt_learning:
            old_forward = self.forward
            self.forward = self.base_model.forward
        else:
            self.base_model.disable_adapter_layers()
        try:
            yield
        finally:
            # Restore in `finally` so an exception inside the `with` block does not leave the adapter disabled.
            if is_prompt_learning:
                self.forward = old_forward
            else:
                self.base_model.enable_adapter_layers()

    def get_base_model(self):
        """
        Return the unwrapped model.

        For prompt learning this is ``self.base_model`` itself. For other methods it is the ``model`` attribute
        of the tuner stored in ``self.base_model``.
        """
        if not isinstance(self.peft_config, PromptLearningConfig):
            return self.base_model.model
        return self.base_model
    
    def unfreeze_sparsegen_for_training(self):
        """Unfreeze sparsegen parameters for training.

        Delegates to ``base_model.unfreeze_sparsegen_for_training`` when the wrapped model defines it; otherwise
        does nothing.
        """
        wrapped = self.base_model
        if not hasattr(wrapped, 'unfreeze_sparsegen_for_training'):
            return
        wrapped.unfreeze_sparsegen_for_training()
    
    def freeze_sparsegen_for_eval(self):
        """Freeze sparsegen parameters for evaluation.

        Delegates to ``base_model.freeze_sparsegen_for_eval`` when the wrapped model defines it; otherwise does
        nothing.
        """
        wrapped = self.base_model
        if not hasattr(wrapped, 'freeze_sparsegen_for_eval'):
            return
        wrapped.freeze_sparsegen_for_eval()
    
    @classmethod
    def from_pretrained_preserve_sparsegen(cls, model, model_id, adapter_name="default", is_trainable=False, config=None, **kwargs):
        """Load PEFT model while preserving sparsegen parameter states from warmup training.

        Every submodule of ``model`` whose ``sparsegen`` attribute is truthy is flagged with
        ``_skip_sparsegen_init = True`` before loading. Code that honors this flag can then skip re-initializing
        the sparsegen parameters of a model that was partially trained during the warmup phase.

        Args:
            model: The model to adapt.
            model_id: Adapter location(s), forwarded to `from_pretrained` as ``model_id_or_dict``.
            adapter_name: Accepted for API compatibility; `from_pretrained` has no such parameter, so it is unused.
            is_trainable: Accepted for API compatibility; unused.
            config: Optional config, forwarded to `from_pretrained` as ``lora_config``.
            **kwargs: Forwarded to `from_pretrained`.

        Returns:
            The model returned by `from_pretrained`.
        """
        # Set flag to skip sparsegen initialization
        if hasattr(model, 'named_modules'):
            for _, module in model.named_modules():
                if hasattr(module, 'sparsegen') and module.sparsegen:
                    module._skip_sparsegen_init = True

        # `from_pretrained` takes (model, model_id_or_dict, lora_config=None, **kwargs). Passing adapter_name,
        # is_trainable and config positionally raised TypeError (too many positional arguments).
        return cls.from_pretrained(model, model_id, lora_config=config, **kwargs)
    
    def add_adapter(self, adapter_name: str, peft_config: PeftConfig, low_cpu_mem_usage: bool = False) -> None:
        """
        Add an adapter to the model based on the passed configuration.

        This adapter is not trained. To load a trained adapter, check out [`PeftModel.load_adapter`].

        The name for the new adapter should be unique.

        The new adapter is not automatically set as the active adapter. Use [`PeftModel.set_adapter`] to set the active
        adapter.

        Args:
            adapter_name (`str`):
                The name of the adapter to be added.
            peft_config ([`PeftConfig`]):
                The configuration of the adapter to be added.
            low_cpu_mem_usage (`bool`, `optional`, defaults to `False`):
                Create empty adapter weights on meta device. Useful to speed up the process when loading saved
                adapters. Don't use this option when creating a new PEFT adapter for training.

        Raises:
            ValueError: If ``peft_config.peft_type`` differs from this model's ``peft_type``.
        """
        # NOTE(review): this method treats `self.peft_config` as a dict of adapter name -> config, whereas
        # `__init__` stores a single config object. Confirm it is only used where `peft_config` is a dict.
        prefix = PEFT_TYPE_TO_PREFIX_MAPPING.get(peft_config.peft_type)
        if prefix and adapter_name in prefix:
            warnings.warn(
                f"Adapter name {adapter_name} should not be contained in the prefix {prefix}. "
                "This may lead to reinitialization of the adapter weights during loading."
            )

        if peft_config.peft_type != self.peft_type:
            raise ValueError(
                f"Cannot combine adapters with different peft types. "
                f"Found {self.peft_type} and {peft_config.peft_type}."
            )

        try:
            if peft_config.is_prompt_learning:
                self.peft_config[adapter_name] = peft_config
                if hasattr(self.config, "to_dict"):
                    dict_config = self.config.to_dict()
                else:
                    dict_config = self.config

                peft_config = _prepare_prompt_learning_config(peft_config, dict_config)
                self._setup_prompt_encoder(adapter_name)
                set_additional_trainable_modules(
                    model=self.base_model,
                    peft_config=peft_config,
                    model_config=BaseTuner.get_model_config(self),
                    adapter_name=adapter_name,
                )
            elif peft_config.is_adaption_prompt:
                self.base_model.add_adapter(adapter_name, peft_config)
                set_additional_trainable_modules(
                    model=self.base_model,
                    peft_config=peft_config,
                    model_config=BaseTuner.get_model_config(self),
                    adapter_name=adapter_name,
                )
            else:
                self.peft_config[adapter_name] = peft_config
                self.base_model.inject_adapter(
                    self.base_model.model, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage
                )
        except Exception:  # something went wrong, roll back
            # Only roll back when `peft_config` is a dict. A membership test on a single config object raises a
            # new TypeError that would hide the original exception.
            if isinstance(self.peft_config, dict) and adapter_name in self.peft_config:
                del self.peft_config[adapter_name]
            raise

    def _check_new_adapter_config(self, peft_config: PeftConfig, is_trainable: bool) -> None:
        """Perform checks on newly added PEFT configs to ensure integrity.

        Args:
            peft_config: The config of the adapter about to be added.
            is_trainable: Whether the caller wants the new adapter to be trainable.

        Raises:
            ValueError: If a prompt-learning adapter is requested as trainable.

        Emits a warning (only the first matching one) when PiSSA, CorDA or OLoRA initialization is combined with
        any other adapter.
        """
        # Same prompt-learning test as the rest of this class (isinstance instead of `.is_prompt_learning`).
        if isinstance(peft_config, PromptLearningConfig) and is_trainable:
            raise ValueError("Cannot set a prompt learning adapter to trainable when loading pretrained adapter.")

        # `self.peft_config` may be a dict of adapter configs or, as set in `__init__`, a single config.
        if isinstance(self.peft_config, dict):
            existing_configs = list(self.peft_config.values())
        else:
            existing_configs = [self.peft_config]
        all_configs = [peft_config] + existing_configs
        if len(all_configs) <= 1:
            return

        # Since PiSSA/CorDA/OLoRA modifies the base weights, it should not be combined with other adapters.
        base_weight_modifying_inits = (
            (
                "pissa",
                "PiSSA changes the base weights of the model and should thus not be used with other adapters. "
                "Consider converting the PiSSA adapter into a normal LoRA adapter: "
                "https://github.com/huggingface/peft/tree/main/examples/pissa_finetuning#convert-pissa-to-lora",
            ),
            (
                "corda",
                "CorDA changes the base weights of the model and should thus not be used with other adapters. "
                "Consider converting the CorDA adapter into a normal LoRA adapter: "
                "https://github.com/huggingface/peft/tree/main/examples/corda_finetuning#convert-corda-to-lora",
            ),
            (
                "olora",
                "OLoRA changes the base weights of the model and should thus not be used with other adapters. "
                "Consider converting the OLoRA adapter into a normal LoRA adapter: "
                "https://github.com/huggingface/peft/tree/main/examples/olora_finetuning#olora-and-lora",
            ),
        )
        for init_name, msg in base_weight_modifying_inits:
            if any(getattr(config, "init_lora_weights", None) == init_name for config in all_configs):
                warnings.warn(msg)
                break

    def load_adapter(
        self,
        model_id: Union[str, os.PathLike],
        adapter_name: str,
        is_trainable: bool = False,
        torch_device: Optional[str] = None,
        autocast_adapter_dtype: bool = True,
        ephemeral_gpu_offload: bool = False,
        low_cpu_mem_usage: bool = False,
        **kwargs: Any,
    ):
        """
        Load a trained adapter into the model.

        The name for the new adapter should be unique.

        The new adapter is not automatically set as the active adapter. Use [`PeftModel.set_adapter`] to set the active
        adapter.

        Args:
            model_id (`str` or `os.PathLike`):
                The name of the PEFT configuration to use. Can be either:
                    - A string, the `model id` of a PEFT configuration hosted inside a model repo on the Hugging Face
                    Hub.
                    - A path to a directory containing a PEFT configuration file saved using the `save_pretrained`
                    method (`./my_peft_config_directory/`).
            adapter_name (`str`):
                The name of the adapter to be added.
            is_trainable (`bool`, *optional*, defaults to `False`):
                Whether the adapter should be trainable or not. If `False`, the adapter will be frozen and can only be
                used for inference.
            torch_device (`str`, *optional*, defaults to None):
                The device to load the adapter on. If `None`, the device will be inferred.
            autocast_adapter_dtype (`bool`, *optional*, defaults to `True`):
                Whether to autocast the adapter dtype. Defaults to `True`. Right now, this will only cast adapter
                weights using float16 and bfloat16 to float32, as this is typically required for stable training, and
                only affect select PEFT tuners.
            ephemeral_gpu_offload (`bool`, *optional*, defaults to `False`):
                Accepted for API compatibility; not used by this implementation.
            low_cpu_mem_usage (`bool`, `optional`, defaults to `False`):
                Create empty adapter weights on meta device before loading the saved weights. Useful to speed up the
                process.
            kwargs: (`optional`):
                Additional arguments to modify the way the adapter is loaded, e.g. the token for Hugging Face Hub.
                ``ignore_mismatched_sizes``, ``device_map``, ``max_memory``, ``offload_folder`` and
                ``offload_index`` are read here.

        Returns:
            The result of `set_peft_model_state_dict`, with ``missing_keys`` filtered down to keys that belong to
            this adapter.
        """
        from .mapping import PEFT_TYPE_TO_CONFIG_MAPPING

        # NOTE(review): `adapter_name not in self.peft_config`, `self.peft_config[adapter_name]` and
        # `len(self.peft_config)` assume a dict of configs; `__init__` stores a single config. Confirm usage.
        hf_hub_download_kwargs, kwargs = self._split_kwargs(kwargs)
        if torch_device is None:
            torch_device = infer_device()

        if adapter_name not in self.peft_config:
            # load the config
            peft_type = PeftConfig._get_peft_type(model_id, **hf_hub_download_kwargs)
            peft_config = PEFT_TYPE_TO_CONFIG_MAPPING[peft_type].from_pretrained(model_id, **hf_hub_download_kwargs)
            self._check_new_adapter_config(peft_config, is_trainable=is_trainable)
            peft_config.inference_mode = not is_trainable
            self.add_adapter(adapter_name, peft_config, low_cpu_mem_usage=low_cpu_mem_usage)

        adapters_weights = load_peft_weights(model_id, device=torch_device, **hf_hub_download_kwargs)

        # load the weights into the model
        ignore_mismatched_sizes = kwargs.get("ignore_mismatched_sizes", False)
        load_result = set_peft_model_state_dict(
            self,
            adapters_weights,
            adapter_name=adapter_name,
            ignore_mismatched_sizes=ignore_mismatched_sizes,
            low_cpu_mem_usage=low_cpu_mem_usage,
        )

        # Filter missing keys specific to the current adapter and tuner prefix. An unregistered tuner yields the
        # prefix "", which matches every key, so filtering is then by adapter name only.
        tuner = self.peft_config[adapter_name].peft_type
        tuner_prefix = PEFT_TYPE_TO_PREFIX_MAPPING.get(tuner, "")
        adapter_missing_keys = [
            key for key in load_result.missing_keys if tuner_prefix in key and adapter_name in key
        ]
        load_result.missing_keys.clear()
        load_result.missing_keys.extend(adapter_missing_keys)

        if (
            (getattr(self, "hf_device_map", None) is not None)
            and (len(set(self.hf_device_map.values()).intersection({"cpu", "disk"})) > 0)
            and len(self.peft_config) == 1
        ):
            device_map = kwargs.get("device_map", "auto")
            max_memory = kwargs.get("max_memory", None)
            offload_dir = kwargs.get("offload_folder", None)
            offload_index = kwargs.get("offload_index", None)

            # Safety checker for previous `accelerate` versions:
            # `offload_index` was introduced in https://github.com/huggingface/accelerate/pull/873/
            supports_offload_index = "offload_index" in inspect.signature(dispatch_model).parameters

            no_split_module_classes = self._no_split_modules

            if device_map != "sequential":
                max_memory = get_balanced_memory(
                    self,
                    max_memory=max_memory,
                    no_split_module_classes=no_split_module_classes,
                    low_zero=(device_map == "balanced_low_0"),
                )

            if isinstance(device_map, str):
                device_map = infer_auto_device_map(
                    self, max_memory=max_memory, no_split_module_classes=no_split_module_classes
                )

            # NOTE(review): `_update_offload` is not defined in this class; it resolves via `__getattr__` on
            # `base_model` if at all — confirm it exists.
            self._update_offload(offload_index, adapters_weights)

            # Only pass `offload_index` when the installed accelerate accepts it; previously it was added
            # unconditionally, which defeated the version check above.
            dispatch_model_kwargs = {"offload_index": offload_index} if supports_offload_index else {}

            dispatch_model(
                self,
                device_map=device_map,
                offload_dir=offload_dir,
                **dispatch_model_kwargs,
            )

            hook = AlignDevicesHook(io_same_device=True)
            if self.peft_config[adapter_name].is_prompt_learning:
                remove_hook_from_submodules(self.prompt_encoder)
            add_hook_to_module(self.get_base_model(), hook)

        if hasattr(self.base_model, "_cast_adapter_dtype"):
            self.base_model._cast_adapter_dtype(
                adapter_name=adapter_name, autocast_adapter_dtype=autocast_adapter_dtype
            )

        # Set model in evaluation mode to deactivate Dropout modules by default
        if not is_trainable:
            self.eval()
        return load_result
    @classmethod
    def _split_kwargs(cls, kwargs: dict[str, Any]):
        """
        Partition ``kwargs`` into arguments for `hf_hub_download` and everything else.

        Args:
            kwargs: Keyword arguments supplied by the caller.

        Returns:
            A tuple ``(hf_hub_download_kwargs, other_kwargs)`` of two new dicts. A key goes to the first dict when
            it is a parameter of `hf_hub_download`, or is ``use_auth_token``.
        """
        _kwargs_not_in_hf_hub_download_signature = ("use_auth_token",)
        # The signature is loop-invariant; compute it once instead of once per key.
        hub_download_params = inspect.signature(hf_hub_download).parameters
        hf_hub_download_kwargs = {}
        other_kwargs = {}

        for key, value in kwargs.items():
            if key in hub_download_params or key in _kwargs_not_in_hf_hub_download_signature:
                hf_hub_download_kwargs[key] = value
            else:
                other_kwargs[key] = value

        return hf_hub_download_kwargs, other_kwargs

class PeftModelForSequenceClassification(PeftModel):
    """
    Peft model for sequence classification tasks.

    Args:
        model ([`PreTrainedModel`]): Base transformer model
        peft_config ([`PeftConfig`]): Peft config.

    **Attributes**:
        - **config** ([`PretrainedConfig`]) -- The configuration object of the base model.
        - **cls_layer_name** (`str`) -- The name of the classification layer.

    Example::

        >>> from transformers import AutoModelForSequenceClassification
        >>> from peft import PeftModelForSequenceClassification, get_peft_config
        >>> config = {
        ...     'peft_type': 'PREFIX_TUNING', 'task_type': 'SEQ_CLS', 'inference_mode': False, 'num_virtual_tokens': 20,
        ...     'token_dim': 768, 'num_transformer_submodules': 1, 'num_attention_heads': 12, 'num_layers': 12,
        ...     'encoder_hidden_size': 768, 'prefix_projection': False, 'postprocess_past_key_value_function': None
        ... }
        >>> peft_config = get_peft_config(config)
        >>> model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased")
        >>> peft_model = PeftModelForSequenceClassification(model, peft_config)
        >>> peft_model.print_trainable_parameters()
        trainable params: 370178 || all params: 108680450 || trainable%: 0.3406113979101117
    """

    def __init__(self, model, peft_config: PeftConfig):
        super().__init__(model, peft_config)
        # Top-level children whose name contains one of these substrings are treated as the classification head.
        self.modules_to_save = ["classifier", "score"]

        # Remember the first matching child; the prefix-tuning fallback path applies it to the pooled output.
        # If no child matches, `cls_layer_name` is left unset.
        for name, _ in self.base_model.named_children():
            if any(module_name in name for module_name in self.modules_to_save):
                self.cls_layer_name = name
                break

        # to make sure classifier layer is trainable
        _set_trainable(self)

    def forward(  # pylint: disable=W0221
        self,
        input_ids=None,
        attention_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        """
        Run the wrapped classifier, injecting virtual tokens for prompt-learning adapters.

        Non prompt-learning configs are forwarded to the base model unchanged. For prompt learning the attention mask
        is left-extended with ``num_virtual_tokens`` ones and ``position_ids`` are dropped (with a warning). Prefix
        tuning is delegated to ``_prefix_tuning_forward``. Other methods prepend the learned prompt embeddings to
        ``inputs_embeds`` and left-pad ``token_type_ids`` with zeros.

        The batch size is read from ``input_ids``, or from ``inputs_embeds`` when ``input_ids`` is ``None``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if not isinstance(self.peft_config, PromptLearningConfig):
            return self.base_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                labels=labels,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                **kwargs,
            )

        # Callers may supply embeddings instead of token ids.
        batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]
        num_virtual_tokens = self.peft_config.num_virtual_tokens
        if attention_mask is not None:
            # concat prompt attention mask so the virtual tokens are attended to
            prefix_attention_mask = torch.ones(batch_size, num_virtual_tokens).to(self.device)
            attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
        if kwargs.get("position_ids", None) is not None:
            warnings.warn("Position ids are not supported for parameter efficient tuning. Ignoring position ids.")
            kwargs["position_ids"] = None
        kwargs.update(
            {
                "attention_mask": attention_mask,
                "labels": labels,
                "output_attentions": output_attentions,
                "output_hidden_states": output_hidden_states,
                "return_dict": return_dict,
            }
        )

        if self.peft_config.peft_type == PeftType.PREFIX_TUNING:
            # `_prefix_tuning_forward` accepts and forwards `inputs_embeds`; pass it instead of silently dropping it.
            return self._prefix_tuning_forward(input_ids=input_ids, inputs_embeds=inputs_embeds, **kwargs)

        if kwargs.get("token_type_ids", None) is not None:
            # virtual tokens get token type 0
            prefix_token_type_ids = torch.zeros(batch_size, num_virtual_tokens).to(self.device)
            kwargs["token_type_ids"] = torch.cat((prefix_token_type_ids, kwargs["token_type_ids"]), dim=1).long()
        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        prompts = self.get_prompt(batch_size=batch_size).to(inputs_embeds.dtype)
        inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1)
        return self.base_model(inputs_embeds=inputs_embeds, **kwargs)

    def _prefix_tuning_forward(
        self,
        input_ids=None,
        attention_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        """
        Forward pass for prefix tuning, where the prompt is supplied as ``past_key_values``.

        If the base model's ``forward`` accepts ``past_key_values``, it is called directly. Otherwise the transformer
        backbone (``self.transformer_backbone_name``) is run with the prefix, and the optional ``dropout`` child, the
        classification head (``self.cls_layer_name``) and the loss are applied here.

        Raises:
            ValueError: If neither the base model nor its backbone accepts ``past_key_values``.
        """
        batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]
        past_key_values = self.get_prompt(batch_size)
        fwd_params = list(inspect.signature(self.base_model.forward).parameters.keys())
        kwargs.update(
            {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "inputs_embeds": inputs_embeds,
                "output_attentions": output_attentions,
                "output_hidden_states": output_hidden_states,
                "return_dict": return_dict,
                "past_key_values": past_key_values,
            }
        )
        if "past_key_values" in fwd_params:
            return self.base_model(labels=labels, **kwargs)

        transformer_backbone = self.base_model.get_submodule(self.transformer_backbone_name)
        fwd_params = list(inspect.signature(transformer_backbone.forward).parameters.keys())
        if "past_key_values" not in fwd_params:
            raise ValueError("Model does not support past key values which are required for prefix tuning.")
        outputs = transformer_backbone(**kwargs)
        # Use the pooled output when the backbone returns one, otherwise its first output.
        pooled_output = outputs[1] if len(outputs) > 1 else outputs[0]
        if any(name == "dropout" for name, _ in self.base_model.named_children()):
            pooled_output = self.base_model.dropout(pooled_output)
        logits = self.base_model.get_submodule(self.cls_layer_name)(pooled_output)

        loss = None
        if labels is not None:
            num_labels = self.base_model.num_labels
            # Infer the problem type once (cached on the config) from num_labels and the labels dtype.
            if self.config.problem_type is None:
                if num_labels == 1:
                    self.config.problem_type = "regression"
                elif num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class PeftModelForCausalLM(PeftModel):
    """
    Peft model for Causal LM

    Args:
        model ([`PreTrainedModel`]): Base transformer model
        peft_config ([`PeftConfig`]): Peft config.


    Example::

        >>> from transformers import AutoModelForCausalLM
        >>> from peft import PeftModelForCausalLM, get_peft_config
        >>> config = {
        ...     'peft_type': 'PREFIX_TUNING', 'task_type': 'CAUSAL_LM', 'inference_mode': False, 'num_virtual_tokens': 20,
        ...     'token_dim': 1280, 'num_transformer_submodules': 1, 'num_attention_heads': 20, 'num_layers': 36,
        ...     'encoder_hidden_size': 1280, 'prefix_projection': False, 'postprocess_past_key_value_function': None
        ... }
        >>> peft_config = get_peft_config(config)
        >>> model = AutoModelForCausalLM.from_pretrained("gpt2-large")
        >>> peft_model = PeftModelForCausalLM(model, peft_config)
        >>> peft_model.print_trainable_parameters()
    """

    def __init__(self, model, peft_config: PeftConfig):
        super().__init__(model, peft_config)
        # Keep the base model's own hook so `generate` can swap in ours temporarily and restore it afterwards.
        self.base_model_prepare_inputs_for_generation = self.base_model.prepare_inputs_for_generation

    def forward(  # pylint: disable=W0221
        self,
        input_ids=None,
        attention_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        task_types=None,
        **kwargs,
    ):
        """
        Run the wrapped causal LM, injecting virtual tokens for prompt-learning adapters.

        Non prompt-learning configs are forwarded to the base model unchanged, including ``task_types``. For prompt
        learning:
          - the attention mask is left-extended with ``num_virtual_tokens`` ones;
          - ``position_ids`` and ``token_type_ids`` are dropped (with a warning);
          - prefix tuning passes the prompt as ``past_key_values``;
          - other methods prepend prompt embeddings and left-pad ``labels`` with -100.

        ``task_types`` is not forwarded on the prompt-learning path.
        """
        if not isinstance(self.peft_config, PromptLearningConfig):
            return self.base_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                labels=labels,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                task_types=task_types,
                **kwargs,
            )

        # Callers may supply embeddings instead of token ids.
        batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]
        num_virtual_tokens = self.peft_config.num_virtual_tokens
        if attention_mask is not None:
            # concat prompt attention mask so the virtual tokens are attended to
            prefix_attention_mask = torch.ones(batch_size, num_virtual_tokens).to(self.device)
            attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)

        if kwargs.get("position_ids", None) is not None:
            warnings.warn("Position ids are not supported for parameter efficient tuning. Ignoring position ids.")
            kwargs["position_ids"] = None
        if kwargs.get("token_type_ids", None) is not None:
            warnings.warn("Token type ids are not supported for parameter efficient tuning. Ignoring token type ids")
            kwargs["token_type_ids"] = None
        kwargs.update(
            {
                "attention_mask": attention_mask,
                "labels": labels,
                "output_attentions": output_attentions,
                "output_hidden_states": output_hidden_states,
                "return_dict": return_dict,
            }
        )

        if self.peft_config.peft_type == PeftType.PREFIX_TUNING:
            past_key_values = self.get_prompt(batch_size)
            return self.base_model(input_ids=input_ids, past_key_values=past_key_values, **kwargs)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        # concat prompt labels: -100 on the virtual-token positions
        if labels is not None:
            prefix_labels = torch.full((batch_size, num_virtual_tokens), -100).to(self.device)
            kwargs["labels"] = torch.cat((prefix_labels, labels), dim=1)
        prompts = self.get_prompt(batch_size=batch_size).to(inputs_embeds.dtype)
        inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1)
        return self.base_model(inputs_embeds=inputs_embeds, **kwargs)

    def generate(self, **kwargs):
        """
        Call ``self.base_model.generate(**kwargs)`` with this wrapper's ``prepare_inputs_for_generation`` installed.

        The base model's original ``prepare_inputs_for_generation`` is restored afterwards, whether generation
        returns or raises.
        """
        self.base_model.prepare_inputs_for_generation = self.prepare_inputs_for_generation
        try:
            return self.base_model.generate(**kwargs)
        finally:
            self.base_model.prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation

    def prepare_inputs_for_generation(self, *args, **kwargs):
        """
        Wrap the base model's input preparation and inject the prompt when no cache is present.

        Behaviour when ``past_key_values`` is absent or ``None`` and the config is a prompt-learning config:
          - prefix tuning supplies the prompt as ``past_key_values``;
          - other methods replace ``input_ids`` with ``inputs_embeds`` that have the prompt prepended.
        """
        model_kwargs = self.base_model_prepare_inputs_for_generation(*args, **kwargs)
        if not isinstance(self.peft_config, PromptLearningConfig):
            return model_kwargs

        # `.get`: some base models do not include the key at all.
        if model_kwargs.get("past_key_values") is None:
            batch_size = model_kwargs["input_ids"].shape[0]
            if self.peft_config.peft_type == PeftType.PREFIX_TUNING:
                model_kwargs["past_key_values"] = self.get_prompt(batch_size=batch_size)
            else:
                inputs_embeds = self.word_embeddings(model_kwargs["input_ids"])
                prompts = self.get_prompt(batch_size=batch_size).to(inputs_embeds.dtype)
                model_kwargs["inputs_embeds"] = torch.cat((prompts, inputs_embeds), dim=1)
                model_kwargs["input_ids"] = None

        return model_kwargs



class PeftModelForSeq2SeqLM(PeftModel):
    """
    Peft model for Seq2Seq LM

    Args:
        model ([`PreTrainedModel`]): Base transformer model
        peft_config ([`PeftConfig`]): Peft config.


    Example::

        >>> from transformers import AutoModelForSeq2SeqLM
        >>> from peft import PeftModelForSeq2SeqLM, get_peft_config
        >>> config = {
        ...     'peft_type': 'LORA', 'task_type': 'SEQ_2_SEQ_LM', 'inference_mode': False, 'r': 8, 'target_modules':
        ...     ['q', 'v'], 'lora_alpha': 32, 'lora_dropout': 0.1, 'merge_weights': False, 'fan_in_fan_out': False,
        ...     'enable_lora': None, 'bias': 'none'
        ... }
        >>> peft_config = get_peft_config(config)
        >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
        >>> peft_model = PeftModelForSeq2SeqLM(model, peft_config)
        >>> peft_model.print_trainable_parameters()
        trainable params: 884736 || all params: 223843584 || trainable%: 0.3952474242013566
    """

    def __init__(self, model, peft_config: PeftConfig):
        super().__init__(model, peft_config)
        # Keep the base model's own hooks so `generate` can swap in ours temporarily and restore them afterwards.
        self.base_model_prepare_inputs_for_generation = self.base_model.prepare_inputs_for_generation
        self.base_model_prepare_encoder_decoder_kwargs_for_generation = (
            self.base_model._prepare_encoder_decoder_kwargs_for_generation
        )

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        inputs_embeds=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        decoder_inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        """
        Run the wrapped encoder-decoder, injecting virtual tokens for prompt-learning adapters.

        Non prompt-learning configs are forwarded to the base model unchanged. For prompt learning:
          - the decoder attention mask is left-extended with ``num_virtual_tokens`` ones;
          - ``position_ids`` and ``token_type_ids`` are dropped (with a warning);
          - prefix tuning passes the prompt as ``past_key_values``.

        For other methods:
          - the first ``num_virtual_tokens`` prompt vectors are prepended to the encoder inputs;
          - with ``num_transformer_submodules == 2`` the remaining prompt vectors are prepended to the decoder inputs,
            and the labels are left-padded with -100.

        When neither decoder ids nor decoder embeddings are given, they are derived from ``labels`` via
        ``shift_tokens_right``.
        """
        if not isinstance(self.peft_config, PromptLearningConfig):
            return self.base_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
                decoder_inputs_embeds=decoder_inputs_embeds,
                labels=labels,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                **kwargs,
            )

        # Callers may supply embeddings instead of token ids.
        batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]
        num_virtual_tokens = self.peft_config.num_virtual_tokens
        if decoder_attention_mask is not None:
            # concat prompt attention mask
            prefix_attention_mask = torch.ones(batch_size, num_virtual_tokens).to(self.device)
            decoder_attention_mask = torch.cat((prefix_attention_mask, decoder_attention_mask), dim=1)

        if kwargs.get("position_ids", None) is not None:
            warnings.warn("Position ids are not supported for parameter efficient tuning. Ignoring position ids.")
            kwargs["position_ids"] = None
        if kwargs.get("token_type_ids", None) is not None:
            warnings.warn("Token type ids are not supported for parameter efficient tuning. Ignoring token type ids")
            kwargs["token_type_ids"] = None
        kwargs.update(
            {
                "attention_mask": attention_mask,
                "decoder_attention_mask": decoder_attention_mask,
                "labels": labels,
                "output_attentions": output_attentions,
                "output_hidden_states": output_hidden_states,
                "return_dict": return_dict,
            }
        )

        if self.peft_config.peft_type == PeftType.PREFIX_TUNING:
            past_key_values = self.get_prompt(batch_size)
            return self.base_model(
                input_ids=input_ids, decoder_input_ids=decoder_input_ids, past_key_values=past_key_values, **kwargs
            )

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        if decoder_inputs_embeds is None and decoder_input_ids is None:
            decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
            decoder_inputs_embeds = self.word_embeddings(decoder_input_ids)

        if attention_mask is not None:
            # concat prompt attention mask
            prefix_attention_mask = torch.ones(batch_size, num_virtual_tokens).to(self.device)
            kwargs["attention_mask"] = torch.cat((prefix_attention_mask, attention_mask), dim=1)
        # concat prompt labels
        if labels is not None:
            if self.peft_config.num_transformer_submodules == 1:
                kwargs["labels"] = labels
            elif self.peft_config.num_transformer_submodules == 2:
                prefix_labels = torch.full((batch_size, num_virtual_tokens), -100).to(self.device)
                kwargs["labels"] = torch.cat((prefix_labels, labels), dim=1)
        prompts = self.get_prompt(batch_size=batch_size).to(inputs_embeds.dtype)
        inputs_embeds = torch.cat((prompts[:, :num_virtual_tokens], inputs_embeds), dim=1)
        if self.peft_config.num_transformer_submodules == 1:
            return self.base_model(inputs_embeds=inputs_embeds, **kwargs)
        elif self.peft_config.num_transformer_submodules == 2:
            if decoder_inputs_embeds is None:
                # decoder_input_ids were supplied without embeddings; embed them so the decoder prompt can be
                # prepended (previously torch.cat received None here).
                decoder_inputs_embeds = self.word_embeddings(decoder_input_ids)
            decoder_inputs_embeds = torch.cat((prompts[:, num_virtual_tokens:], decoder_inputs_embeds), dim=1)
            return self.base_model(inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds, **kwargs)

    def generate(self, **kwargs):
        """
        Call ``self.base_model.generate`` with this wrapper's generation hooks installed.

        Both ``prepare_inputs_for_generation`` and ``_prepare_encoder_decoder_kwargs_for_generation`` on the base
        model are replaced for the duration of the call and restored afterwards, whether generation returns or raises.

        Raises:
            ValueError: For prompt-learning configs when ``input_ids`` is not provided.
            NotImplementedError: For prompt-learning methods other than prefix tuning.
        """
        self.base_model.prepare_inputs_for_generation = self.prepare_inputs_for_generation
        self.base_model._prepare_encoder_decoder_kwargs_for_generation = (
            self._prepare_encoder_decoder_kwargs_for_generation
        )
        try:
            if not isinstance(self.peft_config, PromptLearningConfig):
                return self.base_model.generate(**kwargs)

            if "input_ids" not in kwargs:
                raise ValueError("input_ids must be provided for Peft model generation")
            if kwargs.get("position_ids", None) is not None:
                warnings.warn("Position ids are not supported for parameter efficient tuning. Ignoring position ids.")
                kwargs["position_ids"] = None
            if kwargs.get("token_type_ids", None) is not None:
                warnings.warn(
                    "Token type ids are not supported for parameter efficient tuning. Ignoring token type ids"
                )
                kwargs["token_type_ids"] = None

            if self.peft_config.peft_type == PeftType.PREFIX_TUNING:
                return self.base_model.generate(**kwargs)
            raise NotImplementedError
        finally:
            self.base_model.prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation
            self.base_model._prepare_encoder_decoder_kwargs_for_generation = (
                self.base_model_prepare_encoder_decoder_kwargs_for_generation
            )

    def prepare_inputs_for_generation(self, *args, **kwargs):
        """
        Wrap the base model's input preparation.

        For prefix tuning, when ``past_key_values`` is absent or ``None``, the prompt is supplied as
        ``past_key_values``, sized by the batch dimension of ``decoder_input_ids``.
        """
        model_kwargs = self.base_model_prepare_inputs_for_generation(*args, **kwargs)
        # `.get`: some base models do not include the key at all.
        if model_kwargs.get("past_key_values") is None and self.peft_config.peft_type == PeftType.PREFIX_TUNING:
            batch_size = model_kwargs["decoder_input_ids"].shape[0]
            model_kwargs["past_key_values"] = self.get_prompt(batch_size)
        return model_kwargs


class PeftModelForTokenClassification(PeftModel):
    """
    Peft model for token classification tasks.

    Args:
        model ([`PreTrainedModel`]): Base transformer model
        peft_config ([`PeftConfig`]): Peft config.

    **Attributes**:
        - **config** ([`PretrainedConfig`]) -- The configuration object of the base model.
        - **cls_layer_name** (`str`) -- The name of the classification layer.

    Example::

        >>> from transformers import AutoModelForTokenClassification
        >>> from peft import PeftModelForTokenClassification, get_peft_config
        >>> config = {
        ...     'peft_type': 'PREFIX_TUNING', 'task_type': 'TOKEN_CLS', 'inference_mode': False, 'num_virtual_tokens': 20,
        ...     'token_dim': 768, 'num_transformer_submodules': 1, 'num_attention_heads': 12, 'num_layers': 12,
        ...     'encoder_hidden_size': 768, 'prefix_projection': False, 'postprocess_past_key_value_function': None
        ... }
        >>> peft_config = get_peft_config(config)
        >>> model = AutoModelForTokenClassification.from_pretrained("bert-base-cased")
        >>> peft_model = PeftModelForTokenClassification(model, peft_config)
        >>> peft_model.print_trainable_parameters()
        trainable params: 370178 || all params: 108680450 || trainable%: 0.3406113979101117
    """

    def __init__(self, model, peft_config: PeftConfig):
        super().__init__(model, peft_config)
        # Top-level children whose name contains one of these substrings are treated as the classification head.
        self.modules_to_save = ["classifier", "score"]

        # Remember the first matching child; the prefix-tuning fallback path applies it to the sequence output.
        # If no child matches, `cls_layer_name` is left unset.
        for name, _ in self.base_model.named_children():
            if any(module_name in name for module_name in self.modules_to_save):
                self.cls_layer_name = name
                break

        # to make sure classifier layer is trainable
        _set_trainable(self)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        """
        Run the wrapped token classifier, injecting virtual tokens for prompt-learning adapters.

        Non prompt-learning configs are forwarded to the base model unchanged. For prompt learning the attention mask
        is left-extended with ``num_virtual_tokens`` ones and ``position_ids`` are dropped (with a warning). Prefix
        tuning is delegated to ``_prefix_tuning_forward``. Other methods prepend the learned prompt embeddings to
        ``inputs_embeds`` and left-pad ``token_type_ids`` with zeros.

        The batch size is read from ``input_ids``, or from ``inputs_embeds`` when ``input_ids`` is ``None``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if not isinstance(self.peft_config, PromptLearningConfig):
            return self.base_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                labels=labels,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                **kwargs,
            )

        # Callers may supply embeddings instead of token ids.
        batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]
        num_virtual_tokens = self.peft_config.num_virtual_tokens
        if attention_mask is not None:
            # concat prompt attention mask so the virtual tokens are attended to
            prefix_attention_mask = torch.ones(batch_size, num_virtual_tokens).to(self.device)
            attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
        if kwargs.get("position_ids", None) is not None:
            warnings.warn("Position ids are not supported for parameter efficient tuning. Ignoring position ids.")
            kwargs["position_ids"] = None
        kwargs.update(
            {
                "attention_mask": attention_mask,
                "labels": labels,
                "output_attentions": output_attentions,
                "output_hidden_states": output_hidden_states,
                "return_dict": return_dict,
            }
        )

        if self.peft_config.peft_type == PeftType.PREFIX_TUNING:
            # `_prefix_tuning_forward` accepts and forwards `inputs_embeds`; pass it instead of silently dropping it.
            return self._prefix_tuning_forward(input_ids=input_ids, inputs_embeds=inputs_embeds, **kwargs)

        if kwargs.get("token_type_ids", None) is not None:
            # virtual tokens get token type 0
            prefix_token_type_ids = torch.zeros(batch_size, num_virtual_tokens).to(self.device)
            kwargs["token_type_ids"] = torch.cat((prefix_token_type_ids, kwargs["token_type_ids"]), dim=1).long()
        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        prompts = self.get_prompt(batch_size=batch_size).to(inputs_embeds.dtype)
        inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1)
        return self.base_model(inputs_embeds=inputs_embeds, **kwargs)

    def _prefix_tuning_forward(
        self,
        input_ids=None,
        attention_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        """
        Forward pass for prefix tuning, where the prompt is supplied as ``past_key_values``.

        If the base model's ``forward`` accepts ``past_key_values``, it is called directly. Otherwise the transformer
        backbone (``self.transformer_backbone_name``) is run with the prefix, and the optional ``dropout`` child, the
        classification head (``self.cls_layer_name``) and a cross-entropy loss are applied here.

        Raises:
            ValueError: If neither the base model nor its backbone accepts ``past_key_values``.
        """
        batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]
        past_key_values = self.get_prompt(batch_size)
        fwd_params = list(inspect.signature(self.base_model.forward).parameters.keys())
        kwargs.update(
            {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "inputs_embeds": inputs_embeds,
                "output_attentions": output_attentions,
                "output_hidden_states": output_hidden_states,
                "return_dict": return_dict,
                "past_key_values": past_key_values,
            }
        )
        if "past_key_values" in fwd_params:
            return self.base_model(labels=labels, **kwargs)

        transformer_backbone = self.base_model.get_submodule(self.transformer_backbone_name)
        fwd_params = list(inspect.signature(transformer_backbone.forward).parameters.keys())
        if "past_key_values" not in fwd_params:
            raise ValueError("Model does not support past key values which are required for prefix tuning.")
        outputs = transformer_backbone(**kwargs)
        sequence_output = outputs[0]
        if any(name == "dropout" for name, _ in self.base_model.named_children()):
            sequence_output = self.base_model.dropout(sequence_output)
        logits = self.base_model.get_submodule(self.cls_layer_name)(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            # num_labels is read from the base model, consistent with the sequence-classification wrapper.
            loss = loss_fct(logits.view(-1, self.base_model.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
