import torch
from torchvision import transforms
import drjit as dr
import lpips
from diffmat.optim.descriptor import TextureDescriptor

# Shared models used by the "vgg" and "lpips" branches of get_loss_fn.
# They are built once at import time; get_loss_fn moves them to the requested device.
texture_descriptor = TextureDescriptor(device="cpu")
lpips_model = lpips.LPIPS(net="vgg")
# LPIPS is only used as a fixed metric here, so freeze every one of its parameters.
lpips_model.requires_grad_(False)


# Factory: returns a different loss function depending on the requested name.
def get_loss_fn(loss_type, device='cuda'):
    """
    Return a loss callable ``loss(x, y)`` selected by name.

    Supported names:
        "mse"   -- mean squared error, ``mean((x - y) ** 2)``.
        "mae"   -- mean absolute error, ``mean(|x - y|)``.
        "vgg"   -- mean absolute difference between the outputs of the
                   shared ``texture_descriptor`` for x and for y.
        "lpips" -- mean absolute distance reported by the shared
                   ``lpips_model``; inputs (assumed to be in [0, 1]) are
                   rescaled to [-1, 1] first.

    For "vgg" and "lpips", a 3-D input is treated as an (H, W, C) image and
    laid out as a (1, C, H, W) batch; inputs of any other rank pass through
    unchanged.

    Args:
        loss_type: Name of the loss (see above).
        device: Device the shared descriptor / LPIPS network is moved to.
            Only used by "vgg" and "lpips".

    Returns:
        function: The loss callable.

    Raises:
        ValueError: If ``loss_type`` is not one of the supported names.
    """

    def _as_batch(img):
        # (H, W, C) -> (1, C, H, W); anything else is returned as-is.
        if len(img.shape) == 3:
            return img.permute(2, 0, 1).unsqueeze(0)
        return img

    if loss_type == "mse":
        def mse_loss(x, y):
            return torch.mean((x - y) ** 2)

        return mse_loss

    if loss_type == "mae":
        def mae_loss(x, y):
            return torch.mean(torch.abs(x - y))

        return mae_loss

    if loss_type == "vgg":
        # Note: this moves the module-level descriptor shared by every caller.
        texture_descriptor.to(device)

        def vgg_loss(x, y):
            feat_x = texture_descriptor.evaluate(_as_batch(x))
            feat_y = texture_descriptor.evaluate(_as_batch(y))
            return (feat_x - feat_y).abs().mean()

        return vgg_loss

    if loss_type == "lpips":
        # Note: this moves the module-level LPIPS network shared by every caller.
        lpips_model.to(device)

        def _to_lpips_range(img):
            # LPIPS expects inputs in [-1, 1]; map [0, 1] -> [-1, 1].
            return _as_batch(img) * 2 - 1

        def lpips_loss(x, y):
            return lpips_model(_to_lpips_range(x), _to_lpips_range(y)).abs().mean()

        return lpips_loss

    raise ValueError(f"Unknown loss type: {loss_type}")


def get_loss_fn_drjit(loss_type):
    """
    Get a Dr.Jit version of a simple loss function by name.

    Counterpart of the "mse" / "mae" branches of ``get_loss_fn`` for Dr.Jit
    arrays. Unlike ``get_loss_fn`` it takes no ``device`` argument.

    Args:
        loss_type: Type of loss: "mse" (mean squared error) or
            "mae" (mean absolute error).

    Returns:
        function: A callable ``loss(x, y)`` that reduces with ``dr.mean``.

    Raises:
        ValueError: If ``loss_type`` is not "mse" or "mae".
    """
    if loss_type == "mse":
        return lambda x, y: dr.mean((x - y) ** 2)
    elif loss_type == "mae":
        return lambda x, y: dr.mean(dr.abs(x - y))
    else:
        raise ValueError(f"Unknown loss type: {loss_type}")


def batch_total_variation_loss(x):
    """
    Total-variation smoothness penalty over a whole batch of images.

    Penalizes differences between neighbouring elements to suppress noise and
    aliasing. Neighbours are taken along dim 2 (labelled "horizontal") and
    dim 1 (labelled "vertical"). This presumably matches a (batch, H, W, C)
    layout; TODO confirm against callers.

    Args:
        x: Tensor with at least 4 dimensions.

    Returns:
        Scalar tensor: mean |step along dim 2| + mean |step along dim 1|.
    """
    horizontal = (x[:, :, 1:, :] - x[:, :, :-1, :]).abs().mean()
    vertical = (x[:, 1:, :, :] - x[:, :-1, :, :]).abs().mean()
    return horizontal + vertical


def total_variation_loss(x):
    """
    Total-variation smoothness penalty for a single image.

    Penalizes differences between neighbouring elements to suppress noise and
    aliasing. Neighbours are taken along dim 1 (labelled "horizontal") and
    dim 0 (labelled "vertical"). This presumably matches an (H, W, C) layout;
    TODO confirm against callers.

    Args:
        x: Tensor with at least 3 dimensions.

    Returns:
        Scalar tensor: mean |step along dim 1| + mean |step along dim 0|.
    """
    step_cols = x[:, 1:, :] - x[:, :-1, :]
    step_rows = x[1:, :, :] - x[:-1, :, :]
    return step_cols.abs().mean() + step_rows.abs().mean()


def total_variation_loss_drjit(x):
    """
    Dr.Jit version of ``total_variation_loss``: smoothness penalty for ONE image.

    Despite living next to ``batch_total_variation_loss``, this uses 3-D
    indexing exactly like the single-image PyTorch ``total_variation_loss``.
    It does not handle a batch dimension. Neighbours are taken along dim 1
    ("horizontal") and dim 0 ("vertical"), which presumably matches an
    (H, W, C) tensor layout; TODO confirm against callers.

    Args:
        x: Dr.Jit tensor with 3 dimensions.

    Returns:
        mean |step along dim 1| + mean |step along dim 0|, reduced with ``dr.mean``.
    """
    # Differences between horizontally adjacent elements (dim 1)
    h_diff = x[:, 1:, :] - x[:, :-1, :]
    # Differences between vertically adjacent elements (dim 0)
    v_diff = x[1:, :, :] - x[:-1, :, :]
    # Mean absolute difference in each direction
    h_loss = dr.mean(dr.abs(h_diff))
    v_loss = dr.mean(dr.abs(v_diff))
    return h_loss + v_loss


def unbiased_l2_loss(img0, img1, ref, *, absolute=True):
    """
    L2-style loss comparing two renderings of the same scene against one reference.

    Multiplies the residuals of the two images, ``(img0 - ref) * (img1 - ref)``,
    and averages over all elements. When img0 and img1 are independent noisy
    estimates of the same image, the plain mean of this product (no absolute
    value) is an unbiased estimate of the squared error of their expectation.

    Args:
        img0: First image tensor.
        img1: Second image tensor (same shape as img0).
        ref: Reference image tensor.
        absolute: Keyword-only. If True (default, the original behaviour), take
            the absolute value of each per-element product before averaging.
            This keeps the loss non-negative but makes it biased: in
            expectation it over-estimates the error, and each element's
            gradient is multiplied by the sign of its product. Pass False for
            the unbiased estimator. Its value can be negative for a single
            sample.

    Returns:
        Scalar tensor.
    """
    product = (img0 - ref) * (img1 - ref)
    if absolute:
        product = torch.abs(product)
    return torch.mean(product)