import torch
from torch import Tensor
from torch.nn import Module
import numpy as np
from scipy.optimize import linear_sum_assignment


class DETRLoss(Module):
    """Hungarian-matched set loss with a fixed per-dimension residual scale.

    Predictions ``x`` are assumed to be ``(bs, N, 7)``: channel 0 is a score
    ``m`` that is clamped to ``[eps, 1 - eps]`` and used through ``log(m)`` and
    ``log(1 - m)``; channels 1..6 are regression values. Ground truth ``x_gt``
    must be reshapeable to ``(bs, G, 6)``.

    For each image, ground-truth rows are matched to predictions with
    ``scipy.optimize.linear_sum_assignment``. The matching cost is the mean
    squared scaled residual over the first three dimensions minus ``log(m)``.
    Matched pairs contribute the mean squared scaled residual over all six
    dimensions minus ``log(m)``. Unmatched predictions contribute
    ``-log(1 - m)``. Per-image sums are averaged over the batch.

    Args:
        inv_sigma: 1-D tensor of 6 factors multiplied into the residual
            before squaring.
        eps: clamp margin applied to ``m`` before taking logs.

    Raises:
        AttributeError: if ``inv_sigma`` is not a 1-D tensor of size 6 (kept as
            AttributeError for backward compatibility with existing callers).
    """

    def __init__(self, inv_sigma: Tensor, eps: float):
        super().__init__()
        if inv_sigma.ndim != 1 or inv_sigma.size(0) != 6:
            raise AttributeError("inv_sigma must be a 1d tensor of shape [6].")
        # Non-persistent: moves with .to()/.cuda() but is not saved in state_dict.
        self.register_buffer("inv_sigma", inv_sigma, persistent=False)
        self.eps = eps

    def forward(self, x: Tensor, x_gt: Tensor) -> dict[str, Tensor]:
        """Compute the matched / unmatched loss terms.

        Args:
            x: predictions, assumed ``(bs, N, 7)`` (score + 6 values).
            x_gt: ground truth, reshaped to ``(bs, G, 1, 6)``.

        Returns:
            dict with ``"loss"`` (pos + neg), ``"pos"`` (matched term) and
            ``"neg"`` (unmatched term), each a scalar tensor.
        """
        bs, N = x.size(0), x.size(1)

        m, x = x[..., 0], x[..., 1:]
        # Clamp so both log(m) and log(1 - m) stay finite.
        m_eps = torch.clamp(m, min=self.eps, max=1.0 - self.eps)
        log_m = torch.log(m_eps)
        log1m_m = torch.log1p(-m_eps)

        # Pairwise scaled squared residuals, shape (bs, G, N, 6).
        x_gt = x_gt.reshape(bs, -1, 1, 6)
        M = ((x[:, None] - x_gt) * self.inv_sigma).square()

        # Hungarian algorithm only on the first three dims (dx, dy, dz).
        c = M[..., :3].mean(dim=-1) - log_m[:, None]
        # bfloat16 has no numpy dtype; scipy solves in float64 anyway, so
        # converting here changes nothing for float32/float64 inputs.
        c = c.detach().double().cpu()
        idxs = [linear_sum_assignment(e.numpy()) for e in c.unbind()]

        # Loss matrix on every dimension.
        M = M.mean(dim=-1) - log_m[:, None]
        l_match = torch.stack(
            [e[gt_idx, p_idx].sum() for e, (gt_idx, p_idx) in zip(M, idxs)]
        )

        # Predictions not assigned to any ground truth are pushed towards m = 0.
        all_p_idx = np.arange(N)
        unmatched = [np.setdiff1d(all_p_idx, p_idx) for _, p_idx in idxs]
        l_no_match = torch.stack(
            [-cost[idx].sum() for cost, idx in zip(log1m_m, unmatched)]
        )

        l_match = l_match.mean()
        l_no_match = l_no_match.mean()
        loss = l_match + l_no_match

        return {"loss": loss, "pos": l_match, "neg": l_no_match}


def bce(input, target):
    """Element-wise binary cross-entropy (no reduction).

    Args:
        input: probabilities in ``[0, 1]`` (required by
            ``torch.nn.functional.binary_cross_entropy``).
        target: a tensor broadcast-compatible with ``input``, or a real
            scalar (int or float) that is expanded to ``input``'s shape and
            dtype.

    Returns:
        Tensor of per-element losses with the same shape as ``input``.
    """
    # Accept ints too (e.g. ``bce(p, 1)``); binary_cross_entropy itself only
    # takes tensor targets.
    if isinstance(target, (int, float)):
        target = torch.full_like(input, fill_value=float(target))
    return torch.nn.functional.binary_cross_entropy(input, target, reduction="none")


class DETRLossWithUncertainties(Module):
    """Hungarian-matched set loss with predicted per-dimension scales.

    Each prediction row is laid out as ``[m, x_0..x_{d-1}, u_0..u_{d-1}]``.
    The score ``m`` is trained with a focal-style binary cross-entropy, so it
    must lie in ``[0, 1]``. The localisation term per dimension is
    ``0.5 * ((x - x_gt) / u) ** 2 + log(max(u, eps))``.

    Args:
        lambd: weight of the localisation term; the score term gets
            ``1 - lambd``.
        eps: lower clip applied to ``u`` inside the ``log``.
    """

    def __init__(self, lambd: float, eps: float):
        super().__init__()
        self.eps = eps
        self.lambd = lambd

    def forward(self, x: Tensor, x_gt: Tensor) -> tuple[Tensor, Tensor]:
        """Compute the matched (positive) and unmatched (negative) losses.

        Args:
            x: predictions, assumed ``(bs, N, 1 + 2 * d)``.
            x_gt: ground truth, reshaped to ``(bs, G, 1, d)``.

        Returns:
            ``(pos_loss, neg_loss)``, each a scalar averaged over the batch.
        """
        bs, N = x.size(0), x.size(1)

        m, x = x[..., 0], x[..., 1:]

        # Focal-loss terms for the score (alpha = 0.25, gamma = 2).
        alpha = 0.25
        gamma = 2.0
        pos_loss_m = alpha * ((1.0 - m) ** gamma) * bce(m, 1.0)
        neg_loss_m = (1.0 - alpha) * (m**gamma) * bce(1.0 - m, 1.0)

        # Localisation cost, shape (bs, G, N, d).
        d = x.size(-1) // 2
        x, u = x[..., :d], x[..., d:]
        x_gt = x_gt.reshape(bs, -1, 1, d)
        M = x[:, None] - x_gt
        # NOTE(review): only the log term clips u; the quadratic term divides
        # by the raw u, so u == 0 yields inf here.
        M = M * u[:, None].reciprocal()
        M = 0.5 * M.square()
        M = M + torch.log(u.clip(min=self.eps))[:, None]

        # Matching cost: localisation uses only the first three dims (dx, dy, dz).
        with torch.no_grad():
            cost_loc = M[..., :3].mean(dim=-1)
            cost_mask = pos_loss_m - neg_loss_m
            c = self.lambd * cost_loc + (1.0 - self.lambd) * cost_mask[:, None]
            # bfloat16 has no numpy dtype; scipy solves in float64 anyway.
            c = c.double().cpu()
            idxs = [linear_sum_assignment(e.numpy()) for e in c.unbind()]

        # Loss matrix over all d dimensions.
        M = self.lambd * M.mean(dim=-1) + (1.0 - self.lambd) * pos_loss_m[:, None]
        pos_loss = torch.stack(
            [e[gt_idx, p_idx].sum() for e, (gt_idx, p_idx) in zip(M, idxs)]
        )

        # Predictions left unassigned only receive the negative score term.
        all_p_idx = np.arange(N)
        unmatched = [np.setdiff1d(all_p_idx, p_idx) for _, p_idx in idxs]
        neg_loss = torch.stack(
            [row[idx].sum() for row, idx in zip(neg_loss_m, unmatched)]
        )
        neg_loss = (1.0 - self.lambd) * neg_loss

        return pos_loss.mean(), neg_loss.mean()


def bce_nested(input: Tensor, target: "Tensor | float", eps: float):
    """Element-wise binary cross-entropy that allows soft / tensor targets.

    Unlike ``torch.nn.functional.binary_cross_entropy``, this is written out
    by hand, so gradients also flow into a tensor ``target``.

    Args:
        input: probabilities in ``[0, 1]``; clamped to ``[eps, 1 - eps]``
            before taking logs.
        target: tensor broadcast-compatible with ``input``, or a real scalar,
            with values in ``[0, 1]``.
        eps: clamp margin keeping both logs finite.

    Returns:
        ``-y * log(x) - (1 - y) * log(1 - x)`` element-wise.

    Raises:
        ValueError: if ``input`` or ``target`` has values outside ``[0, 1]``.
    """
    # Empty tensors are skipped: .min()/.max() raise on zero elements.
    if isinstance(target, Tensor):
        if target.numel() > 0 and (target.min() < 0.0 or target.max() > 1.0):
            raise ValueError("target must be in [0, 1]")
    elif isinstance(target, (int, float)) and (target < 0.0 or target > 1.0):
        raise ValueError("target must be in [0, 1]")
    if input.numel() > 0 and (input.min() < 0.0 or input.max() > 1.0):
        raise ValueError("input must be in [0, 1]")

    x = torch.clamp(input, min=eps, max=1.0 - eps)
    y = target

    loss = -y * x.log() - (1.0 - y) * (1.0 - x).log()
    return loss


class DETRLossWithUncertaintiesStableMatching(Module):
    """Uncertainty-aware set loss with quality-aware ("stable") matching.

    Reference: Liu et al., "Detection Transformer with Stable Matching",
    ICCV 2023.

    Each prediction row is laid out as ``[m, x_0..x_{d-1}, u_0..u_{d-1}]``,
    with ``m`` in ``[0, 1]``. The localisation term per dimension is
    ``0.5 * ((x - x_gt) / u) ** 2 + log(max(u, eps))``.

    Matching: the score is scaled by a quality factor
    ``exp(-C / 2)``, where ``C`` is the localisation cost shifted by its
    per-image minimum.

    Loss: the score of matched pairs is trained towards
    ``f1 = exp(-(L - min L))``, where ``L`` is the localisation loss and the
    minimum is taken over the whole batch.

    Args:
        eps: lower clip for ``u`` in the log term and clamp margin for the
            nested BCE.
        lambd: weight of the localisation term; score terms get ``1 - lambd``.
    """

    # https://openaccess.thecvf.com/content/ICCV2023/papers/Liu_Detection_Transformer_with_Stable_Matching_ICCV_2023_paper.pdf
    def __init__(self, eps: float, lambd: float):
        super().__init__()
        # Honour the argument (it used to be silently overridden with 1e-6).
        self.eps = eps
        self.lambd = lambd

    def forward(self, x: Tensor, x_gt: Tensor) -> tuple[Tensor, Tensor]:
        """Compute the matched (positive) and unmatched (negative) losses.

        Args:
            x: predictions, assumed ``(bs, N, 1 + 2 * d)``.
            x_gt: ground truth, reshaped to ``(bs, G, 1, d)``.

        Returns:
            ``(pos_loss, neg_loss)``, each a scalar averaged over the batch.
        """
        bs, N = x.size(0), x.size(1)

        m, x = x[..., 0], x[..., 1:]

        # Focal-loss hyper-parameters.
        a = 0.25
        g = 2.0

        # Localisation term, shape (bs, G, N, d).
        d = x.size(-1) // 2
        x, u = x[..., :d], x[..., d:]
        x_gt = x_gt.reshape(bs, -1, 1, d)
        M = x[:, None] - x_gt
        # NOTE(review): only the log term clips u; the quadratic term divides
        # by the raw u, so u == 0 yields inf here.
        M = M * u[:, None].reciprocal()
        M = 0.5 * M.square()
        M = M + torch.log(u.clip(min=self.eps))[:, None]

        # Matching cost: localisation uses only the first three dims (dx, dy, dz).
        with torch.no_grad():
            cost_loc = M[..., :3].mean(dim=-1)

            # Per-image min shift, so the best pair in each image has f2 == 1.
            C = cost_loc - cost_loc.amin(dim=(1, 2))[:, None, None]
            f2 = torch.exp(-C).sqrt()

            m2 = f2 * m[:, None]  # score reduced by the prediction's quality
            pos_loss_m = a * ((1.0 - m2) ** g) * bce_nested(m2, 1.0, eps=self.eps)
            neg_loss_m = (1.0 - a) * (m2**g) * bce_nested(1.0 - m2, 1.0, eps=self.eps)

            cost_mask = pos_loss_m - neg_loss_m
            c = self.lambd * cost_loc + (1.0 - self.lambd) * cost_mask
            # bfloat16 has no numpy dtype; scipy solves in float64 anyway.
            c = c.double().cpu()
            idxs = [linear_sum_assignment(e.numpy()) for e in c.unbind()]

        # Loss matrix over all d dimensions.
        loss_loc = M.mean(dim=-1)

        # Whole-batch min shift (unlike the per-image shift used for matching).
        # NOTE(review): f1 is not detached, so gradients also flow through the
        # quality target into loss_loc -- confirm this is intended.
        L_min = torch.stack([e.min() for e in loss_loc], dim=0).min()
        f1 = torch.exp(-(loss_loc - L_min))

        # The score m is trained to predict the localisation quality f1.
        pos_loss_m = (
            a * ((f1 - m[:, None]) ** g) * bce_nested(m[:, None], f1, eps=self.eps)
        )
        neg_loss_m = (1.0 - a) * (m**g) * bce(1.0 - m, 1.0)

        pos_loss = self.lambd * loss_loc + (1.0 - self.lambd) * pos_loss_m
        pos_loss = torch.stack(
            [e[gt_idx, p_idx].sum() for e, (gt_idx, p_idx) in zip(pos_loss, idxs)]
        )

        # Predictions left unassigned only receive the negative score term.
        all_p_idx = np.arange(N)
        unmatched = [np.setdiff1d(all_p_idx, p_idx) for _, p_idx in idxs]
        neg_loss = torch.stack(
            [row[idx].sum() for row, idx in zip(neg_loss_m, unmatched)]
        )
        neg_loss = (1.0 - self.lambd) * neg_loss

        return pos_loss.mean(), neg_loss.mean()


def main() -> None:
    """Script entry point; currently a placeholder that does nothing."""
    return None


if __name__ == "__main__":
    main()
