# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from typing import List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.pytorch_utils import Conv1D

from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
from peft.utils.other import transpose

from .._buffer_dict import BufferDict


class HOFTLayer(BaseTunerLayer):
    """Shared adapter state and initialization logic for HOFT layers.

    Each adapter stores its Householder vectors column-wise: ``hoft_U[name]`` has shape ``(out_features, r)`` and
    ``hoft_V[name]`` has shape ``(in_features, r)``. Adapters created with ``use_shoft=True`` additionally own a
    trainable per-output scale ``hoft_magnitude_vector[name]`` of shape ``(out_features,)`` initialized to ones.
    """

    # Names of the per-adapter containers holding trainable adapter weights. `hoft_magnitude_vector` is appended per
    # instance by `update_layer` once an adapter actually uses it.
    adapter_layer_names = ("hoft_U", "hoft_V")
    # Per-adapter, non-trainable state that has to be cleaned up together with the adapter weights.
    other_param_names = ("r", "hoft_dropout")

    def __init__(self, base_layer: nn.Module, **kwargs):
        self.base_layer = base_layer
        self.r = {}
        self.hoft_dropout = nn.ModuleDict({})

        # Trainable Householder vectors (U: output side, V: input side) and optional SHOFT magnitudes
        self.hoft_U = nn.ParameterDict({})
        self.hoft_V = nn.ParameterDict({})
        self.hoft_magnitude_vector = nn.ParameterDict({})

        # Mark the weight as unmerged
        self._disable_adapters = False
        self.merged_adapters = []

        base_layer = self.get_base_layer()
        if isinstance(base_layer, nn.Linear):
            in_features, out_features = base_layer.in_features, base_layer.out_features
        else:
            in_features, out_features = None, None
            warnings.warn(f"Unsupported layer type '{type(base_layer)}' encountered, proceed at your own risk.", UserWarning)

        self.in_features = in_features
        self.out_features = out_features
        self.kwargs = kwargs

    @property
    def merged(self) -> bool:
        return bool(self.merged_adapters)

    def update_layer(
        self,
        adapter_name: str,
        r: int,
        hoft_dropout: float,
        init_weights: str,
        use_shoft: bool
    ):
        """Create (or replace) the HOFT parameters of ``adapter_name`` on this layer.

        Args:
            adapter_name: Key under which the adapter's parameters are stored.
            r: Number of Householder vectors per side; must satisfy ``0 < r <= min(in_features, out_features)``.
            hoft_dropout: Dropout probability applied to the input of the V-side low-rank path; ``0`` disables it.
            init_weights: ``'normal'`` or ``'uniform'`` (random paired vectors), or ``'householder'`` (SVD-based
                initialization that also rewrites the base weight).
            use_shoft: Whether to add a trainable per-output magnitude vector initialized to ones.

        Raises:
            TypeError: If the base layer's feature sizes are unknown (unsupported layer type).
            ValueError: If ``r`` or ``init_weights`` is invalid.
        """
        # Validate everything up front so that a bad call does not leave a half-registered adapter behind.
        if self.in_features is None or self.out_features is None:
            raise TypeError(
                f"HOFT requires a base layer with known `in_features`/`out_features` (nn.Linear), "
                f"got {type(self.get_base_layer())}"
            )
        if r <= 0:
            raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")
        if r > min(self.in_features, self.out_features):
            raise ValueError(
                f"`r` should be less than or equal to min({self.in_features}, {self.out_features}), "
                f"but the value passed is {r}"
            )
        if init_weights not in ("normal", "uniform", "householder"):
            raise ValueError(f"`init_weights` should be a value in ['normal', 'uniform', 'householder'] but the value passed is {init_weights}")

        self.r[adapter_name] = r
        self.hoft_dropout[adapter_name] = nn.Dropout(p=hoft_dropout) if hoft_dropout > 0.0 else nn.Identity()

        # Add magnitude vector (register its container only once, however many SHOFT adapters are added)
        if use_shoft:
            self.hoft_magnitude_vector[adapter_name] = nn.Parameter(torch.ones(self.out_features), requires_grad=True)
            if "hoft_magnitude_vector" not in self.adapter_layer_names:
                self.adapter_layer_names = self.adapter_layer_names + ("hoft_magnitude_vector",)

        # Actual trainable parameters
        if init_weights == "householder":
            U_vectors, V_vectors = self.householder_initialization(r)
        else:
            U_vectors = self.create_hoft_parameters(self.out_features, r, init_weights)
            V_vectors = self.create_hoft_parameters(self.in_features, r, init_weights)

        self.hoft_U[adapter_name] = nn.Parameter(U_vectors, requires_grad=True)
        self.hoft_V[adapter_name] = nn.Parameter(V_vectors, requires_grad=True)

        self._move_adapter_to_device_of_base_layer(adapter_name)
        self.set_adapter(self.active_adapters)

    def create_hoft_parameters(self, dim: int, amount: int, init_weights: str):
        """Return ``amount`` random unit-norm Householder vectors of length ``dim`` as a ``(dim, amount)`` matrix.

        Vectors are drawn in identical consecutive pairs (the last one is unpaired when ``amount`` is odd). Because a
        Householder reflection is its own inverse, each pair is meant to cancel out so the adapter starts near the
        identity transform.

        Raises:
            ValueError: If ``init_weights`` is neither ``'uniform'`` nor ``'normal'``.
        """
        base_vectors = torch.zeros((amount + 1) // 2, dim)
        scale = 1. / (dim ** .5)

        if init_weights == "uniform":
            base_vectors.uniform_(-scale, scale)
        elif init_weights == "normal":
            base_vectors.normal_(0, scale)
        else:
            raise ValueError(f"Unsupported distribution: {init_weights}. Supported distributions are 'uniform' and 'normal'.")

        # Normalizing the Householder vectors helps to reduce inverse-approximation errors
        base_vectors = F.normalize(base_vectors)

        # Duplicate each vector in place ([v0, v0, v1, v1, ...]) and keep exactly `amount` rows
        householder_vectors = base_vectors.repeat_interleave(2, dim=0)[:amount]

        # One vector per column, contiguous so it can be wrapped directly in an nn.Parameter
        return householder_vectors.T.contiguous()

    def householder_initialization(self, r: int):
        """Initialize an adapter from the SVD of the base weight and rewrite the base weight in place.

        ``weight.T = V @ diag(S) @ Uh`` is computed in float32. Both orthogonal factors are converted into Householder
        form (QR followed by ``torch.geqrf``). The first ``r`` reflections of each side become the trainable vectors.
        The remaining reflections, the singular values, the QR ``R`` factors and a correction for the approximate
        inverse used by the forward pass are folded into a new base weight. That new weight is written back to the
        base layer.

        NOTE(review): this overwrites the base layer's weight, so adding a second adapter with this init starts from
        the already rewritten weight — confirm this is intended.

        Args:
            r: Number of Householder vectors per side.

        Returns:
            ``(U, V)`` with shapes ``(out_features, r)`` and ``(in_features, r)``, cast to the base weight dtype.

        Raises:
            TypeError: If the base weight is not float32, float16 or bfloat16.
        """
        weight = self.get_base_layer().weight.T
        dtype = weight.dtype
        device = weight.device

        if dtype not in [torch.float32, torch.float16, torch.bfloat16]:
            raise TypeError(
                "Please initialize under float32, float16, or bfloat16. "
                "Subsequently, re-quantize the residual model to help minimize quantization errors."
            )

        # Obtain weight.T = V @ D @ U.T (SVD computed in float32 for accuracy)
        V, S, Uh = torch.linalg.svd(weight.to(torch.float32), full_matrices=True)
        D = torch.zeros((self.in_features, self.out_features)).to(dtype=dtype, device=device)
        D.diagonal().add_(S)

        # Obtain the Householder representation of both orthogonal factors. Scaling column j by sqrt(tau_j) folds
        # the reflector coefficient into the vector itself (I - tau v v^T = I - (sqrt(tau) v)(sqrt(tau) v)^T).
        # NOTE(review): geqrf also stores R in the upper triangle of `hu`/`hv` and those entries are not masked out;
        # this presumably relies on R being close to diag(+-1) for an orthogonal input — confirm.
        QU, RU = torch.linalg.qr(Uh.T)
        hu, tauu = torch.geqrf(QU)
        QV, RV = torch.linalg.qr(V)
        hv, tauv = torch.geqrf(QV)
        householder_U = hu * torch.sqrt(tauu)
        householder_V1 = hv * torch.sqrt(tauv)
        BU = torch.eye(self.out_features).to(dtype=dtype, device=device)
        BV = torch.eye(self.in_features).to(dtype=dtype, device=device)

        # Computation of UH, UB. Each product of reflections is built in compact-WY form I - X T^-1 X^T, where
        # T = triu(X^T X) with its diagonal halved.
        U1 = householder_U[:, :r]
        delta_2 = torch.triu(torch.matmul(U1.T, U1))
        delta_2.diagonal().div_(2)
        delta_3 = torch.linalg.solve_triangular(delta_2, U1, upper=True, left=False)
        inverse_UH = BU - torch.matmul(delta_3, U1.T)

        # The trainable U vectors are re-derived (second geqrf) from `inverse_UH`
        hu, tauu = torch.geqrf(inverse_UH.T)
        U1 = hu * torch.sqrt(tauu)
        U1 = U1[:, :r]
        delta_2 = torch.triu(torch.matmul(U1.T, U1))
        delta_2.diagonal().div_(2)
        delta_3 = torch.linalg.solve_triangular(delta_2, U1, upper=True, left=False)
        UH = BU - torch.matmul(delta_3, U1.T)

        # Remaining reflections r..n-2; the last column is excluded (presumably a trivial reflector for a square
        # input — confirm)
        U2 = householder_U[:, r:-1]
        delta_2 = torch.triu(torch.matmul(U2.T, U2))
        delta_2.diagonal().div_(2)
        delta_3 = torch.linalg.solve_triangular(delta_2, U2, upper=True, left=False)
        UB = (BU - torch.matmul(delta_3, U2.T)).to(dtype=dtype)

        # Computation of VH, VB (V vectors are used directly, without re-derivation)
        V1 = householder_V1[:, :r]
        delta_2 = torch.triu(torch.matmul(V1.T, V1))
        delta_2.diagonal().div_(2)
        delta_3 = torch.linalg.solve_triangular(delta_2, V1, upper=True, left=False)
        VH = BV - torch.matmul(delta_3, V1.T)

        V2 = householder_V1[:, r:-1]
        delta_2 = torch.triu(torch.matmul(V2.T, V2))
        delta_2.diagonal().div_(2)
        delta_3 = torch.linalg.solve_triangular(delta_2, V2, upper=True, left=False)
        VB = (BV - torch.matmul(delta_3, V2.T)).to(dtype=dtype)

        # Compute UH, VH errors due to the inverse approximation used during forward. This mirrors the forward
        # pass formula exactly.
        # NOTE(review): `DU * WU.triu(1) * DU` broadcasts DU along the last dim twice (entry (i, j) scaled by
        # DU[j] ** 2), not DU[i] * DU[j]; both agree for unit-norm vectors — confirm this is intended.
        U = U1
        V = V1

        WU = torch.matmul(U.T, U)
        DU = WU.diagonal().reciprocal().mul(2)
        SU = DU * WU.triu(1) * DU
        SU.diagonal().sub_(DU)

        approx_UH = BU + torch.matmul(U, torch.matmul(SU, U.T))

        WV = torch.matmul(V.T, V)
        DV = WV.diagonal().reciprocal().mul(2)
        SV = DV * WV.triu(1) * DV
        SV.diagonal().sub_(DV)

        approx_VH = BV + torch.matmul(V, torch.matmul(SV, V.T))

        # error_UH = UH @ approx_UH^-1, error_VH = approx_VH^-1 @ VH
        error_UH = torch.linalg.solve(approx_UH, UH, left=False).to(dtype=dtype)
        error_VH = torch.linalg.solve(approx_VH, VH).to(dtype=dtype)

        RU = RU.to(dtype=dtype)
        RV = RV.to(dtype=dtype)

        # Matrix core computation: everything not represented by the trainable vectors goes into the base weight
        new_weight = error_VH @ VB @ RV @ D @ RU.T @ UB.T @ error_UH
        # Store as (out_features, in_features) in contiguous memory; transposing *after* `.contiguous()` would leave
        # a non-contiguous base weight behind.
        self.get_base_layer().weight.data = new_weight.T.contiguous()

        return U.contiguous().to(dtype), V.contiguous().to(dtype)

class Linear(nn.Linear, HOFTLayer):
    """HOFT adapter implemented on top of a dense (``nn.Linear``) layer.

    With base weight ``W`` (``(out_features, in_features)``), an adapter transforms the layer as
    ``y = ((x @ HV) @ W.T) * m @ HU + bias``, where ``HV = I + V @ S(V) @ V.T`` and ``HU = I + U @ S(U) @ U.T`` are
    built from the adapter's Householder vectors (see ``_hoft_core``). ``m`` is the optional SHOFT magnitude vector;
    the product with ``m`` is elementwise over the output features.
    """

    def __init__(
        self,
        base_layer,
        adapter_name: str,
        r: int = 0,
        hoft_dropout: float = 0.0,
        fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
        init_weights: str = "normal",
        use_shoft: bool = False,
        **kwargs,
    ) -> None:
        # this gets the init from nn.Linear's super perspective, i.e. nn.Module.__init__, which should always be called
        super(nn.Linear, self).__init__()
        HOFTLayer.__init__(self, base_layer, **kwargs)
        # NOTE(review): stored but not used by any computation in this class
        self.fan_in_fan_out = fan_in_fan_out
        # Kept for backward compatibility; computations check per adapter whether a magnitude vector exists
        self.use_shoft = use_shoft

        self._active_adapter = adapter_name
        self.update_layer(adapter_name, r, hoft_dropout, init_weights, use_shoft)

    @staticmethod
    def _hoft_core(vectors: torch.Tensor) -> torch.Tensor:
        """Return the ``(r, r)`` core ``S`` such that ``I + vectors @ S @ vectors.T`` is the HOFT transform.

        With ``G = vectors.T @ vectors`` and ``d = 2 / diag(G)``, ``S`` has diagonal ``-d`` and strictly upper
        triangle ``G.triu(1) * d * d``.
        """
        gram = torch.matmul(vectors.T, vectors)
        inv_half_norms = gram.diagonal().reciprocal().mul(2)
        # NOTE(review): both products broadcast along the last dim, so entry (i, j) is scaled by d[j] ** 2 rather
        # than d[i] * d[j]; the two agree for unit-norm vectors. Left unchanged because changing it would alter the
        # function computed by existing adapters (and the householder init compensates for this exact formula).
        core = inv_half_norms * gram.triu(1) * inv_half_norms
        core.diagonal().sub_(inv_half_norms)
        return core

    def _transform_matrices(self, adapter: str, dtype: torch.dtype, device: torch.device):
        """Return ``(HV, HU, magnitude)`` for ``adapter``, materialized in ``dtype`` on ``device``.

        ``HV`` is ``(in_features, in_features)`` and ``HU`` is ``(out_features, out_features)``. ``magnitude`` is the
        adapter's SHOFT vector, or ``None`` if the adapter has none.
        """
        U = self.hoft_U[adapter].to(device=device, dtype=dtype)
        V = self.hoft_V[adapter].to(device=device, dtype=dtype)

        HV = torch.eye(self.in_features, device=device, dtype=dtype) + torch.matmul(V, torch.matmul(self._hoft_core(V), V.T))
        HU = torch.eye(self.out_features, device=device, dtype=dtype) + torch.matmul(U, torch.matmul(self._hoft_core(U), U.T))

        magnitude = None
        if adapter in self.hoft_magnitude_vector:
            magnitude = self.hoft_magnitude_vector[adapter].to(device=device, dtype=dtype)
        return HV, HU, magnitude

    def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
        """
        Merge the active adapter weights into the base weights

        Each adapter's delta is computed from the current (possibly already merged) weight, so merging several
        adapters composes their transforms in the given order.

        Args:
            safe_merge (`bool`, *optional*):
                If True, the merge operation will be performed in a copy of the original weights and check for NaNs
                before merging the weights. This is useful if you want to check if the merge operation will produce
                NaNs. Defaults to `False`.
            adapter_names (`List[str]`, *optional*):
                The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
                to `None`.
        """
        adapter_names = check_adapters_to_merge(self, adapter_names)
        if not adapter_names:
            # no adapter to merge
            return

        base_layer = self.get_base_layer()
        for active_adapter in adapter_names:
            if active_adapter not in self.hoft_U:
                continue

            if safe_merge:
                # Note that safe_merge will be slower than the normal merge because of the copy operation.
                orig_weights = base_layer.weight.data.clone()
                orig_weights += self.get_delta_weight(active_adapter)

                if not torch.isfinite(orig_weights).all():
                    raise ValueError(
                        f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
                    )

                base_layer.weight.data = orig_weights
            else:
                base_layer.weight.data += self.get_delta_weight(active_adapter)
            self.merged_adapters.append(active_adapter)

    def unmerge(self) -> None:
        """Undo ``merge`` by applying the inverse transform of each merged adapter, in reverse merge order.

        The HOFT transform is multiplicative, so subtracting a delta recomputed from the merged weight does not
        restore the original weight; the inverse is solved for instead (see ``_get_unmerged_weight``).
        """
        if not self.merged:
            warnings.warn("Already unmerged. Nothing to do.")
            return

        base_layer = self.get_base_layer()
        while self.merged_adapters:
            active_adapter = self.merged_adapters.pop()
            if active_adapter in self.hoft_U:
                base_layer.weight.data.copy_(self._get_unmerged_weight(active_adapter))

    def _get_unmerged_weight(self, adapter: str) -> torch.Tensor:
        """Return the current base weight with ``adapter``'s merged transform removed.

        ``merge`` turns ``W`` into ``W' = (HV @ W.T @ diag(m) @ HU).T``. This function recovers
        ``W = (HV^-1 @ W'.T @ HU^-1 @ diag(1 / m)).T`` using linear solves. For half-precision weights the solves run
        in float32, since ``torch.linalg.solve`` does not accept fp16/bf16 inputs.
        """
        merged_weight = self.get_base_layer().weight.data
        device, dtype = merged_weight.device, merged_weight.dtype
        compute_dtype = torch.float32 if dtype in (torch.float16, torch.bfloat16) else dtype

        HV, HU, magnitude = self._transform_matrices(adapter, compute_dtype, device)

        weight_t = torch.linalg.solve(HV, merged_weight.T.to(compute_dtype))  # HV^-1 @ W'.T
        weight_t = torch.linalg.solve(HU, weight_t, left=False)  # ... @ HU^-1
        if magnitude is not None:
            weight_t = weight_t / magnitude  # ... @ diag(1 / m)
        return weight_t.T.to(dtype)

    def get_delta_weight(self, adapter) -> torch.Tensor:
        """
        Compute the delta weight for the given adapter.

        The delta is ``(HV @ W.T @ diag(m) @ HU).T - W`` for the *current* base weight ``W``, so that
        ``W + delta`` is the merged weight. The adapter parameters are read but never modified.

        Args:
            adapter (str):
                The name of the adapter for which the delta weight should be computed.

        Returns:
            A tensor shaped like the base weight, in the base weight's dtype.
        """
        base_weight = self.get_base_layer().weight
        device = base_weight.device
        dtype = base_weight.dtype

        # In case users wants to merge the adapter weights that are in (b)float16 while being on CPU, we need to cast
        # the weights to float32, perform the merge and then cast back to (b)float16 because some CPUs have slow
        # bf16/fp16 matmuls. Every operand (including the base weight) must use the same dtype for torch.mm.
        cast_to_fp32 = device.type == "cpu" and dtype in (torch.float16, torch.bfloat16)
        compute_dtype = torch.float32 if cast_to_fp32 else dtype

        orig_weight = base_weight.data.to(compute_dtype)
        HV, HU, magnitude = self._transform_matrices(adapter, compute_dtype, device)

        transformed = torch.mm(HV, orig_weight.T)
        if magnitude is not None:
            # Right-multiplying by diag(m) scales each output column
            transformed = transformed * magnitude
        transformed = torch.mm(transformed, HU)

        return (transformed.T - orig_weight).to(dtype)

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """Apply the base layer with all active HOFT adapters of this layer.

        Several active adapters are composed in the same way ``merge`` composes them. Their input-side (V)
        transforms are applied to ``x`` in reverse order, and their output-side scaling and (U) transforms in order.
        If no active adapter lives on this layer, the plain base layer is used.
        """
        previous_dtype = x.dtype

        if self.disable_adapters:
            if self.merged:
                self.unmerge()
            return self.base_layer(x, *args, **kwargs).to(previous_dtype)
        if self.merged:
            return self.base_layer(x, *args, **kwargs).to(previous_dtype)

        adapters = [name for name in self.active_adapters if name in self.hoft_U]
        if not adapters:
            return self.base_layer(x, *args, **kwargs).to(previous_dtype)

        # Input side: x @ HV, with dropout applied only on the low-rank correction path
        hidden = x
        for name in reversed(adapters):
            V = self.hoft_V[name]
            hidden = hidden.to(self.hoft_U[name].dtype)
            result_V = torch.matmul(self.hoft_dropout[name](hidden), V)
            hidden = hidden + torch.matmul(result_V, torch.matmul(self._hoft_core(V), V.T))

        base_layer = self.get_base_layer()
        result = torch.matmul(hidden, base_layer.weight.T.to(hidden.dtype))

        # Output side: optional per-feature magnitude, then result @ HU
        for name in adapters:
            if name in self.hoft_magnitude_vector:
                result = self.hoft_magnitude_vector[name] * result
            U = self.hoft_U[name]
            result_U = torch.matmul(result, U)
            result = result + torch.matmul(result_U, torch.matmul(self._hoft_core(U), U.T))

        if base_layer.bias is not None:
            result = result + base_layer.bias

        return result.to(previous_dtype)

    def __repr__(self) -> str:
        rep = super().__repr__()
        return "hoft." + rep
