import sys

import torch
import torch.nn as nn
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
d_cosine = nn.CosineSimilarity(dim=-1, eps=1e-8)
import numpy as np


font2 = {'family': 'Times New Roman',
         'weight': 'normal',
         'size': 25,
         }


class SinkhornDistance(nn.Module):
    r"""
    Given two empirical measures each with :math:`P_1` locations
    :math:`x\in\mathbb{R}^{D_1}` and :math:`P_2` locations :math:`y\in\mathbb{R}^{D_2}`,
    outputs an approximation of the regularized OT cost for point clouds.
    Args:
        eps (float): regularization coefficient
        max_iter (int): maximum number of Sinkhorn iterations
        reduction (string, optional): Specifies the reduction to apply to the output:
            'none' | 'mean' | 'sum'. 'none': no reduction will be applied,
            'mean': the sum of the output will be divided by the number of
            elements in the output, 'sum': the output will be summed. Default: 'none'
    Shape:
        - Input: :math:`(N, P_1, D_1)`, :math:`(N, P_2, D_2)`
        - Output: :math:`(N)` or :math:`()`, depending on `reduction`

    """

    def __init__(self, eps, max_iter, args, reduction='none'):
        super(SinkhornDistance, self).__init__()
        self.eps = eps
        self.max_iter = max_iter
        self.reduction = reduction
        self.args = args

    def forward(self, x, y):
        C = self._cost_matrix(x, y)
        x_points = x.shape[-2]
        y_points = y.shape[-2]
        if x.dim() == 2:
            batch_size = 1
        else:
            batch_size = x.shape[0]

        mu = torch.empty(batch_size, x_points, dtype=torch.float,
                         requires_grad=False).fill_(1.0 / x_points).cuda(self.args.gpu).squeeze()
        nu = torch.empty(batch_size, y_points, dtype=torch.float,
                         requires_grad=False).fill_(1.0 / y_points).cuda(self.args.gpu).squeeze()
        u = torch.zeros_like(mu).cuda(self.args.gpu)
        v = torch.zeros_like(nu).cuda(self.args.gpu)
        actual_nits = 0
        thresh = 1e-1

        for i in range(self.max_iter):

            u1 = u
            u = self.eps * (torch.log(mu + 1e-8) - torch.logsumexp(self.M(C, u, v), dim=-1)) + u
            v = self.eps * (torch.log(nu + 1e-8) - torch.logsumexp(self.M(C, u, v).transpose(-2, -1), dim=-1)) + v
            err = (u - u1).abs().sum(-1).mean()
            actual_nits += 1

            U, V = u, v
            pi = torch.exp(self.M(C, U, V))

            if err.item() < thresh:
                break

        U, V = u, v
        pi = torch.exp(self.M(C, U, V))

        cost = torch.sum(pi * C, dim=(-2, -1))

        if self.reduction == 'mean':
            cost = cost.mean()
        elif self.reduction == 'sum':
            cost = cost.sum()

        return cost, pi, C

    def M(self, C, u, v):
        "Modified cost for logarithmic updates"
        "$M_{ij} = (-c_{ij} + u_i + v_j) / \epsilon$"
        return (-C + u.unsqueeze(-1) + v.unsqueeze(-2)) / self.eps

    def _cost_matrix(self, x, y, p=2):
        "Returns the matrix of $|x_i-y_j|^p$."
        x_col = x.unsqueeze(-2)
        y_lin = y.unsqueeze(-3)
        C = 1 - d_cosine(x_col, y_lin)
        return C

    @staticmethod
    def ave(u, u1, tau):
        "Barycenter subroutine, used by kinetic acceleration through extrapolation."
        return tau * u + (1 - tau) * u1


# class SinkhornDistance_one_to_multi(nn.Module):
#     r"""
#     Given two empirical measures each with :math:`P_1` locations
#     :math:`x\in\mathbb{R}^{D_1}` and :math:`P_2` locations :math:`y\in\mathbb{R}^{D_2}`,
#     outputs an approximation of the regularized OT cost for point clouds.
#     Args:
#         eps (float): regularization coefficient
#         max_iter (int): maximum number of Sinkhorn iterations
#         reduction (string, optional): Specifies the reduction to apply to the output:
#             'none' | 'mean' | 'sum'. 'none': no reduction will be applied,
#             'mean': the sum of the output will be divided by the number of
#             elements in the output, 'sum': the output will be summed. Default: 'none'
#     Shape:
#         - Input: :math:`(N, P_1, D_1)`, :math:`(N, P_2, D_2)`
#         - Output: :math:`(N)` or :math:`()`, depending on `reduction`
#
#     """
#
#     def __init__(self, eps, max_iter, reduction='none'):
#         super(SinkhornDistance_one_to_multi, self).__init__()
#         self.eps = eps
#         self.max_iter = max_iter
#         self.reduction = reduction
#
#     def forward(self, x, y):
#         C = self._cost_matrix(x, y)
#         x_points = x.shape[-2]
#         y_points = y.shape[-2]
#         if x.dim() == 2:
#             batch_size = 1
#         else:
#             batch_size = x.shape[0]
#
#         mu = torch.empty(batch_size, x_points, dtype=torch.float,
#                          requires_grad=False).fill_(1.0 / x_points).to(device).squeeze()
#         nu = torch.empty(batch_size, y_points, dtype=torch.float,
#                          requires_grad=False).fill_(1.0 / y_points).to(device)
#         u = torch.zeros_like(mu).to(device)
#         v = torch.zeros_like(nu).to(device)
#         actual_nits = 0
#         thresh = 1e-1
#
#         for i in range(self.max_iter):
#
#             u1 = u
#             u = self.eps * (torch.log(mu + 1e-8) - torch.logsumexp(self.M(C, u, v), dim=-1)) + u
#             v = self.eps * (torch.log(nu + 1e-8) - torch.logsumexp(self.M(C, u, v).transpose(-2, -1), dim=-1)) + v
#             err = (u - u1).abs().sum(-1).mean()
#             actual_nits += 1
#
#             if err.item() < thresh:
#                 break
#
#         U, V = u, v
#         pi = torch.exp(self.M(C, U, V))
#
#         cost = torch.sum(pi * C, dim=(-2, -1))
#
#         if self.reduction == 'mean':
#             cost = cost.mean()
#         elif self.reduction == 'sum':
#             cost = cost.sum()
#
#         return cost, pi, C
#
#     def M(self, C, u, v):
#         "Modified cost for logarithmic updates"
#         "$M_{ij} = (-c_{ij} + u_i + v_j) / \epsilon$"
#         return (-C + u.unsqueeze(-1) + v.unsqueeze(-2)) / self.eps
#
#     def _cost_matrix(x, y, p=2):
#         "Returns the matrix of $|x_i-y_j|^p$."
#         x_col = x.unsqueeze(-2)
#         y_lin = y.unsqueeze(-3)
#         C = 1 - d_cosine(x_col, y_lin)
#         return C
#
#     @staticmethod
#     def ave(u, u1, tau):
#         "Barycenter subroutine, used by kinetic acceleration through extrapolation."
#         return tau * u + (1 - tau) * u1


class SinkhornDistance_another_one_to_multi(nn.Module):
    r"""
    Given two empirical measures each with :math:`P_1` locations
    :math:`x\in\mathbb{R}^{D_1}` and :math:`P_2` locations :math:`y\in\mathbb{R}^{D_2}`,
    outputs an approximation of the regularized OT cost for point clouds.
    Args:
        eps (float): regularization coefficient
        max_iter (int): maximum number of Sinkhorn iterations
        reduction (string, optional): Specifies the reduction to apply to the output:
            'none' | 'mean' | 'sum'. 'none': no reduction will be applied,
            'mean': the sum of the output will be divided by the number of
            elements in the output, 'sum': the output will be summed. Default: 'none'
    Shape:
        - Input: :math:`(N, P_1, D_1)`, :math:`(N, P_2, D_2)`
        - Output: :math:`(N)` or :math:`()`, depending on `reduction`

    """

    def __init__(self, eps, max_iter, reduction='none'):
        super(SinkhornDistance_another_one_to_multi, self).__init__()
        self.eps = eps
        self.max_iter = max_iter
        self.reduction = reduction
        self.given_prob_mu = True

    def forward(self, x, y, mu):
        C = self._cost_matrix(x, y)
        x_points = x.shape[-2]
        y_points = y.shape[-2]
        if x.dim() == 2:
            batch_size = 1
        else:
            batch_size = x.shape[0]

        if not self.given_prob_mu:
            mu = torch.empty(batch_size, x_points, dtype=torch.float,
                             requires_grad=False).fill_(1.0 / x_points).to(device).squeeze()

        nu = torch.empty(batch_size, y_points, dtype=torch.float,
                         requires_grad=False).fill_(1.0 / y_points).to(device)
        u = torch.zeros_like(mu).to(device)
        v = torch.zeros_like(nu).to(device)
        actual_nits = 0
        thresh = 1e-1

        for i in range(self.max_iter):

            u1 = u
            u = self.eps * (torch.log(mu + 1e-8) - torch.logsumexp(self.M(C, u, v), dim=-1)) + u
            v = self.eps * (torch.log(nu + 1e-8) - torch.logsumexp(self.M(C, u, v).transpose(-2, -1), dim=-1)) + v
            err = (u - u1).abs().sum(-1).mean()
            actual_nits += 1

            if err.item() < thresh:
                break

        U, V = u, v
        pi = torch.exp(self.M(C, U, V))

        cost = torch.sum(pi * C, dim=(-2))
        if self.reduction == 'mean':
            cost = cost.mean()
        elif self.reduction == 'sum':
            cost = cost.sum()

        return cost, pi, C

    def M(self, C, u, v):
        "Modified cost for logarithmic updates"
        "$M_{ij} = (-c_{ij} + u_i + v_j) / \epsilon$"
        return (-C + u.unsqueeze(-1) + v.unsqueeze(-2)) / self.eps

    def _cost_matrix(self, x, y, p=2):
        "Returns the matrix of $|x_i-y_j|^p$."
        x_col = x.unsqueeze(-2)
        y_lin = y.unsqueeze(-3)
        C = 1 - d_cosine(x_col, y_lin)
        return C

    @staticmethod
    def ave(u, u1, tau):
        "Barycenter subroutine, used by kinetic acceleration through extrapolation."
        return tau * u + (1 - tau) * u1

#
class SinkhornDistance_given_cost(nn.Module):
    r"""
    Given two empirical measures each with :math:`P_1` locations
    :math:`x\in\mathbb{R}^{D_1}` and :math:`P_2` locations :math:`y\in\mathbb{R}^{D_2}`,
    outputs an approximation of the regularized OT cost for point clouds.
    Args:
        eps (float): regularization coefficient
        max_iter (int): maximum number of Sinkhorn iterations
        reduction (string, optional): Specifies the reduction to apply to the output:
            'none' | 'mean' | 'sum'. 'none': no reduction will be applied,
            'mean': the sum of the output will be divided by the number of
            elements in the output, 'sum': the output will be summed. Default: 'none'
    Shape:
        - Input: :math:`(N, P_1, D_1)`, :math:`(N, P_2, D_2)`
        - Output: :math:`(N)` or :math:`()`, depending on `reduction`

    """

    def __init__(self, eps, max_iter, reduction='none'):
        super(SinkhornDistance_given_cost, self).__init__()
        self.eps = eps
        self.max_iter = max_iter
        self.reduction = reduction

    def forward(self, x, y, given_cost):
        C = given_cost
        x_points = x.shape[-2]
        y_points = y.shape[-2]
        if x.dim() == 2:
            batch_size = 1
        else:
            batch_size = x.shape[0]

        mu = torch.empty(batch_size, x_points, dtype=torch.float,
                         requires_grad=False).fill_(1.0 / x_points).to(device).squeeze()
        nu = torch.empty(batch_size, y_points, dtype=torch.float,
                         requires_grad=False).fill_(1.0 / y_points).to(device).squeeze()
        u = torch.zeros_like(mu).to(device)
        v = torch.zeros_like(nu).to(device)
        actual_nits = 0
        thresh = 1e-1

        for i in range(self.max_iter):

            u1 = u
            u = self.eps * (torch.log(mu + 1e-8) - torch.logsumexp(self.M(C, u, v), dim=-1)) + u
            v = self.eps * (torch.log(nu + 1e-8) - torch.logsumexp(self.M(C, u, v).transpose(-2, -1), dim=-1)) + v
            err = (u - u1).abs().sum(-1).mean()
            actual_nits += 1

            U, V = u, v
            pi = torch.exp(self.M(C, U, V))

            if err.item() < thresh:
                break

        U, V = u, v
        pi = torch.exp(self.M(C, U, V))

        cost = torch.sum(pi * C, dim=(-2, -1))

        if self.reduction == 'mean':
            cost = cost.mean()
        elif self.reduction == 'sum':
            cost = cost.sum()

        return cost, pi, C

    def M(self, C, u, v):
        "Modified cost for logarithmic updates"
        "$M_{ij} = (-c_{ij} + u_i + v_j) / \epsilon$"
        return (-C + u.unsqueeze(-1) + v.unsqueeze(-2)) / self.eps

    @staticmethod
    def _cost_matrix(x, y, p=2):
        "Returns the matrix of $|x_i-y_j|^p$."
        x_col = x.unsqueeze(-2)
        y_lin = y.unsqueeze(-3)
        C = 1 - d_cosine(x_col, y_lin)
        return C

    @staticmethod
    def ave(u, u1, tau):
        "Barycenter subroutine, used by kinetic acceleration through extrapolation."
        return tau * u + (1 - tau) * u1
