# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Any, Optional

import torch

from peft.tuners.xlora.model import XLoraModel

from .config import PeftConfig
from .mixed_model import PeftMixedModel
from .peft_model import (
    PeftModel,
    PeftModelForCausalLM,
    PeftModelForFeatureExtraction,
    PeftModelForQuestionAnswering,
    PeftModelForSeq2SeqLM,
    PeftModelForSequenceClassification,
    PeftModelForTokenClassification,
)
from .tuners import (
    AdaLoraConfig,
    AdaLoraModel,
    AdaptionPromptConfig,
    BOFTConfig,
    BOFTModel,
    FourierFTConfig,
    FourierFTModel,
    HRAConfig,
    HRAModel,
    IA3Config,
    IA3Model,
    LNTuningConfig,
    LNTuningModel,
    LoHaConfig,
    LoHaModel,
    LoKrConfig,
    LoKrModel,
    LoraConfig,
    LoraModel,
    MultitaskPromptTuningConfig,
    OFTConfig,
    OFTModel,
    PolyConfig,
    PolyModel,
    PrefixTuningConfig,
    PromptEncoderConfig,
    PromptTuningConfig,
    VBLoRAConfig,
    VBLoRAModel,
    VeraConfig,
    VeraModel,
    XLoraConfig,
)
from .tuners.tuners_utils import BaseTuner
from .utils import _prepare_prompt_learning_config

if TYPE_CHECKING:
    from transformers import PreTrainedModel

# Task-type string -> task-specific `PeftModel` subclass. `get_peft_model` looks up `peft_config.task_type` here;
# task types without an entry fall back to the generic `PeftModel` (unless the config is a prompt-learning one).
MODEL_TYPE_TO_PEFT_MODEL_MAPPING: dict[str, type[PeftModel]] = dict(
    SEQ_CLS=PeftModelForSequenceClassification,
    SEQ_2_SEQ_LM=PeftModelForSeq2SeqLM,
    CAUSAL_LM=PeftModelForCausalLM,
    TOKEN_CLS=PeftModelForTokenClassification,
    QUESTION_ANS=PeftModelForQuestionAnswering,
    FEATURE_EXTRACTION=PeftModelForFeatureExtraction,
)

# `peft_type` string -> config class, used by `get_peft_config` to rebuild a config from a plain dict.
# Note that "LORAPLUS" is an alias: it reuses `LoraConfig` rather than having a dedicated config class.
PEFT_TYPE_TO_CONFIG_MAPPING: dict[str, type[PeftConfig]] = dict(
    ADAPTION_PROMPT=AdaptionPromptConfig,
    PROMPT_TUNING=PromptTuningConfig,
    PREFIX_TUNING=PrefixTuningConfig,
    P_TUNING=PromptEncoderConfig,
    LORA=LoraConfig,
    LOHA=LoHaConfig,
    LORAPLUS=LoraConfig,
    LOKR=LoKrConfig,
    ADALORA=AdaLoraConfig,
    BOFT=BOFTConfig,
    IA3=IA3Config,
    MULTITASK_PROMPT_TUNING=MultitaskPromptTuningConfig,
    OFT=OFTConfig,
    POLY=PolyConfig,
    LN_TUNING=LNTuningConfig,
    VERA=VeraConfig,
    FOURIERFT=FourierFTConfig,
    XLORA=XLoraConfig,
    HRA=HRAConfig,
    VBLORA=VBLoRAConfig,
)

# `peft_type` string -> `BaseTuner` subclass that injects that adapter type into a model; used by
# `inject_adapter_in_model`. Prompt-learning and adaption-prompt types (and the "LORAPLUS" alias) have no entry here.
PEFT_TYPE_TO_TUNER_MAPPING: dict[str, type[BaseTuner]] = dict(
    LORA=LoraModel,
    LOHA=LoHaModel,
    LOKR=LoKrModel,
    ADALORA=AdaLoraModel,
    BOFT=BOFTModel,
    IA3=IA3Model,
    OFT=OFTModel,
    POLY=PolyModel,
    LN_TUNING=LNTuningModel,
    VERA=VeraModel,
    FOURIERFT=FourierFTModel,
    XLORA=XLoraModel,
    HRA=HRAModel,
    VBLORA=VBLoRAModel,
)


def get_peft_config(config_dict: dict[str, Any]) -> PeftConfig:
    """
    Builds a PEFT config object from a plain dictionary.

    The config class is selected by looking up `config_dict["peft_type"]` in `PEFT_TYPE_TO_CONFIG_MAPPING`. The whole
    dictionary (including the `peft_type` entry itself) is then forwarded to that class as keyword arguments.

    Args:
        config_dict (`Dict[str, Any]`):
            Dictionary containing the configuration parameters. Must contain a `"peft_type"` key.

    Returns:
        `PeftConfig`: An instance of the config class registered for the given `peft_type`.

    Raises:
        KeyError: If `"peft_type"` is missing from `config_dict` or is not a registered PEFT type.
    """
    requested_type = config_dict["peft_type"]
    config_class = PEFT_TYPE_TO_CONFIG_MAPPING[requested_type]
    return config_class(**config_dict)


def get_peft_model(
    model: PreTrainedModel,
    peft_config: PeftConfig,
    adapter_name: str = "default",
    mixed: bool = False,
    autocast_adapter_dtype: bool = True,
    revision: Optional[str] = None,
) -> PeftModel | PeftMixedModel:
    """
    Returns a Peft model object from a model and a config.

    Note that `peft_config` is modified in place: its `base_model_name_or_path` is overwritten with the name of
    `model`, and its `revision` is overwritten whenever `revision` is passed.

    Args:
        model ([`transformers.PreTrainedModel`]):
            Model to be wrapped.
        peft_config ([`PeftConfig`]):
            Configuration object containing the parameters of the Peft model.
        adapter_name (`str`, `optional`, defaults to `"default"`):
            The name of the adapter to be injected, if not provided, the default adapter name is used ("default").
        mixed (`bool`, `optional`, defaults to `False`):
            Whether to allow mixing different (compatible) adapter types.
        autocast_adapter_dtype (`bool`, *optional*):
            Whether to autocast the adapter dtype. Defaults to `True`. Right now, this will only cast adapter weights
            using float16 or bfloat16 to float32, as this is typically required for stable training, and only affects
            select PEFT tuners. Ignored when `mixed=True`.
        revision (`str`, `optional`, defaults to `None`):
            The revision of the base model. If this isn't set, the saved peft model will load the `main` revision for
            the base model.

    Returns:
        A `PeftMixedModel` when `mixed=True`; otherwise the task-specific `PeftModel` subclass registered for
        `peft_config.task_type` in `MODEL_TYPE_TO_PEFT_MODEL_MAPPING`, or a plain `PeftModel` when the task type is
        not registered and the config is not a prompt-learning config.
    """
    model_config = BaseTuner.get_model_config(model)

    # Point the config at the base model actually being wrapped. Only an attribute stored directly on the
    # instance is considered (`__dict__` lookup); otherwise the name becomes None.
    old_name = peft_config.base_model_name_or_path
    new_name = model.__dict__.get("name_or_path", None)
    peft_config.base_model_name_or_path = new_name

    if (old_name is not None) and (old_name != new_name):
        # stacklevel=2 attributes the warning to the caller of get_peft_model rather than to this module.
        warnings.warn(
            f"The PEFT config's `base_model_name_or_path` was renamed from '{old_name}' to '{new_name}'. "
            "Please ensure that the correct base model is loaded when loading this checkpoint.",
            stacklevel=2,
        )

    if revision is not None:
        if peft_config.revision is not None and peft_config.revision != revision:
            warnings.warn(
                f"peft config has already set base model revision to {peft_config.revision}, overwriting with revision {revision}",
                stacklevel=2,
            )
        peft_config.revision = revision

    if mixed:
        # note: PeftMixedModel does not support autocast_adapter_dtype, so don't pass it
        return PeftMixedModel(model, peft_config, adapter_name=adapter_name)

    if peft_config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING and not peft_config.is_prompt_learning:
        # Unknown / unset task type: fall back to the generic wrapper.
        return PeftModel(model, peft_config, adapter_name=adapter_name, autocast_adapter_dtype=autocast_adapter_dtype)

    if peft_config.is_prompt_learning:
        # Prompt-learning configs are completed from the base model's config before use.
        peft_config = _prepare_prompt_learning_config(peft_config, model_config)
    # NOTE(review): a prompt-learning config whose task_type is not registered reaches this lookup and raises KeyError.
    return MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type](
        model, peft_config, adapter_name=adapter_name, autocast_adapter_dtype=autocast_adapter_dtype
    )


def inject_adapter_in_model(
    peft_config: PeftConfig, model: torch.nn.Module, adapter_name: str = "default", low_cpu_mem_usage: bool = False
) -> torch.nn.Module:
    r"""
    A simple API to create and inject adapter in-place into a model. Currently the API does not support prompt learning
    methods and adaption prompt. Make sure to have the correct `target_modules` set in the `peft_config` object.

    Unlike `get_peft_model`, this does not wrap the result in a `PeftModel`. It directly instantiates the tuner class
    registered for `peft_config.peft_type` in `PEFT_TYPE_TO_TUNER_MAPPING`, and returns that tuner's underlying model.

    Args:
        peft_config (`PeftConfig`):
            Configuration object containing the parameters of the Peft model.
        model (`torch.nn.Module`):
            The input model where the adapter will be injected.
        adapter_name (`str`, `optional`, defaults to `"default"`):
            The name of the adapter to be injected, if not provided, the default adapter name is used ("default").
        low_cpu_mem_usage (`bool`, `optional`, defaults to `False`):
            Create empty adapter weights on meta device. Useful to speed up the loading process.

    Returns:
        `torch.nn.Module`: The `model` attribute of the instantiated tuner, i.e. the model with the adapter injected.

    Raises:
        ValueError: If `peft_config` is a prompt-learning or adaption-prompt config, or if its `peft_type` has no
            registered tuner.
    """
    if peft_config.is_prompt_learning or peft_config.is_adaption_prompt:
        raise ValueError("`inject_adapter_in_model` does not support prompt learning and adaption prompt yet.")

    tuner_cls = PEFT_TYPE_TO_TUNER_MAPPING.get(peft_config.peft_type)
    if tuner_cls is None:
        raise ValueError(
            f"`inject_adapter_in_model` does not support {peft_config.peft_type} yet. Please use `get_peft_model`."
        )

    # Instantiating the tuner injects the adapter layers described by `peft_config` into `model`'s modules.
    peft_model = tuner_cls(model, peft_config, adapter_name=adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)

    return peft_model.model
