# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings

import torch
from transformers.pytorch_utils import Conv1D
from itertools import chain
import re
from peft.import_utils import is_bnb_4bit_available, is_bnb_available
from peft.tuners.lora import LoraConfig, LoraModel
from peft.tuners.tuners_utils import BaseTunerLayer
from peft.utils import (
    TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING,
    _freeze_adapter,
    _get_submodules,
    get_auto_gptq_quant_linear,
    get_quantization_config,
)
from peft.utils.integrations import gather_params_ctx
from peft.tuners.lora.model import LoraModel
from .layer import SoraLayer, Linear
from peft.tuners.lora.layer import LoraLayer
import numpy as np

class SoraModel(LoraModel):
    """
    ``LoraModel`` subclass that wraps target modules in SoRA layers.

    Compared to the parent class, the rank of each wrapped module can be
    overridden through ``lora_config.rank_pattern`` (see
    ``_create_and_replace``), and the rank of already-injected layers can be
    re-sized with ``extend_modules``.
    """

    def __init__(self, model, config, adapter_name):
        super().__init__(model, config, adapter_name)

    def _create_and_replace(
        self,
        lora_config,
        adapter_name,
        target,
        target_name,
        parent,
        current_key,
    ):
        """
        Add adapter ``adapter_name`` to ``target``, wrapping it if necessary.

        The rank defaults to ``lora_config.r``. If ``lora_config.rank_pattern``
        has an entry for ``current_key``, the *length* of that entry is used as
        the rank instead (``extend_modules`` stores one list element per rank).

        Args:
            lora_config: Adapter configuration (``r``, ``lora_alpha``,
                ``lora_dropout``, ``fan_in_fan_out``, ``init_lora_weights``,
                ``rank_pattern`` are read).
            adapter_name: Name of the adapter being injected.
            target: Module to adapt. An existing ``SoraLayer`` is updated in
                place; anything else is replaced by a new SoRA module.
            target_name: Attribute name of ``target`` on ``parent``.
            parent: Module that owns ``target``.
            current_key: Dotted name of ``target`` in the model.

        Raises:
            ValueError: If ``current_key`` is ``None``.
        """
        if current_key is None:
            raise ValueError("Current Key shouldn't be `None`")

        r = lora_config.r
        alpha = lora_config.lora_alpha

        if lora_config.rank_pattern:
            pattern_keys = list(lora_config.rank_pattern.keys())
            # Prefer the entry written for exactly this module
            # ("<prefix>.<current_key>.lora_E"). A plain substring test alone
            # would let e.g. "...attn.q" pick up the entry of "...attn.q_proj".
            exact_suffix = f"{current_key}.lora_E"
            target_name_key = next(
                (
                    key
                    for key in pattern_keys
                    if key == exact_suffix or key.endswith(f".{exact_suffix}")
                ),
                None,
            )
            if target_name_key is None:
                # Fall back to the looser substring rule for patterns that do
                # not follow the "<module>.lora_E" naming.
                target_name_key = next(
                    (key for key in pattern_keys if current_key in key), None
                )
            if target_name_key:
                r = len(lora_config.rank_pattern[target_name_key])

        kwargs = {
            "r": r,
            "lora_alpha": alpha,
            "lora_dropout": lora_config.lora_dropout,
            "fan_in_fan_out": lora_config.fan_in_fan_out,
            "init_lora_weights": lora_config.init_lora_weights,
        }

        if isinstance(target, SoraLayer):
            target.update_layer(
                adapter_name,
                r,
                lora_alpha=alpha,
                lora_dropout=lora_config.lora_dropout,
                init_lora_weights=lora_config.init_lora_weights,
            )
        else:
            new_module = self._create_new_module(lora_config, adapter_name, target, **kwargs)
            if adapter_name not in self.active_adapters:
                # adding an additional adapter: it is not automatically trainable
                new_module.requires_grad_(False)
            self._replace_module(parent, target_name, new_module, target)

    @staticmethod
    def _create_new_module(lora_config, adapter_name, target, **kwargs):
        """
        Build the SoRA replacement module for ``target``.

        Args:
            lora_config: Adapter configuration; its ``fan_in_fan_out`` flag is
                reset to ``False`` when it was set for a ``torch.nn.Linear``.
            adapter_name: Name of the adapter to create inside the new module.
            target: Module to wrap. If it is already a ``BaseTunerLayer`` its
                base layer decides whether it is supported.
            **kwargs: Forwarded to the ``Linear`` SoRA layer; must contain
                ``"fan_in_fan_out"``.

        Returns:
            The newly created ``Linear`` SoRA module.

        Raises:
            ValueError: If the (base) target is not a ``torch.nn.Linear``.
        """
        if isinstance(target, BaseTunerLayer):
            target_base_layer = target.get_base_layer()
        else:
            target_base_layer = target

        if not isinstance(target_base_layer, torch.nn.Linear):
            raise ValueError(
                f"Target module {target} is not supported. "
                "Currently, only `torch.nn.Linear` is supported."
            )

        if kwargs["fan_in_fan_out"]:
            warnings.warn(
                "fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
                "Setting fan_in_fan_out to False."
            )
            kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False
        return Linear(target, adapter_name, **kwargs)

    def extend_modules(model, adapter_name, iteration):
        """
        Re-size every SoRA layer of ``adapter_name`` and keep its surviving ranks.

        For each module that owns a ``lora_E`` parameter for ``adapter_name``:
        the rows of ``lora_E`` that are non-zero are kept (together with the
        matching rows of ``lora_A`` and columns of ``lora_B``), the layer is
        re-created with a new rank via ``update_layer``, and the kept weights
        are copied into the leading slots of the fresh parameters.

        The new rank interpolates between two values with weight
        ``step = iteration / (iteration + 1)``:

        * ``kept * (total_old / total_kept)`` - the surviving rank, scaled so
          that the summed rank over all modules stays at the previous total;
        * the module's previous rank.

        Finally ``rank_pattern`` of the adapter's config is replaced by a dict
        mapping ``"<module key>.lora_E"`` to a list with one ``True`` per rank.

        Note: this function has no ``self`` parameter. When it is invoked on an
        instance, the instance is bound to ``model``.

        Args:
            model: Object exposing ``peft_config`` (dict keyed by adapter name)
                and ``named_parameters()``, and accepted by ``_get_submodules``.
            adapter_name: Adapter whose layers are re-sized.
            iteration: Current iteration index used for the interpolation
                weight. NOTE(review): ``iteration == -1`` divides by zero -
                callers are presumably expected to pass values >= 0.
        """
        step = iteration / (iteration + 1.0)
        lora_config = model.peft_config

        total_old_rank = 0
        total_kept_rank = 0
        # module key -> (kept rank clamped to >= 1, previous rank)
        ranks = {}
        for name, param in model.named_parameters():
            parts = name.split(".")
            # Parameter names look like "<module key>.lora_E.<adapter>"; only
            # the requested adapter may contribute, otherwise ranks of other
            # adapters would leak into the budget and overwrite each other.
            if parts[-2:] != ["lora_E", adapter_name]:
                continue
            old_rank = param.shape[0]
            # Plain Python ints keep the arithmetic below free of 0-d tensors.
            kept_rank = max(int(torch.count_nonzero(param)), 1)
            total_old_rank += old_rank
            total_kept_rank += kept_rank
            ranks[".".join(parts[:-2])] = (kept_rank, old_rank)

        # `total_kept_rank` is >= 1 per module, so this only guards "no modules".
        budget_scale = total_old_rank / total_kept_rank if ranks else 1.0

        rank_pattern = {}
        for key, (kept_rank, old_rank) in ranks.items():
            _, target, _ = _get_submodules(model, key)
            new_rank = int(np.round(kept_rank * budget_scale * (1 - step) + step * old_rank))

            # Snapshot the surviving weights before `update_layer` replaces the
            # parameters. Boolean-mask indexing copies, so the snapshot does
            # not alias the old storage; no autograd history is needed.
            with torch.no_grad():
                old_e = target.lora_E[adapter_name]
                keep = (old_e != 0).reshape(-1)
                kept_e = old_e[keep]
                kept_a = target.lora_A[adapter_name][keep]
                kept_b = target.lora_B[adapter_name][:, keep]
            n_kept = kept_e.shape[0]

            target.update_layer(
                adapter_name,
                new_rank,
                lora_config[adapter_name].lora_alpha,
                lora_config[adapter_name].lora_dropout,
                lora_config[adapter_name].init_lora_weights,
            )

            if n_kept > 0:
                with torch.no_grad():
                    # Surviving ranks occupy the first `n_kept` slots; the rest
                    # keeps whatever `update_layer` initialised.
                    target.lora_E[adapter_name][:n_kept].copy_(kept_e)
                    target.lora_A[adapter_name][:n_kept, :].copy_(kept_a)
                    target.lora_B[adapter_name][:, :n_kept].copy_(kept_b)

            rank_pattern[f"{key}.lora_E"] = [True] * new_rank

        lora_config[adapter_name].rank_pattern = rank_pattern
    