import torch
from torch.autograd import Function
from torch.amp import autocast
from timm.utils import NativeScaler
from torch.amp import custom_fwd, custom_bwd
from torch import nn

class CustomLinearFunction(Function):
    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, input, lora_A, lora_B):
        ctx.save_for_backward(input, lora_A, lora_B)
        # 并行计算低秩和滤波的输出
        result = (input @ lora_A.transpose(0, 1)) @ lora_B.transpose(0, 1)
        return result

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        x, lora_A, lora_B = ctx.saved_tensors

        grad_input = grad_A = grad_B = None
        x_reshape = x
        grad_reshape = grad_output

        # shape > 2说明第一个维度是batchsize，为了避免误差的引入，将batchsize维度展开
        if x.dim() == 3:
            x_reshape = x.view(x.shape[0] * x.shape[1], x.shape[2])
        if grad_output.dim() == 3:
            grad_reshape = grad_output.reshape(grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2])

        if ctx.needs_input_grad[0]:
            grad_input = grad_output @ lora_B @ lora_A
        if ctx.needs_input_grad[1]:
            grad_A_temp = torch.matmul(grad_reshape, lora_B)
            grad_A = grad_A_temp.t() @ x_reshape
        if ctx.needs_input_grad[2]:
            grad_B_temp = torch.matmul(x_reshape, lora_A.t())
            grad_B = torch.matmul(grad_reshape.t(), grad_B_temp)

        return grad_input, grad_A, grad_B

# Fix the global RNG seed so the random parameter initialisation and the
# random input created further down are reproducible between runs.
torch.manual_seed(42)
# Print tensors with 7 digits of precision so small differences stay visible.
torch.set_printoptions(precision=7)


class test_model(nn.Module):
    """Bias-free square linear layer plus a low-rank branch computed by ``CustomLinearFunction``.

    Args:
        input_size: Size of the last dimension of the input; the output has
            the same size.
        rank: Inner dimension of the low-rank branch.
    """

    def __init__(self, input_size=10, rank=5):
        super().__init__()
        self.rank = rank
        self.size_in = input_size
        self.size_out = input_size
        # Size the dense layer from input_size so non-default sizes work.
        self.linear = nn.Linear(self.size_in, self.size_out, bias=False)
        self.lora_A = nn.Parameter(torch.empty(self.rank, self.size_in), requires_grad=True)
        self.lora_B = nn.Parameter(torch.empty(self.size_out, self.rank), requires_grad=True)
        # Not used by forward(); kept so the module's parameter set is unchanged.
        self.weight = nn.Parameter(torch.empty(self.size_out, self.rank), requires_grad=True)

        nn.init.kaiming_uniform_(self.linear.weight)
        nn.init.kaiming_uniform_(self.lora_A)
        nn.init.kaiming_uniform_(self.lora_B)
        # torch.empty leaves arbitrary memory behind; zero it so the registered
        # parameter is deterministic. This does not consume any random numbers.
        nn.init.zeros_(self.weight)

    def forward(self, x):
        """Return ``linear(x)`` plus the low-rank branch applied to ``x``."""
        dense_out = self.linear(x)
        low_rank_out = CustomLinearFunction.apply(x, self.lora_A, self.lora_B)
        return dense_out + low_rank_out



# Use one device object for both the model and the input.
device = torch.device("cuda:0")

# Bind the instance to its own name so the `test_model` class stays usable.
model = test_model()
model.to(device)

# [batch, n, d]; sampled on the CPU (same random stream as before), then turned
# into a leaf tensor on the device so that x.grad is populated by backward().
x = torch.rand(2, 3, 10).to(device).requires_grad_(True)


# Forward pass under mixed precision.
loss_scaler = torch.amp.GradScaler(device.type)

with autocast(device_type=device.type):
    output = model(x)
    loss = output.mean()

loss_scaler.scale(loss).backward()

# backward() ran on the *scaled* loss, so every .grad is multiplied by the
# current loss scale. Divide it out to report the true gradients.
grad_scale = loss_scaler.get_scale()
auto_grad_A = model.lora_A.grad / grad_scale
auto_grad_B = model.lora_B.grad / grad_scale

print(auto_grad_A)
# Lora_A.grad = None
# Lora_B.grad = None
# random1.grad = None
# random2.grad = None
# x.grad = None

#
# with autocast(device_type="cuda"):
#     temp2 = x @ random1
#     # temp3 = x @ random2
#     temp4 = CustomLinearFunction.apply(x.half(), Lora_A, Lora_B)
#     temp5 = temp2 + temp4
#     result2 = temp5.sum()             # scalar loss
#
# loss_scaler.scale(result2).backward(create_graph=False)





#
# output = x + CustomLinearFunction.apply(x,Lora_A, Lora_B)
# # cus = CustomLinearFunction.apply(x,Lora_A, Lora_B)
# # output = output + cus
# # cus_grad = cus.grad
# result1 = output.sum()
# result1.backward()

# print(Lora_A.grad)

# # Compute the gradients manually
# grad_result = torch.ones_like(temp2)  # assume the upstream gradient is 1
# grad_result_reshape = grad_result.view(grad_result.shape[0] * grad_result.shape[1], grad_result.shape[2])
# x_reshape = x.view(x.shape[0] * x.shape[1], x.shape[2])
#
# temp_custom_x = x_reshape @ Lora_A.t()
# grad_B_custom = grad_result_reshape.t() @ temp_custom_x

# ccc = grad_result.transpose(-2,-1) @ temp1  # [10, 3] @ [3, 5] = [10, 5]
# grad_temp1_manual = grad_result @ Lora_B      # [3, 10] @ [10, 5] = [3, 5]
# grad_Lora_A_manual = grad_temp1_manual.transpose(-1,-2) @ x  # [5, 3] @ [3, 10] = [5, 10]
# grad_x_manual = grad_result @ Lora_B @ Lora_A
#
# print(x.grad)
# Compare the gradients from autograd with the manually computed ones
# print("Lora_A梯度误差:", torch.norm(Lora_A.grad - auto_grad_A).item())
# print("Lora_B梯度误差:", torch.norm(Lora_B.grad- auto_grad_B).item())
# print("X梯度误差:", torch.norm(x.grad - auto_grad_x).item())
# print("Loss误差:", torch.norm(temp2_grad - cus_grad).item())