import torch
import torch.nn.functional as F

def ct_loss(x_true, x_fake, model_n):
    """Compute forward and backward transport losses between two batches.

    The pairwise cost is the squared difference summed over the last axis.
    ``model_n`` scores every (true, fake) pair; a softmax of those scores
    along each batch axis gives the two transport maps that weight the cost.

    Args:
        x_true: Batch of real samples. Assumed to be 2-D ``(N, h)`` as the
            pairwise broadcasting below implies -- confirm against callers.
        x_fake: Batch of generated samples, assumed ``(M, h)``.
        model_n: Callable applied to the ``(N, M, h)`` tensor of pairwise
            squared differences. It must return exactly one score per pair
            (e.g. ``(N, M, 1)`` or ``(N, M)``); any other element count
            raises in the reshape below.

    Returns:
        Tuple ``(gloss, nloss)`` of scalar tensors: the cost weighted by the
        softmax over ``dim=1`` (forward map) and over ``dim=0`` (backward
        map), each summed along its softmax axis and averaged.
    """
    diff = (x_true[:, None] - x_fake).pow(2)  # pairwise squared differences: N x M x h
    cost = diff.sum(-1)  # pairwise cost: N x M
    # Reshape to the cost matrix instead of a bare ``squeeze()``: squeeze
    # drops *every* size-1 dim, so a batch of one collapsed the score matrix
    # and the ``dim=1`` softmax below raised.
    scores = model_n(diff).reshape(cost.shape)  # navigator scores: N x M
    m_backward = F.softmax(scores, dim=0)  # backward map
    m_forward = F.softmax(scores, dim=1)  # forward map
    gloss = (cost * m_forward).sum(1).mean()  # forward transport
    nloss = (cost * m_backward).sum(0).mean()  # backward transport
    return gloss, nloss


def ct_withd_loss(x_true, x_fake, model_d, model_n):
    """Compute forward and backward transport losses in a feature space.

    Both batches are passed through ``model_d`` and L2-normalized along the
    last axis. The pairwise cost is the squared feature difference summed
    over that axis. ``model_n`` scores every (true, fake) pair; a softmax of
    those scores along each batch axis gives the two transport maps.

    Args:
        x_true: Batch of real samples fed to ``model_d``.
        x_fake: Batch of generated samples fed to ``model_d``.
        model_d: Callable mapping a batch to features. Assumed to return a
            2-D ``(batch, h)`` tensor, as the pairwise broadcasting below
            implies -- confirm against callers.
        model_n: Callable applied to the ``(N, M, h)`` tensor of pairwise
            squared feature differences. It must return exactly one score
            per pair (e.g. ``(N, M, 1)`` or ``(N, M)``); any other element
            count raises in the reshape below.

    Returns:
        Tuple ``(gloss, nloss)`` of scalar tensors: the cost weighted by the
        softmax over ``dim=1`` (forward map) and over ``dim=0`` (backward
        map), each summed along its softmax axis and averaged.
    """
    feat_true = F.normalize(model_d(x_true), dim=-1)
    feat_fake = F.normalize(model_d(x_fake), dim=-1)
    diff = (feat_true[:, None] - feat_fake).pow(2)  # pairwise squared differences: N x M x h
    cost = diff.sum(-1)  # pairwise cost: N x M
    # Reshape to the cost matrix instead of a bare ``squeeze()``: squeeze
    # drops *every* size-1 dim, so a batch of one collapsed the score matrix
    # and the ``dim=1`` softmax below raised.
    scores = model_n(diff).reshape(cost.shape)  # navigator scores: N x M
    m_backward = F.softmax(scores, dim=0)  # backward map
    m_forward = F.softmax(scores, dim=1)  # forward map
    gloss = (cost * m_forward).sum(1).mean()  # forward transport
    nloss = (cost * m_backward).sum(0).mean()  # backward transport
    return gloss, nloss
