# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings

import torch
from transformers.pytorch_utils import Conv1D
from itertools import chain
import re
from peft.import_utils import is_bnb_4bit_available, is_bnb_available
from peft.tuners.lora import LoraConfig, LoraModel
from peft.tuners.tuners_utils import BaseTunerLayer
from peft.utils import (
    TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING,
    _freeze_adapter,
    _get_submodules,
    get_auto_gptq_quant_linear,
    get_quantization_config,
)
from peft.utils.integrations import gather_params_ctx
from peft.tuners.adalora.model import AdaLoraModel
from .layer import RoLoraLayer, SVDLinearRoLora
from peft.tuners.adalora.layer import AdaLoraLayer
import numpy as np

class RoLoraModel(AdaLoraModel):
    """
    AdaLoRA-derived tuner model whose adapted layers are `SVDLinearRoLora` modules.

    The class builds on `AdaLoraModel` (AdaLoRA paper: https://openreview.net/forum?id=lq62uWRJjiY). Targeted
    `torch.nn.Linear` / `Conv1D` layers are wrapped in `SVDLinearRoLora`, and `extend_modules` can re-grow the adapter
    matrices of every layer listed in the adapter's `rank_pattern`.

    Args:
        model ([`transformers.PreTrainedModel`]): The model to be adapted.
        config: The adapter configuration. The code in this class reads the AdaLoRA-style fields `init_r`, `target_r`,
            `rank_pattern`, `lora_alpha`, `lora_dropout`, `fan_in_fan_out` and `init_lora_weights`.
        adapter_name (`str`): The name of the adapter.

    Returns:
        `torch.nn.Module`: The adapted model.

    Example::

        >>> from transformers import AutoModelForSeq2SeqLM
        >>> from peft import AdaLoraConfig
        >>> config = AdaLoraConfig(
        ...     peft_type="ADALORA", task_type="SEQ_2_SEQ_LM", r=8, lora_alpha=32, target_modules=["q", "v"],
        ...     lora_dropout=0.01,
        ... )
        >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
        >>> model = RoLoraModel(model, config, "default")

    **Attributes**:
        - **model** ([`transformers.PreTrainedModel`]) -- The model to be adapted.
        - **peft_config** -- Mapping from adapter name to its configuration.
    """

    # Note: don't redefine prefix here, it should be inherited from LoraModel

    def __init__(self, model, config, adapter_name):
        """Delegate construction entirely to the parent class."""
        super().__init__(model, config, adapter_name)

    def _create_and_replace(
        self,
        lora_config,
        adapter_name,
        target,
        target_name,
        parent,
        current_key,
    ):
        """
        Wrap `target` in a new adapter module, or add/refresh `adapter_name` on an already wrapped layer.

        Args:
            lora_config: Configuration of the adapter being injected.
            adapter_name (`str`): Name of the adapter being injected.
            target (`torch.nn.Module`): The module selected for adaptation.
            target_name (`str`): Attribute name of `target` on `parent`.
            parent (`torch.nn.Module`): Module that owns `target`.
            current_key (`str`): Dotted path of `target`, used to look up its entry in `rank_pattern`.

        Raises:
            ImportError: If the model is loaded in 8-bit or 4-bit but `bitsandbytes` is not available.
        """
        # The rank defaults to `init_r`; a matching `rank_pattern` entry overrides it with the length of its mask.
        r = lora_config.init_r
        if lora_config.rank_pattern:
            # Accept an exact key or a key nested below `current_key`. Requiring the "." separator prevents a
            # sibling module whose name merely starts with the target's name from being picked up by mistake.
            child_prefix = f"{current_key}."
            matched_key = next(
                (
                    key
                    for key in lora_config.rank_pattern
                    if key == current_key or key.startswith(child_prefix)
                ),
                None,
            )
            if matched_key is not None:
                r = len(lora_config.rank_pattern[matched_key])

        kwargs = {
            "r": r,
            "lora_alpha": lora_config.lora_alpha,
            "lora_dropout": lora_config.lora_dropout,
            "fan_in_fan_out": lora_config.fan_in_fan_out,
            "init_lora_weights": lora_config.init_lora_weights,
            "loaded_in_8bit": getattr(self.model, "is_loaded_in_8bit", False),
            "loaded_in_4bit": getattr(self.model, "is_loaded_in_4bit", False),
        }
        if (kwargs["loaded_in_8bit"] or kwargs["loaded_in_4bit"]) and not is_bnb_available():
            raise ImportError(
                "To use AdaLora with 8-bit quantization, please install the `bitsandbytes` package. "
                "You can install it with `pip install bitsandbytes`."
            )

        quantization_config = get_quantization_config(self.model, method="gptq")
        if quantization_config is not None:
            kwargs["gptq_quantization_config"] = quantization_config

        # If it is not an AdaLoraLayer, create a new module, else update it with new adapters
        if not isinstance(target, AdaLoraLayer):
            new_module = self._create_new_module(lora_config, adapter_name, target, **kwargs)
            if adapter_name not in self.active_adapters:
                # adding an additional adapter: it is not automatically trainable
                new_module.requires_grad_(False)
            self._replace_module(parent, target_name, new_module, target)
            return

        # A brand-new adapter has no `ranknum` entry on this layer yet; only an already registered adapter has a
        # previous value that needs to be carried over the `update_layer` call.
        previous_ranknum = target.ranknum[adapter_name] if adapter_name in target.ranknum else None
        target.update_layer(
            adapter_name,
            r,
            lora_config.lora_alpha,
            lora_config.lora_dropout,
            lora_config.init_lora_weights,
        )
        if previous_ranknum is not None and previous_ranknum > 0:
            target.ranknum[adapter_name].data.fill_(float(previous_ranknum))
            target.ranknum[adapter_name].requires_grad = False

    @staticmethod
    def _create_new_module(lora_config, adapter_name, target, **kwargs):
        """
        Build the `SVDLinearRoLora` wrapper for `target`.

        `fan_in_fan_out` is forced to `False` for `torch.nn.Linear` and to `True` for `Conv1D` base layers; when a
        correction is needed a warning is emitted and `lora_config.fan_in_fan_out` is updated as well.

        Args:
            lora_config: Adapter configuration; its `fan_in_fan_out` field may be modified.
            adapter_name (`str`): Name of the adapter to create.
            target (`torch.nn.Module`): Module to wrap (possibly already a tuner layer).
            **kwargs: Keyword arguments forwarded to `SVDLinearRoLora`; must contain `fan_in_fan_out`.

        Returns:
            The newly created `SVDLinearRoLora` module.

        Raises:
            ValueError: If the base layer is neither `torch.nn.Linear` nor `Conv1D`.
        """
        # Inspect the underlying layer when the target is already wrapped by a tuner layer.
        target_base_layer = target.get_base_layer() if isinstance(target, BaseTunerLayer) else target

        if isinstance(target_base_layer, torch.nn.Linear):
            if kwargs["fan_in_fan_out"]:
                warnings.warn(
                    "fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
                    "Setting fan_in_fan_out to False."
                )
                kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False
        elif isinstance(target_base_layer, Conv1D):
            if not kwargs["fan_in_fan_out"]:
                warnings.warn(
                    "fan_in_fan_out is set to False but the target module is `Conv1D`. "
                    "Setting fan_in_fan_out to True."
                )
                kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = True
        else:
            raise ValueError(
                f"Target module {target} is not supported. "
                "Currently, only `torch.nn.Linear` and `Conv1D` are supported."
            )

        # The wrapper receives the original `target`, not the unwrapped base layer.
        return SVDLinearRoLora(target, adapter_name, **kwargs)

    def extend_modules(model, adapter_name, iteration, max_iter):
        """
        Re-size the adapter matrices of every layer listed in `rank_pattern` and reset the rank budget.

        The first parameter is named `model` but receives the instance when this is called as a bound method.

        For each `rank_pattern` entry the number of kept ranks (at least 1) is scaled by the ratio of total mask
        length to total kept ranks, then blended with the entry's mask length using the weight
        `step = iteration / (iteration + 1)`. Each layer is rebuilt with the resulting rank via `update_layer`, and
        the previously held `lora_E` / `lora_A` / `lora_B` values are copied into the leading rows/columns of the new
        parameters. Afterwards `rank_pattern` is replaced by all-`True` masks of the new sizes, `target_r` is
        increased (capped at `init_r - 1`) and the rank allocator's budget scheduler is re-initialised.

        Args:
            model: Object exposing `peft_config`, `model` and `rankallocator`.
            adapter_name (`str`): Adapter whose layers are extended.
            iteration (`int`): Index of the current extension round.
            max_iter (`int`): Used together with `iteration` to scale `target_r`; `max_iter + iteration` must be
                non-zero.

        Raises:
            ValueError: If `rank_pattern` is `None` or contains a value that is neither a list nor a tensor.
        """
        adapter_config = model.peft_config[adapter_name]
        current_pattern = adapter_config.rank_pattern
        if current_pattern is None:
            # Fail with an explicit message before any layer has been modified.
            raise ValueError(
                f"Adapter '{adapter_name}' has no rank_pattern; rank allocation must run before extend_modules."
            )

        step = iteration / (iteration + 1.0)

        total_mask_len = 0  # summed length of all masks
        total_kept = 0  # summed kept ranks, each clamped to >= 1
        layer_ranks = {}
        for name, rank_idx in current_pattern.items():
            if isinstance(rank_idx, list):
                kept = sum(rank_idx)
            elif isinstance(rank_idx, torch.Tensor):
                rank_idx = rank_idx.view(-1)
                kept = rank_idx.sum().item()
            else:
                raise ValueError("Unexpected type of rank_idx")
            clamped = max(kept, 1)
            mask_len = len(rank_idx)
            total_mask_len += mask_len
            total_kept += clamped

            # Strip the trailing "lora_E[.<adapter>]" components to obtain the owning module's path.
            parts = name.split(".")
            key = ".".join(parts[0:-2]) if adapter_name in name else ".".join(parts[0:-1])
            layer_ranks[key] = (clamped, mask_len)

        # Loop-invariant scale factor; `total_kept` is >= 1 whenever there is at least one layer.
        growth = total_mask_len / total_kept if layer_ranks else 1.0

        new_pattern = {}
        for key, (clamped, mask_len) in layer_ranks.items():
            _, target, _ = _get_submodules(model.model, key)
            new_rank = int(np.round(clamped * growth * (1 - step) + step * mask_len))

            # Keep references to the current parameters before `update_layer` is called.
            # NOTE(review): the copy below assumes `update_layer` installs fresh parameters and that the current
            # size does not exceed `new_rank` -- confirm against the layer implementation.
            old_ranknum = target.ranknum[adapter_name]
            old_E = target.lora_E[adapter_name]
            old_A = target.lora_A[adapter_name]
            old_B = target.lora_B[adapter_name]
            old_size = old_E.shape[0]

            target.update_layer(
                adapter_name,
                new_rank,
                adapter_config.lora_alpha,
                adapter_config.lora_dropout,
                adapter_config.init_lora_weights,
            )
            if old_size > 0:
                with torch.no_grad():
                    target.lora_E[adapter_name][:old_size].copy_(old_E)
                    target.lora_A[adapter_name][:old_size, :].copy_(old_A)
                    target.lora_B[adapter_name][:, :old_size].copy_(old_B)
                    target.ranknum[adapter_name].copy_(old_ranknum)

            new_pattern[f"{key}.lora_E"] = [True] * new_rank
            target.ranknum[adapter_name].requires_grad = False

        adapter_config.rank_pattern = new_pattern
        adapter_config.target_r = min(
            int(np.ceil(adapter_config.target_r * (max_iter + iteration + 1.0) / (max_iter + iteration))),
            adapter_config.init_r - 1,
        )
        model.rankallocator._set_budget_scheduler(model)
    