import torch
import torch.nn as nn

class OT_loss(nn.Module):
    """
    The proposed OT-based loss.

    The loss couples N predicted samples with M ground-truth labels through
    a coupling matrix ``P`` and is the sum of two terms:

    * ``seg_loss``: the pair-wise cost weighted by the (detached) coupling.
    * ``kl_loss``: a KL term between the sample probabilities ``prob`` and
      the (detached) row sums of the coupling, scaled by ``beta``.

    Args:
        cost_fn: name of the pair-wise cost, either ``'iou'`` or ``'ce'``.
        beta: weight of the KL term in the total loss.
    """

    # Pair-wise costs implemented by ``get_cost_matrix``.
    _SUPPORTED_COSTS = ('ce', 'iou')

    def __init__(self, cost_fn='iou', beta=1):
        super().__init__()
        self.cost_fn = cost_fn
        self.beta = beta

    def forward(self, gt_arr, sample_arr, prob, prob_gt, gamma):
        """
        Compute the total loss and its two components.

        Args:
            gt_arr: integer label maps; everything after dim 1 is flattened,
                giving (B, M, L).
            sample_arr: per-class sample maps; everything after dim 2 is
                flattened, giving (B, N, C, L).
            prob: probabilities of the N samples; must be broadcastable with
                the (B, N) row sums of the coupling matrix.
            prob_gt: probabilities of the M ground-truth labels, (B, M).
                It is not modified.
            gamma: upper bound on the mass a single sample may receive;
                ``gamma == 1`` selects the closed-form nearest-sample coupling.

        Returns:
            Tuple ``(loss, seg_loss, kl_loss)`` with
            ``loss = seg_loss + beta * kl_loss``.
        """
        cost = self.get_cost_matrix(
            sample_arr.flatten(3), gt_arr.flatten(2), cost_fn=self.cost_fn
        )  # B,N,M
        P = self.get_coupling_matrix(cost, prob_gt, gamma)
        # The coupling is treated as a constant: gradients reach the samples
        # only through ``cost``.
        seg_loss = (P.detach() * cost).sum([1, 2]).mean(0)
        kl_loss = self.get_kl_loss(P, prob)

        loss = seg_loss + self.beta * kl_loss
        return loss, seg_loss, kl_loss

    def get_kl_loss(self, P, prob):
        """
        Get the KL divergence loss.

        The target distribution is the mass each sample receives under the
        coupling (row sums of ``P``), detached from the graph. ``prob`` is
        converted to log space with a 1e-8 offset to avoid ``log(0)``.

        Returns:
            Scalar tensor: the mean of the element-wise KL terms.
        """
        target_prob = P.detach().sum(-1)
        log_prob = torch.log(prob + 1e-8)
        pointwise_kl = nn.functional.kl_div(
            log_prob, target_prob, reduction='none', log_target=False
        )
        return pointwise_kl.mean()

    def get_cost_matrix(self, sample_arr, gt_arr, cost_fn=None):
        """
        Calculate the pair wise cost matrix between sample_arr and gt_arr
        using the pair-wise cost function.

        Args:
            sample_arr: (B, N, C, L) tensor. For ``'ce'`` a log-softmax is
                taken over the class axis; for ``'iou'`` the values are used
                directly as soft class masks.
            gt_arr: (B, M, L) tensor of class indices in ``[0, C)``.
            cost_fn: ``'ce'`` or ``'iou'``; ``None`` falls back to
                ``self.cost_fn``.

        Returns:
            (B, N, M) tensor of costs.

        Raises:
            ValueError: if the resolved ``cost_fn`` is not supported.
        """
        if cost_fn is None:
            cost_fn = self.cost_fn
        if cost_fn not in self._SUPPORTED_COSTS:
            raise ValueError(
                "Unsupported cost_fn {!r}; expected one of {}".format(
                    cost_fn, self._SUPPORTED_COSTS
                )
            )

        _, _, num_classes, _ = sample_arr.shape

        # One-hot encode the labels and move the class axis next to L: B,M,C,L
        gt_onehot = nn.functional.one_hot(gt_arr.long(), num_classes).permute(0, 1, 3, 2)

        # Broadcasting over the N and M axes replaces explicit expand/permute
        # copies; only the pair-wise products materialise a B,N,M,C,L tensor.
        samples = sample_arr.unsqueeze(2)  # B,N,1,C,L
        targets = gt_onehot.unsqueeze(1)   # B,1,M,C,L

        if cost_fn == 'ce':
            # Log-softmax depends on the sample only, so compute it before
            # pairing with the ground truths.
            neg_log_prob = -nn.functional.log_softmax(samples, dim=3)
            per_position = (neg_log_prob * targets).sum(-2)  # B,N,M,L
            # NOTE(review): the class axis has already been summed out, so
            # this slice drops the first spatial position rather than class 0
            # ("exclude bg" in the original comment). Behaviour is kept
            # unchanged here; confirm the intent.
            return per_position[:, :, :, 1:].mean(-1)

        # cost_fn == 'iou'
        intersection = (samples * targets).sum(-1)  # B,N,M,C
        union = samples.sum(-1) + targets.sum(-1) - intersection
        smoothed_iou = (intersection + 1) / (union + 1)
        # Exclude class 0 (background) and average over the remaining
        # classes; the cost is 1 - IoU.
        return 1.0 - smoothed_iou[:, :, :, 1:].mean(-1)

    def get_coupling_matrix(self, cost, prob_gt, gamma):
        """
        Solve the coupling matrix in our problem.

        Args:
            cost: (B, N, M) pair-wise cost matrix.
            prob_gt: (B, M) mass of each ground-truth label. Left untouched;
                the greedy branch works on a copy.
            gamma: maximum total mass a single sample may receive.

        Returns:
            (B, N, M) coupling matrix ``P`` where ``P[b, j, i]`` is the mass
            moved from ground truth ``i`` to sample ``j``.
        """
        B, N, M = cost.shape

        if gamma == 1:
            # Moving the mass of each ground truth label to its nearest
            # prediction.
            nearest = nn.functional.one_hot(cost.argmin(-2), N)  # B,M,N
            return (nearest * prob_gt.unsqueeze(-1)).transpose(-1, -2)

        # A standard linear programming problem. We adopt a greedy strategy
        # which empirically works fine.
        # Track the not-yet-assigned mass on a copy so the caller's tensor is
        # never modified.
        remaining = prob_gt.clone()
        P = torch.zeros_like(cost)
        # Serve ground truths in order of their best achievable cost.
        gt_order = cost.min(-2).values.argsort()
        for b in range(B):
            for i in gt_order[b]:
                # For gt i, visit the samples from cheapest to most expensive.
                for j in cost[b, :, i].argsort():
                    used = P[b, j].sum()
                    if used < gamma:
                        # Move as much as the sample's free capacity allows.
                        P[b, j, i] = min(gamma - used, remaining[b, i])
                        remaining[b, i] -= P[b, j, i]
                        if remaining[b, i] == 0:
                            break
        return P