import torch
import torch.nn as nn
import torch.nn.functional as F


def sinkhorn_forward_stabilized(C, mu, nu, epsilon, max_iter, device):
    """Run log-domain Sinkhorn iterations and return the transport plan.

    The dual potentials ``f`` (one per row) and ``g`` (one per column) are
    updated alternately with log-sum-exp reductions, so ``exp(-C / epsilon)``
    is never formed on its own.

    Args:
        C: cost tensor of shape (bs, n, m).
        mu: row marginal, broadcastable against (bs, n, 1); the callers in
            this file pass shape (1, n, 1).
        nu: column marginal, broadcastable against (bs, 1, m); the callers in
            this file pass shape (1, 1, m).
        epsilon: entropic regularisation strength (divides the cost).
        max_iter: number of (f, g) update sweeps to perform.
        device: device on which the potentials are allocated. ``None`` means
            "use the device of ``C``".

    Returns:
        Tensor of shape (bs, n, m): ``exp((-C + f + g) / epsilon)``. Because
        ``g`` is updated last, its column sums match ``nu`` after any
        ``max_iter >= 1``; the row sums approach ``mu`` as iterations grow.
    """
    bs, n, k_ = C.size()

    # Follow the input when no explicit device was requested, so that the
    # potentials can always be combined with C.
    if device is None:
        device = C.device

    # Allocate directly on the target device instead of CPU-then-copy.
    f = torch.zeros([bs, n, 1], device=device)
    g = torch.zeros([bs, 1, k_], device=device)

    # Loop invariants: the marginal terms of the dual updates.
    epsilon_log_mu = epsilon * torch.log(mu)
    epsilon_log_nu = epsilon * torch.log(nu)

    def min_epsilon_row(Z, epsilon):
        # Soft minimum over the column axis (one value per row).
        return -epsilon * torch.logsumexp((-Z) / epsilon, -1, keepdim=True)

    def min_epsilon_col(Z, epsilon):
        # Soft minimum over the row axis (one value per column).
        return -epsilon * torch.logsumexp((-Z) / epsilon, -2, keepdim=True)

    for _ in range(max_iter):
        f = min_epsilon_row(C - g, epsilon) + epsilon_log_mu
        g = min_epsilon_col(C - f, epsilon) + epsilon_log_nu

    Gamma = torch.exp((-C + f + g) / epsilon)
    return Gamma


def sinkhorn_backward(grad_output_Gamma, Gamma, mu, nu, epsilon):
    """Map an upstream gradient on ``Gamma`` to a gradient on the cost ``C``.

    The result is assembled in closed form from ``Gamma``, ``mu`` and ``nu``;
    no Sinkhorn iteration is replayed. The only linear-algebra step is the
    inverse of a (k, k) matrix built from all but the last column of
    ``Gamma``.

    Shapes, following the annotations used by the callers in this file:
        grad_output_Gamma, Gamma: (bs, n, k+1)
        mu: (1, n, 1)
        nu: (1, 1, k+1)

    Returns:
        Tensor shaped like ``Gamma``, equal to ``(-G1 + G2 + G3) / epsilon``
        where the three terms are built below.
    """
    # Work with every column except the last one.
    nu_head = nu[:, :, :-1]
    plan_head = Gamma[:, :, :-1]

    # Unpacking requires Gamma to be exactly three-dimensional.
    _batch, _rows, _cols = Gamma.size()

    recip_mu = 1. / (mu.view([1, -1]))  # (1, n)

    # kappa = diag(nu_head) - plan_head^T diag(1/mu) plan_head, (bs, k, k)
    head_t_over_mu = plan_head.transpose(-1, -2) * recip_mu.unsqueeze(-2)
    kappa = torch.diag_embed(nu_head.squeeze(-2)) - torch.matmul(head_t_over_mu, plan_head)
    kappa_inv = torch.inverse(kappa)  # (bs, k, k)

    plan_over_mu = recip_mu.unsqueeze(-1) * plan_head  # (bs, n, k)
    proj = plan_over_mu.matmul(kappa_inv)  # (bs, n, k)

    # G1: upstream gradient weighted element-wise by the plan, (bs, n, k+1)
    weighted = grad_output_Gamma * Gamma

    # ---- G2: terms driven by the row sums of G1 ----
    row_tot = weighted.sum(-1)  # (bs, n)
    term_rows = (row_tot * recip_mu).unsqueeze(-1) * Gamma
    row_proj = row_tot.unsqueeze(-2).matmul(proj)  # (bs, 1, k)
    term_back = row_proj.matmul(plan_over_mu.transpose(-1, -2)).transpose(-1, -2) * Gamma
    # Zero-pad on the right so the dropped last column contributes nothing.
    term_cols = - F.pad(row_proj, pad=(0, 1), mode='constant', value=0) * Gamma
    second = term_rows + term_back + term_cols  # (bs, n, k+1)

    # Release intermediates that are no longer needed.
    del row_tot, term_rows, term_back, term_cols, plan_over_mu

    # ---- G3: terms driven by the column sums of G1 (last column dropped) ----
    col_tot = weighted.sum(-2).unsqueeze(-1)[:, :-1, :]  # (bs, k, 1)
    term_proj = - proj.matmul(col_tot) * Gamma
    term_inv = F.pad(
        kappa_inv.matmul(col_tot).transpose(-1, -2),
        pad=(0, 1),
        mode='constant',
        value=0,
    ) * Gamma
    third = term_proj + term_inv  # (bs, n, k+1)

    return (-weighted + second + third) / epsilon


class TopKFunc(torch.autograd.Function):
    """Autograd function pairing the Sinkhorn forward pass with a
    hand-written backward pass.

    Only the cost tensor ``C`` receives a gradient; every other forward
    argument gets ``None``.
    """

    @staticmethod
    def forward(ctx, C, mu, nu, epsilon, max_iter, device):
        """Compute the plan ``Gamma`` and stash what backward needs."""
        ctx.epsilon = epsilon
        # The iterations are not recorded; the gradient comes from
        # ``sinkhorn_backward`` instead.
        with torch.no_grad():
            plan = sinkhorn_forward_stabilized(C, mu, nu, epsilon, max_iter, device)
        ctx.save_for_backward(mu, nu, plan)
        return plan

    @staticmethod
    def backward(ctx, grad_output_Gamma):
        """Turn the gradient w.r.t. ``Gamma`` into the gradient w.r.t. ``C``."""
        # Saved in forward as: mu (1, n, 1), nu (1, 1, k+1), Gamma (bs, n, k+1).
        mu, nu, plan = ctx.saved_tensors
        with torch.no_grad():
            cost_grad = sinkhorn_backward(grad_output_Gamma, plan, mu, nu, ctx.epsilon)
        # One slot per forward input: C, mu, nu, epsilon, max_iter, device.
        return cost_grad, None, None, None, None, None


class TopK_custom(nn.Module):
    """Soft, differentiable top-k indicator built on ``TopKFunc``.

    Each score is matched against two fixed anchors, 1 and 0, using a squared
    distance as cost. The column marginal gives the first anchor a mass of
    ``k / n`` and the second ``(n - k) / n``. The output is the first column
    of the resulting plan rescaled by ``n``.

    Args:
        epsilon: entropic regularisation strength forwarded to ``TopKFunc``.
        max_iter: number of Sinkhorn sweeps forwarded to ``TopKFunc``.
        device: device used for the internally created tensors. ``None``
            (the default) means "use the device of the scores passed to
            ``forward``".
    """

    def __init__(self, epsilon=1e-1, max_iter=200, device=None):
        super(TopK_custom, self).__init__()
        self.epsilon = epsilon
        # Two anchors, shaped to broadcast against (bs, n, 1) scores.
        self.anchors = torch.FloatTensor([1, 0]).view([1, 1, 2])
        self.max_iter = max_iter
        self.device = device

    def forward(self, scores, k):
        """Return the soft top-k weights for ``scores``.

        Args:
            scores: tensor of shape (bs, n). Entries equal to ``-inf`` are
                replaced by a finite value below every finite score.
            k: number of elements to select. NOTE(review): ``k / n`` and
                ``(n - k) / n`` are used as marginals whose logarithm is
                taken downstream, so ``0 < k < n`` is presumably expected.

        Returns:
            Tensor of shape (bs, n, 1): first column of the plan times ``n``.
        """
        bs, n = scores.size()
        scores = scores.view([bs, n, 1])

        # Without an explicit device, keep every helper tensor next to the
        # input so the arithmetic below never mixes devices.
        device = self.device if self.device is not None else scores.device

        # Replace -inf by ``min - (max - min)``, where min is taken over the
        # entries that are not -inf.
        neg_inf_mask = scores == float('-inf')
        detached = scores.detach()
        max_scores = torch.max(detached)
        min_scores = torch.min(detached.masked_fill(neg_inf_mask, float('inf')))
        filled_value = min_scores - (max_scores - min_scores)
        scores = scores.masked_fill(neg_inf_mask, filled_value)

        # Squared distance to each anchor, normalised by its (detached) maximum.
        C = (scores - self.anchors.to(device)) ** 2
        C = C / (C.max().detach())

        # Uniform row marginal; column marginal splits the mass k : (n - k).
        mu = torch.ones([1, n, 1], requires_grad=False, device=device) / n
        nu = torch.FloatTensor([k / n, (n - k) / n]).view([1, 1, 2]).to(device)

        Gamma = TopKFunc.apply(C, mu, nu, self.epsilon, self.max_iter, device)
        A = Gamma[:, :, :1] * n
        return A



