import torch
import torch.nn as nn
import geomloss  # 用于计算 Sinkhorn 散度

class SinkhornLoss(nn.Module):
    """Sinkhorn divergence evaluated only at the final sequence position.

    The ``(batch_size, dim)`` slice at index ``-1`` of the sequence axis is
    handed to ``geomloss.SamplesLoss`` for both predictions and targets.
    """

    def __init__(self, epsilon=0.1, scaling=0.95, max_iter=100):
        """Build the underlying geomloss Sinkhorn criterion.

        Args:
            epsilon (float): Forwarded to ``geomloss.SamplesLoss`` as ``blur``.
            scaling (float): Forwarded to ``geomloss.SamplesLoss`` as ``scaling``.
            max_iter (int): Stored on the instance; not used by this class.
        """
        super().__init__()
        self.epsilon = epsilon
        self.scaling = scaling
        self.max_iter = max_iter
        sinkhorn_options = dict(
            loss='sinkhorn',
            p=2,
            blur=epsilon,
            scaling=scaling,
            backend='tensorized',
        )
        self.sinkhorn = geomloss.SamplesLoss(**sinkhorn_options)

    def forward(self, pred, target, alpha=1):
        """Return the Sinkhorn divergence between the last-step slices.

        Args:
            pred (torch.Tensor): Predictions, expected shape
                (batch_size, lookback_window + 1, dim).
            target (torch.Tensor): Targets with the same shape as ``pred``.
            alpha (float): Unused; accepted only so callers may pass it.

        Returns:
            torch.Tensor: Result of ``self.sinkhorn`` on the two
            ``(batch_size, dim)`` slices taken at the last sequence index.
        """
        final_pred = pred[:, -1, :]
        final_target = target[:, -1, :]
        return self.sinkhorn(final_pred, final_target)




class WassersteinLoss(nn.Module):
    """Mixed loss: histogram-based entropic Wasserstein term plus MSE.

    The Wasserstein term compares the empirical distribution of the values at
    the last time step of ``pred`` and ``target``. Both are binned into
    ``num_bins`` histograms over a *shared* value range, and the transport
    cost between the two histograms is approximated with Sinkhorn iterations.
    The MSE term is computed over the full ``pred`` / ``target`` tensors.

    Note:
        Histogram counts are piecewise constant in the inputs, so the
        histograms are built from detached tensors and the Wasserstein term
        carries no gradient; during training only the MSE term propagates
        gradients to ``pred``.
    """

    def __init__(self, num_bins=100, epsilon=0.1):
        """Initialize the Wasserstein loss function.

        Args:
            num_bins (int): Number of histogram bins.
            epsilon (float): Entropic regularization strength used by the
                Sinkhorn iterations in ``sinkhorn_knopp``.
        """
        super().__init__()
        self.num_bins = num_bins
        self.epsilon = epsilon
        self.mse = nn.MSELoss()

    def compute_histogram(self, x, min_val=None, max_val=None):
        """Compute a normalized histogram of all values in ``x``.

        Args:
            x (torch.Tensor): Input tensor of any shape; it is flattened.
            min_val (float, optional): Lower edge of the first bin. Defaults
                to the minimum of ``x``.
            max_val (float, optional): Upper edge of the last bin. Defaults
                to the maximum of ``x``. If the lower and upper edges are
                equal, the range is widened by 1 on each side so the bins
                have non-zero width. Values outside the range are ignored by
                ``torch.histc``.

        Returns:
            tuple:
                torch.Tensor: Histogram of shape (num_bins,) summing to 1,
                    with the dtype/device of ``x``.
                torch.Tensor: Bin edges of shape (num_bins + 1,) on the same
                    dtype/device as ``x``.
        """
        # Counts have no useful gradient; work on a detached copy so autograd
        # never sees histc.
        x_flat = x.detach().flatten()

        # Plain Python floats for the range: histc/linspace take scalars.
        lo = x_flat.min().item() if min_val is None else float(min_val)
        hi = x_flat.max().item() if max_val is None else float(max_val)
        if lo == hi:
            # Constant data: give the bins a non-zero width so the edges are
            # distinct and every value still falls inside the range.
            lo -= 1.0
            hi += 1.0

        hist = torch.histc(x_flat, bins=self.num_bins, min=lo, max=hi)
        hist = hist / hist.sum()

        # Same range as histc, created on x's device/dtype so the cost matrix
        # built from these edges matches the histograms.
        bins = torch.linspace(lo, hi, self.num_bins + 1,
                              device=x.device, dtype=x.dtype)
        return hist, bins

    def forward(self, pred, target, alpha=0.5):
        """Calculate the mixed loss between predictions and targets.

        Args:
            pred (torch.Tensor): Model predictions with shape (batch_size, seq_len, dim).
            target (torch.Tensor): Target values with shape (batch_size, seq_len, dim).
            alpha (float): Weight of the Wasserstein term; the MSE term gets
                weight ``1 - alpha``.

        Returns:
            torch.Tensor: ``alpha * wasserstein + (1 - alpha) * mse``.
        """
        pred_last = pred[:, -1, :]      # (batch_size, dim)
        target_last = target[:, -1, :]  # (batch_size, dim)

        # Both histograms must share the SAME bin edges; otherwise bin k of
        # pred and bin k of target describe different values and the cost
        # matrix below (built from one set of centres) is meaningless.
        lo = min(pred_last.min().item(), target_last.min().item())
        hi = max(pred_last.max().item(), target_last.max().item())
        pred_hist, bins = self.compute_histogram(pred_last, min_val=lo, max_val=hi)
        target_hist, _ = self.compute_histogram(target_last, min_val=lo, max_val=hi)

        # Ground cost between bin centres: 1-D distance |c_i - c_j|.
        bin_centers = (bins[:-1] + bins[1:]) / 2
        M = (bin_centers[:, None] - bin_centers[None, :]).abs()

        # Entropic optimal transport plan between the two histograms.
        P = self.sinkhorn_knopp(pred_hist, target_hist, M, epsilon=self.epsilon)
        wasserstein_loss = torch.sum(P * M)

        # MSE over the full tensors (all time steps).
        mse_loss = self.mse(pred, target)

        return alpha * wasserstein_loss + (1 - alpha) * mse_loss

    def sinkhorn_knopp(self, a, b, M, epsilon=0.1, max_iter=100):
        """Solve an entropy-regularized optimal transport problem.

        Runs ``max_iter`` Sinkhorn-Knopp iterations in the log domain. The
        iterates equal the classic scaling updates ``u = a / (K @ v)``,
        ``v = b / (K.T @ u)`` with ``K = exp(-M / epsilon)``, expressed via the
        potentials ``f = epsilon * log(u)`` and ``g = epsilon * log(v)``. Using
        ``logsumexp`` avoids forming ``K``, whose entries underflow to zero
        (turning the scaling updates into inf/NaN) when ``M / epsilon`` is large.
        Zero-mass entries of ``a``/``b`` get ``-inf`` potentials and hence
        zero rows/columns in the plan.

        Args:
            a (torch.Tensor): Source distribution, shape (n,).
            b (torch.Tensor): Target distribution, shape (m,).
            M (torch.Tensor): Cost matrix, shape (n, m).
            epsilon (float): Regularization parameter (must be > 0).
            max_iter (int): Number of iterations.

        Returns:
            torch.Tensor: Transport plan ``P = diag(u) @ K @ diag(v)``, shape (n, m).
        """
        log_a = torch.log(a)
        log_b = torch.log(b)
        # v = 1 in the scaling form corresponds to g = 0.
        f = torch.zeros_like(a)
        g = torch.zeros_like(b)

        for _ in range(max_iter):
            f = epsilon * (log_a - torch.logsumexp((g[None, :] - M) / epsilon, dim=1))
            g = epsilon * (log_b - torch.logsumexp((f[:, None] - M) / epsilon, dim=0))

        return torch.exp((f[:, None] + g[None, :] - M) / epsilon)


class SequenceSinkhornLoss(nn.Module):
    """Mean of per-time-step Sinkhorn divergences over a sequence.

    For every index ``t`` of the sequence axis, the ``(batch_size, dim)``
    slices of predictions and targets are compared with
    ``geomloss.SamplesLoss``; the per-step results are then averaged.
    """

    def __init__(self, epsilon=0.1, scaling=0.95, max_iter=100):
        """Build the underlying geomloss Sinkhorn criterion.

        Args:
            epsilon (float): Forwarded to ``geomloss.SamplesLoss`` as ``blur``.
            scaling (float): Forwarded to ``geomloss.SamplesLoss`` as ``scaling``.
            max_iter (int): Stored on the instance; not used by this class.
        """
        super().__init__()
        self.epsilon = epsilon
        self.scaling = scaling
        self.max_iter = max_iter
        # geomloss Sinkhorn divergence with exponent p=2 on the ground cost.
        self.sinkhorn = geomloss.SamplesLoss(
            backend='tensorized',
            loss='sinkhorn',
            scaling=scaling,
            blur=epsilon,
            p=2,
        )
        # Registered but not used by forward().
        self.mse = nn.MSELoss()

    def forward(self, pred, target):
        """Return the average Sinkhorn divergence across all time steps.

        Args:
            pred (torch.Tensor): Predictions with shape (batch_size, seq_len, dim).
            target (torch.Tensor): Targets with the same shape as ``pred``.

        Returns:
            torch.Tensor: Sum of per-step divergences divided by ``seq_len``.
            An empty sequence (``seq_len == 0``) raises ``ZeroDivisionError``.
        """
        _, seq_len, _ = pred.shape
        per_step = (
            self.sinkhorn(pred[:, step, :], target[:, step, :])
            for step in range(seq_len)
        )
        return sum(per_step, 0.0) / seq_len





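# NOTE: disabled alternative, kept for reference only. Despite its docstring, it
# does not compute per-time-step divergences: it flattens every time step into a
# single (output_len * batch_size, dim) point cloud and returns one divergence
# between the pooled prediction and target clouds.
# class SequenceSinkhornLoss(nn.Module):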
#     def __init__(self, epsilon=0.1, scaling=0.95, max_iter=100):
#         """
#         初始化序列 Sinkhorn 损失函数。
        
#         Args:
#             epsilon: Sinkhorn 正则化参数
#             scaling: Sinkhorn 缩放参数
#             max_iter: 最大迭代次数
#         """
#         super().__init__()
#         self.epsilon = epsilon
#         self.scaling = scaling
#         self.max_iter = max_iter
#         # 使用 geomloss 中的 SamplesLoss 来计算 Sinkhorn 散度
#         self.sinkhorn = geomloss.SamplesLoss(
#             loss='sinkhorn',
#             p=2,  # 2-Wasserstein 距离
#             blur=epsilon,
#             scaling=scaling,
#             backend='tensorized'
#         )

#     def forward(self, pred, target):
#         """
#         计算序列中每个时间步的分布之间的 Sinkhorn 散度总和。
#         使用矩阵运算来并行计算不同时间步的损失。
        
#         Args:
#             pred: 模型预测值，形状为 (batch_size, output_len, dim)
#             target: 目标值，形状为 (batch_size, output_len, dim)
        
#         Returns:
#             total_loss: 平均 Sinkhorn 损失值
#         """
#         batch_size, output_len, dim = pred.shape
        
#         # 重塑张量以便并行计算
#         # 确保张量内存连续并正确重塑
#         pred_reshaped = pred.transpose(0, 1).contiguous()    # (output_len, batch_size, dim)
#         target_reshaped = target.transpose(0, 1).contiguous() # (output_len, batch_size, dim)
        
#         # 将每个时间步的预测和目标展平为二维张量
#         pred_flat = pred_reshaped.reshape(-1, dim)      # (output_len * batch_size, dim)
#         target_flat = target_reshaped.reshape(-1, dim)  # (output_len * batch_size, dim)
        
#         # 计算所有时间步的 Sinkhorn 散度
#         sinkhorn_loss = self.sinkhorn(pred_flat, target_flat)
        
#         return sinkhorn_loss
