import math
import numpy as np
import torch
import torch.nn as nn
from scipy.optimize import linear_sum_assignment
import torch.nn.functional as F
import torch.distributed as dist
import ipdb
st = ipdb.set_trace


def is_dist_avail_and_initialized():
    """Return True only when torch.distributed is both available and initialized."""
    # Short-circuit keeps is_initialized() from being queried when the
    # distributed package is not available at all.
    return dist.is_available() and dist.is_initialized()


def get_world_size():
    """Return the distributed world size, or 1 when not running distributed."""
    if is_dist_avail_and_initialized():
        return dist.get_world_size()
    return 1


def smoothl1_loss(error, delta=1.0):
    """Element-wise smooth L1 loss.

    With x = error (e.g. pred - gt, or a distance) and d = delta:
        0.5 * |x|^2 / d      if |x| <  d
        |x| - 0.5 * d        otherwise

    Args:
        error: tensor of residuals.
        delta: switch point between the quadratic and the linear branch.

    Returns:
        Tensor of the same shape as ``error`` (no reduction).
    """
    magnitude = torch.abs(error)
    quadratic_branch = 0.5 * magnitude * magnitude / delta
    linear_branch = magnitude - 0.5 * delta
    return torch.where(magnitude < delta, quadratic_branch, linear_branch)


def l1_loss(error):
    """Element-wise L1 loss: the absolute value of ``error`` (no reduction)."""
    return torch.abs(error)


def box_cxcyczwhd_to_xyzxyz(x):
    """Convert boxes from (cx, cy, cz, w, h, d) to (x0, y0, z0, x1, y1, z1).

    Args:
        x: tensor whose last dimension holds the six values
            center-x, center-y, center-z, width, height, depth.

    Returns:
        Tensor with the same leading dimensions and a last dimension of six:
        the minimum corner followed by the maximum corner.

    Sizes are clamped to at least 1e-6 before conversion, so the returned
    max corner is never below the min corner. The clamp already rules out
    negative sizes, which is why no further sign check is done here (such a
    check could never fail and would only force a device-to-host sync).
    """
    x_c, y_c, z_c, w, h, d = x.unbind(-1)
    half_w = 0.5 * torch.clamp(w, min=1e-6)
    half_h = 0.5 * torch.clamp(h, min=1e-6)
    half_d = 0.5 * torch.clamp(d, min=1e-6)
    corners = [
        x_c - half_w, y_c - half_h, z_c - half_d,
        x_c + half_w, y_c + half_h, z_c + half_d,
    ]
    return torch.stack(corners, dim=-1)


def _iou3d(box_a, box_b):
    """IoU of two single axis-aligned 3D boxes in (x0, y0, z0, x1, y1, z1) form.

    Returns:
        (iou, union) where union is the volume of the union of both boxes.
    """
    overlap = _intersect(box_a, box_b)
    union = _volume(box_a) + _volume(box_b) - overlap
    return overlap / union, union


def _intersect(box_a, box_b):
    # print(box_a)
    xA = max(box_a[0], box_b[0])
    yA = max(box_a[1], box_b[1])
    zA = max(box_a[2], box_b[2])
    xB = min(box_a[3], box_b[3])
    yB = min(box_a[4], box_b[4])
    zB = min(box_a[5], box_b[5])
    return max(0, xB - xA) * max(0, yB - yA) * max(0, zB - zA)


def _volume(box):
    return (box[3] - box[0]) * (box[4] - box[1]) * (box[5] - box[2])


def _volume_par(box):
    return (box[:, 3] - box[:, 0]) * (box[:, 4] - box[:, 1]) * (box[:, 5] - box[:, 2])


def _intersect_par(box_a, box_b):
    xA = torch.max(box_a[:, 0][:, None], box_b[:, 0][None, :])
    yA = torch.max(box_a[:, 1][:, None], box_b[:, 1][None, :])
    zA = torch.max(box_a[:, 2][:, None], box_b[:, 2][None, :])
    xB = torch.min(box_a[:, 3][:, None], box_b[:, 3][None, :])
    yB = torch.min(box_a[:, 4][:, None], box_b[:, 4][None, :])
    zB = torch.min(box_a[:, 5][:, None], box_b[:, 5][None, :])
    return torch.clamp(xB - xA, 0) * torch.clamp(yB - yA, 0) * torch.clamp(zB - zA, 0)


def _iou3d_par(box_a, box_b):
    """Pairwise IoU between two sets of (x0, y0, z0, x1, y1, z1) boxes.

    Args:
        box_a: [N, 6] tensor.
        box_b: [M, 6] tensor.

    Returns:
        (iou, union), both [N, M]. A pair whose union volume is zero yields
        a non-finite IoU because no epsilon is added to the denominator.
    """
    overlap = _intersect_par(box_a, box_b)
    volume_a = _volume_par(box_a)
    volume_b = _volume_par(box_b)
    union = volume_a[:, None] + volume_b[None, :] - overlap
    return overlap / union, union


def generalized_box_iou3d(boxes1, boxes2):
    """Pairwise Generalized IoU for axis-aligned 3D boxes.

    Generalized IoU from https://giou.stanford.edu/, extended to 3D.
    Boxes must be in [x0, y0, z0, x1, y1, z1] format.

    Args:
        boxes1: [N, 6] tensor.
        boxes2: [M, 6] tensor.

    Returns:
        [N, M] tensor with GIoU(boxes1[n], boxes2[m]).

    Raises:
        AssertionError: if any box has a max corner below its min corner.
    """
    # Degenerate boxes give inf / nan results, so check the corner ordering early.
    assert (boxes1[:, 3:] >= boxes1[:, :3]).all()
    assert (boxes2[:, 3:] >= boxes2[:, :3]).all()

    iou, union = _iou3d_par(boxes1, boxes2)

    # Smallest axis-aligned box enclosing each (boxes1[n], boxes2[m]) pair.
    hull_min = torch.min(boxes1[:, None, :3], boxes2[:, :3])
    hull_max = torch.max(boxes1[:, None, 3:], boxes2[:, 3:])
    hull_size = (hull_max - hull_min).clamp(min=0)  # [N, M, 3]
    hull_volume = hull_size[:, :, 0] * hull_size[:, :, 1] * hull_size[:, :, 2]

    # GIoU = IoU - (share of the enclosing volume not covered by the union).
    return iou - (hull_volume - union) / hull_volume


class SigmoidFocalClassificationLoss(nn.Module):
    """Sigmoid focal cross-entropy loss (element-wise, no reduction)."""

    def __init__(self, gamma: float = 2.0, alpha: float = 0.25):
        """
        Args:
            gamma: Exponent of the modulating factor; balances hard vs. easy examples.
            alpha: Weight for positive targets; negatives receive ``1 - alpha``.
        """
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    @staticmethod
    def sigmoid_cross_entropy_with_logits(input: torch.Tensor, target: torch.Tensor):
        """Numerically stable sigmoid cross entropy computed from logits.

        PyTorch counterpart of tf.nn.sigmoid_cross_entropy_with_logits:
            max(x, 0) - x * z + log(1 + exp(-abs(x)))
        https://www.tensorflow.org/api_docs/python/tf/nn/sigmoid_cross_entropy_with_logits

        Args:
            input: (B, #proposals, #classes) float tensor of predicted logits.
            target: (B, #proposals, #classes) float tensor of one-hot targets.

        Returns:
            (B, #proposals, #classes) float tensor, unreduced loss.
        """
        positive_part = torch.clamp(input, min=0)
        log_term = torch.log1p(torch.exp(-torch.abs(input)))
        return positive_part - input * target + log_term

    def forward(self, input: torch.Tensor, target: torch.Tensor, weights: torch.Tensor):
        """Compute the weighted focal loss.

        Args:
            input: (B, #proposals, #classes) float tensor of predicted logits.
            target: (B, #proposals, #classes) float tensor of one-hot targets.
            weights: (B, #proposals) float tensor of per-proposal weights.

        Returns:
            (B, #proposals, #classes) float tensor: focal loss times weights.
        """
        prob = torch.sigmoid(input)

        # alpha for positives, (1 - alpha) for negatives
        class_balance = target * self.alpha + (1 - target) * (1 - self.alpha)
        # probability mass assigned to the wrong outcome
        miss_prob = target * (1.0 - prob) + (1.0 - target) * prob
        modulation = class_balance * torch.pow(miss_prob, self.gamma)

        focal = modulation * self.sigmoid_cross_entropy_with_logits(input, target)

        # add a trailing class axis so the weights broadcast over classes
        weights = weights.unsqueeze(-1)
        assert len(weights.shape) == len(focal.shape)

        return focal * weights


class HungarianMatcher(nn.Module):
    """
    Assign targets to predictions.

    This class is taken from M-DETR and is modified for our purposes.

    Targets do not include a "no object" entry, so in general there are
    more predictions than targets. The best predictions are matched 1-to-1
    to targets; the remaining predictions stay un-matched (and are treated
    as non-objects by the loss).
    """

    def __init__(self, cost_class=1, cost_bbox=5, cost_giou=2,
                 soft_token=False, use_detected_boxes=False):
        """
        Initialize matcher.

        Args:
            cost_class: relative weight of the classification error
            cost_bbox: relative weight of the L1 bounding box regression error
            cost_giou: relative weight of the giou loss of the bounding box
            soft_token: whether to use soft-token prediction
            use_detected_boxes: when set, each sample's cost rows are
                filtered by the target's 'mask' entry before matching

        Raises:
            AssertionError: if all three cost weights are zero.
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_bbox = cost_bbox
        self.cost_giou = cost_giou
        assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0
        self.soft_token = soft_token
        self.use_detected_boxes = use_detected_boxes

    @torch.no_grad()
    def forward(self, outputs, targets):
        """
        Perform the matching.

        Args:
            outputs: This is a dict that contains at least these entries:
                "pred_logits" (tensor): [batch_size, num_queries, num_classes]
                "pred_boxes" (tensor): [batch_size, num_queries, 6], cxcyczwhd
            targets: list (len(targets) = batch_size) of dict:
                "labels" (tensor): [num_target_boxes]
                    (where num_target_boxes is the no. of ground-truth objects)
                "boxes" (tensor): [num_target_boxes, 6], cxcyczwhd
                "positive_map" (tensor): [num_target_boxes, 256]
                "mask" (tensor): read only when use_detected_boxes is set;
                    used to index the query rows of the cost matrix

        Returns:
            A list of size batch_size, containing tuples of (index_i, index_j):
                - index_i is the indices of the selected predictions
                - index_j is the indices of the corresponding selected targets
            For each batch element, it holds:
            len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
        """
        # Notation: {B: batch_size, Q: num_queries, C: num_classes}
        pred_logits = outputs["pred_logits"]
        batch_size, n_queries = pred_logits.shape[:2]

        # Flatten the batch so one cost matrix covers every (query, target) pair
        probs = pred_logits.flatten(0, 1).softmax(-1)  # [B*Q, C]
        pred_boxes = outputs["pred_boxes"].flatten(0, 1)  # [B*Q, 6]

        # Targets of all samples, concatenated along the box axis
        token_maps = torch.cat([t["positive_map"] for t in targets])
        gt_labels = torch.cat([t["labels"] for t in targets])
        gt_boxes = torch.cat([t["boxes"] for t in targets])

        # Zero-pad the class axis when it differs from the positive-map width
        n_tokens = token_maps.shape[-1]
        if probs.shape[-1] != n_tokens:
            padded = torch.zeros(probs.shape[0], n_tokens).to(probs.device)
            padded[:, :probs.shape[-1]] = probs
            probs = padded

        if self.soft_token:
            # Probability mass each query puts on each target's positive tokens
            class_cost = -torch.matmul(probs, token_maps.transpose(0, 1))
        else:
            # Contrary to the loss, the NLL is approximated by
            # 1 - proba[target class]; the constant 1 does not change the
            # matching and is omitted (as in DETR).
            class_cost = -probs[:, gt_labels]

        # L1 distance between boxes
        l1_cost = torch.cdist(pred_boxes, gt_boxes, p=1)

        # Negative GIoU between boxes
        giou_cost = -generalized_box_iou3d(
            box_cxcyczwhd_to_xyzxyz(pred_boxes),
            box_cxcyczwhd_to_xyzxyz(gt_boxes)
        )

        # Weighted sum, reshaped to [B, Q, total_targets] and moved to CPU
        cost = (
            self.cost_bbox * l1_cost
            + self.cost_class * class_cost
            + self.cost_giou * giou_cost
        ).view(batch_size, n_queries, -1).cpu()

        gt_counts = [len(t["boxes"]) for t in targets]
        query_mask = None
        if self.use_detected_boxes:
            # NOTE(review): after masking, the returned prediction indices
            # are positions within the kept rows, not original query ids --
            # confirm callers expect that.
            query_mask = torch.cat([t['mask'].unsqueeze(0) for t in targets])

        assignments = []
        # Chunk b of the split holds the target columns that belong to sample b
        for b, chunk in enumerate(cost.split(gt_counts, -1)):
            sample_cost = chunk[b]
            if query_mask is not None:
                sample_cost = sample_cost[query_mask[b]]
            assignments.append(linear_sum_assignment(sample_cost))

        return [
            (
                torch.as_tensor(pred_ids, dtype=torch.int64),  # matched pred boxes
                torch.as_tensor(gt_ids, dtype=torch.int64)  # corresponding gt boxes
            )
            for pred_ids, gt_ids in assignments
        ]


class SetCriterion(nn.Module):
    """ This class computes the loss for DETR.
    The process happens in two steps:
        1) we compute hungarian assignment between ground truth boxes and the outputs of the model
        2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
    """

    def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses,
                 temperature, soft_token=False,
                 contrastive_hungarian=False, use_gt_box=False,
                 new_contrastive=False, detect_intermediate=False,
                 use_detected_boxes=False, train_viewpoint_prototype=False):
        """ Create the criterion.
        Parameters:
            num_classes: number of object categories, omitting the special no-object category
            matcher: module able to compute a matching between targets and proposals
            weight_dict: dict containing as key the names of the losses and as values their relative weight.
            eos_coef: relative classification weight applied to the no-object category
            losses: list of all the losses to be applied. See get_loss for list of available losses.
            soft_token: whether to use soft-token prediction
            pessimistic: whether to use the pessimistic class loss
        """
        super().__init__()
        self.num_classes = num_classes
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.eos_coef = eos_coef
        self.losses = losses
        self.temperature = temperature
        self.soft_token = soft_token
        self.contrastive_hungarian = contrastive_hungarian
        empty_weight = torch.ones(self.num_classes + 1)
        empty_weight[-1] = self.eos_coef
        self.register_buffer('empty_weight', empty_weight)
        self.use_gt_box = use_gt_box
        self.use_detected_boxes = use_detected_boxes
        self.new_contrastive = new_contrastive
        self.detect_intermediate = detect_intermediate

    def loss_labels_st(self, outputs, targets, indices, num_boxes, log=False):
        """Classification loss (NLL)
        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
        """
        logits = outputs["pred_logits"].log_softmax(-1)  # BS x (num_queries) x (num_tokens)
        positive_map = torch.cat([t["positive_map"] for t in targets])
        src_idx = self._get_src_permutation_idx(indices)
        tgt_idx = []
        offset = 0
        for i, (_, tgt) in enumerate(indices):
            tgt_idx.append(tgt + offset)
            offset += len(targets[i]["boxes"])
        tgt_idx = torch.cat(tgt_idx)

        tgt_pos = positive_map[tgt_idx]
        target_sim = torch.zeros_like(logits)
        target_sim[:, :, -1] = 1
        target_sim[src_idx] = tgt_pos

        entropy = torch.log(target_sim+1e-6) * target_sim
        loss_ce = (entropy - logits * target_sim).sum(-1)

        eos_coef = torch.full(loss_ce.shape, self.eos_coef, device=target_sim.device)
        eos_coef[src_idx] = 1

        loss_ce = loss_ce * eos_coef

        if self.use_gt_box or self.use_detected_boxes:
            mask = torch.cat([t['mask'].unsqueeze(0) for t in targets])
            loss_ce = loss_ce * mask

        loss_ce = loss_ce.sum() / num_boxes

        losses = {"loss_ce": loss_ce}

        return losses

    def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
        """Classification loss (NLL)
        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
        """
        assert 'pred_logits' in outputs
        src_logits = outputs['pred_logits']

        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
        target_classes = torch.full(src_logits.shape[:2], self.num_classes,
                                    dtype=torch.int64, device=src_logits.device)
        target_classes[idx] = target_classes_o

        loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
        losses = {'loss_ce': loss_ce}

        return losses

    def loss_labels_pess(self, outputs, targets, indices, num_boxes, log=True):
        """Classification loss (NLL)
        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
        """
        assert 'pred_logits' in outputs
        src_logits = outputs['pred_logits']

        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
        target_classes = torch.full(src_logits.shape[:2], self.num_classes,
                                    dtype=torch.int64, device=src_logits.device)
        target_classes[idx] = target_classes_o

        # Keep only the 32 most pessimistic and optimistic scores
        num = max(32, max(len(t["labels"]) for t in targets))
        inds = src_logits[:, :, -1].argsort(1)[:, -num:]
        for i, ind in enumerate(indices):
            inds[i, :len(ind[0])] = ind[0]
        batch_idx = torch.cat([torch.full_like(src, i) for i, src in enumerate(inds)])
        src_idx = torch.cat([src for src in inds])

        loss_ce = F.cross_entropy(
            src_logits[batch_idx, src_idx],
            target_classes[batch_idx, src_idx],
            self.empty_weight
        )
        losses = {'loss_ce': loss_ce}

        return losses

    def loss_boxes(self, outputs, targets, indices, num_boxes):
        """Box regression losses for the matched pairs: L1 and GIoU.

        targets dicts must contain the key "boxes"; both predictions and
        targets are passed through box_cxcyczwhd_to_xyzxyz, i.e. they are
        treated as (cx, cy, cz, w, h, d). Both losses are summed over the
        matched pairs and divided by num_boxes.
        """
        # TODO: Should I normalize 3D boxes? How?

        assert 'pred_boxes' in outputs
        matched_queries = self._get_src_permutation_idx(indices)
        pred = outputs['pred_boxes'][matched_queries]
        gt = torch.cat(
            [t['boxes'][gt_ids] for t, (_, gt_ids) in zip(targets, indices)],
            dim=0
        )

        l1 = F.l1_loss(pred, gt, reduction='none')

        # NOTE(review): the full pairwise GIoU matrix is built only to read
        # its diagonal (prediction k vs. its own target k).
        pairwise_giou = generalized_box_iou3d(
            box_cxcyczwhd_to_xyzxyz(pred),
            box_cxcyczwhd_to_xyzxyz(gt))
        giou_loss = 1 - torch.diag(pairwise_giou)

        return {
            'loss_bbox': l1.sum() / num_boxes,
            'loss_giou': giou_loss.sum() / num_boxes,
        }

    def loss_masks(self, outputs, targets, indices, num_boxes):
        """
        Compute the losses related to the masks.

        targets dicts must contain the key "masks": [nb_target_boxes, h, w]
        """
        assert "pred_masks" in outputs
        src_masks = outputs["pred_masks"]
        masks = targets  # torch.stack([t["masks"] for t in targets])
        return {"loss_mask": F.cross_entropy(src_masks.transpose(1, 2), masks)}

    def loss_contrastive_align(self, outputs, targets, indices, num_boxes):
        """Compute contrastive losses between projected queries and tokens."""
        tokenized = outputs["tokenized"]
        if self.use_gt_box:
            mask = torch.cat([t['mask'].unsqueeze(0) for t in targets])

        if self.contrastive_hungarian:
            logits = outputs["pred_logits"]
        else:
            norm_text_emb = outputs["proj_tokens"]  # B, num_tokens, dim
            norm_img_emb = outputs["proj_queries"]  # B, num_queries, dim
            logits = (
                torch.matmul(norm_img_emb, norm_text_emb.transpose(-1, -2))
                / self.temperature
            )  # B, num_queries, num_tokens

        # construct a map such that positive_map[k, i, j] = True
        # iff query i is associated to token j in batch item k
        # For efficency, the construction happens on CPU,
        # then the whole matrix is transferred to GPU in one go.
        positive_map = torch.zeros(logits.shape, dtype=torch.bool)
        for i, ((idx_src, idx_tgt), tgt) in enumerate(zip(indices, targets)):
            if "tokens_positive" in tgt:
                cur_tokens = [tgt["tokens_positive"][j] for j in idx_tgt]
            else:
                cur_tokens = [tgt["tokens"][j] for j in idx_tgt]

            for j, tok_list in enumerate(cur_tokens):
                beg, end = tok_list.cpu()
                beg_pos = tokenized.char_to_token(i, beg)
                end_pos = tokenized.char_to_token(i, end - 1)
                if beg_pos is None:
                    try:
                        beg_pos = tokenized.char_to_token(beg + 1)
                        if beg_pos is None:
                            beg_pos = tokenized.char_to_token(beg + 2)
                    except:
                        beg_pos = None
                if end_pos is None:
                    try:
                        end_pos = tokenized.char_to_token(end - 2)
                        if end_pos is None:
                            end_pos = tokenized.char_to_token(end - 3)
                    except:
                        end_pos = None
                if beg_pos is None or end_pos is None:
                    continue

                assert beg_pos is not None and end_pos is not None
                positive_map[i, idx_src[j], beg_pos:end_pos + 1].fill_(True)

        positive_map = positive_map.to(logits.device)
        positive_logits = -logits.masked_fill(~positive_map, 0)
        negative_logits = logits  # .masked_fill(positive_map, -1000000)

        boxes_with_pos = positive_map.any(2)
        pos_term = positive_logits.sum(2)
        neg_term = negative_logits.logsumexp(2)

        nb_pos = positive_map.sum(2) + 1e-6
        entropy = -torch.log(nb_pos+1e-6) / nb_pos  # entropy of 1/nb_pos
        box_to_token_loss_ = ((entropy + pos_term / nb_pos + neg_term)).masked_fill(~boxes_with_pos, 0)
        if self.use_gt_box or self.use_detected_boxes:
            box_to_token_loss = (box_to_token_loss_ * mask).sum()
        else:
            box_to_token_loss = box_to_token_loss_.sum()

        tokens_with_pos = positive_map.any(1)
        if self.use_gt_box or self.use_detected_boxes:
            pos_term = (positive_logits * mask.unsqueeze(-1)).sum(1)
            neg_term = negative_logits.masked_fill(~mask.unsqueeze(-1), -np.inf).logsumexp(1)
        else:
            pos_term = positive_logits.sum(1)
            neg_term = negative_logits.logsumexp(1)

        nb_pos = positive_map.sum(1) + 1e-6
        entropy = -torch.log(nb_pos+1e-6) / nb_pos
        tokens_to_boxes_loss = ((entropy + pos_term / nb_pos + neg_term)).masked_fill(~tokens_with_pos, 0).sum()
        tot_loss = (box_to_token_loss + tokens_to_boxes_loss) / 2

        return {"loss_contrastive_align": tot_loss / num_boxes}

    def loss_contrastive_align_new(self, outputs, targets, indices, num_boxes):
        """Compute contrastive losses between projected queries and tokens."""
        tokenized = outputs["tokenized"]

        if self.contrastive_hungarian:
            logits = outputs["pred_logits"]
        else:
            norm_text_emb = outputs["proj_tokens"]  # B, num_tokens, dim
            norm_img_emb = outputs["proj_queries"]  # B, num_queries, dim
            logits = (
                torch.matmul(norm_img_emb, norm_text_emb.transpose(-1, -2))
                / self.temperature
            )  # B, num_queries, num_tokens

        # construct a map such that positive_map[k, i, j] = True
        # iff query i is associated to token j in batch item k
        # For efficency, the construction happens on CPU,
        # then the whole matrix is transferred to GPU in one go.
        positive_map = torch.zeros(logits.shape, dtype=torch.bool)
        for i, ((idx_src, idx_tgt), tgt) in enumerate(zip(indices, targets)):
            if "tokens_positive" in tgt:
                cur_tokens = [tgt["tokens_positive"][j] for j in idx_tgt]
            else:
                cur_tokens = [tgt["tokens"][j] for j in idx_tgt]

            for j, tok_list in enumerate(cur_tokens):
                beg, end = tok_list.cpu()
                beg_pos = tokenized.char_to_token(i, beg)
                end_pos = tokenized.char_to_token(i, end - 1)
                if beg_pos is None:
                    try:
                        beg_pos = tokenized.char_to_token(beg + 1)
                        if beg_pos is None:
                            beg_pos = tokenized.char_to_token(beg + 2)
                    except:
                        beg_pos = None
                if end_pos is None:
                    try:
                        end_pos = tokenized.char_to_token(end - 2)
                        if end_pos is None:
                            end_pos = tokenized.char_to_token(end - 3)
                    except:
                        end_pos = None
                if beg_pos is None or end_pos is None:
                    continue

                assert beg_pos is not None and end_pos is not None
                positive_map[i, idx_src[j], beg_pos:end_pos + 1].fill_(True)

        positive_map = positive_map.to(logits.device)

        # Loss 1: which tokens should each query match?
        qt_target = positive_map.flatten(0, 1).float()
        qt_target[qt_target.sum(1) == 0] = 1
        qt_target = qt_target / qt_target.sum(1).unsqueeze(1)
        query_token_loss = F.kl_div(
            logits.flatten(0, 1).log_softmax(1),
            qt_target, reduction='none'
        ).sum(1).sum()

        # Loss 2: which queries should each token match?
        tq_target = positive_map.transpose(1, 2).flatten(0, 1).float()
        tq_target[tq_target.sum(1) == 0] = 1
        tq_target = tq_target / tq_target.sum(1).unsqueeze(1)
        token_query_loss = F.kl_div(
            logits.transpose(1, 2).flatten(0, 1).log_softmax(1),
            tq_target, reduction='none'
        ).sum(1).sum()

        total_loss = (query_token_loss + token_query_loss) / 2
        return {"loss_contrastive_align": total_loss / num_boxes}

    def _get_src_permutation_idx(self, indices):
        # permute predictions following indices
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_tgt_permutation_idx(self, indices):
        # permute targets following indices
        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

    def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
        loss_map = {
            'labels': self.loss_labels if not self.soft_token else self.loss_labels_st,
            # 'cardinality': self.loss_cardinality,
            'boxes': self.loss_boxes,
            # 'masks': self.loss_masks
            'contrastive_align': self.loss_contrastive_align if not self.new_contrastive else self.loss_contrastive_align_new
        }
        # if "pred_masks" in outputs:
        #    loss_map['masks'] = self.loss_masks
        assert loss in loss_map, f'do you really want to compute {loss} loss?'
        return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)

    def forward(self, outputs, targets):
        """ This performs the loss computation.
        Parameters:
             outputs: dict of tensors, see the output specification of the model for the format
             targets: list of dicts, such that len(targets) == batch_size.
                      The expected keys in each dict depends on the losses applied, see each loss' doc
        Returns:
             (losses, indices): the merged dict of all requested losses and
             the prediction/target assignment that was used.
        """
        # Retrieve the matching between outputs and targets
        if not self.use_gt_box:
            indices = self.matcher(outputs, targets)
        elif self.detect_intermediate:
            # Assignment is given by the targets: target box first, then anchors
            indices = []
            for t in targets:
                pred_ids = [t['target_id'].item()]
                pred_ids.extend(_id for _id in t['anchor_ids'])
                gt_ids = list(range(len(t['anchor_ids']) + 1))
                indices.append((
                    torch.as_tensor(pred_ids, dtype=torch.int64),  # matched pred boxes
                    torch.as_tensor(gt_ids, dtype=torch.int64),  # corresponding gt boxes
                ))
        else:
            # Assignment is given by the targets: only the target box
            indices = []
            for t in targets:
                indices.append((
                    torch.as_tensor([t['target_id']], dtype=torch.int64),  # matched pred boxes
                    torch.as_tensor([0], dtype=torch.int64),  # corresponding gt boxes
                ))

        # Average number of target boxes across all nodes, for normalization
        total_targets = sum(len(t["labels"]) for t in targets)
        device = next(iter(outputs.values())).device
        num_boxes = torch.as_tensor([total_targets], dtype=torch.float, device=device)
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_boxes)
        # never normalize by less than one box
        num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()

        # Compute all the requested losses
        losses = {}
        for loss_name in self.losses:
            losses.update(self.get_loss(loss_name, outputs, targets, indices, num_boxes))

        return losses, indices

class ViewpointCriterion(nn.Module):
    """Viewpoint loss: summed L1 regression, or bin classification.

    With train_viewpoint_prototype unset, the loss is the summed absolute
    difference between "pred_viewpoint" and the stacked "target_eul".
    With it set, the targets are converted to degrees, digitized against 12
    evenly spaced edges on [-180, 180], and the bin of the third target
    component supervises "pred_viewpoint" via cross entropy.
    """

    def __init__(self, weight, train_viewpoint_prototype=False):
        super().__init__()
        # stored only; forward() does not apply it
        self.weight = weight
        self.train_viewpoint_prototype = train_viewpoint_prototype

    def forward(self, outputs, targets):
        """Return the viewpoint loss as a scalar tensor.

        Args:
            outputs: dict with "pred_viewpoint".
            targets: list of dicts, each with a "target_eul" tensor
                (presumably Euler angles in radians -- confirm with the
                data pipeline).
        """
        prediction = outputs['pred_viewpoint']
        target = torch.stack([t['target_eul'] for t in targets])

        if not self.train_viewpoint_prototype:
            return F.l1_loss(prediction, target, reduction='none').sum()

        bin_edges = np.linspace(-180, 180, 12)
        target_deg = np.degrees(target.cpu().numpy())
        bin_ids = np.digitize(target_deg, bin_edges).astype(np.int64)
        # only the third angle component is classified
        labels = torch.from_numpy(bin_ids[:, 2]).to(prediction.device)
        return F.cross_entropy(prediction, labels)

