class Matcher(object):
    """
    Assign to each predicted "element" (e.g., a box) a ground-truth element.

    Each predicted element will have exactly zero or one matches; each
    ground-truth element may be matched to zero or more predicted elements.

    The matching is determined by the MxN match_quality_matrix, which
    characterizes how well each (ground-truth, prediction)-pair match each
    other. For example, if the elements are boxes, this matrix may contain box
    intersection-over-union overlap values.

    The matcher returns (a) a vector of length N containing the index of the
    ground-truth element m in [0, M) that best matches prediction n in [0, N),
    and (b) a vector of length N containing the label of each prediction.
    """

    def __init__(
            self, thresholds: List[float], labels: List[int], allow_low_quality_matches: bool = False
    ):
        """
        Args:
            thresholds (list): a list of thresholds used to stratify predictions
                into levels.
            labels (list): a list of values to label predictions belonging at
                each level. A label can be one of {-1, 0, 1} signifying
                {ignore, negative class, positive class}, respectively.
            allow_low_quality_matches (bool): if True, additionally mark the
                top-scoring prediction(s) of every ground-truth element as
                positive, even if their quality is below the last threshold.
                See set_low_quality_matches_ for more details.

            For example,
                thresholds = [0.3, 0.5]
                labels = [0, -1, 1]
                All predictions with iou < 0.3 will be marked with 0 and
                thus will be considered as false positives while training.
                All predictions with 0.3 <= iou < 0.5 will be marked with -1 and
                thus will be ignored.
                All predictions with 0.5 <= iou will be marked with 1 and
                thus will be considered as true positives.
        """
        # Work on a copy so the caller's list is not mutated below.
        thresholds = thresholds[:]
        assert thresholds[0] > 0
        # Add -inf and +inf to first and last position in thresholds so that
        # consecutive pairs cover the whole real line.
        thresholds.insert(0, -float("inf"))
        thresholds.append(float("inf"))
        # Currently torchscript does not support all + generator
        assert all([low <= high for (low, high) in zip(thresholds[:-1], thresholds[1:])]), thresholds
        assert all([label in [-1, 0, 1] for label in labels])
        assert len(labels) == len(thresholds) - 1
        self.thresholds = thresholds
        self.labels = labels
        self.allow_low_quality_matches = allow_low_quality_matches

    def __call__(self, match_quality_matrix):
        """
        Args:
            match_quality_matrix (Tensor[float]): an MxN tensor, containing the
                pairwise quality between M ground-truth elements and N predicted
                elements. All elements must be >= 0 (this is asserted).

        Returns:
            matches (Tensor[int64]): a vector of length N, where matches[i] is a matched
                ground-truth index in [0, M). When the matrix is empty it is
                filled with 0.
            match_labels (Tensor[int8]): a vector of length N, where match_labels[i] indicates
                whether a prediction is a true or false positive or ignored
        """
        assert match_quality_matrix.dim() == 2
        num_preds = match_quality_matrix.size(1)
        if match_quality_matrix.numel() == 0:
            default_matches = match_quality_matrix.new_full((num_preds,), 0, dtype=torch.int64)
            # When no gt boxes exist, we define IOU = 0 and therefore set labels
            # to `self.labels[0]`, which usually defaults to background class 0
            # To choose to ignore instead, can make labels=[-1,0,-1,1] + set appropriate thresholds
            default_match_labels = match_quality_matrix.new_full(
                (num_preds,), self.labels[0], dtype=torch.int8
            )
            return default_matches, default_match_labels

        assert torch.all(match_quality_matrix >= 0)

        # match_quality_matrix is M (gt) x N (predicted)
        # Max over gt elements (dim 0) to find best gt candidate for each prediction
        matched_vals, matches = match_quality_matrix.max(dim=0)

        match_labels = matches.new_full(matches.size(), 1, dtype=torch.int8)

        # Each prediction falls into exactly one [low, high) interval.
        for (label, low, high) in zip(self.labels, self.thresholds[:-1], self.thresholds[1:]):
            low_high = (matched_vals >= low) & (matched_vals < high)
            match_labels[low_high] = label

        if self.allow_low_quality_matches:
            self.set_low_quality_matches_(match_labels, match_quality_matrix)

        return matches, match_labels

    def set_low_quality_matches_(self, match_labels, match_quality_matrix, k=1):
        """
        Produce additional matches for predictions that have only low-quality matches.

        For each ground-truth G, the ``k`` predictions with the highest quality
        w.r.t. G (selected with ``topk``, so ties are broken arbitrarily rather
        than all included) are labeled positive (1) in ``match_labels``, which is
        modified in place. If ``k`` exceeds the number of predictions, every
        prediction is selected.

        This function is a variant of the RPN assignment case (i) in Sec. 3.1.2 of
        :paper:`Faster R-CNN`.
        """
        # topk raises when k > N; clamp so that case selects all predictions.
        k = min(k, match_quality_matrix.size(1))
        _, topk_pred_inds = match_quality_matrix.topk(k=k, dim=1)
        # If an anchor was labeled positive only due to a low-quality match with
        # gt_A, but it has larger overlap with gt_B, its matched index (returned
        # by __call__) will still be gt_B. This follows Detectron.
        match_labels[topk_pred_inds.flatten()] = 1


# from https://github.com/facebookresearch/detectron2/blob/cbbc1ce26473cb2a5cc8f58e8ada9ae14cb41052/detectron2/modeling/sampling.py#L9
def subsample_labels(
        labels: torch.Tensor, num_samples: int, positive_fraction: float, bg_label: int
):
    """
    Return `num_samples` (or fewer, if not enough found)
    random samples from `labels` which is a mixture of positives & negatives.
    It will try to return as many positives as possible without
    exceeding `positive_fraction * num_samples`, and then try to
    fill the remaining slots with negatives.

    Args:
        labels (Tensor): (N, ) label vector with values:
            * -1: ignore
            * bg_label: background ("negative") class
            * otherwise: one or more foreground ("positive") classes
        num_samples (int): The total number of labels with value >= 0 to return.
            Values that are not sampled will be filled with -1 (ignore).
        positive_fraction (float): The number of subsampled foreground labels
            is `min(num_positives, int(positive_fraction * num_samples))`. The number
            of negatives sampled is `min(num_negatives, num_samples - num_positives_sampled)`.
            In other words, if there are not enough positives, the sample is filled with
            negatives. If there are also not enough negatives, then as many elements are
            sampled as is possible.
        bg_label (int): label index of background ("negative") class.

    Returns:
        pos_idx, neg_idx (Tensor):
            1D vector of indices. The total length of both is `num_samples` or fewer.
    """
    positive = nonzero_tuple((labels != -1) & (labels != bg_label))[0]
    negative = nonzero_tuple(labels == bg_label)[0]

    num_pos = int(num_samples * positive_fraction)
    # protect against not enough positive examples
    num_pos = min(positive.numel(), num_pos)
    num_neg = num_samples - num_pos
    # protect against not enough negative examples
    num_neg = min(negative.numel(), num_neg)

    # Randomly select positive and negative examples. The permutations are
    # created on the same device as the indices they select from (instead of a
    # hard-coded `.cuda()`), so this works on CPU and on any GPU.
    perm1 = torch.randperm(positive.numel(), device=positive.device)[:num_pos]
    perm2 = torch.randperm(negative.numel(), device=negative.device)[:num_neg]

    pos_idx = positive[perm1]
    neg_idx = negative[perm2]
    return pos_idx, neg_idx


def sample_topk(pr_inds, gt_inds, cost_matrix, k):
    """
    Re-select, for every ground truth that received matches, its best-scoring
    predictions according to ``cost_matrix``.

    A ground truth appearing ``c`` times in ``gt_inds`` is re-assigned to its
    ``min(c, k)`` highest-scoring predictions (each row of ``cost_matrix`` is
    ranked with ``topk``). The values in ``pr_inds`` are only returned as-is
    when ``gt_inds`` is empty; otherwise only the per-gt counts matter.

    Args:
        pr_inds (Tensor): prediction indices, shape (M,).
        gt_inds (Tensor): ground-truth index for each entry of ``pr_inds``, shape (M,).
        cost_matrix (Tensor): scores of shape (num_targets, num_queries); larger
            values are preferred.
        k (int): maximum number of predictions kept per ground truth. Values
            larger than ``num_queries`` are clamped to ``num_queries``.

    Returns:
        (Tensor, Tensor): new prediction indices and their ground-truth indices,
        grouped by ground truth in ascending ground-truth order.
    """
    if len(gt_inds) == 0:
        return pr_inds, gt_inds
    # topk raises when k exceeds the number of queries.
    k = min(k, cost_matrix.size(1))

    # Top-k predictions for each distinct gt (unique() returns sorted values).
    uniq_gt, counts = gt_inds.unique(return_counts=True)
    _, topk_pr = cost_matrix[uniq_gt].topk(k, dim=1)  # (num_unique_gt, k)
    topk_gt = uniq_gt[:, None].expand(-1, k)

    # Keep only as many matches as each gt originally had: column j of row g
    # survives iff j < counts[g]. Row-major boolean indexing preserves the
    # per-gt grouping order.
    keep = torch.arange(k, device=counts.device)[None, :] < counts[:, None]
    return topk_pr[keep], topk_gt[keep]







class Stage2AssignerRel(nn.Module):
    """
    Assign predicted (subject, object) query pairs to ground-truth relations.

    For each image a cost matrix blending box IoU and class probability of the
    subject and of the object is built against every annotated relation. A
    :class:`Matcher` labels queries as positive/negative, proposals are
    subsampled with :func:`subsample_labels`, and the positive queries are
    re-selected with :func:`sample_topk` (at most ``max_k`` per relation).
    """

    def __init__(self, num_queries, max_k=6):
        super().__init__()
        self.positive_fraction = 0.25
        self.bg_label = 400  # number > 91 to filter out later
        self.batch_size_per_image = num_queries
        self.proposal_matcher = Matcher(thresholds=[0.4], labels=[0, 1], allow_low_quality_matches=True)
        self.k = max_k

    def _sample_proposals(
            self, matched_idxs: torch.Tensor, matched_labels: torch.Tensor, gt_classes: torch.Tensor
    ):
        """
        Based on the matching between N proposals and M groundtruth,
        sample the proposals and set their classification labels.

        Args:
            matched_idxs (Tensor): a vector of length N, each is the best-matched
                gt index in [0, M) for each proposal.
            matched_labels (Tensor): a vector of length N, the matcher's label
                (one of the Matcher's labels, i.e. -1, 0 or 1) for each proposal.
            gt_classes (Tensor): a vector of length M.

        Returns:
            Tensor: a vector of indices of sampled proposals. Each is in [0, N).
            Tensor: a vector of the same length, the classification label for
                each sampled proposal: either the matched gt class or
                ``self.bg_label`` for background.
        """
        has_gt = gt_classes.numel() > 0
        # Get the corresponding GT for each proposal
        if has_gt:
            # Advanced indexing returns a copy, so the caller's tensor is not modified.
            gt_classes = gt_classes[matched_idxs]
            # Label unmatched proposals (0 label from matcher) as background
            gt_classes[matched_labels == 0] = self.bg_label
            # Label ignore proposals (-1 label)
            gt_classes[matched_labels == -1] = -1
        else:
            gt_classes = torch.zeros_like(matched_idxs) + self.bg_label
        sampled_fg_idxs, sampled_bg_idxs = subsample_labels(
            gt_classes, self.batch_size_per_image, self.positive_fraction, self.bg_label
        )

        sampled_idxs = torch.cat([sampled_fg_idxs, sampled_bg_idxs], dim=0)
        return sampled_idxs, gt_classes[sampled_idxs]

    @torch.no_grad()
    def get_cost_matrix_rel(self, pred_logits_sub, pred_logits_obj, pred_boxes_sub, pred_boxes_obj, gt_boxes_sub, gt_boxes_obj, gt_classes_sub, gt_classes_obj, gt_rel):
        """
        Build the (num_gt_relations, num_queries) matching-quality matrix.

        Each entry is the mean over subject and object of
        ``0.7 * box_iou + 0.3 * sigmoid(class_logit)``. Predicted and gt boxes are
        converted from (cx, cy, w, h) to (x0, y0, x1, y1) before the IoU.
        ``gt_rel`` is accepted for interface compatibility but is not used.
        Works with zero gt relations (returns a (0, num_queries) matrix).
        """
        num_queries = len(pred_logits_sub)
        num_gt = len(gt_boxes_sub)

        prob_sub = pred_logits_sub.sigmoid()
        prob_obj = pred_logits_obj.sigmoid()

        # presumably box_iou returns (iou, union) as in DETR's box_ops; [0] keeps
        # the IoU matrix of shape (num_queries, num_gt) — confirm against helper.
        iou_sub = box_iou(box_cxcywh_to_xyxy(pred_boxes_sub), box_cxcywh_to_xyxy(gt_boxes_sub))[0]
        iou_obj = box_iou(box_cxcywh_to_xyxy(pred_boxes_obj), box_cxcywh_to_xyxy(gt_boxes_obj))[0]
        # Probability each query assigns to the gt class: (num_queries, num_gt).
        cls_sub = prob_sub[:, gt_classes_sub]
        cls_obj = prob_obj[:, gt_classes_obj]

        C_sub = 0.7 * iou_sub + 0.3 * cls_sub
        C_obj = 0.7 * iou_obj + 0.3 * cls_obj
        C = (C_sub + C_obj) / 2.0
        # Use the explicit gt count rather than -1: view(num_queries, -1) raises
        # for images without relations because -1 is ambiguous for 0 elements.
        C = C.view(num_queries, num_gt)
        return C.T

    def forward(self, outputs, targets, return_cost_matrix=False):
        """
        Compute per-image (prediction indices, gt relation indices) assignments.

        Returns a list with one ``(pos_pr_inds, pos_gt_inds)`` tuple per image,
        plus the list of per-image cost matrices when ``return_cost_matrix``.
        """
        # VG categories are from 1 to 150. They set num_classes=151 and apply sigmoid.
        bs = len(targets)
        indices = []
        cost_matrices = []

        for b in range(bs):
            rel_annotations = targets[b]['rel_annotations']
            # Columns of rel_annotations: subject box index, object box index, predicate.
            sub_inds = rel_annotations[:, 0]
            obj_inds = rel_annotations[:, 1]
            gt_boxes_sub = targets[b]['boxes'][sub_inds]
            gt_boxes_obj = targets[b]['boxes'][obj_inds]
            gt_classes_sub = targets[b]['labels'][sub_inds]
            gt_classes_obj = targets[b]['labels'][obj_inds]
            gt_rel = rel_annotations[:, 2]

            pred_logits_sub = outputs['pred_logits_sub_cdecoder'][b].detach()
            pred_logits_obj = outputs['pred_logits_obj_cdecoder'][b].detach()

            pred_boxes_sub = outputs['pred_boxes_sub'][b]
            pred_boxes_obj = outputs['pred_boxes_obj'][b]

            cost_matrix = self.get_cost_matrix_rel(pred_logits_sub=pred_logits_sub,
                                                   pred_logits_obj=pred_logits_obj,
                                                   pred_boxes_sub=pred_boxes_sub,
                                                   pred_boxes_obj=pred_boxes_obj,
                                                   gt_boxes_sub=gt_boxes_sub,
                                                   gt_boxes_obj=gt_boxes_obj,
                                                   gt_classes_sub=gt_classes_sub,
                                                   gt_classes_obj=gt_classes_obj,
                                                   gt_rel=gt_rel)

            matched_idxs, matched_labels = self.proposal_matcher(cost_matrix)

            # Relation predicates play the role of "classes" for sampling.
            sampled_idxs, sampled_gt_classes = self._sample_proposals(
                matched_idxs, matched_labels, gt_rel
            )

            pos_pr_inds = sampled_idxs[sampled_gt_classes != self.bg_label]
            pos_gt_inds = matched_idxs[pos_pr_inds]
            pos_pr_inds, pos_gt_inds = self.postprocess_indices(pos_pr_inds, pos_gt_inds, cost_matrix)
            indices.append((pos_pr_inds, pos_gt_inds))
            cost_matrices.append(cost_matrix)

        if return_cost_matrix:
            return indices, cost_matrices
        return indices

    def postprocess_indices(self, pr_inds, gt_inds, iou):
        """Re-select positives with :func:`sample_topk` using at most ``self.k`` per gt."""
        return sample_topk(pr_inds, gt_inds, iou, self.k)