import math

import numpy as np
import torch
import torch.nn.functional as F

# Shared module-level L1 (mean absolute error) criterion. The reduction is
# spelled out explicitly; "mean" is the library default.
loss_fn = torch.nn.L1Loss(reduction="mean")

def adain_loss(q_style, k_style, v_style, q_ref, k_ref, v_ref, lambda_q=0.5):
    """Return an L1 attention-matching loss plus a weighted L1 query loss.

    Two scaled-dot-product attention outputs are computed, both driven by the
    same query ``q_style``: one attends over the style keys/values, the other
    over the reference keys/values. The loss is the mean absolute difference
    between those two outputs, plus ``lambda_q`` times the mean absolute
    difference between ``q_style`` and ``q_ref``.

    The reference-side attention output and ``q_ref`` are detached, so both
    terms back-propagate only through the style-side computation.

    Args:
        q_style: Query tensor of the style branch (used for both attentions).
        k_style: Key tensor of the style branch.
        v_style: Value tensor of the style branch.
        q_ref: Reference query tensor; target of the query term.
        k_ref: Reference key tensor.
        v_ref: Reference value tensor.
        lambda_q: Weight applied to the query term.

    Returns:
        The combined loss as a tensor (mean-reduced L1 terms).
    """
    # Build both targets first and cut them out of the autograd graph so they
    # behave as constants during the backward pass.
    attn_target = F.scaled_dot_product_attention(q_style, k_ref, v_ref).detach()
    query_target = q_ref.detach()

    # Prediction side: plain self-attention over the style branch.
    attn_pred = F.scaled_dot_product_attention(q_style, k_style, v_style)

    attn_term = F.l1_loss(attn_pred, attn_target)
    query_term = F.l1_loss(q_style, query_target)

    return attn_term + lambda_q * query_term





