
import torch
import torch.nn as nn

import ot

# Adapted from https://github.com/gpeyre/SinkhornAutoDiff
class SinkhornDistance(nn.Module):
    r"""Entropy-regularized optimal-transport cost between two point clouds.

    Given two empirical measures with :math:`P_1` locations
    :math:`x\in\mathbb{R}^{D}` and :math:`P_2` locations
    :math:`y\in\mathbb{R}^{D}`, both carrying uniform weights, runs
    log-domain Sinkhorn iterations and returns
    :math:`\sum_{ij}\pi_{ij}C_{ij}` for the resulting transport plan
    :math:`\pi` and the cost matrix :math:`C_{ij}=\sum_d|x_{id}-y_{jd}|^2`.

    Args:
        eps (float): regularization coefficient
        max_iter (int): maximum number of Sinkhorn iterations
        thresh (float): the iterations stop early once the change of the
            dual variable ``u`` (absolute differences summed over points,
            averaged over the batch) drops below this value
        reduction (string, optional): Specifies the reduction to apply to the output:
            'none' | 'mean' | 'sum'. 'mean': the per-sample costs are averaged,
            'sum': the per-sample costs are summed, any other value: no
            reduction is applied. Default: 'mean'
        device: stored on the instance for backward compatibility. ``forward``
            allocates its working tensors on the device of the cost matrix,
            i.e. on the device of the inputs.
    Shape:
        - Input: :math:`(N, P_1, D)`, :math:`(N, P_2, D)`
          (or unbatched :math:`(P_1, D)`, :math:`(P_2, D)`)
        - Output: :math:`(N)` or :math:`()`, depending on `reduction`
    """

    def __init__(self, eps=0.1, max_iter=2000, thresh=1e-3, reduction='mean', device='cpu'):
        super().__init__()
        self.eps = eps
        self.max_iter = max_iter
        self.thresh = thresh
        self.reduction = reduction
        self.device = device
        print('=============== SinkHorn ===============')
        print(f'========= epsilon:{self.eps}')
        print(f'========= max iteration:{self.max_iter}')
        print(f'========= stop threshold:{self.thresh}')
        print('=============== SinkHorn ===============')

    def forward(self, x, y, normalized=False):
        """Return the Sinkhorn cost between the point clouds ``x`` and ``y``.

        Args:
            x: source points, ``(..., P_1, D)``.
            y: target points, ``(..., P_2, D)``.
            normalized (bool): forwarded to ``_cost_matrix``; when true the
                cost matrix is divided by its norm before the iterations.

        Returns:
            The transport cost per batch entry, reduced according to
            ``self.reduction``.
        """
        C = self._cost_matrix(x, y, normalized=normalized)  # Wasserstein cost function
        x_points, y_points = C.shape[-2], C.shape[-1]
        batch_shape = tuple(C.shape[:-2])

        # Both marginals are fixed with equal weights. Their shapes are taken
        # from the cost matrix (never squeezed) so that a single-point cloud
        # or a batch of size one keeps its point axis and broadcasts correctly
        # inside ``M``. They live on the same device as the cost matrix.
        mu = torch.full((*batch_shape, x_points), 1.0 / x_points,
                        dtype=torch.float, device=C.device)
        nu = torch.full((*batch_shape, y_points), 1.0 / y_points,
                        dtype=torch.float, device=C.device)

        # The log-marginals do not change across iterations.
        log_mu = torch.log(mu + 1e-8)
        log_nu = torch.log(nu + 1e-8)

        u = torch.zeros_like(mu)
        v = torch.zeros_like(nu)

        # Sinkhorn iterations in the log domain
        for _ in range(self.max_iter):
            u_prev = u  # kept to measure the size of the update
            u = self.eps * (log_mu - torch.logsumexp(self.M(C, u, v), dim=-1)) + u
            v = self.eps * (
                log_nu - torch.logsumexp(self.M(C, u, v).transpose(-2, -1), dim=-1)
            ) + v
            err = (u - u_prev).abs().sum(-1).mean()

            # Stopping criterion
            if err.item() < self.thresh:
                break

        # Transport plan pi = diag(a)*K*diag(b)
        pi = torch.exp(self.M(C, u, v))
        # Sinkhorn distance
        cost = torch.sum(pi * C, dim=(-2, -1))

        if self.reduction == 'mean':
            cost = cost.mean()
        elif self.reduction == 'sum':
            cost = cost.sum()

        return cost

    def M(self, C, u, v):
        r"""Modified cost for logarithmic updates.

        :math:`M_{ij} = (-c_{ij} + u_i + v_j) / \epsilon`
        """
        return (-C + u.unsqueeze(-1) + v.unsqueeze(-2)) / self.eps

    @staticmethod
    def _cost_matrix(x, y, p=2, normalized=False):
        r"""Return the matrix of :math:`\sum_d |x_{id}-y_{jd}|^p`.

        When ``normalized`` is true the matrix is divided by ``torch.norm(C)``,
        which is computed over the whole tensor (all batch entries together).
        """
        x_col = x.unsqueeze(-2)
        y_lin = y.unsqueeze(-3)
        C = torch.sum((torch.abs(x_col - y_lin)) ** p, -1)
        if normalized:
            C = C / torch.norm(C)
        return C

    @staticmethod
    def ave(u, u1, tau):
        "Barycenter subroutine, used by kinetic acceleration through extrapolation."
        return tau * u + (1 - tau) * u1


class SinkhornLoss2(nn.Module):
    """Thin wrapper around POT's ``ot.bregman.empirical_sinkhorn2``.

    Args:
        eps (float): stored on the instance.
            NOTE(review): it is not forwarded to POT. The regularization stays
            at the fixed value ``reg=1.0`` so that results for existing callers
            do not change (the default ``eps`` is 0.1). Confirm whether ``eps``
            is meant to drive ``reg``.
        max_iter (int): maximum number of Sinkhorn iterations, forwarded to
            POT as ``numIterMax``.
        device: stored on the instance; not used by ``forward``.
    """

    def __init__(self, eps=0.1, max_iter=1000, device='cpu'):
        super().__init__()
        self.eps = eps
        self.max_iter = max_iter
        self.device = device

    def forward(self, x, y):
        """Return POT's empirical Sinkhorn loss between samples ``x`` and ``y``.

        Both arguments are handed to POT unchanged.
        Presumably they are sample matrices of shape ``(n_samples, dim)`` --
        TODO confirm against the installed POT version.
        """
        # Honour the iteration budget given to the constructor instead of a
        # hard-coded 1000 (identical for the default ``max_iter``).
        return ot.bregman.empirical_sinkhorn2(
            x, y, reg=1.0, numIterMax=self.max_iter, stopThr=1e-3
        )


if __name__ == "__main__":
    # Smoke test: a batch of 128 clouds with 2 one-dimensional points each.
    # The source is standard normal, the target is uniform on [5, 10).
    source = torch.randn(128, 2, 1, requires_grad=True)
    target = 5 + 5 * torch.rand(128, 2, 1)

    criterion = SinkhornDistance()
    print(criterion(source, target))

    