# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import os
import warnings
from contextlib import contextmanager
from copy import deepcopy

import torch
from accelerate import dispatch_model, infer_auto_device_map
from accelerate.hooks import AlignDevicesHook, add_hook_to_module, remove_hook_from_submodules
from accelerate.utils import get_balanced_memory
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError
from safetensors.torch import load_file as safe_load_file
from safetensors.torch import save_file as safe_save_file
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers import PreTrainedModel
from transformers.modeling_outputs import QuestionAnsweringModelOutput, SequenceClassifierOutput, TokenClassifierOutput
from transformers.utils import PushToHubMixin
import torch.nn.functional as F


from .tuners import (
    AdaLoraModel,
    AdaptionPromptModel,
    LoraModel,
    PrefixEncoder,
    PromptEmbedding,
    PromptEncoder,
    PromptEmbeddingLoRA,
    PromptEmbeddingab,
    PromptEmbeddingLoRAX,
    PromptEmbeddingLoRAXL,
    PromptEmbeddingLoRAXAB,
)
from .utils import (
    SAFETENSORS_WEIGHTS_NAME,
    TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING,
    WEIGHTS_NAME,
    PeftConfig,
    PeftType,
    PromptLearningConfig,
    TaskType,
    _set_adapter,
    _set_trainable,
    add_or_edit_model_card,
    get_peft_model_state_dict,
    hub_file_exists,
    set_peft_model_state_dict,
    shift_tokens_right,
)


# Registry from each `PeftType` to the class that implements it. `PeftModel.__init__` looks a class up here
# to wrap the base model when the config is not a `PromptLearningConfig`.
PEFT_TYPE_TO_MODEL_MAPPING = {
    PeftType.LORA: LoraModel,
    PeftType.PROMPT_TUNING: PromptEmbedding,
    PeftType.P_TUNING: PromptEncoder,
    PeftType.PREFIX_TUNING: PrefixEncoder,
    PeftType.ADALORA: AdaLoraModel,
    PeftType.ADAPTION_PROMPT: AdaptionPromptModel,
    PeftType.PROMPT_TUNING_LORA: PromptEmbeddingLoRA,
    PeftType.PROMPT_TUNING_ab: PromptEmbeddingab,
    PeftType.PROMPT_TUNING_LORAX: PromptEmbeddingLoRAX,
    PeftType.PROMPT_TUNING_LORAXL: PromptEmbeddingLoRAXL,
    PeftType.PROMPT_TUNING_LORAXAB: PromptEmbeddingLoRAXAB,
}


class PeftModel(PushToHubMixin, torch.nn.Module):
    """
    Base model encompassing various Peft methods.

    Args:
        model ([`~transformers.PreTrainedModel`]): The base transformer model used for Peft.
        peft_config ([`PeftConfig`]): The configuration of the Peft model.


    **Attributes**:
        - **base_model** ([`~transformers.PreTrainedModel`]) -- The base transformer model used for Peft.
        - **peft_config** ([`PeftConfig`]) -- The configuration of the Peft model.
        - **modules_to_save** (`list` of `str`) -- The list of sub-module names to save when
        saving the model.
        - **prompt_encoder** ([`PromptEncoder`]) -- The prompt encoder used for Peft if
        using [`PromptLearningConfig`].
        - **prompt_tokens** (`torch.Tensor`) -- The virtual prompt tokens used for Peft if
        using [`PromptLearningConfig`].
        - **transformer_backbone_name** (`str`) -- The name of the transformer
        backbone in the base model if using [`PromptLearningConfig`].
        - **word_embeddings** (`torch.nn.Embedding`) -- The word embeddings of the transformer backbone
        in the base model if using [`PromptLearningConfig`].
    """

    def __init__(self, model, peft_config: PeftConfig, adapter_name="default"):
        """
        Wrap `model` with the adapter described by `peft_config`.

        Args:
            model: The base model to adapt; its `config` is exposed as `self.config`.
            peft_config ([`PeftConfig`]): Configuration of the first adapter.
            adapter_name (`str`, *optional*, defaults to `"default"`): Name under which the adapter is registered
                and which becomes the active adapter.
        """
        super().__init__()
        self.base_model = model
        self.config = model.config
        self.modules_to_save = None
        self.peft_config = {}
        self.active_adapter = adapter_name
        self.peft_type = peft_config.peft_type
        self.base_model_torch_dtype = getattr(model, "dtype", None)

        if isinstance(peft_config, PromptLearningConfig):
            # Prompt-learning adapters keep the base model untouched and attach a prompt encoder.
            self.add_adapter(adapter_name, peft_config)
        else:
            # Every other adapter type replaces `base_model` with the tuner class registered for it.
            self.peft_config[adapter_name] = peft_config
            tuner_cls = PEFT_TYPE_TO_MODEL_MAPPING[peft_config.peft_type]
            self.base_model = tuner_cls(self.base_model, self.peft_config, adapter_name)
            self.set_additional_trainable_modules(peft_config, adapter_name)

        # The default is `True`, so models that do not expose the flag are prepared as well.
        if getattr(model, "is_gradient_checkpointing", True):
            self._prepare_model_for_gradient_checkpointing(model)

    def save_pretrained(self, save_directory, safe_serialization=False, **kwargs):
        r"""
        Save the weights and the configuration of every registered adapter to `save_directory`.

        The adapter named `"default"` is written directly into `save_directory`; every other adapter is written to a
        sub-directory named after the adapter.

        Args:
            save_directory (`str`):
                Directory where the adapter weights and configuration files will be saved (will be created if it does
                not exist).
            safe_serialization (`bool`, *optional*, defaults to `False`):
                If `True` the weights are written with safetensors, otherwise with `torch.save`.
            kwargs (additional keyword arguments, *optional*):
                Only `state_dict` is read here; it is forwarded to `get_peft_model_state_dict`.

        Raises:
            ValueError: If `save_directory` is an existing file.
        """
        if os.path.isfile(save_directory):
            raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")
        os.makedirs(save_directory, exist_ok=True)
        add_or_edit_model_card(save_directory)

        for adapter_name, peft_config in self.peft_config.items():
            # save only the trainable weights
            output_state_dict = get_peft_model_state_dict(
                self, state_dict=kwargs.get("state_dict", None), adapter_name=adapter_name
            )
            output_dir = os.path.join(save_directory, adapter_name) if adapter_name != "default" else save_directory
            os.makedirs(output_dir, exist_ok=True)

            if safe_serialization:
                safe_save_file(
                    output_state_dict, os.path.join(output_dir, SAFETENSORS_WEIGHTS_NAME), metadata={"format": "pt"}
                )
            else:
                torch.save(output_state_dict, os.path.join(output_dir, WEIGHTS_NAME))

            # Record where the base model came from if the config does not know it yet.
            if peft_config.base_model_name_or_path is None:
                peft_config.base_model_name_or_path = (
                    self.base_model.__dict__.get("name_or_path", None)
                    if isinstance(peft_config, PromptLearningConfig)
                    else self.base_model.model.__dict__.get("name_or_path", None)
                )

            # The config is saved with `inference_mode=True`. The live config object is shared with the model,
            # so its original value is restored even when writing the file fails.
            inference_mode = peft_config.inference_mode
            peft_config.inference_mode = True
            try:
                peft_config.save_pretrained(output_dir)
            finally:
                peft_config.inference_mode = inference_mode

    @classmethod
    def from_pretrained(cls, model, model_id, adapter_name="default", is_trainable=False, **kwargs):
        r"""
        Instantiate a Peft model from a pretrained adapter configuration and its weights.

        Args:
            model ([`~transformers.PreTrainedModel`]):
                The model to be adapted. The model should be initialized with the
                [`~transformers.PreTrainedModel.from_pretrained`] method from the 🤗 Transformers library.
            model_id (`str` or `os.PathLike`):
                The name of the adapter configuration to use. Can be either:
                    - A string, the `model id` of an adapter configuration hosted inside a model repo on the Hugging
                      Face Hub.
                    - A path to a directory containing an adapter configuration file saved using the
                      `save_pretrained` method (`./my_lora_config_directory/`).
            adapter_name (`str`, *optional*, defaults to `"default"`):
                Name under which the loaded adapter is registered.
            is_trainable (`bool`, *optional*, defaults to `False`):
                Whether the adapter should be trainable. When `False` the config is put in inference mode.
            kwargs (additional keyword arguments, *optional*):
                Forwarded to the config's `from_pretrained` and to `load_adapter`. `subfolder`, `revision` and
                `cache_dir` are additionally used to resolve the adapter type.

        Raises:
            ValueError: If a prompt learning adapter is requested with `is_trainable=True`.
        """
        from .mapping import MODEL_TYPE_TO_PEFT_MODEL_MAPPING, PEFT_TYPE_TO_CONFIG_MAPPING

        # Pass `subfolder` exactly once: spelling it out next to `**kwargs` raised a duplicate-keyword
        # `TypeError` whenever the caller supplied `subfolder` themselves.
        config_kwargs = dict(kwargs)
        config_kwargs.setdefault("subfolder", None)

        # load the config
        peft_type = PeftConfig._get_peft_type(
            model_id,
            subfolder=kwargs.get("subfolder", None),
            revision=kwargs.get("revision", None),
            cache_dir=kwargs.get("cache_dir", None),
        )
        config = PEFT_TYPE_TO_CONFIG_MAPPING[peft_type].from_pretrained(model_id, **config_kwargs)

        # Hooks are removed first when part of the model lives on cpu/disk according to its device map.
        device_map = getattr(model, "hf_device_map", None)
        if device_map is not None and set(device_map.values()) & {"cpu", "disk"}:
            remove_hook_from_submodules(model)

        if isinstance(config, PromptLearningConfig) and is_trainable:
            raise ValueError("Cannot set a prompt learning adapter to trainable when loading pretrained adapter.")
        config.inference_mode = not is_trainable

        # Use the task-specific wrapper when one is registered for the task type, else the class itself.
        if config.task_type in MODEL_TYPE_TO_PEFT_MODEL_MAPPING:
            model = MODEL_TYPE_TO_PEFT_MODEL_MAPPING[config.task_type](model, config, adapter_name)
        else:
            model = cls(model, config, adapter_name)
        model.load_adapter(model_id, adapter_name, is_trainable=is_trainable, **kwargs)
        return model

    def _setup_prompt_encoder(self, adapter_name):
        """
        Create the prompt encoder and the virtual token ids for the prompt-learning adapter `adapter_name`.

        All parameters of the base model's direct children are frozen, the word embedding layer is located by
        matching the first dimension of a parameter against `vocab_size`, and the encoder matching the config's
        `peft_type` is registered under `adapter_name`.

        Raises:
            ValueError: If the config's `peft_type` has no prompt encoder.
        """
        config = self.peft_config[adapter_name]
        # Create the containers only once so that adding a further adapter keeps the encoders that were
        # registered before it instead of silently discarding them.
        if "prompt_encoder" not in self._modules:
            self.prompt_encoder = torch.nn.ModuleDict({})
            self.prompt_tokens = {}

        transformer_backbone = None
        for name, module in self.base_model.named_children():
            for param in module.parameters():
                param.requires_grad = False
            # The first child that is a Transformers model is used as the backbone.
            if isinstance(module, PreTrainedModel) and transformer_backbone is None:
                transformer_backbone = module
                self.transformer_backbone_name = name
        if transformer_backbone is None:
            # No child is a `PreTrainedModel`: search the base model itself rather than failing on `None`.
            transformer_backbone = self.base_model

        if config.num_transformer_submodules is None:
            config.num_transformer_submodules = 2 if config.task_type == TaskType.SEQ_2_SEQ_LM else 1

        # The first parameter whose leading dimension equals the vocabulary size is taken as the word embedding.
        for named_param, value in transformer_backbone.named_parameters():
            if value.shape[0] == self.base_model.config.vocab_size:
                self.word_embeddings = transformer_backbone.get_submodule(named_param.replace(".weight", ""))
                break

        if config.peft_type == PeftType.PROMPT_TUNING:
            prompt_encoder = PromptEmbedding(config, self.word_embeddings)
        elif config.peft_type == PeftType.PROMPT_TUNING_LORA:
            prompt_encoder = PromptEmbeddingLoRA(config, self.word_embeddings)
        elif config.peft_type == PeftType.PROMPT_TUNING_ab:
            # NOTE(review): this type is registered in PEFT_TYPE_TO_MODEL_MAPPING but had no branch here.
            # Assumes the same `(config, word_embeddings)` constructor as the sibling variants - confirm.
            prompt_encoder = PromptEmbeddingab(config, self.word_embeddings)
        elif config.peft_type == PeftType.PROMPT_TUNING_LORAX:
            prompt_encoder = PromptEmbeddingLoRAX(config, self.word_embeddings)
        elif config.peft_type == PeftType.PROMPT_TUNING_LORAXL:
            prompt_encoder = PromptEmbeddingLoRAXL(config, self.word_embeddings)
        elif config.peft_type == PeftType.PROMPT_TUNING_LORAXAB:
            prompt_encoder = PromptEmbeddingLoRAXAB(config, self.word_embeddings)
        elif config.peft_type == PeftType.P_TUNING:
            prompt_encoder = PromptEncoder(config)
        elif config.peft_type == PeftType.PREFIX_TUNING:
            prompt_encoder = PrefixEncoder(config)
        else:
            raise ValueError("Not supported")
        self.prompt_encoder.update(torch.nn.ModuleDict({adapter_name: prompt_encoder}))
        self.prompt_tokens[adapter_name] = torch.arange(
            config.num_virtual_tokens * config.num_transformer_submodules
        ).long()

    def _prepare_model_for_gradient_checkpointing(self, model):
        r"""
        Make the outputs of the input embeddings require gradients, unless the model is flagged as loaded in
        8-bit or 4-bit, in which case it is returned untouched.

        Args:
            model: The model to prepare.

        Returns:
            The same `model` object.
        """
        quantized = getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False)
        if quantized:
            return model

        # Prefer the model's own helper when it provides one.
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
            return model

        def _mark_output_requires_grad(module, inputs, output):
            output.requires_grad_(True)

        # Fallback: a forward hook on the input embeddings flags their output as requiring grad.
        model.get_input_embeddings().register_forward_hook(_mark_output_requires_grad)
        return model

    def get_prompt_embedding_to_save(self, adapter_name):
        """
        Compute the prompt embeddings of adapter `adapter_name` for saving.

        The adapter's virtual token ids are run through its prompt encoder as a batch of one; the first element of
        the result is returned detached and on CPU. For prefix tuning only the first `num_virtual_tokens` ids are
        used.
        """
        encoder = self.prompt_encoder[adapter_name]
        token_ids = self.prompt_tokens[adapter_name].unsqueeze(0).expand(1, -1)
        token_ids = token_ids.to(encoder.embedding.weight.device)

        adapter_config = self.peft_config[adapter_name]
        if adapter_config.peft_type == PeftType.PREFIX_TUNING:
            token_ids = token_ids[:, : adapter_config.num_virtual_tokens]

        return encoder(token_ids)[0].detach().cpu()

    def get_prompt(self, batch_size):
        """
        Build the virtual prompts of the active adapter for a batch of `batch_size` examples.

        For prefix tuning the encoder output is reshaped, permuted and split into the past-key-value structure
        (optionally post-processed by the function registered for the model type); for every other prompt-learning
        type the prompt embeddings are returned directly. In inference mode the encoder's embedding weight is
        repeated across the batch instead of running the encoder.
        """
        active_config = self.active_peft_config
        encoder = self.prompt_encoder[self.active_adapter]
        token_ids = self.prompt_tokens[self.active_adapter].unsqueeze(0).expand(batch_size, -1)
        token_ids = token_ids.to(encoder.embedding.weight.device)

        # Plain prompt embeddings: nothing to reshape.
        if active_config.peft_type != PeftType.PREFIX_TUNING:
            if active_config.inference_mode:
                return encoder.embedding.weight.repeat(batch_size, 1, 1)
            return encoder(token_ids)

        # Prefix tuning: turn the encoder output into past key values.
        token_ids = token_ids[:, : active_config.num_virtual_tokens]
        if active_config.inference_mode:
            key_values = encoder.embedding.weight.repeat(batch_size, 1, 1)
        else:
            key_values = encoder(token_ids)

        head_dim = active_config.token_dim // active_config.num_attention_heads
        key_values = key_values.view(
            batch_size,
            active_config.num_virtual_tokens,
            active_config.num_layers * 2,
            active_config.num_attention_heads,
            head_dim,
        )
        if active_config.num_transformer_submodules == 2:
            key_values = torch.cat([key_values, key_values], dim=2)
        key_values = key_values.permute([2, 0, 3, 1, 4]).split(active_config.num_transformer_submodules * 2)

        post_process_fn = TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING.get(self.config.model_type, None)
        if post_process_fn is not None:
            key_values = post_process_fn(key_values)
        return key_values

    def print_trainable_parameters(self):
        """
        Prints the number of trainable parameters, the total number of parameters and the trainable percentage.

        A parameter that reports zero elements but carries a `ds_numel` attribute (DeepSpeed ZeRO-3, where the
        weights are initialized empty) is counted with `ds_numel`. A model without any parameter reports a
        trainable percentage of 0 instead of raising `ZeroDivisionError`.
        """
        trainable_params = 0
        all_param = 0
        for param in self.parameters():
            num_params = param.numel()
            # if using DS Zero 3 and the weights are initialized empty
            if num_params == 0 and hasattr(param, "ds_numel"):
                num_params = param.ds_numel

            all_param += num_params
            if param.requires_grad:
                trainable_params += num_params

        # Guard the division: an empty model has no parameters at all.
        trainable_percent = 100 * trainable_params / all_param if all_param else 0.0
        print(
            f"trainable params: {trainable_params:,d} || all params: {all_param:,d} || trainable%: {trainable_percent}"
        )

    def __getattr__(self, name: str):
        """
        Forward missing attributes to the wrapped module.

        `torch.nn.Module`'s own lookup is tried first; only when it fails is the attribute read from
        `self.base_model`.

        Raises:
            AttributeError: If neither this module nor the wrapped one has the attribute.
        """
        try:
            return super().__getattr__(name)  # defer to nn.Module's logic
        except AttributeError:
            # If `base_model` itself is missing (instance not initialized yet), reading `self.base_model`
            # below would re-enter this method forever; surface the AttributeError instead.
            if name == "base_model":
                raise
            return getattr(self.base_model, name)

    def forward(self, *args, **kwargs):
        """
        Run the forward pass by calling the model returned by `get_base_model` with the given arguments.
        """
        wrapped_model = self.get_base_model()
        return wrapped_model(*args, **kwargs)

    @contextmanager
    def disable_adapter(self):
        """
        Context manager that disables the adapter for the duration of the `with` block.

        For a prompt-learning adapter `self.forward` is temporarily replaced by the base model's `forward`; for
        every other adapter type the adapter layers of the base model are disabled and re-enabled on exit.
        """
        # `self.peft_config` is a dict of configs keyed by adapter name, so the type check has to look at the
        # active adapter's config. It is evaluated once so that entry and exit always take the same branch.
        is_prompt_learning = isinstance(self.active_peft_config, PromptLearningConfig)
        try:
            if is_prompt_learning:
                old_forward = self.forward
                self.forward = self.base_model.forward
            else:
                self.base_model.disable_adapter_layers()
            yield
        finally:
            if is_prompt_learning:
                self.forward = old_forward
            else:
                self.base_model.enable_adapter_layers()

    def get_base_model(self):
        """
        Returns the base model: `self.base_model` itself when the active adapter is a prompt-learning one,
        otherwise the `model` attribute of `self.base_model`.
        """
        if isinstance(self.active_peft_config, PromptLearningConfig):
            return self.base_model
        return self.base_model.model

    def add_adapter(self, adapter_name, peft_config):
        """
        Register a further adapter under `adapter_name`.

        Args:
            adapter_name (`str`): Name of the new adapter.
            peft_config ([`PeftConfig`]): Its configuration; must have the same `peft_type` as this model.

        Raises:
            ValueError: If `peft_config.peft_type` differs from the model's `peft_type`.
        """
        types_match = peft_config.peft_type == self.peft_type
        if not types_match:
            raise ValueError(
                f"Cannot combine adapters with different peft types. "
                f"Found {self.peft_type} and {peft_config.peft_type}."
            )

        self.peft_config[adapter_name] = peft_config
        if not isinstance(peft_config, PromptLearningConfig):
            # Non prompt-learning adapters are injected by the wrapped tuner model.
            self.base_model.add_adapter(adapter_name, peft_config)
        else:
            self._setup_prompt_encoder(adapter_name)

        self.set_additional_trainable_modules(peft_config, adapter_name)

    def set_additional_trainable_modules(self, peft_config, adapter_name):
        """
        Merge the config's `modules_to_save` (if any) into `self.modules_to_save` and call `_set_trainable` for
        `adapter_name`. Does nothing when the config has no `modules_to_save`.
        """
        extra_modules = getattr(peft_config, "modules_to_save", None)
        if extra_modules is None:
            return

        if self.modules_to_save is not None:
            self.modules_to_save.update(extra_modules)
        else:
            self.modules_to_save = set(extra_modules)
        _set_trainable(self, adapter_name)

    @classmethod
    def _split_kwargs(cls, kwargs):
        """
        Split `kwargs` into the entries whose key is a parameter of `hf_hub_download` and all remaining entries.

        Args:
            kwargs (`dict`): Keyword arguments to partition; it is not modified.

        Returns:
            `tuple(dict, dict)`: `(hf_hub_download_kwargs, other_kwargs)`.
        """
        # The signature does not change between keys, so inspect it once instead of once per key.
        hub_parameters = inspect.signature(hf_hub_download).parameters

        hf_hub_download_kwargs = {}
        other_kwargs = {}
        for key, value in kwargs.items():
            target = hf_hub_download_kwargs if key in hub_parameters else other_kwargs
            target[key] = value

        return hf_hub_download_kwargs, other_kwargs

    def load_adapter(self, model_id, adapter_name, is_trainable=False, **kwargs):
        """
        Load the weights of adapter `adapter_name` from `model_id`, registering the adapter first if needed.

        Weights are looked up locally (safetensors file first, then the torch file) and otherwise downloaded with
        `hf_hub_download`.

        Args:
            model_id (`str` or `os.PathLike`): Local directory or Hub repo id holding the adapter.
            adapter_name (`str`): Name of the adapter to load the weights into.
            is_trainable (`bool`, *optional*, defaults to `False`):
                Whether the adapter stays trainable. When `False` the model is put in evaluation mode.
            kwargs (additional keyword arguments, *optional*):
                `subfolder`, `revision` and `cache_dir` locate the adapter files; keys accepted by
                `hf_hub_download` are forwarded to it; `device_map`, `max_memory`, `offload_folder` and
                `offload_index` are used when the model is re-dispatched.

        Returns:
            The result of `set_peft_model_state_dict`.

        Raises:
            ValueError: If a prompt learning adapter is requested with `is_trainable=True`, or if no weights file
                can be found.
        """
        from .mapping import PEFT_TYPE_TO_CONFIG_MAPPING

        # Read these before splitting: `_split_kwargs` moves every key that is an `hf_hub_download` parameter
        # out of `kwargs`, so looking them up afterwards would always yield `None`.
        subfolder = kwargs.get("subfolder", None)
        revision = kwargs.get("revision", None)
        cache_dir = kwargs.get("cache_dir", None)

        hf_hub_download_kwargs, kwargs = self._split_kwargs(kwargs)
        # `subfolder` is passed explicitly to `hf_hub_download` below; keeping it in the dict as well would
        # raise a duplicate-keyword `TypeError`.
        hf_hub_download_kwargs.pop("subfolder", None)

        if adapter_name not in self.peft_config:
            # load the config
            peft_type = PeftConfig._get_peft_type(
                model_id,
                subfolder=subfolder,
                revision=revision,
                cache_dir=cache_dir,
            )
            peft_config = PEFT_TYPE_TO_CONFIG_MAPPING[peft_type].from_pretrained(
                model_id,
                subfolder=subfolder,
                revision=revision,
                cache_dir=cache_dir,
            )
            if isinstance(peft_config, PromptLearningConfig) and is_trainable:
                raise ValueError("Cannot set a prompt learning adapter to trainable when loading pretrained adapter.")
            peft_config.inference_mode = not is_trainable
            self.add_adapter(adapter_name, peft_config)

        # load weights if any
        path = os.path.join(model_id, subfolder) if subfolder is not None else model_id

        if os.path.exists(os.path.join(path, SAFETENSORS_WEIGHTS_NAME)):
            filename = os.path.join(path, SAFETENSORS_WEIGHTS_NAME)
            use_safetensors = True
        elif os.path.exists(os.path.join(path, WEIGHTS_NAME)):
            filename = os.path.join(path, WEIGHTS_NAME)
            use_safetensors = False
        else:
            has_remote_safetensors_file = hub_file_exists(model_id, SAFETENSORS_WEIGHTS_NAME, revision=revision)
            use_safetensors = has_remote_safetensors_file

            if has_remote_safetensors_file:
                # Priority 1: load safetensors weights
                filename = hf_hub_download(
                    model_id,
                    SAFETENSORS_WEIGHTS_NAME,
                    subfolder=subfolder,
                    **hf_hub_download_kwargs,
                )
            else:
                try:
                    filename = hf_hub_download(
                        model_id, WEIGHTS_NAME, subfolder=subfolder, **hf_hub_download_kwargs
                    )
                except EntryNotFoundError as err:
                    raise ValueError(
                        f"Can't find weights for {model_id} in {model_id} or in the Hugging Face Hub. "
                        f"Please check that the file {WEIGHTS_NAME} or {SAFETENSORS_WEIGHTS_NAME} is present at {model_id}."
                    ) from err

        load_device = "cuda" if torch.cuda.is_available() else "cpu"
        if use_safetensors:
            adapters_weights = safe_load_file(filename, device=load_device)
        else:
            # NOTE(security): `torch.load` unpickles the file; only load adapter weights from trusted sources.
            adapters_weights = torch.load(filename, map_location=torch.device(load_device))

        # load the weights into the model
        load_result = set_peft_model_state_dict(self, adapters_weights, adapter_name=adapter_name)

        # Re-dispatch only for the first adapter of a model that has part of its weights on cpu/disk.
        if (
            (getattr(self, "hf_device_map", None) is not None)
            and (len(set(self.hf_device_map.values()).intersection({"cpu", "disk"})) > 0)
            and len(self.peft_config) == 1
        ):
            device_map = kwargs.get("device_map", "auto")
            max_memory = kwargs.get("max_memory", None)
            offload_dir = kwargs.get("offload_folder", None)
            offload_index = kwargs.get("offload_index", None)

            dispatch_model_kwargs = {}
            # Safety checker for previous `accelerate` versions
            # `offload_index` was introduced in https://github.com/huggingface/accelerate/pull/873/
            if "offload_index" in inspect.signature(dispatch_model).parameters:
                dispatch_model_kwargs["offload_index"] = offload_index

            no_split_module_classes = self._no_split_modules

            if device_map != "sequential":
                max_memory = get_balanced_memory(
                    self,
                    max_memory=max_memory,
                    no_split_module_classes=no_split_module_classes,
                    low_zero=(device_map == "balanced_low_0"),
                )
            if isinstance(device_map, str):
                device_map = infer_auto_device_map(
                    self, max_memory=max_memory, no_split_module_classes=no_split_module_classes
                )
            dispatch_model(
                self,
                device_map=device_map,
                offload_dir=offload_dir,
                **dispatch_model_kwargs,
            )
            hook = AlignDevicesHook(io_same_device=True)
            if isinstance(self.peft_config[adapter_name], PromptLearningConfig):
                remove_hook_from_submodules(self.prompt_encoder)
            add_hook_to_module(self.get_base_model(), hook)

        # Set model in evaluation mode to deactivate Dropout modules by default
        if not is_trainable:
            self.eval()
        return load_result

    def set_adapter(self, adapter_name):
        """
        Make `adapter_name` the active adapter.

        Raises:
            ValueError: If no adapter with that name is registered.
        """
        known_adapters = self.peft_config
        if adapter_name not in known_adapters:
            raise ValueError(f"Adapter {adapter_name} not found.")

        self.active_adapter = adapter_name
        is_prompt_learning = isinstance(known_adapters[adapter_name], PromptLearningConfig)
        if not is_prompt_learning:
            # The wrapped tuner model tracks its own active adapter.
            self.base_model.set_adapter(adapter_name)
        _set_adapter(self, adapter_name)

    @property
    def active_peft_config(self):
        """The config registered in `self.peft_config` for the currently active adapter."""
        active_name = self.active_adapter
        return self.peft_config[active_name]


class PeftModelForSequenceClassification(PeftModel):
    """
    Peft model for sequence classification tasks.

    Args:
        model ([`~transformers.PreTrainedModel`]): Base transformer model.
        peft_config ([`PeftConfig`]): Peft config.

    **Attributes**:
        - **config** ([`~transformers.PretrainedConfig`]) -- The configuration object of the base model.
        - **cls_layer_name** (`str`) -- The name of the classification layer.

    Example:

        ```py
        >>> from transformers import AutoModelForSequenceClassification
        >>> from peft import PeftModelForSequenceClassification, get_peft_config

        >>> config = {
        ...     "peft_type": "PREFIX_TUNING",
        ...     "task_type": "SEQ_CLS",
        ...     "inference_mode": False,
        ...     "num_virtual_tokens": 20,
        ...     "token_dim": 768,
        ...     "num_transformer_submodules": 1,
        ...     "num_attention_heads": 12,
        ...     "num_layers": 12,
        ...     "encoder_hidden_size": 768,
        ...     "prefix_projection": False,
        ...     "postprocess_past_key_value_function": None,
        ... }

        >>> peft_config = get_peft_config(config)
        >>> model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased")
        >>> peft_model = PeftModelForSequenceClassification(model, peft_config)
        >>> peft_model.print_trainable_parameters()
        trainable params: 370178 || all params: 108680450 || trainable%: 0.3406113979101117
        ```
    """

    def __init__(self, model, peft_config: PeftConfig, adapter_name="default"):
        """
        Build the Peft model and make sure the classification head is saved and trainable.

        `"classifier"` and `"score"` are added to `modules_to_save`, and the first child of the base model whose
        name contains one of the `modules_to_save` entries is recorded as `cls_layer_name`.
        """
        super().__init__(model, peft_config, adapter_name)

        head_names = {"classifier", "score"}
        if self.modules_to_save is not None:
            self.modules_to_save.update(head_names)
        else:
            self.modules_to_save = head_names

        matching_children = (
            child_name
            for child_name, _ in self.base_model.named_children()
            if any(module_name in child_name for module_name in self.modules_to_save)
        )
        first_match = next(matching_children, None)
        if first_match is not None:
            self.cls_layer_name = first_match

        # to make sure classifier layer is trainable
        _set_trainable(self, adapter_name)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        """
        Forward pass for sequence classification.

        When the active adapter is not a prompt-learning one, all arguments are passed straight to the base model.
        Otherwise the attention mask (and `token_type_ids`, if given) are extended to cover the virtual tokens and
        either the prefix-tuning path is taken or the prompt embeddings are prepended to the input embeddings.
        Either `input_ids` or `inputs_embeds` may be the sole input; `position_ids` are ignored with a warning.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        peft_config = self.active_peft_config
        if not isinstance(peft_config, PromptLearningConfig):
            return self.base_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                labels=labels,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                **kwargs,
            )

        # Callers may pass only `inputs_embeds`, so fall back to it for the batch size.
        batch_size = (input_ids if input_ids is not None else inputs_embeds).shape[0]
        if attention_mask is not None:
            # concat prompt attention mask
            prefix_attention_mask = torch.ones(batch_size, peft_config.num_virtual_tokens).to(attention_mask.device)
            attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
        if kwargs.get("position_ids", None) is not None:
            warnings.warn("Position ids are not supported for parameter efficient tuning. Ignoring position ids.")
            kwargs["position_ids"] = None
        kwargs.update(
            {
                "attention_mask": attention_mask,
                "labels": labels,
                "output_attentions": output_attentions,
                "output_hidden_states": output_hidden_states,
                "return_dict": return_dict,
            }
        )

        if peft_config.peft_type == PeftType.PREFIX_TUNING:
            if input_ids is None:
                # Hand the embeddings over only when they are the sole input, so calls that provide
                # `input_ids` behave exactly as before.
                kwargs["inputs_embeds"] = inputs_embeds
            return self._prefix_tuning_forward(input_ids=input_ids, **kwargs)

        if kwargs.get("token_type_ids", None) is not None:
            # Virtual tokens get token type 0.
            kwargs["token_type_ids"] = torch.cat(
                (
                    torch.zeros(batch_size, peft_config.num_virtual_tokens).to(self.word_embeddings.weight.device),
                    kwargs["token_type_ids"],
                ),
                dim=1,
            ).long()
        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        prompts = self.get_prompt(batch_size=batch_size)
        prompts = prompts.to(inputs_embeds.dtype)
        inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1)
        return self.base_model(inputs_embeds=inputs_embeds, **kwargs)

    def _prefix_tuning_forward(
        self,
        input_ids=None,
        attention_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        """
        Forward pass for prefix tuning, injecting the virtual prompts as `past_key_values`.

        If the base model's `forward` accepts `past_key_values` it is called directly. Otherwise the transformer
        backbone is called, and dropout (if the base model has a `dropout` child), the classification layer and
        the loss are applied here.

        Raises:
            ValueError: If neither the base model nor its backbone accepts `past_key_values`.
        """
        # Either `input_ids` or `inputs_embeds` may carry the batch.
        batch_size = (input_ids if input_ids is not None else inputs_embeds).shape[0]
        past_key_values = self.get_prompt(batch_size)
        kwargs.update(
            {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "inputs_embeds": inputs_embeds,
                "output_attentions": output_attentions,
                "output_hidden_states": output_hidden_states,
                "return_dict": return_dict,
                "past_key_values": past_key_values,
            }
        )
        if "past_key_values" in inspect.signature(self.base_model.forward).parameters:
            return self.base_model(labels=labels, **kwargs)

        # The task model does not take `past_key_values`: call the backbone and apply the head manually.
        transformer_backbone = self.base_model.get_submodule(self.transformer_backbone_name)
        if "past_key_values" not in inspect.signature(transformer_backbone.forward).parameters:
            raise ValueError("Model does not support past key values which are required for prefix tuning.")
        outputs = transformer_backbone(**kwargs)
        pooled_output = outputs[1] if len(outputs) > 1 else outputs[0]
        if any(name == "dropout" for name, _ in self.base_model.named_children()):
            pooled_output = self.base_model.dropout(pooled_output)
        logits = self.base_model.get_submodule(self.cls_layer_name)(pooled_output)

        loss = None
        if labels is not None:
            num_labels = self.base_model.num_labels
            # Infer the problem type once from the label count and dtype when the config does not set it.
            if self.config.problem_type is None:
                if num_labels == 1:
                    self.config.problem_type = "regression"
                elif num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)
        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class PeftModelForCausalLM(PeftModel):
    """
    Peft model for causal language modeling.

    Args:
        model ([`~transformers.PreTrainedModel`]): Base transformer model.
        peft_config ([`PeftConfig`]): Peft config.


    Example:

        ```py
        >>> from transformers import AutoModelForCausalLM
        >>> from peft import PeftModelForCausalLM, get_peft_config

        >>> config = {
        ...     "peft_type": "PREFIX_TUNING",
        ...     "task_type": "CAUSAL_LM",
        ...     "inference_mode": False,
        ...     "num_virtual_tokens": 20,
        ...     "token_dim": 1280,
        ...     "num_transformer_submodules": 1,
        ...     "num_attention_heads": 20,
        ...     "num_layers": 36,
        ...     "encoder_hidden_size": 1280,
        ...     "prefix_projection": False,
        ...     "postprocess_past_key_value_function": None,
        ... }

        >>> peft_config = get_peft_config(config)
        >>> model = AutoModelForCausalLM.from_pretrained("gpt2-large")
        >>> peft_model = PeftModelForCausalLM(model, peft_config)
        >>> peft_model.print_trainable_parameters()
        trainable params: 1843200 || all params: 775873280 || trainable%: 0.23756456724479544
        ```
    """

    def __init__(self, model, peft_config: PeftConfig, adapter_name="default"):
        super().__init__(model, peft_config, adapter_name)
        # Keep a handle on the base model's own hook: `generate` temporarily replaces it with
        # `self.prepare_inputs_for_generation` and restores this reference afterwards.
        self.base_model_prepare_inputs_for_generation = self.base_model.prepare_inputs_for_generation

    @staticmethod
    def _is_lora_prompt_tuning(peft_type):
        """Return `True` for the prompt-tuning variants that also rewrite the token embeddings."""
        return peft_type in (
            PeftType.PROMPT_TUNING_LORA,
            PeftType.PROMPT_TUNING_LORAX,
            PeftType.PROMPT_TUNING_LORAXL,
            PeftType.PROMPT_TUNING_LORAXAB,
        )

    def _apply_lora_embedding_update(self, inputs_embeds, peft_type):
        """
        Add the active adapter's learned update to freshly looked-up token embeddings.

        Single implementation of the embedding update shared by `forward` and
        `prepare_inputs_for_generation`, so that training and generation cannot drift apart.

        Args:
            inputs_embeds (`torch.Tensor`):
                Token embeddings, as returned by `self.word_embeddings`. The `lora_embedding_*` tensors are read
                from the active prompt encoder and must be `@`-compatible with the last dimension of
                `inputs_embeds`; their shapes are defined by the prompt encoder classes, not here.
            peft_type ([`PeftType`]):
                Type of the active adapter.

        Returns:
            `torch.Tensor`: The updated embeddings, or `inputs_embeds` unchanged when `peft_type` is not one of the
            `PROMPT_TUNING_LORA*` variants.
        """
        if not self._is_lora_prompt_tuning(peft_type):
            return inputs_embeds

        encoder = self.prompt_encoder[self.active_adapter]
        device = inputs_embeds.device

        if peft_type == PeftType.PROMPT_TUNING_LORA:
            lora_A = encoder.lora_embedding_A.to(device)
            lora_B = encoder.lora_embedding_B.to(device)
            return inputs_embeds + F.relu(inputs_embeds @ lora_A) @ lora_B

        if peft_type == PeftType.PROMPT_TUNING_LORAX:
            lora_A = encoder.lora_embedding_A.to(device)
            lora_B = encoder.lora_embedding_B.to(device)
            lora_a = encoder.lora_embedding_a.to(device)
            lora_b = encoder.lora_embedding_b.to(device)
            hidden = inputs_embeds @ lora_A + lora_a
            if self.training:
                # Dropout on the bottleneck pre-activation; only active while the module is in training mode.
                hidden = F.dropout(hidden, p=0.1)
            return inputs_embeds + F.relu(hidden) @ lora_B + lora_b

        if peft_type == PeftType.PROMPT_TUNING_LORAXL:
            eps = 1e-12
            hidden = (
                F.relu(inputs_embeds @ encoder.lora_embedding_A + encoder.lora_embedding_a)
                @ encoder.lora_embedding_H
                + encoder.lora_embedding_h
            )
            # Normalize over the last dimension (biased variance), then apply the encoder's gamma / beta.
            mean = hidden.mean(-1, keepdim=True)
            var = hidden.var(-1, unbiased=False, keepdim=True)
            hidden = (hidden - mean) / torch.sqrt(var + eps)
            hidden = encoder.gamma * hidden + encoder.beta
            return inputs_embeds + (hidden @ encoder.lora_embedding_B + encoder.lora_embedding_b)

        # PeftType.PROMPT_TUNING_LORAXAB: non-linear bottleneck (C/D) plus a scaled, input-independent
        # term (A @ B) repeated once per batch element.
        lora_A = encoder.lora_embedding_A.to(device)
        lora_B = encoder.lora_embedding_B.to(device)
        lora_C = encoder.lora_embedding_C.to(device)
        lora_D = encoder.lora_embedding_D.to(device)
        lora_c = encoder.lora_embedding_c.to(device)
        lora_d = encoder.lora_embedding_d.to(device)
        batch_size = inputs_embeds.shape[0]
        return (
            inputs_embeds
            + F.relu(inputs_embeds @ lora_C + lora_c) @ lora_D
            + lora_d
            + encoder.scaling * (lora_A @ lora_B).repeat(batch_size, 1, 1)
        )

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        """
        Run the base model with the active adapter applied.

        For non prompt-learning adapters the arguments are passed straight to the base model. For prefix tuning
        the learned prefix is passed as `past_key_values`. For every other prompt-learning type the virtual-token
        embeddings are prepended to the input embeddings, the attention mask is extended with ones and the labels
        are extended with `-100` for the virtual-token positions.

        Args:
            input_ids (`torch.Tensor`, *optional*): Token ids. Required unless `inputs_embeds` is given.
            attention_mask (`torch.Tensor`, *optional*): Attention mask for the (un-prefixed) input.
            inputs_embeds (`torch.Tensor`, *optional*):
                Pre-computed input embeddings. When given, they are used as-is: the `PROMPT_TUNING_LORA*`
                embedding update is only applied to embeddings looked up from `input_ids`.
            labels (`torch.Tensor`, *optional*): Labels for the (un-prefixed) input.
            output_attentions, output_hidden_states, return_dict: Forwarded to the base model.
            **kwargs: Forwarded to the base model. `position_ids` and `token_type_ids` are ignored (with a
                warning) for prompt-learning adapters.

        Returns:
            Whatever the base model's forward returns.

        Raises:
            AssertionError: If `inputs_embeds` is passed to an `mpt` base model with a non prompt-learning adapter.
            ValueError: If neither `input_ids` nor `inputs_embeds` is given for a prompt-learning adapter.
        """
        peft_config = self.active_peft_config
        if not isinstance(peft_config, PromptLearningConfig):
            if self.base_model.config.model_type == "mpt":
                if inputs_embeds is not None:
                    raise AssertionError("forward in MPTForCausalLM does not support inputs_embeds")
                return self.base_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels,
                    output_attentions=output_attentions,
                    output_hidden_states=output_hidden_states,
                    return_dict=return_dict,
                    **kwargs,
                )

            return self.base_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                labels=labels,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                **kwargs,
            )

        # Callers may supply embeddings instead of ids, so take the batch size from whichever is present.
        if input_ids is not None:
            batch_size = input_ids.shape[0]
        elif inputs_embeds is not None:
            batch_size = inputs_embeds.shape[0]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if attention_mask is not None:
            # concat prompt attention mask
            prefix_attention_mask = torch.ones(batch_size, peft_config.num_virtual_tokens).to(attention_mask.device)
            attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)

        if kwargs.get("position_ids", None) is not None:
            warnings.warn("Position ids are not supported for parameter efficient tuning. Ignoring position ids.")
            kwargs["position_ids"] = None
        if kwargs.get("token_type_ids", None) is not None:
            warnings.warn("Token type ids are not supported for parameter efficient tuning. Ignoring token type ids")
            kwargs["token_type_ids"] = None
        kwargs.update(
            {
                "attention_mask": attention_mask,
                "labels": labels,
                "output_attentions": output_attentions,
                "output_hidden_states": output_hidden_states,
                "return_dict": return_dict,
            }
        )

        if peft_config.peft_type == PeftType.PREFIX_TUNING:
            past_key_values = self.get_prompt(batch_size)
            if input_ids is None:
                # Only embeddings were supplied; forward them so the base model still receives an input.
                kwargs["inputs_embeds"] = inputs_embeds
            return self.base_model(input_ids=input_ids, past_key_values=past_key_values, **kwargs)

        # All remaining prompt-learning types prepend virtual-token embeddings to the input embeddings.
        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
            # No-op unless the adapter is one of the PROMPT_TUNING_LORA* variants.
            inputs_embeds = self._apply_lora_embedding_update(inputs_embeds, peft_config.peft_type)
        # concat prompt labels
        if labels is not None:
            prefix_labels = torch.full((batch_size, peft_config.num_virtual_tokens), -100).to(labels.device)
            kwargs["labels"] = torch.cat((prefix_labels, labels), dim=1)
        prompts = self.get_prompt(batch_size=batch_size)
        prompts = prompts.to(inputs_embeds.dtype)
        inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1)
        return self.base_model(inputs_embeds=inputs_embeds, **kwargs)

    def generate(self, **kwargs):
        """
        Generate with the base model while the PEFT-aware `prepare_inputs_for_generation` is installed.

        The base model's `prepare_inputs_for_generation` is replaced for the duration of the call and always
        restored afterwards, whether generation succeeds or raises.

        Args:
            **kwargs: Forwarded to the base model's `generate`. For prompt-learning adapters `input_ids` is
                required; `position_ids` and `token_type_ids` are ignored (with a warning).

        Returns:
            Whatever the base model's `generate` returns.

        Raises:
            ValueError: If a prompt-learning adapter is active and `input_ids` is missing from `kwargs`.
        """
        peft_config = self.active_peft_config
        self.base_model.prepare_inputs_for_generation = self.prepare_inputs_for_generation
        if hasattr(self.base_model, "model"):
            self.base_model.model.generation_config = self.generation_config
        else:
            self.base_model.generation_config = self.generation_config
        try:
            if isinstance(peft_config, PromptLearningConfig):
                if "input_ids" not in kwargs:
                    raise ValueError("input_ids must be provided for Peft model generation")
                # For gpt2 models, we construct position_ids on the fly by using attention mask, and position ids
                # need to match input_shape. For prefix tuning, input shape is determined using `input_ids`, thus
                # we should not expand `attention_mask` here. For prompt tuning, `input_ids` is not passed but a
                # concatenated `inputs_embeds` is, thus `attention_mask` needs to cover
                # num_virtual_tokens + input_ids.
                if kwargs.get("attention_mask", None) is not None and peft_config.peft_type in (
                    PeftType.PROMPT_TUNING,
                    PeftType.P_TUNING,
                    PeftType.PROMPT_TUNING_LORA,
                    PeftType.PROMPT_TUNING_LORAX,
                    PeftType.PROMPT_TUNING_LORAXL,
                    PeftType.PROMPT_TUNING_LORAXAB,
                ):
                    # concat prompt attention mask
                    prefix_attention_mask = torch.ones(
                        kwargs["input_ids"].shape[0], peft_config.num_virtual_tokens
                    ).to(kwargs["input_ids"].device)
                    kwargs["attention_mask"] = torch.cat((prefix_attention_mask, kwargs["attention_mask"]), dim=1)

                if kwargs.get("position_ids", None) is not None:
                    warnings.warn(
                        "Position ids are not supported for parameter efficient tuning. Ignoring position ids."
                    )
                    kwargs["position_ids"] = None
                if kwargs.get("token_type_ids", None) is not None:
                    warnings.warn(
                        "Token type ids are not supported for parameter efficient tuning. Ignoring token type ids"
                    )
                    kwargs["token_type_ids"] = None
            return self.base_model.generate(**kwargs)
        finally:
            # Always put the base model's own hook back, on success and on error alike.
            self.base_model.prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation

    def prepare_inputs_for_generation(self, *args, **kwargs):
        """
        Build the per-step model inputs for generation, injecting the adapter's prompt where needed.

        Delegates to the base model's own `prepare_inputs_for_generation` and then, for prompt-learning adapters:

        - prefix tuning: extends the attention mask on every call and, while `past_key_values` is still `None`,
          supplies the learned prefix as `past_key_values` (cast to `self.base_model_torch_dtype` when set);
        - other types: while `past_key_values` is still `None`, replaces `input_ids` with `inputs_embeds` made of
          the virtual-token embeddings followed by the token embeddings (with the `PROMPT_TUNING_LORA*`
          embedding update applied where applicable).

        Args:
            *args, **kwargs: Forwarded to the base model's `prepare_inputs_for_generation`.

        Returns:
            `dict`: The model kwargs for the next generation step.
        """
        peft_config = self.active_peft_config
        model_kwargs = self.base_model_prepare_inputs_for_generation(*args, **kwargs)
        if not isinstance(peft_config, PromptLearningConfig):
            return model_kwargs

        peft_type = peft_config.peft_type
        is_prefix_tuning = peft_type == PeftType.PREFIX_TUNING

        if is_prefix_tuning:
            prefix_attention_mask = torch.ones(
                model_kwargs["input_ids"].shape[0], peft_config.num_virtual_tokens
            ).to(model_kwargs["input_ids"].device)
            model_kwargs["attention_mask"] = torch.cat(
                (prefix_attention_mask, model_kwargs["attention_mask"]), dim=1
            )

        if model_kwargs["past_key_values"] is not None:
            # A cache already exists, so the prompt was injected on an earlier step.
            return model_kwargs

        if is_prefix_tuning:
            past_key_values = self.get_prompt(batch_size=model_kwargs["input_ids"].shape[0])

            if self.base_model_torch_dtype is not None:
                # handle the case for Bloom where it outputs tuple of tuples
                if isinstance(past_key_values[0], tuple):
                    past_key_values = tuple(
                        tuple(
                            past_key_value.to(self.base_model_torch_dtype)
                            for past_key_value in past_key_value_tuple
                        )
                        for past_key_value_tuple in past_key_values
                    )
                else:
                    past_key_values = tuple(
                        past_key_value.to(self.base_model_torch_dtype) for past_key_value in past_key_values
                    )

            model_kwargs["past_key_values"] = past_key_values
            return model_kwargs

        input_ids = model_kwargs["input_ids"]
        uses_lora_update = self._is_lora_prompt_tuning(peft_type)

        inputs_embeds = self.word_embeddings(input_ids)
        if uses_lora_update:
            inputs_embeds = self._apply_lora_embedding_update(inputs_embeds, peft_type)
        prompts = self.get_prompt(batch_size=input_ids.shape[0])
        prompts = prompts.to(inputs_embeds.dtype)
        if uses_lora_update:
            # The LoRA variants keep only the first `num_virtual_tokens` prompt positions.
            prompts = prompts[:, : peft_config.num_virtual_tokens]
        model_kwargs["inputs_embeds"] = torch.cat((prompts, inputs_embeds), dim=1)
        # Exactly one branch runs per adapter type, so `input_ids` is cleared only after the
        # embeddings have been built from it.
        model_kwargs["input_ids"] = None

        return model_kwargs


class PeftModelForSeq2SeqLM(PeftModel):
    """
    Peft model for sequence-to-sequence language modeling.

    Args:
        model ([`~transformers.PreTrainedModel`]): Base transformer model.
        peft_config ([`PeftConfig`]): Peft config.


    Example:

        ```py
        >>> from transformers import AutoModelForSeq2SeqLM
        >>> from peft import PeftModelForSeq2SeqLM, get_peft_config

        >>> config = {
        ...     "peft_type": "LORA",
        ...     "task_type": "SEQ_2_SEQ_LM",
        ...     "inference_mode": False,
        ...     "r": 8,
        ...     "target_modules": ["q", "v"],
        ...     "lora_alpha": 32,
        ...     "lora_dropout": 0.1,
        ...     "fan_in_fan_out": False,
        ...     "enable_lora": None,
        ...     "bias": "none",
        ... }

        >>> peft_config = get_peft_config(config)
        >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
        >>> peft_model = PeftModelForSeq2SeqLM(model, peft_config)
        >>> peft_model.print_trainable_parameters()
        trainable params: 884736 || all params: 223843584 || trainable%: 0.3952474242013566
        ```
    """

    def __init__(self, model, peft_config: PeftConfig, adapter_name="default"):
        """
        Wrap `model` with the given adapter and remember the base model's generation hooks.

        Args:
            model ([`~transformers.PreTrainedModel`]): Base transformer model.
            peft_config ([`PeftConfig`]): Peft config.
            adapter_name (`str`, *optional*, defaults to `"default"`): Name passed on to the parent constructor.
        """
        super().__init__(model, peft_config, adapter_name)
        wrapped_model = self.base_model
        # Store references to the base model's original generation helpers on the instance.
        self.base_model_prepare_inputs_for_generation = wrapped_model.prepare_inputs_for_generation
        encoder_decoder_kwargs_hook = wrapped_model._prepare_encoder_decoder_kwargs_for_generation
        self.base_model_prepare_encoder_decoder_kwargs_for_generation = encoder_decoder_kwargs_hook

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        inputs_embeds=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        decoder_inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        peft_config = self.active_peft_config
        if not isinstance(peft_config, PromptLearningConfig):
            return self.base_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
                decoder_inputs_embeds=decoder_inputs_embeds,
                labels=labels,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                **kwargs,
            )

        batch_size = input_ids.shape[0]
        if decoder_attention_mask is not None:
            # concat prompt attention mask
            prefix_attention_mask = torch.ones(batch_size, peft_config.num_virtual_tokens).to(
                decoder_attention_mask.device
            )
            decoder_attention_mask = torch.cat((prefix_attention_mask, decoder_attention_mask), dim=1)

        if kwargs.get("position_ids", None) is not None:
            warnings.warn("Position ids are not supported for parameter efficient tuning. Ignoring position ids.")
            kwargs["position_ids"] = None
        if kwargs.get("token_type_ids", None) is not None:
            warnings.warn("Token type ids are not supported for parameter efficient tuning. Ignoring token type ids")
            kwargs["token_type_ids"] = None
        kwargs.update(
            {
                "attention_mask": attention_mask,
                "decoder_attention_mask": decoder_attention_mask,
                "labels": labels,
                "output_attentions": output_attentions,
                "output_hidden_states": output_hidden_states,
                "return_dict": return_dict,
            }
        )

        if peft_config.peft_type == PeftType.PREFIX_TUNING:
            past_key_values = self.get_prompt(batch_size)
            return self.base_model(
                input_ids=input_ids, decoder_input_ids=decoder_input_ids, past_key_values=past_key_values, **kwargs
            )
        elif peft_config.peft_type == PeftType.PROMPT_TUNING:
            if inputs_embeds is None:
                inputs_embeds = self.word_embeddings(input_ids)

            if attention_mask is not None:
                # concat prompt attention mask
                prefix_attention_mask = torch.ones(batch_size, peft_config.num_virtual_tokens).to(
                    attention_mask.device
                )
                kwargs["attention_mask"] = torch.cat((prefix_attention_mask, attention_mask), dim=1)

            prompts = self.get_prompt(batch_size=batch_size)
            prompts = prompts.to(inputs_embeds.dtype)
            inputs_embeds = torch.cat((prompts[:, : peft_config.num_virtual_tokens], inputs_embeds), dim=1)

            return self.base_model(inputs_embeds=inputs_embeds, **kwargs)
        elif peft_config.peft_type == PeftType.PROMPT_TUNING_LORA:
            if inputs_embeds is None:
                batch_size = input_ids.shape[0]
                inputs_embeds = self.word_embeddings(input_ids)
                lora_embedding_A = self.prompt_encoder[self.active_adapter].lora_embedding_A
                lora_embedding_B = self.prompt_encoder[self.active_adapter].lora_embedding_B
                scaling = self.prompt_encoder[self.active_adapter].scaling
                #inputs_embeds += scaling * (lora_embedding_A @ lora_embedding_B).repeat(batch_size, 1, 1)
                inputs_embeds =inputs_embeds +  torch.nn.functional.relu(inputs_embeds @ lora_embedding_A ) @ lora_embedding_B
            if attention_mask is not None:
                # concat prompt attention mask
                prefix_attention_mask = torch.ones(batch_size, peft_config.num_virtual_tokens).to(
                    attention_mask.device
                )
                kwargs["attention_mask"] = torch.cat((prefix_attention_mask, attention_mask), dim=1)

            prompts = self.get_prompt(batch_size=batch_size)
            prompts = prompts.to(inputs_embeds.dtype)
            inputs_embeds = torch.cat((prompts[:, : peft_config.num_virtual_tokens], inputs_embeds), dim=1)

            return self.base_model(inputs_embeds=inputs_embeds, **kwargs)
        elif peft_config.peft_type == PeftType.PROMPT_TUNING_LORAX:
            
            if inputs_embeds is None:
                batch_size = input_ids.shape[0]
                inputs_embeds = self.word_embeddings(input_ids)
                lora_embedding_A = self.prompt_encoder[self.active_adapter].lora_embedding_A
                lora_embedding_B = self.prompt_encoder[self.active_adapter].lora_embedding_B
                #lora_embedding_C = self.prompt_encoder[self.active_adapter].lora_embedding_C
                #lora_embedding_D = self.prompt_encoder[self.active_adapter].lora_embedding_D
                lora_embedding_a = self.prompt_encoder[self.active_adapter].lora_embedding_a
                lora_embedding_b = self.prompt_encoder[self.active_adapter].lora_embedding_b
                scaling = self.prompt_encoder[self.active_adapter].scaling

                #inputs_embeds += torch.matmul(inputs_embeds,(lora_embedding_A @ lora_embedding_B))#.repeat(batch_size, 1, 1)\
                #temp = inputs_embeds @ (lora_embedding_A @ lora_embedding_B)# + lora_embedding_B
                temp = inputs_embeds @ lora_embedding_A + lora_embedding_a
                if self.training:
                    temp = F.dropout(temp, p=0.1)
                    #print(1293)
                inputs_embeds =inputs_embeds +  torch.nn.functional.relu(temp) @ lora_embedding_B + lora_embedding_b #+ scaling * (lora_embedding_C @ lora_embedding_D).repeat(batch_size, 1, 1)
            if attention_mask is not None:
                # concat prompt attention mask
                prefix_attention_mask = torch.ones(batch_size, peft_config.num_virtual_tokens).to(
                    attention_mask.device
                )
                kwargs["attention_mask"] = torch.cat((prefix_attention_mask, attention_mask), dim=1)

            prompts = self.get_prompt(batch_size=batch_size)
            prompts = prompts.to(inputs_embeds.dtype)
            inputs_embeds = torch.cat((prompts[:, : peft_config.num_virtual_tokens], inputs_embeds), dim=1)

            return self.base_model(inputs_embeds=inputs_embeds, **kwargs)
        elif peft_config.peft_type == PeftType.PROMPT_TUNING_LORAXAB:
            if inputs_embeds is None:
                batch_size = input_ids.shape[0]
                inputs_embeds = self.word_embeddings(input_ids)
                lora_embedding_A = self.prompt_encoder[self.active_adapter].lora_embedding_A
                lora_embedding_B = self.prompt_encoder[self.active_adapter].lora_embedding_B
                lora_embedding_C = self.prompt_encoder[self.active_adapter].lora_embedding_C
                lora_embedding_D = self.prompt_encoder[self.active_adapter].lora_embedding_D
                lora_embedding_c = self.prompt_encoder[self.active_adapter].lora_embedding_c
                lora_embedding_d = self.prompt_encoder[self.active_adapter].lora_embedding_d
                scaling = self.prompt_encoder[self.active_adapter].scaling

                #inputs_embeds += torch.matmul(inputs_embeds,(lora_embedding_A @ lora_embedding_B))#.repeat(batch_size, 1, 1)\
                #temp = inputs_embeds @ (lora_embedding_A @ lora_embedding_B)# + lora_embedding_B
                inputs_embeds =inputs_embeds +  torch.nn.functional.relu(inputs_embeds @ lora_embedding_C + lora_embedding_c) @ lora_embedding_D + lora_embedding_d + scaling * (lora_embedding_A @ lora_embedding_B).repeat(batch_size, 1, 1)
            if attention_mask is not None:
                # concat prompt attention mask
                prefix_attention_mask = torch.ones(batch_size, peft_config.num_virtual_tokens).to(
                    attention_mask.device
                )
                kwargs["attention_mask"] = torch.cat((prefix_attention_mask, attention_mask), dim=1)

            prompts = self.get_prompt(batch_size=batch_size)
            prompts = prompts.to(inputs_embeds.dtype)
            inputs_embeds = torch.cat((prompts[:, : peft_config.num_virtual_tokens], inputs_embeds), dim=1)

            return self.base_model(inputs_embeds=inputs_embeds, **kwargs)
        elif peft_config.peft_type == PeftType.PROMPT_TUNING_LORAXL:
            if inputs_embeds is None:
                batch_size = input_ids.shape[0]
                inputs_embeds = self.word_embeddings(input_ids)
                lora_embedding_A = self.prompt_encoder[self.active_adapter].lora_embedding_A
                lora_embedding_B = self.prompt_encoder[self.active_adapter].lora_embedding_B
                lora_embedding_H = self.prompt_encoder[self.active_adapter].lora_embedding_H
                lora_embedding_a = self.prompt_encoder[self.active_adapter].lora_embedding_a
                lora_embedding_b = self.prompt_encoder[self.active_adapter].lora_embedding_b
                lora_embedding_h = self.prompt_encoder[self.active_adapter].lora_embedding_h
                gamma = self.prompt_encoder[self.active_adapter].gamma
                beta = self.prompt_encoder[self.active_adapter].beta
                scaling = self.prompt_encoder[self.active_adapter].scaling
                eps = 1e-12

                #inputs_embeds += torch.matmul(inputs_embeds,(lora_embedding_A @ lora_embedding_B))#.repeat(batch_size, 1, 1)\
                #temp = inputs_embeds @ (lora_embedding_A @ lora_embedding_B)# + lora_embedding_B
                temp = torch.nn.functional.relu(inputs_embeds @ lora_embedding_A + lora_embedding_a) @ lora_embedding_H + lora_embedding_h
                mean = temp.mean(-1, keepdim = True)
                var = temp.var(-1, unbiased = False, keepdim = True)
                temp = (temp - mean )/torch.sqrt(var + eps)
                temp = gamma * temp + beta

                inputs_embeds =inputs_embeds +  (temp@lora_embedding_B + lora_embedding_b)

            if attention_mask is not None:
                # concat prompt attention mask
                prefix_attention_mask = torch.ones(batch_size, peft_config.num_virtual_tokens).to(
                    attention_mask.device
                )
                kwargs["attention_mask"] = torch.cat((prefix_attention_mask, attention_mask), dim=1)

            prompts = self.get_prompt(batch_size=batch_size)
            prompts = prompts.to(inputs_embeds.dtype)
            inputs_embeds = torch.cat((prompts[:, : peft_config.num_virtual_tokens], inputs_embeds), dim=1)

            return self.base_model(inputs_embeds=inputs_embeds, **kwargs)

        else:
            if inputs_embeds is None:
                inputs_embeds = self.word_embeddings(input_ids)
            if decoder_inputs_embeds is None and decoder_input_ids is None:
                decoder_input_ids = shift_tokens_right(
                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
                )
                decoder_inputs_embeds = self.word_embeddings(decoder_input_ids)

            if attention_mask is not None:
                # concat prompt attention mask
                prefix_attention_mask = torch.ones(batch_size, peft_config.num_virtual_tokens).to(
                    attention_mask.device
                )
                kwargs["attention_mask"] = torch.cat((prefix_attention_mask, attention_mask), dim=1)
            # concat prompt labels
            if labels is not None:
                if peft_config.num_transformer_submodules == 1:
                    kwargs["labels"] = labels
                elif peft_config.num_transformer_submodules == 2:
                    prefix_labels = torch.full((batch_size, peft_config.num_virtual_tokens), -100).to(labels.device)
                    kwargs["labels"] = torch.cat((prefix_labels, labels), dim=1)
            prompts = self.get_prompt(batch_size=batch_size)
            prompts = prompts.to(inputs_embeds.dtype)
            inputs_embeds = torch.cat((prompts[:, : peft_config.num_virtual_tokens], inputs_embeds), dim=1)
            if peft_config.num_transformer_submodules == 1:
                return self.base_model(inputs_embeds=inputs_embeds, **kwargs)
            elif peft_config.num_transformer_submodules == 2:
                decoder_inputs_embeds = torch.cat(
                    (prompts[:, peft_config.num_virtual_tokens :], decoder_inputs_embeds), dim=1
                )
                return self.base_model(
                    inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds, **kwargs
                )

    def generate(self, **kwargs):
        """
        Run `self.base_model.generate`, injecting the active adapter's prompt where the adapter type needs it.

        While generation runs, the base model's `prepare_inputs_for_generation` and
        `_prepare_encoder_decoder_kwargs_for_generation` are replaced by this wrapper's versions. They are always
        put back afterwards, whether generation returns normally or raises.

        Args:
            **kwargs: Keyword arguments forwarded to `self.base_model.generate`. For prompt-learning adapters
                `input_ids` is required; non-`None` `position_ids` / `token_type_ids` are replaced by `None` with a
                warning. For the prompt-tuning variants `encoder_outputs` is dropped with a warning, `input_ids` is
                replaced by `inputs_embeds` (virtual-token prompts followed by the adapted word embeddings) and
                `attention_mask`, when present, is extended to cover the virtual tokens.

        Returns:
            Whatever `self.base_model.generate` returns.

        Raises:
            ValueError: If the active config is a prompt-learning config and `input_ids` is missing.
            NotImplementedError: If the prompt-learning `peft_type` is not one handled here.
        """
        peft_config = self.active_peft_config
        self.base_model.prepare_inputs_for_generation = self.prepare_inputs_for_generation
        self.base_model._prepare_encoder_decoder_kwargs_for_generation = (
            self._prepare_encoder_decoder_kwargs_for_generation
        )
        try:
            if not isinstance(peft_config, PromptLearningConfig):
                return self.base_model.generate(**kwargs)

            if "input_ids" not in kwargs:
                raise ValueError("input_ids must be provided for Peft model generation")
            if kwargs.get("position_ids", None) is not None:
                warnings.warn(
                    "Position ids are not supported for parameter efficient tuning. Ignoring position ids."
                )
                kwargs["position_ids"] = None
            if kwargs.get("token_type_ids", None) is not None:
                warnings.warn(
                    "Token type ids are not supported for parameter efficient tuning. Ignoring token type ids"
                )
                kwargs["token_type_ids"] = None

            peft_type = peft_config.peft_type
            if peft_type == PeftType.PREFIX_TUNING:
                # The prefix is supplied later through `prepare_inputs_for_generation`.
                return self.base_model.generate(**kwargs)

            prompt_tuning_types = (
                PeftType.PROMPT_TUNING,
                PeftType.PROMPT_TUNING_LORA,
                PeftType.PROMPT_TUNING_LORAX,
                PeftType.PROMPT_TUNING_LORAXAB,
                PeftType.PROMPT_TUNING_LORAXL,
            )
            if peft_type not in prompt_tuning_types:
                raise NotImplementedError

            kwargs = deepcopy(kwargs)

            if "encoder_outputs" in kwargs:
                del kwargs["encoder_outputs"]
                warnings.warn(
                    "`encoder_outputs` should not be passed to `generate` when using prompt tuning. Ignoring it."
                )

            input_ids = kwargs.pop("input_ids")
            inputs_embeds = self._adapt_generation_embeds(peft_type, self.word_embeddings(input_ids))
            batch_size = inputs_embeds.shape[0]
            prompts = self.get_prompt(batch_size=batch_size)
            prompts = prompts.to(inputs_embeds.dtype)
            kwargs["inputs_embeds"] = torch.cat(
                (prompts[:, : peft_config.num_virtual_tokens], inputs_embeds), dim=1
            )

            if "attention_mask" in kwargs:
                # concat prompt attention mask
                prefix_attention_mask = torch.ones(batch_size, peft_config.num_virtual_tokens).to(
                    kwargs["attention_mask"].device
                )
                kwargs["attention_mask"] = torch.cat((prefix_attention_mask, kwargs["attention_mask"]), dim=1)

            return self.base_model.generate(**kwargs)
        finally:
            # `finally` (rather than `except`/`else`) so the hooks are restored on every exit path,
            # including the early `return`s above.
            self.base_model.prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation
            self.base_model._prepare_encoder_decoder_kwargs_for_generation = (
                self.base_model_prepare_encoder_decoder_kwargs_for_generation
            )

    def _adapt_generation_embeds(self, peft_type, inputs_embeds):
        """
        Apply the active prompt encoder's additive update to the word embeddings for the given prompt-tuning variant.

        Args:
            peft_type: The `peft_type` of the active config.
            inputs_embeds (`torch.Tensor`): Output of `self.word_embeddings` for the input ids.

        Returns:
            `torch.Tensor`: `inputs_embeds` unchanged for `PROMPT_TUNING`, otherwise `inputs_embeds` plus the
            variant-specific update computed from the encoder's `lora_embedding_*` parameters.

        Raises:
            NotImplementedError: If `peft_type` is not one of the prompt-tuning variants.
        """
        if peft_type == PeftType.PROMPT_TUNING:
            return inputs_embeds

        encoder = self.prompt_encoder[self.active_adapter]

        if peft_type == PeftType.PROMPT_TUNING_LORA:
            return inputs_embeds + F.relu(inputs_embeds @ encoder.lora_embedding_A) @ encoder.lora_embedding_B

        if peft_type == PeftType.PROMPT_TUNING_LORAX:
            hidden = inputs_embeds @ encoder.lora_embedding_A + encoder.lora_embedding_a
            if self.training:
                hidden = F.dropout(hidden, p=0.1)
            return inputs_embeds + F.relu(hidden) @ encoder.lora_embedding_B + encoder.lora_embedding_b

        if peft_type == PeftType.PROMPT_TUNING_LORAXAB:
            batch_size = inputs_embeds.shape[0]
            return (
                inputs_embeds
                + F.relu(inputs_embeds @ encoder.lora_embedding_C + encoder.lora_embedding_c)
                @ encoder.lora_embedding_D
                + encoder.lora_embedding_d
                + encoder.scaling * (encoder.lora_embedding_A @ encoder.lora_embedding_B).repeat(batch_size, 1, 1)
            )

        if peft_type == PeftType.PROMPT_TUNING_LORAXL:
            eps = 1e-12
            hidden = (
                F.relu(inputs_embeds @ encoder.lora_embedding_A + encoder.lora_embedding_a)
                @ encoder.lora_embedding_H
                + encoder.lora_embedding_h
            )
            # Normalize along the last dimension, then scale by gamma and shift by beta.
            mean = hidden.mean(-1, keepdim=True)
            var = hidden.var(-1, unbiased=False, keepdim=True)
            hidden = (hidden - mean) / torch.sqrt(var + eps)
            hidden = encoder.gamma * hidden + encoder.beta
            return inputs_embeds + (hidden @ encoder.lora_embedding_B + encoder.lora_embedding_b)

        raise NotImplementedError

    def prepare_inputs_for_generation(self, *args, **kwargs):
        """
        Build the generation inputs via the base model's original hook and, for prefix tuning, seed the cache.

        The arguments are passed straight to `self.base_model_prepare_inputs_for_generation`. When the returned
        `past_key_values` is `None` and the active adapter is `PREFIX_TUNING`, it is replaced by the output of
        `self.get_prompt`, cast to `self.base_model_torch_dtype` when that attribute is set.

        Returns:
            The dict produced by the base model's hook, possibly with `past_key_values` filled in.
        """
        active_config = self.active_peft_config
        model_kwargs = self.base_model_prepare_inputs_for_generation(*args, **kwargs)

        needs_prefix = (
            model_kwargs["past_key_values"] is None and active_config.peft_type == PeftType.PREFIX_TUNING
        )
        if not needs_prefix:
            return model_kwargs

        prefix_cache = self.get_prompt(model_kwargs["decoder_input_ids"].shape[0])

        target_dtype = self.base_model_torch_dtype
        if target_dtype is not None:
            if isinstance(prefix_cache[0], tuple):
                # handle the case for Bloom where it outputs tuple of tuples
                prefix_cache = tuple(
                    tuple(tensor.to(target_dtype) for tensor in layer_cache) for layer_cache in prefix_cache
                )
            else:
                prefix_cache = tuple(tensor.to(target_dtype) for tensor in prefix_cache)

        model_kwargs["past_key_values"] = prefix_cache
        return model_kwargs


class PeftModelForTokenClassification(PeftModel):
    """
    Peft model for token classification tasks.

    Args:
        model ([`~transformers.PreTrainedModel`]): Base transformer model.
        peft_config ([`PeftConfig`]): Peft config.

    **Attributes**:
        - **config** ([`~transformers.PretrainedConfig`]) -- The configuration object of the base model.
        - **cls_layer_name** (`str`) -- The name of the classification layer.

    Example:

        ```py
        >>> from transformers import AutoModelForTokenClassification
        >>> from peft import PeftModelForTokenClassification, get_peft_config

        >>> config = {
        ...     "peft_type": "PREFIX_TUNING",
        ...     "task_type": "TOKEN_CLS",
        ...     "inference_mode": False,
        ...     "num_virtual_tokens": 20,
        ...     "token_dim": 768,
        ...     "num_transformer_submodules": 1,
        ...     "num_attention_heads": 12,
        ...     "num_layers": 12,
        ...     "encoder_hidden_size": 768,
        ...     "prefix_projection": False,
        ...     "postprocess_past_key_value_function": None,
        ... }

        >>> peft_config = get_peft_config(config)
        >>> model = AutoModelForTokenClassification.from_pretrained("bert-base-cased")
        >>> peft_model = PeftModelForTokenClassification(model, peft_config)
        >>> peft_model.print_trainable_parameters()
        trainable params: 370178 || all params: 108680450 || trainable%: 0.3406113979101117
        ```
    """

    def __init__(self, model, peft_config: PeftConfig = None, adapter_name="default"):
        """
        Wrap `model`, register the classification head names as modules to save and record the head's name.

        Args:
            model: Base transformer model.
            peft_config ([`PeftConfig`]): Peft config.
            adapter_name (`str`): Name of the adapter being added.
        """
        super().__init__(model, peft_config, adapter_name)
        if self.modules_to_save is None:
            self.modules_to_save = {"classifier", "score"}
        else:
            self.modules_to_save.update({"classifier", "score"})

        # The first direct child whose name contains one of the saved-module names is used as the head.
        for name, _ in self.base_model.named_children():
            if any(module_name in name for module_name in self.modules_to_save):
                self.cls_layer_name = name
                break

        # to make sure classifier layer is trainable
        _set_trainable(self, adapter_name)

    @staticmethod
    def _infer_batch_size(input_ids, inputs_embeds):
        """
        Return the size of dimension 0 of `input_ids`, falling back to `inputs_embeds` when no ids are given.

        Raises:
            ValueError: If both arguments are `None`.
        """
        if input_ids is not None:
            return input_ids.shape[0]
        if inputs_embeds is not None:
            return inputs_embeds.shape[0]
        raise ValueError("You have to specify either input_ids or inputs_embeds")

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        """
        Run the base model, adding the adapter's prefix or prompt for prompt-learning configs.

        For non prompt-learning configs every argument is passed to `self.base_model` unchanged. Otherwise the
        attention mask is extended by `num_virtual_tokens` ones, non-`None` `position_ids` are dropped with a
        warning, and either `_prefix_tuning_forward` is used (`PREFIX_TUNING`) or the output of `get_prompt` is
        concatenated in front of the input embeddings.

        Args:
            input_ids: Token ids; may be `None` when `inputs_embeds` is given.
            attention_mask: Optional attention mask for the real tokens.
            inputs_embeds: Optional pre-computed input embeddings.
            labels: Optional labels, forwarded to the base model.
            output_attentions, output_hidden_states, return_dict: Forwarded to the base model; `return_dict`
                defaults to `self.config.use_return_dict`.
            **kwargs: Extra keyword arguments forwarded to the base model.

        Returns:
            The output of the base model (or of `_prefix_tuning_forward`).

        Raises:
            ValueError: For prompt-learning configs when both `input_ids` and `inputs_embeds` are `None`.
        """
        peft_config = self.active_peft_config
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if not isinstance(peft_config, PromptLearningConfig):
            return self.base_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                labels=labels,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                **kwargs,
            )

        # Callers may pass only `inputs_embeds`, so the batch size cannot rely on `input_ids` alone.
        batch_size = self._infer_batch_size(input_ids, inputs_embeds)
        if attention_mask is not None:
            # concat prompt attention mask
            prefix_attention_mask = torch.ones(batch_size, peft_config.num_virtual_tokens).to(attention_mask.device)
            attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
        if kwargs.get("position_ids", None) is not None:
            warnings.warn("Position ids are not supported for parameter efficient tuning. Ignoring position ids.")
            kwargs["position_ids"] = None
        kwargs.update(
            {
                "attention_mask": attention_mask,
                "labels": labels,
                "output_attentions": output_attentions,
                "output_hidden_states": output_hidden_states,
                "return_dict": return_dict,
            }
        )

        if peft_config.peft_type == PeftType.PREFIX_TUNING:
            if input_ids is None:
                # Only forward the embeddings when there are no ids, so callers that pass
                # `input_ids` keep the exact behaviour they had before.
                kwargs["inputs_embeds"] = inputs_embeds
            return self._prefix_tuning_forward(input_ids=input_ids, **kwargs)
        else:
            if kwargs.get("token_type_ids", None) is not None:
                # Virtual tokens get token type 0, prepended to the caller's token type ids.
                kwargs["token_type_ids"] = torch.cat(
                    (
                        torch.zeros(batch_size, peft_config.num_virtual_tokens).to(self.word_embeddings.weight.device),
                        kwargs["token_type_ids"],
                    ),
                    dim=1,
                ).long()
            if inputs_embeds is None:
                inputs_embeds = self.word_embeddings(input_ids)
            prompts = self.get_prompt(batch_size=batch_size)
            prompts = prompts.to(inputs_embeds.dtype)
            inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1)
            return self.base_model(inputs_embeds=inputs_embeds, **kwargs)

    def _prefix_tuning_forward(
        self,
        input_ids=None,
        attention_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        """
        Forward pass for prefix tuning: the output of `get_prompt` is supplied as `past_key_values`.

        If the base model's `forward` accepts `past_key_values` it is called directly. Otherwise the submodule
        named by `self.transformer_backbone_name` is called, followed by the base model's `dropout` child (when
        one exists) and the classification head, and the cross-entropy loss is computed here when `labels` is given.

        Returns:
            The base model's output, or a `TokenClassifierOutput` / tuple built here on the fallback path.

        Raises:
            ValueError: If both `input_ids` and `inputs_embeds` are `None`, or if neither the base model nor its
                backbone accepts `past_key_values`.
        """
        batch_size = self._infer_batch_size(input_ids, inputs_embeds)
        past_key_values = self.get_prompt(batch_size)
        fwd_params = list(inspect.signature(self.base_model.forward).parameters.keys())
        kwargs.update(
            {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "inputs_embeds": inputs_embeds,
                "output_attentions": output_attentions,
                "output_hidden_states": output_hidden_states,
                "return_dict": return_dict,
                "past_key_values": past_key_values,
            }
        )
        if "past_key_values" in fwd_params:
            return self.base_model(labels=labels, **kwargs)
        else:
            transformer_backbone = self.base_model.get_submodule(self.transformer_backbone_name)
            fwd_params = list(inspect.signature(transformer_backbone.forward).parameters.keys())
            if "past_key_values" not in fwd_params:
                raise ValueError("Model does not support past key values which are required for prefix tuning.")
            outputs = transformer_backbone(**kwargs)
            sequence_output = outputs[0]
            if any(name == "dropout" for name, _ in self.base_model.named_children()):
                sequence_output = self.base_model.dropout(sequence_output)
            logits = self.base_model.get_submodule(self.cls_layer_name)(sequence_output)

            loss = None
            if labels is not None:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

            if not return_dict:
                output = (logits,) + outputs[2:]
                return ((loss,) + output) if loss is not None else output

            return TokenClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )


class PeftModelForQuestionAnswering(PeftModel):
    """
    Peft model for extractive question answering.

    Args:
        model ([`~transformers.PreTrainedModel`]): Base transformer model.
        peft_config ([`PeftConfig`]): Peft config.
        adapter_name (`str`, *optional*, defaults to `"default"`): Name passed on to the parent constructor and to
            `_set_trainable`.

    **Attributes**:
        - **config** ([`~transformers.PretrainedConfig`]) -- The configuration object of the base model.
        - **cls_layer_name** (`str`) -- The name of the classification layer. It is the first direct child of the
          base model whose name contains one of the entries of `modules_to_save` (`"qa_outputs"` is always added).

    Example:

        ```py
        >>> from transformers import AutoModelForQuestionAnswering
        >>> from peft import PeftModelForQuestionAnswering, get_peft_config

        >>> config = {
        ...     "peft_type": "LORA",
        ...     "task_type": "QUESTION_ANS",
        ...     "inference_mode": False,
        ...     "r": 16,
        ...     "target_modules": ["query", "value"],
        ...     "lora_alpha": 32,
        ...     "lora_dropout": 0.05,
        ...     "fan_in_fan_out": False,
        ...     "bias": "none",
        ... }

        >>> peft_config = get_peft_config(config)
        >>> model = AutoModelForQuestionAnswering.from_pretrained("bert-base-cased")
        >>> peft_model = PeftModelForQuestionAnswering(model, peft_config)
        >>> peft_model.print_trainable_parameters()
        trainable params: 592900 || all params: 108312580 || trainable%: 0.5473971721475013
        ```
    """

    def __init__(self, model, peft_config: PeftConfig = None, adapter_name="default"):
        super().__init__(model, peft_config, adapter_name)
        # The QA head is always registered as a module to save, on top of whatever
        # the parent constructor already collected.
        if self.modules_to_save is None:
            self.modules_to_save = {"qa_outputs"}
        else:
            self.modules_to_save.update({"qa_outputs"})

        # Remember the first direct child of the base model that matches a saved module name;
        # `_prefix_tuning_forward` uses it to locate the QA head.
        for name, _ in self.base_model.named_children():
            if any(module_name in name for module_name in self.modules_to_save):
                self.cls_layer_name = name
                break

        # to make sure classifier layer is trainable
        _set_trainable(self, adapter_name)

    @staticmethod
    def _qa_batch_size(input_ids, inputs_embeds):
        """Return the size of dim 0 of whichever of `input_ids` / `inputs_embeds` was supplied.

        Raises:
            ValueError: If both arguments are `None`.
        """
        if input_ids is not None:
            return input_ids.shape[0]
        if inputs_embeds is not None:
            return inputs_embeds.shape[0]
        raise ValueError("You have to specify either input_ids or inputs_embeds")

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        """
        Run the question-answering forward pass with the active adapter.

        For configs that are not a `PromptLearningConfig` the call is delegated to the base model. For prefix
        tuning it is delegated to `_prefix_tuning_forward`. For every other prompt-learning type the prompt
        returned by `get_prompt` is prepended to the input embeddings before calling the base model.

        `token_type_ids` is forwarded whenever it is not `None` (padded with zeros for the prepended virtual
        tokens on the prompt path). `position_ids` is forwarded only on the non-prompt-learning path; on
        prompt-learning paths it is ignored with a warning.

        Returns:
            Whatever the base model (or `_prefix_tuning_forward`) returns.

        Raises:
            ValueError: On prompt-learning paths, if neither `input_ids` nor `inputs_embeds` is given.
        """
        peft_config = self.active_peft_config
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if not isinstance(peft_config, PromptLearningConfig):
            # `token_type_ids` / `position_ids` are named parameters, so they never land in
            # `kwargs` on their own. Forward them explicitly, and only when supplied, so base
            # models whose forward() lacks these parameters are not handed unexpected keywords.
            if token_type_ids is not None:
                kwargs["token_type_ids"] = token_type_ids
            if position_ids is not None:
                kwargs["position_ids"] = position_ids
            return self.base_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                start_positions=start_positions,
                end_positions=end_positions,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                **kwargs,
            )

        batch_size = self._qa_batch_size(input_ids, inputs_embeds)
        num_virtual_tokens = peft_config.num_virtual_tokens
        if attention_mask is not None:
            # concat prompt attention mask
            prefix_attention_mask = torch.ones(batch_size, num_virtual_tokens, device=attention_mask.device)
            attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
        if position_ids is not None:
            # Dropped on purpose: the ids are not forwarded on prompt-learning paths.
            warnings.warn("Position ids are not supported for parameter efficient tuning. Ignoring position ids.")
        kwargs.update(
            {
                "attention_mask": attention_mask,
                "start_positions": start_positions,
                "end_positions": end_positions,
                "output_attentions": output_attentions,
                "output_hidden_states": output_hidden_states,
                "return_dict": return_dict,
            }
        )

        if peft_config.peft_type == PeftType.PREFIX_TUNING:
            # The input sequence itself is not lengthened here, so segment ids pass through unchanged.
            if token_type_ids is not None:
                kwargs["token_type_ids"] = token_type_ids
            return self._prefix_tuning_forward(input_ids=input_ids, inputs_embeds=inputs_embeds, **kwargs)

        if token_type_ids is not None:
            # Pad the segment ids with zeros for the virtual tokens prepended below.
            kwargs["token_type_ids"] = torch.cat(
                (
                    torch.zeros(batch_size, num_virtual_tokens).to(self.word_embeddings.weight.device),
                    token_type_ids,
                ),
                dim=1,
            ).long()
        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        prompts = self.get_prompt(batch_size=batch_size)
        prompts = prompts.to(inputs_embeds.dtype)
        inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1)
        # NOTE(review): start_positions / end_positions are forwarded unshifted although the
        # sequence now starts with `num_virtual_tokens` prompt embeddings -- confirm this is intended.
        return self.base_model(inputs_embeds=inputs_embeds, **kwargs)

    def _prefix_tuning_forward(
        self,
        input_ids=None,
        attention_mask=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        """
        Forward pass for prefix tuning: the output of `get_prompt` is passed as `past_key_values`.

        If the base model's forward accepts `past_key_values`, the whole computation (including the loss) is
        delegated to it. Otherwise the transformer backbone is called directly and the QA head, the start/end
        logit split and the loss are applied here.

        Returns:
            The base model's output on the first path. On the fallback path a [`QuestionAnsweringModelOutput`]
            when `return_dict` is truthy, else a tuple `(loss?, start_logits, end_logits, *outputs[2:])`.

        Raises:
            ValueError: If neither `input_ids` nor `inputs_embeds` is given, or if neither the base model nor
                its backbone accepts `past_key_values`.
        """
        batch_size = self._qa_batch_size(input_ids, inputs_embeds)
        past_key_values = self.get_prompt(batch_size)
        kwargs.update(
            {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "inputs_embeds": inputs_embeds,
                "output_attentions": output_attentions,
                "output_hidden_states": output_hidden_states,
                "return_dict": return_dict,
                "past_key_values": past_key_values,
            }
        )
        if "past_key_values" in inspect.signature(self.base_model.forward).parameters:
            return self.base_model(start_positions=start_positions, end_positions=end_positions, **kwargs)

        # Fallback: the task model's forward has no `past_key_values`, so call the backbone directly.
        transformer_backbone = self.base_model.get_submodule(self.transformer_backbone_name)
        if "past_key_values" not in inspect.signature(transformer_backbone.forward).parameters:
            raise ValueError("Model does not support past key values which are required for prefix tuning.")
        outputs = transformer_backbone(**kwargs)
        sequence_output = outputs[0]
        # Apply the base model's own dropout only if it has a direct child with that name.
        if any(name == "dropout" for name, _ in self.base_model.named_children()):
            sequence_output = self.base_model.dropout(sequence_output)
        logits = self.base_model.get_submodule(self.cls_layer_name)(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split add a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
