import torch
from torch import Tensor
from torch.nn import Module
import numpy as np
from scipy.optimize import linear_sum_assignment


def bce(input, target):
    """Element-wise binary cross-entropy with no reduction.

    Args:
        input: predicted probabilities (expected to lie in [0, 1], as required
            by ``binary_cross_entropy``).
        target: either a tensor of targets with the same shape as ``input``,
            or a Python scalar (int or float) that is broadcast to
            ``input``'s shape.

    Returns:
        A tensor shaped like ``input`` holding the per-element losses.
    """
    # binary_cross_entropy needs a tensor target, so plain numbers are
    # expanded here. Integers (e.g. 0 / 1) are accepted as well as floats.
    if isinstance(target, (int, float)):
        target = torch.full_like(input, fill_value=target)
    return torch.nn.functional.binary_cross_entropy(input, target, reduction="none")


class HungarianLoss(Module):
    """Set-prediction loss using Hungarian (linear-sum-assignment) matching.

    Each prediction has a confidence ``m`` in channel 0 followed by ``D``
    regression channels. Per batch element, predictions are matched
    one-to-one to ground-truth entries by minimising a cost that mixes a
    scaled squared localisation error with a confidence term. Matched
    predictions contribute a "positive" loss; unmatched predictions
    contribute a "negative" (no-object) loss.

    Args:
        lambd: weight of the localisation term; ``1 - lambd`` weights the
            confidence term (presumably in [0, 1]).
        eps: lower clip applied before taking logs, to avoid ``log(0)``.
        scale: per-channel divisor applied to the regression residuals
            before squaring. Its length must equal ``D``. Defaults to the
            original hard-coded values.
    """

    _DEFAULT_SCALE = (100.0, 100.0, 200.0, 1000.0, 1000.0, 1000.0)

    def __init__(self, lambd: float, eps: float, *, scale=_DEFAULT_SCALE):
        super().__init__()
        self.eps = eps
        self.lambd = lambd
        # Stored as a plain tuple (not a buffer) so the state_dict is unchanged.
        self.scale = tuple(float(s) for s in scale)

    def forward(self, x: Tensor, x_gt: Tensor):
        """Compute the matched (positive) and unmatched (negative) losses.

        Args:
            x: predictions of shape ``(bs, N, 1 + D)``. Channel 0 is a
                probability-like confidence; channels ``1:`` are regression
                values.
            x_gt: ground-truth regression targets of shape ``(bs, K, D)``.

        Returns:
            ``(pos_loss, neg_loss)``: two scalar tensors, each averaged over
            the batch.
        """
        N = x.size(1)

        m, x = x[..., 0], x[..., 1:]

        # Clipped binary cross-entropy terms per prediction, for the
        # "object" (pos) and "no object" (neg) targets.
        pos_loss_m = -torch.log(m.clip(min=self.eps))
        neg_loss_m = -torch.log((1.0 - m).clip(min=self.eps))

        # Pairwise scaled squared residuals. x[:, None] is (bs, 1, N, D) and
        # x_gt.unsqueeze(-2) is (bs, K, 1, D), so M is (bs, K, N, D) with
        # M[b, k, n] comparing ground truth k against prediction n.
        v = torch.tensor(self.scale, device=x.device, dtype=x.dtype)
        M = x[:, None] - x_gt.unsqueeze(-2)
        M = M * v.reciprocal()
        M = M.square()

        # Matching cost matrix (no gradient). Only the first two regression
        # channels enter the localisation cost.
        # NOTE(review): an earlier comment mentioned "(dx, dy, dz)", but
        # only two channels are used here. Confirm which is intended.
        with torch.no_grad():
            cost_loc = M[..., :2].mean(dim=-1)
            cost_mask = pos_loss_m - neg_loss_m
            cost = self.lambd * cost_loc + (1.0 - self.lambd) * cost_mask[:, None]
            # linear_sum_assignment needs NumPy input. float64 also avoids
            # .numpy() failing on reduced-precision dtypes such as bfloat16.
            cost = cost.to(device="cpu", dtype=torch.float64)
            # Each entry is (gt_idx, p_idx): rows index ground truth, columns
            # index predictions.
            matches = [linear_sum_assignment(e.numpy()) for e in cost.unbind()]

        # Positive loss: full localisation error (all D channels) plus the
        # confidence term, summed over matched pairs.
        pos_loss = (
            self.lambd * M.mean(dim=-1) + (1.0 - self.lambd) * pos_loss_m[:, None]
        )
        pos_loss = torch.stack(
            [e[gt_idx, p_idx].sum() for e, (gt_idx, p_idx) in zip(pos_loss, matches)]
        )

        # Negative loss: "no object" term summed over unmatched predictions.
        all_p_idx = np.arange(N)
        unmatched = [np.setdiff1d(all_p_idx, p_idx) for _, p_idx in matches]
        neg_loss = torch.stack(
            [e[idx].sum() for e, idx in zip(neg_loss_m, unmatched)]
        )
        neg_loss = (1.0 - self.lambd) * neg_loss

        pos_loss = pos_loss.mean()
        neg_loss = neg_loss.mean()

        return pos_loss, neg_loss
