# import math
# import torch
# import torch.nn as nn
# from peft.tuners.lora import LoraLayer
#
# class MyLoRALayer(LoraLayer):
#     def __init__(self, r, lora_alpha, lora_dropout, merge_weights=True):
#         super().__init__(r, lora_alpha, lora_dropout, merge_weights)
#         self.lora_C = nn.Parameter(torch.zeros((r, self.d)))
#         self.lora_D = nn.Parameter(torch.zeros((self.d, r)))
#         self.scaling_A =lora_alpha / r
#         self.scaling_C = lora_alpha / r
#         self.lora_dropout = nn.Dropout(p=lora_dropout)
#         self.reset_parameters()
#
#     def reset_parameters(self):
#         nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
#         nn.init.kaiming_uniform_(self.lora_B, a=math.sqrt(5))
#         nn.init.kaiming_uniform_(self.lora_C, a=math.sqrt(5))
#         nn.init.kaiming_uniform_(self.lora_D, a=math.sqrt(5))
#
#     def forward(self, x):
#         if self.merge_weights and self.merged:
#             lora_term_1 = (self.lora_B @ self.lora_A) * self.scaling_A
#             lora_term_2 = (self.lora_D @ self.lora_C) * self.scaling_C
#             return lora_term_1 + lora_term_2 + x
#         else:
#             term_1 = self.lora_dropout(x) @ self.lora_A.T @ self.lora_B.T
#             term_2 = self.lora_dropout(x) @ self.lora_D.T @ self.lora_C.T
#             return (term_1 * self.scaling_A + term_2 * self.scaling_C) + x
#
# if __name__ == '__main__':
#     model = YourPretrainedModel()
#     for name, module in model.named_modules():
#         if isinstance(module, nn.Linear):
#             lora_layer = MyLoRALayer(r=16, lora_alpha=16, lora_dropout=0.1)
#             module.weight = lora_layer.merge_weights()

import torch
import math
import torch.nn as nn
from loralib import Linear

class MyLoRALayer(Linear):
    """LoRA linear layer carrying a second low-rank pair (``lora_C``/``lora_D``).

    On top of the ``lora_A``/``lora_B`` factors created by the parent
    ``loralib.Linear``, this layer owns ``lora_C`` of shape
    ``(r, in_features)`` and ``lora_D`` of shape ``(in_features, r)``.

    ``forward`` returns ``x + scaling * (x @ A.T @ B.T + x @ C.T @ D.T)``.
    Because the update is added to the input itself, the layer is only usable
    when ``in_features == out_features``.

    NOTE(review): ``forward`` never uses the inherited ``weight``/``bias``; the
    result is a residual around ``x`` rather than around the base linear
    output. This matches the previous behaviour of this class -- confirm that
    it is intended.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Output size handed to the parent layer.
        r: Rank of both low-rank pairs. Must be positive.
        lora_alpha: Numerator of the update scale (``lora_alpha / r``).
            ``None`` means a scale of exactly 1.0.
        lora_dropout: Dropout probability applied to the input of the
            low-rank paths. ``None`` disables dropout.
        fan_in_fan_out: Forwarded to the parent layer unchanged. It does not
            affect the low-rank computation done here.
        **kwargs: Forwarded to the parent layer.

    Raises:
        ValueError: If ``r`` is not positive.
    """

    def __init__(self, in_features, out_features, r=4, lora_alpha=None, lora_dropout=None, fan_in_fan_out=False,
                 **kwargs):
        if r <= 0:
            # Without a positive rank there are no low-rank factors to apply
            # and the scale lora_alpha / r is undefined.
            raise ValueError(f"r must be a positive integer, got {r!r}")

        # loralib's LoRALayer does arithmetic/comparisons on these two values,
        # so None has to be replaced by a neutral number before delegating.
        # alpha == r yields a scale of 1.0, the documented meaning of None.
        alpha = r if lora_alpha is None else lora_alpha
        dropout = 0.0 if lora_dropout is None else lora_dropout

        super().__init__(in_features, out_features, r=r, lora_alpha=alpha, lora_dropout=dropout,
                         fan_in_fan_out=fan_in_fan_out, **kwargs)

        # Second low-rank pair; its rank follows `r` instead of a fixed 4.
        self.lora_C = nn.Parameter(self.weight.new_zeros((r, in_features)), requires_grad=True)
        self.lora_D = nn.Parameter(self.weight.new_zeros((in_features, r)), requires_grad=True)
        self.scaling = alpha / r

        self.reset_new_parameters()

    def reset_new_parameters(self):
        """Re-initialise the layer, including the extra ``lora_C``/``lora_D`` pair.

        Delegates to the parent's ``reset_parameters()`` first (so inherited
        parameters are re-initialised as well), then initialises ``lora_C``
        with Kaiming-uniform noise and ``lora_D`` with zeros.
        """
        super().reset_parameters()
        # Only one factor may start at zero: the product (and therefore the
        # initial update) is still zero, but gradients can reach lora_D. With
        # both factors at zero each gradient is proportional to the other
        # factor and the pair would never leave zero.
        nn.init.kaiming_uniform_(self.lora_C, a=math.sqrt(5))
        nn.init.zeros_(self.lora_D)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the scaled sum of both low-rank updates.

        Args:
            x: Input whose last dimension is ``in_features``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        # The factors are stored as (r, in) / (out, r) and (r, in) / (in, r)
        # regardless of fan_in_fan_out, so each one is transposed before being
        # right-multiplied onto the activations.
        dropped = self.lora_dropout(x)
        update = dropped @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)
        update = update + dropped @ self.lora_C.transpose(0, 1) @ self.lora_D.transpose(0, 1)
        return x + update * self.scaling