# Copyright (c) 2021 microsoft
#               2023 Alan (alanfangemail@gmail.com)
#  -----------------------------------------------------------------------------
#  Licensed under the MIT License (MIT). See LICENSE in the repo root for
#  license information.
#  -----------------------------------------------------------------------------

import torch
import torch.nn as nn
import torch.nn.functional as F

import math
from typing import List


class LoRALayer:
    """Mixin that stores the LoRA hyper-parameters shared by all adapted layers.

    Attributes set on the instance:
        r: rank of the low-rank update as passed by the caller.
        lora_alpha: numerator of the LoRA scaling factor.
        lora_dropout: callable applied to the LoRA branch input; an
            ``nn.Dropout`` when the given probability is positive, otherwise
            the no-op ``identity`` method.
        merged: whether the low-rank update is currently folded into the
            base weight; always starts as ``False``.
        merge_weights: whether subclasses should fold the update into the
            base weight.
    """

    def __init__(
        self,
        r: int,
        lora_alpha: int,
        lora_dropout: float,
        merge_weights: bool,
    ):
        self.r, self.lora_alpha = r, lora_alpha
        self.merge_weights = merge_weights
        # A freshly built layer always starts with unmerged weights.
        self.merged = False
        # Dropout is optional; fall back to a pass-through when disabled.
        self.lora_dropout = (nn.Dropout(p=lora_dropout)
                             if lora_dropout > 0. else self.identity)

    def identity(self, x):
        """Return ``x`` unchanged (stand-in for a disabled dropout)."""
        return x

    def __repr__(self):
        return f"LoRALayer(r={self.r})"


class Embedding(nn.Embedding, LoRALayer):
    """``nn.Embedding`` with an optional low-rank (LoRA) update.

    When ``r > 0`` the pre-trained ``weight`` is frozen and two trainable
    matrices are added: ``lora_A`` of shape ``(r, num_embeddings)`` and
    ``lora_B`` of shape ``(embedding_dim, r)``. The effective table is
    ``weight + (lora_B @ lora_A).T * (lora_alpha / r)``.

    Args:
        num_embeddings: size of the vocabulary.
        embedding_dim: size of each embedding vector.
        r: LoRA rank; ``0`` disables the low-rank branch.
        lora_alpha: numerator of the LoRA scaling factor.
        merge_weights: if True, the update is folded into ``weight`` when
            the module is switched to eval mode and removed again when it
            is switched back to training mode.
        **kwargs: forwarded to ``nn.Embedding``.
    """

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 r: int = 0,
                 lora_alpha: int = 1,
                 merge_weights: bool = True,
                 **kwargs):
        nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
        # No dropout is used on the embedding LoRA branch.
        LoRALayer.__init__(self,
                           r=r,
                           lora_alpha=lora_alpha,
                           lora_dropout=0,
                           merge_weights=merge_weights)
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(
                self.weight.new_zeros((r, num_embeddings)))
            self.lora_B = nn.Parameter(
                self.weight.new_zeros((embedding_dim, r)))
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the base table and, if present, the LoRA factors."""
        nn.Embedding.reset_parameters(self)
        # nn.Embedding.__init__ calls this before lora_A/lora_B exist.
        if hasattr(self, 'lora_A'):
            # A is zero and B is normally distributed, so the initial
            # low-rank update (B @ A) is exactly zero.
            nn.init.zeros_(self.lora_A)
            nn.init.normal_(self.lora_B)

    def _lora_delta(self) -> torch.Tensor:
        """Return the scaled low-rank update, shaped like ``weight``."""
        return (self.lora_B @ self.lora_A).transpose(0, 1) * self.scaling

    def train(self, mode: bool = True):
        """Switch train/eval mode and merge or unmerge the LoRA update.

        Returns:
            ``self``, matching ``nn.Module.train`` so that ``.train()`` and
            ``.eval()`` can be chained or assigned.
        """
        nn.Embedding.train(self, mode)
        if mode:
            if self.merge_weights and self.merged:
                # Make sure that the weights are not merged
                if self.r > 0:
                    self.weight.data -= self._lora_delta()
                self.merged = False
        else:
            if self.merge_weights and not self.merged:
                # Merge the weights and mark it
                if self.r > 0:
                    self.weight.data += self._lora_delta()
                self.merged = True
        return self

    def forward(self, x: torch.Tensor):
        """Look up ``x``, adding the low-rank branch while unmerged."""
        result = nn.Embedding.forward(self, x)
        if self.r > 0 and not self.merged:
            # Look up rows of A^T, then project with B^T.
            after_A = F.embedding(x, self.lora_A.transpose(0, 1),
                                  self.padding_idx, self.max_norm,
                                  self.norm_type, self.scale_grad_by_freq,
                                  self.sparse)
            result += (after_A @ self.lora_B.transpose(0, 1)) * self.scaling
        return result


class Linear(nn.Linear, LoRALayer):
    """``nn.Linear`` with an optional low-rank (LoRA) update.

    When ``r > 0`` the pre-trained ``weight`` is frozen and two trainable
    matrices are added: ``lora_A`` of shape ``(r, in_features)`` and
    ``lora_B`` of shape ``(out_features, r)``. The layer computes
    ``linear(x) + dropout(x) @ A^T @ B^T * (lora_alpha / r)``.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        r: LoRA rank; ``0`` disables the low-rank branch.
        lora_alpha: numerator of the LoRA scaling factor.
        lora_dropout: dropout probability applied to the LoRA branch input.
        fan_in_fan_out: set this to True if the layer to replace stores its
            weight like ``(fan_in, fan_out)``; the stored weight is then
            kept transposed and transposed back on use.
        merge_weights: if True, the update is folded into ``weight`` when
            the module is switched to eval mode and removed again when it
            is switched back to training mode.
        **kwargs: forwarded to ``nn.Linear``.
    """

    _BASE_MODULE = nn.Linear

    def __init__(
            self,
            in_features: int,
            out_features: int,
            r: int = 0,
            lora_alpha: int = 1,
            lora_dropout: float = 0.,
            fan_in_fan_out: bool = False,
            merge_weights: bool = True,
            **kwargs):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoRALayer.__init__(self,
                           r=r,
                           lora_alpha=lora_alpha,
                           lora_dropout=lora_dropout,
                           merge_weights=merge_weights)

        self.fan_in_fan_out = fan_in_fan_out
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
            self.lora_B = nn.Parameter(self.weight.new_zeros(
                (out_features, r)))
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
        self.reset_parameters()
        if fan_in_fan_out:
            self.weight.data = self.weight.data.transpose(0, 1)

    def reset_parameters(self):
        """Re-initialise the base layer and, if present, the LoRA factors."""
        nn.Linear.reset_parameters(self)
        # nn.Linear.__init__ calls this before lora_A/lora_B exist.
        if hasattr(self, 'lora_A'):
            # initialize A the same way as the default for nn.Linear and B to
            # zero, so the initial low-rank update (B @ A) is exactly zero
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def T(self, w):
        """Transpose ``w`` when the weight is stored as (fan_in, fan_out)."""
        return w.transpose(0, 1) if self.fan_in_fan_out else w

    def train(self, mode: bool = True):
        """Switch train/eval mode and merge or unmerge the LoRA update.

        Returns:
            ``self``, matching ``nn.Module.train`` so that ``.train()`` and
            ``.eval()`` can be chained or assigned.
        """
        nn.Linear.train(self, mode)
        if mode:
            if self.merge_weights and self.merged:
                # Make sure that the weights are not merged
                if self.r > 0:
                    temp = self.T(self.lora_B @ self.lora_A)
                    self.weight.data -= temp * self.scaling
                self.merged = False
        else:
            if self.merge_weights and not self.merged:
                # Merge the weights and mark it
                if self.r > 0:
                    temp = self.T(self.lora_B @ self.lora_A)
                    self.weight.data += temp * self.scaling
                self.merged = True
        return self

    def forward(self, x: torch.Tensor):
        """Apply the base layer, adding the low-rank branch while unmerged."""
        result = self.forward_linear(x)
        if self.r > 0 and not self.merged:
            # Cast the LoRA factors like forward_linear casts weight/bias, so
            # an input whose dtype differs from the parameters is handled the
            # same way on both branches.
            lora_A = self.lora_A.to(x.dtype)
            lora_B = self.lora_B.to(x.dtype)
            result += (self.lora_dropout(x) @ lora_A.transpose(0, 1)
                       @ lora_B.transpose(0, 1)) * self.scaling
        return result

    # NOTE(jiahong): for consistency with Whisper model
    def forward_linear(self, x: torch.Tensor):
        """Apply the base weight and bias, both cast to ``x.dtype``."""
        return F.linear(
            x,
            self.T(self.weight.to(x.dtype)),
            None if self.bias is None else self.bias.to(x.dtype),
        )

    @classmethod
    def from_base(cls, mod, **kwargs):
        """Build a LoRA layer from an existing base module, copying its state.

        Args:
            mod: instance of ``cls._BASE_MODULE`` to copy from.
            **kwargs: forwarded to the constructor (``r``, ``lora_alpha``...).

        Returns:
            A new instance whose base parameters are loaded from ``mod``.
        """
        assert isinstance(mod, cls._BASE_MODULE), f"Expect {cls._BASE_MODULE}, got {type(mod)}"
        # Mirror the base layer's bias setting; otherwise a bias-free base
        # would yield a randomly initialised bias that the non-strict
        # load_state_dict below never overwrites.
        kwargs.setdefault('bias', mod.bias is not None)
        lora_module = cls(mod.in_features, mod.out_features, **kwargs)
        # strict=False: the base state dict has no lora_A / lora_B entries.
        lora_module.load_state_dict(mod.state_dict(), strict=False)
        return lora_module

    def __repr__(self):
        return (f"LoRALinear(\n"
                f"  {nn.Linear.__repr__(self)}\n"
                f"  {LoRALayer.__repr__(self)}\n"
                f")")


class MergedLinear(nn.Linear, LoRALayer):
    """``nn.Linear`` whose output is split into equal slices, with LoRA
    applied only to the slices selected by ``enable_lora``.

    ``out_features`` is divided into ``len(enable_lora)`` equally sized
    slices. When ``r > 0`` and at least one slice is enabled, ``weight`` is
    frozen and trainable ``lora_A`` / ``lora_B`` factors are created for the
    enabled slices only; the resulting update is zero-padded to the full
    weight shape before it is used.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        r: LoRA rank; ``0`` disables the low-rank branch.
        lora_alpha: numerator of the LoRA scaling factor.
        lora_dropout: dropout probability applied to the LoRA branch input.
        enable_lora: one flag per output slice; defaults to ``[False]``.
        fan_in_fan_out: set this to True if the layer to replace stores its
            weight like ``(fan_in, fan_out)``.
        merge_weights: if True, the update is folded into ``weight`` when
            the module is switched to eval mode and removed again when it
            is switched back to training mode.
        **kwargs: forwarded to ``nn.Linear``.
    """

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 r: int = 0,
                 lora_alpha: int = 1,
                 lora_dropout: float = 0.,
                 enable_lora: List[bool] = None,
                 fan_in_fan_out: bool = False,
                 merge_weights: bool = True,
                 **kwargs):
        if enable_lora is None:
            enable_lora = [False]
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoRALayer.__init__(self,
                           r=r,
                           lora_alpha=lora_alpha,
                           lora_dropout=lora_dropout,
                           merge_weights=merge_weights)
        assert out_features % len(enable_lora) == 0, \
            'The length of enable_lora must divide out_features'
        self.enable_lora = enable_lora
        self.fan_in_fan_out = fan_in_fan_out
        # Actual trainable parameters
        if r > 0 and any(enable_lora):
            self.lora_A = nn.Parameter(
                self.weight.new_zeros((r * sum(enable_lora), in_features)))
            self.lora_B = nn.Parameter(
                self.weight.new_zeros(
                    (out_features // len(enable_lora) * sum(enable_lora), r)))
            # weights for Conv1D with groups=sum(enable_lora)
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
            # Boolean mask over output rows: True for rows that belong to an
            # enabled slice.
            self.lora_ind = self.weight.new_zeros(
                (out_features, ), dtype=torch.bool).view(len(enable_lora), -1)
            self.lora_ind[enable_lora, :] = True
            self.lora_ind = self.lora_ind.view(-1)
        self.reset_parameters()
        if fan_in_fan_out:
            self.weight.data = self.weight.data.transpose(0, 1)

    def reset_parameters(self):
        """Re-initialise the base layer and, if present, the LoRA factors."""
        nn.Linear.reset_parameters(self)
        # nn.Linear.__init__ calls this before lora_A/lora_B exist.
        if hasattr(self, 'lora_A'):
            # initialize A the same way as the default for nn.Linear and B to
            # zero, so the initial low-rank update is exactly zero
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def _lora_active(self) -> bool:
        """Return True when the LoRA factors were actually created."""
        return self.r > 0 and any(self.enable_lora)

    def zero_pad(self, x):
        """Scatter the rows of ``x`` into the enabled output rows.

        Rows belonging to disabled slices are left as zeros, so the result
        has ``len(self.lora_ind)`` rows.
        """
        result = x.new_zeros((len(self.lora_ind), *x.size()[1:]))
        result[self.lora_ind] = x
        return result

    def T(self, w):
        """Transpose ``w`` when the weight is stored as (fan_in, fan_out)."""
        return w.transpose(0, 1) if self.fan_in_fan_out else w

    def merge_AB(self):
        """Return the unscaled low-rank update in the stored weight layout."""
        # A grouped 1x1 convolution multiplies each enabled slice's B block
        # with its own A block in a single call.
        delta_w = F.conv1d(self.lora_A.unsqueeze(0),
                           self.lora_B.unsqueeze(-1),
                           groups=sum(self.enable_lora)).squeeze(0)
        # delta_w only covers the enabled slices; pad it to the full number
        # of output rows so it lines up with ``weight``.
        return self.T(self.zero_pad(delta_w))

    def train(self, mode: bool = True):
        """Switch train/eval mode and merge or unmerge the LoRA update.

        Returns:
            ``self``, matching ``nn.Module.train`` so that ``.train()`` and
            ``.eval()`` can be chained or assigned.
        """
        nn.Linear.train(self, mode)
        if mode:
            if self.merge_weights and self.merged:
                # Make sure that the weights are not merged
                if self._lora_active():
                    self.weight.data -= self.merge_AB() * self.scaling
                self.merged = False
        else:
            if self.merge_weights and not self.merged:
                # Merge the weights and mark it
                if self._lora_active():
                    self.weight.data += self.merge_AB() * self.scaling
                self.merged = True
        return self

    def forward(self, x: torch.Tensor):
        """Apply the base layer, adding the low-rank branch while unmerged."""
        result = F.linear(x, self.T(self.weight), bias=self.bias)
        # lora_A / lora_B only exist when at least one slice is enabled.
        if not self.merged and self._lora_active():
            temp = self.T(self.merge_AB().T)
            result += self.lora_dropout(x) @ temp * self.scaling
        return result


class ConvLoRA(nn.Module, LoRALayer):
    """Wrap a convolution module with an optional low-rank (LoRA) update.

    The wrapped convolution is stored as ``self.conv``. When ``r > 0`` its
    weight is frozen and trainable ``lora_A`` / ``lora_B`` factors are
    created such that ``lora_B @ lora_A`` has exactly as many elements as
    ``conv.weight`` and can be viewed in its shape.

    Args:
        conv_module: convolution class to instantiate (e.g. ``nn.Conv2d``).
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: kernel size; must be a single ``int``.
        r: LoRA rank; ``0`` disables the low-rank branch.
        lora_alpha: numerator of the LoRA scaling factor.
        lora_dropout: stored through ``LoRALayer``; not applied in
            ``forward`` of this class.
        merge_weights: if True, the update is folded into ``conv.weight``
            when the module is switched to eval mode and removed again when
            it is switched back to training mode.
        **kwargs: forwarded to ``conv_module``.
    """

    def __init__(self,
                 conv_module,
                 in_channels,
                 out_channels,
                 kernel_size,
                 r=0,
                 lora_alpha=1,
                 lora_dropout=0.,
                 merge_weights=True,
                 **kwargs):
        super(ConvLoRA, self).__init__()
        self.conv = conv_module(in_channels, out_channels, kernel_size,
                                **kwargs)
        # Also sets self.merged = False.
        LoRALayer.__init__(self,
                           r=r,
                           lora_alpha=lora_alpha,
                           lora_dropout=lora_dropout,
                           merge_weights=merge_weights)
        assert isinstance(kernel_size, int)
        # Actual trainable parameters
        if r > 0:
            # Number of kernel dimensions after the two channel dimensions
            # of the weight tensor.
            spatial_dims = self.conv.weight.dim() - 2
            # lora_B contributes one factor of kernel_size, lora_A the
            # remaining ones, so that B @ A matches conv.weight's element
            # count for any kernel rank. For a 2-D kernel this is the
            # (r * k, in_channels * k) shape.
            self.lora_A = nn.Parameter(
                self.conv.weight.new_zeros(
                    (r * kernel_size,
                     in_channels * kernel_size**(spatial_dims - 1))))
            self.lora_B = nn.Parameter(
                self.conv.weight.new_zeros(
                    (out_channels // self.conv.groups * kernel_size,
                     r * kernel_size)))
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.conv.weight.requires_grad = False
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the wrapped conv and, if present, the LoRA factors."""
        self.conv.reset_parameters()
        if hasattr(self, 'lora_A'):
            # initialize A the same way as the default for nn.Linear and B to
            # zero, so the initial low-rank update (B @ A) is exactly zero
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def _lora_delta(self):
        """Return the scaled low-rank update, shaped like ``conv.weight``."""
        return (self.lora_B @ self.lora_A).view(
            self.conv.weight.shape) * self.scaling

    def train(self, mode=True):
        """Switch train/eval mode and merge or unmerge the LoRA update.

        Returns:
            ``self``, matching ``nn.Module.train`` so that ``.train()`` and
            ``.eval()`` can be chained or assigned.
        """
        super(ConvLoRA, self).train(mode)
        if mode:
            if self.merge_weights and self.merged:
                if self.r > 0:
                    # Make sure that the weights are not merged
                    self.conv.weight.data -= self._lora_delta()
                self.merged = False
        else:
            if self.merge_weights and not self.merged:
                if self.r > 0:
                    # Merge the weights and mark it
                    self.conv.weight.data += self._lora_delta()
                self.merged = True
        return self

    def forward(self, x):
        """Run the convolution, adding the low-rank update while unmerged."""
        if self.r > 0 and not self.merged:
            return self.conv._conv_forward(
                x, self.conv.weight + self._lora_delta(), self.conv.bias)
        return self.conv(x)


class Conv2d(ConvLoRA):
    """``ConvLoRA`` specialised to wrap ``nn.Conv2d``."""

    def __init__(self, *args, **kwargs):
        # All remaining arguments go to ConvLoRA unchanged.
        super().__init__(nn.Conv2d, *args, **kwargs)


class Conv1d(ConvLoRA):
    """``ConvLoRA`` specialised to wrap ``nn.Conv1d``."""

    def __init__(self, *args, **kwargs):
        # All remaining arguments go to ConvLoRA unchanged.
        super().__init__(nn.Conv1d, *args, **kwargs)


# Other convolution types can be supported in the same way.
class Conv3d(ConvLoRA):
    """``ConvLoRA`` specialised to wrap ``nn.Conv3d``."""

    def __init__(self, *args, **kwargs):
        # All remaining arguments go to ConvLoRA unchanged.
        super().__init__(nn.Conv3d, *args, **kwargs)