import torch
import torch.nn.functional as F

def sinkhorn_knopp(
    cost_matrix: torch.Tensor,
    a: torch.Tensor = None,
    b: torch.Tensor = None,
    reg: float = 0.1,
    numItermax: int = 50,
    stop_thresh: float = 1e-3,
    log_domain: bool = True,  # log-domain iterations are numerically more stable
    eps: float = 1e-8
) -> torch.Tensor:
    """
    Entropy-regularised Sinkhorn-Knopp solver built from differentiable torch ops.

    Default marginals are created on ``cost_matrix``'s device with a dtype
    promoted from ``cost_matrix``'s dtype and torch's default dtype, so a
    float64 cost matrix works in both branches. Caller-supplied ``a`` and ``b``
    are used as given (they are neither cast nor moved).

    Args:
        cost_matrix: (m, n) cost matrix.
        a: (m,) source weights; uniform (1/m each) when None.
        b: (n,) target weights; uniform (1/n each) when None.
        reg: entropic regularisation strength; the kernel is
            ``exp(-cost_matrix / reg)``.
        numItermax: maximum number of Sinkhorn iterations.
        stop_thresh: iteration stops early once the largest absolute change of
            the row scaling ``u`` between two iterations falls below this value.
        log_domain: when True, iterate on log-scalings using logsumexp;
            otherwise iterate on the scalings themselves.
        eps: small constant added to ``a``/``b`` before taking logs (log
            domain) or to the denominators (primal domain).

    Returns:
        (m, n) transport plan ``diag(u) @ K @ diag(v)``, computed with torch
        operations only so gradients can flow back to ``cost_matrix``.
    """
    m, n = cost_matrix.shape
    device = cost_matrix.device
    # Keep the default marginals at least as wide as the cost matrix; matmul in
    # the primal branch refuses mixed float32/float64 operands.
    dtype = torch.promote_types(cost_matrix.dtype, torch.get_default_dtype())

    # Uniform marginals by default
    if a is None:
        a = torch.ones(m, device=device, dtype=dtype) / m
    if b is None:
        b = torch.ones(n, device=device, dtype=dtype) / n

    if log_domain:
        # --- Log-domain implementation (numerically more stable) ---
        log_a = torch.log(a + eps)
        log_b = torch.log(b + eps)
        log_K = -cost_matrix / reg  # K = exp(-C / reg)

        # u and v hold the log-scalings here
        u = torch.zeros_like(a)
        v = torch.zeros_like(b)

        for _ in range(numItermax):
            # u is rebound below, never modified in place, so no copy is needed
            u_prev = u

            # Alternate the column (v) and row (u) updates in log space
            v = log_b - torch.logsumexp(log_K + u.unsqueeze(1), dim=0)
            u = log_a - torch.logsumexp(log_K + v.unsqueeze(0), dim=1)

            # Early stopping on the change of the row log-scaling
            if (u - u_prev).abs().max() < stop_thresh:
                break

        # P = diag(exp(u)) @ K @ diag(exp(v))
        transport_plan = torch.exp(u.unsqueeze(1) + log_K + v.unsqueeze(0))

    else:
        # --- Primal-domain implementation (relies on eps in the denominators) ---
        K = torch.exp(-cost_matrix / reg)
        K = K / K.sum()  # preliminary normalisation of the kernel

        u = torch.ones_like(a) / m
        v = torch.ones_like(b) / n

        for _ in range(numItermax):
            # u is rebound below, never modified in place, so no copy is needed
            u_prev = u

            v = b / (K.T @ u + eps)
            u = a / (K @ v + eps)

            # Early stopping on the change of the row scaling
            if (u - u_prev).abs().max() < stop_thresh:
                break

        transport_plan = u.unsqueeze(1) * K * v.unsqueeze(0)

    return transport_plan

def sinkhorn_knopp_unbalanced(
    cost_matrix: torch.Tensor,
    a: torch.Tensor = None,
    b: torch.Tensor = None,
    reg: float = 0.1,
    reg_m: float = 0.1,
    numItermax: int = 10,
    stop_thresh: float = 1e-3,
    log_domain: bool = True,
    eps: float = 1e-8
) -> torch.Tensor:
    """
    Unbalanced Sinkhorn solver built from differentiable torch ops.

    Each scaling update is damped by the exponent ``reg_m / (reg_m + reg)``,
    so the marginals of the returned plan are not forced to equal ``a`` and
    ``b``.

    Default marginals are created on ``cost_matrix``'s device with a dtype
    promoted from ``cost_matrix``'s dtype and torch's default dtype, so a
    float64 cost matrix works in both branches. Caller-supplied ``a`` and ``b``
    are used as given (they are neither cast nor moved).

    Args:
        cost_matrix: (m, n) cost matrix.
        a: (m,) source weights (need not be normalised); uniform when None.
        b: (n,) target weights (need not be normalised); uniform when None.
        reg: entropic regularisation strength; the kernel is
            ``exp(-cost_matrix / reg)``.
        reg_m: marginal (mass) penalty strength.
        numItermax: maximum number of iterations.
        stop_thresh: iteration stops early once the largest absolute change of
            the row scaling ``u`` between two iterations falls below this value.
        log_domain: when True (recommended), iterate on log-scalings using
            logsumexp; otherwise iterate on the scalings themselves.
        eps: small constant added to ``a``/``b`` before taking logs (log
            domain) or to the matrix-vector products (primal domain).

    Returns:
        (m, n) transport plan ``diag(u) @ K @ diag(v)``.
    """
    m, n = cost_matrix.shape
    device = cost_matrix.device
    # Keep the default marginals at least as wide as the cost matrix; matmul in
    # the primal branch refuses mixed float32/float64 operands.
    dtype = torch.promote_types(cost_matrix.dtype, torch.get_default_dtype())

    if a is None:
        a = torch.ones(m, device=device, dtype=dtype) / m
    if b is None:
        b = torch.ones(n, device=device, dtype=dtype) / n

    # Damping exponent applied to every scaling update
    tau = reg_m / (reg_m + reg)

    if log_domain:
        log_K = -cost_matrix / reg
        # u and v hold the log-scalings here
        u = torch.zeros_like(a)
        v = torch.zeros_like(b)

        log_a = torch.log(a + eps)
        log_b = torch.log(b + eps)

        for _ in range(numItermax):
            # u is rebound below, never modified in place, so no copy is needed
            u_prev = u

            # Column update first, then row update (the primal branch below
            # uses the opposite order)
            v = tau * (log_b - torch.logsumexp(log_K + u.unsqueeze(1), dim=0))
            u = tau * (log_a - torch.logsumexp(log_K + v.unsqueeze(0), dim=1))

            if (u - u_prev).abs().max() < stop_thresh:
                break

        # Transport plan: diag(exp(u)) @ K @ diag(exp(v))
        transport_plan = torch.exp(u.unsqueeze(1) + log_K + v.unsqueeze(0))
    else:
        K = torch.exp(-cost_matrix / reg)
        u = torch.ones_like(a)
        v = torch.ones_like(b)

        for _ in range(numItermax):
            # u is rebound below, never modified in place, so no copy is needed
            u_prev = u

            # eps keeps the divisions finite when a product underflows to zero
            Kv = K @ v + eps
            u = (a / Kv) ** tau

            KTu = K.T @ u + eps
            v = (b / KTu) ** tau

            if (u - u_prev).abs().max() < stop_thresh:
                break

        transport_plan = u.unsqueeze(1) * K * v.unsqueeze(0)

    return transport_plan



# ===== Usage example =====
if __name__ == "__main__":
    # Use the GPU when one is available; otherwise fall back to the CPU so the
    # example also runs on hosts without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Random cost matrix. Values are drawn on the CPU, moved to the target
    # device, and only then marked as requiring grad, which makes the tensor a
    # leaf whose .grad is filled in by backward() without retain_grad().
    torch.manual_seed(42)
    cost_matrix = torch.rand(3, 4).to(device).requires_grad_()

    # Run Sinkhorn (balanced alternative: sinkhorn_knopp(cost_matrix, reg=0.1))
    P = sinkhorn_knopp_unbalanced(cost_matrix, reg=0.1, reg_m=0.5)

    print("Transport Plan:\n", P)
    print("Grad enabled:", P.requires_grad)  # expected: True

    # Check that gradients reach the cost matrix
    loss = P.sum()
    loss.backward()
    print("cost_matrix.grad:\n", cost_matrix.grad)  # expected: not None