"""A class to collect and evaluate language grounding results."""

import numpy as np
import torch

from .classifiers import RelClassifier
from models.losses import _iou3d_par, box_cxcyczwhd_to_xyzxyz
from src.scannet_classes import REL_ALIASES, SCANNET_OBJECTS

import ipdb
st = ipdb.set_trace


class GroundingEvaluator:
    """
    Evaluate language grounding.

    Args:
        only_root (bool): detect only the root noun
        thresholds (list): IoU thresholds to check
        topks (list): k to evaluate top-k accuracy
        prefixes (list): names of layers to evaluate
    """

    def __init__(self, only_root=True, thresholds=[0.25, 0.5],
                 topks=[1, 5, 10], prefixes=[], use_detected_boxes=False,
                 evaluate_viewpoint=False, train_viewpoint_prototype=False,
                 run_on_target_phrases=False):
        """Initialize accumulators."""
        self.only_root = only_root
        self.thresholds = thresholds
        self.topks = topks
        self.prefixes = prefixes
        self.use_detected_boxes = use_detected_boxes
        self.evaluate_viewpoint = evaluate_viewpoint
        self.train_viewpoint_prototype = train_viewpoint_prototype
        self.run_on_target_phrases = run_on_target_phrases
        self.reset()

    def reset(self):
        """Reset accumulators to empty."""
        self.dets = {
            (prefix, t, k, mode): 0
            for prefix in self.prefixes
            for t in self.thresholds
            for k in self.topks
            for mode in ['bbs', 'bbf', 'opt']
        }
        self.dets.update({
            (prefix, k, mode): 0
            for prefix in self.prefixes
            for k in self.topks
            for mode in ['sbb', 'fbb']
        })
        self.dets.update({
            ('per_relation', t, k, mode, rel): 0
            for t in self.thresholds
            for k in self.topks
            for mode in ['bbs', 'bbf', 'opt']
            for rel in list(set(list(REL_ALIASES.values()))) + ['none']
        })
        self.dets.update({
            ('per_target', t, k, mode, name): 0
            for t in self.thresholds
            for k in self.topks
            for mode in ['bbs', 'bbf', 'opt']
            for name in SCANNET_OBJECTS + ['none']
        })
        self.gts = dict(self.dets)

        if self.evaluate_viewpoint:
            self.pred_viewpoints = {
                prefix: 0
                for prefix in self.prefixes
            }
            self.gt_viewpoints = dict(self.pred_viewpoints)
        self.dets.update({'yy': 0, 'ny': 0})
        self.gts.update({'y': 1e-14, 'n': 1e-14})
        self.dets.update({'yy_t1': 0, 'ny_t1': 0})
        self.gts.update({'y_t1': 1e-14, 'n_t1': 1e-14})
        self.dets.update({'grounder_improved_iou': 0, 'detected_iou': 0, 'found_iou': 0})

    def print_stats(self):
        """Print accumulated accuracies."""
        mode_str = {
            'bbs': 'Box given span (soft-token)',
            'bbf': 'Box given span (contrastive)',
            'opt': 'Box given span (optimistic)',
            'sbb': 'Span given box (soft-token)',
            'fbb': 'Span given box (contrastive)'
        }
        for prefix in self.prefixes:
            # for mode in ['bbs', 'bbf', 'opt']:
            for mode in ['bbs']:
                for t in self.thresholds:
                    print(
                        prefix, mode_str[mode], 'Acc%.2f:' % t,
                        ', '.join([
                            'Top-%d: %.3f' % (
                                k,
                                self.dets[(prefix, t, k, mode)]
                                / max(self.gts[(prefix, t, k, mode)], 1)
                            )
                            for k in self.topks
                        ])
                    )
            # for mode in ['sbb', 'fbb']:
            for mode in []:
                print(
                    prefix, mode_str[mode], 'Acc:',
                    ', '.join([
                        'Top-%d: %.3f' % (
                            k,
                            self.dets[(prefix, k, mode)]
                            / max(self.gts[(prefix, k, mode)], 1)
                        )
                        for k in self.topks
                    ])
                )
        """
        print('\nPer relation accuracy')
        for mode in ['bbs', 'bbf', 'opt']:
            for t in self.thresholds:
                for rel in set(list(REL_ALIASES.values())):
                    if self.gts[('per_relation', t, 1, mode, rel)] == 0:
                        continue
                    print(
                        rel,
                        mode_str[mode], 'Acc%.2f:' % t,
                        ', '.join([
                            'Top-%d: %.3f' % (
                                k,
                                self.dets[('per_relation', t, k, mode, rel)]
                                / self.gts[('per_relation', t, k, mode, rel)]
                            )
                            for k in self.topks
                            if self.gts[('per_relation', t, k, mode, rel)]
                        ])
                    )
        """
        print('\nRecall')
        print('Detected and found:', self.dets['yy'] / self.gts['y'])
        print('Not detected but found:', self.dets['ny'] / self.gts['n'])
        print('Grounder Recall:', (self.dets['yy'] + self.dets['ny']) / (self.gts['y'] + self.gts['n']))
        print('Detector Recall:', self.gts['y'] / (self.gts['y'] + self.gts['n']))
        print('\nRecall top-1')
        print('Detected and found:', self.dets['yy_t1'] / self.gts['y_t1'])
        print('Not detected but found:', self.dets['ny_t1'] / self.gts['n_t1'])
        print('Grounder Recall:', (self.dets['yy_t1'] + self.dets['ny_t1']) / (self.gts['y_t1'] + self.gts['n_t1']))
        print('\nIou comparisons')
        print('mIoU detector:', self.dets['detected_iou'] / self.gts['y'])
        print('mIoU grounder:', self.dets['found_iou'] / self.gts['y'])
        print('grounder was better:', self.dets['grounder_improved_iou'] / self.gts['y'])

        if self.evaluate_viewpoint:
            print('\nViewpoint Accuracy')
            for prefix in self.prefixes:
                acc = self.pred_viewpoints[prefix] / self.gt_viewpoints[prefix]
                print(f"{prefix}_viewpoint: {acc}")

    def evaluate(self, end_points, prefix):
        """
        Evaluate all accuracies.

        Args:
            end_points (dict): contains predictions and gt
            prefix (str): layer name
        """
        # self.evaluate_bbox_by_span_opt(end_points, prefix)
        self.evaluate_bbox_by_span(end_points, prefix)
        """
        self.evaluate_bbox_by_contrast(end_points, prefix)
        self.evaluate_span_by_bbox(end_points, prefix)
        self.evaluate_contrast_by_bbox(end_points, prefix)
        """

        if self.evaluate_viewpoint:
            self.evaluate_viewpoint_bins(end_points, prefix)
        if prefix.startswith('last'):
            self.evaluate_recall(end_points, prefix)

    def evaluate_recall(self, end_points, prefix):
        """
        Evaluate whether we can find the correct box.

        Args:
            end_points (dict): contains predictions and gt
            prefix (str): layer name
        """
        # Parse gt
        gt_center = end_points['center_label'][:, :, 0:3]  # (B, K, 3)
        gt_size = end_points['size_gts']  # (B, K2,3)
        gt_bboxes = torch.cat([gt_center, gt_size], dim=-1)
        pred_center = end_points[f'{prefix}center']  # B, Q, 3
        pred_size = end_points[f'{prefix}pred_size']  # (B,Q,3) (l,w,h)
        assert (pred_size < 0).sum() == 0
        pred_bbox = torch.cat([pred_center, pred_size], dim=-1)
        for bid in range(len(gt_bboxes)):
            # IoU
            ious, _ = _iou3d_par(
                box_cxcyczwhd_to_xyzxyz(gt_bboxes[bid][end_points['box_label_mask'][bid].bool()]),  # (gt, 6)
                box_cxcyczwhd_to_xyzxyz(pred_bbox[bid])  # (Q, 6)
            )  # (gt, Q)
            found = (ious[:1] > 0.25).any(1).all() * 1.0
            found_iou = ious[:1].max()
            ious, _ = _iou3d_par(
                box_cxcyczwhd_to_xyzxyz(gt_bboxes[bid][end_points['box_label_mask'][bid].bool()]),  # (gt, 6)
                box_cxcyczwhd_to_xyzxyz(end_points['all_detected_boxes'][bid][end_points['all_detected_bbox_label_mask'][bid].bool()])  # (Q, 6)
            )  # (gt, Q)
            detected = (ious[:1] > 0.25).any(1).all()
            detected_iou = ious[:1].max()
            if detected:
                self.gts['y'] += 1
                self.dets['yy'] += found.item()
                self.dets['found_iou'] += found_iou
                self.dets['detected_iou'] += detected_iou
                self.dets['grounder_improved_iou'] += found_iou >= detected_iou
            else:
                self.gts['n'] += 1
                self.dets['ny'] += found.item()

    def evaluate_viewpoint_bins(self, end_points, prefix):
        if f'{prefix}pred_viewpoint' in end_points:
            gt_viewpoint = end_points['eul'].detach().cpu().numpy()
            bins = np.linspace(-180, 180, 12)
            gt_viewpoint = np.degrees(gt_viewpoint)
            gt_bins = np.digitize(gt_viewpoint, bins)
            pred_viewpoint = end_points[f'{prefix}pred_viewpoint']
            if not self.train_viewpoint_prototype:
                pred_viewpoint = pred_viewpoint.detach().cpu().numpy()
                pred_viewpoint = np.degrees(pred_viewpoint)
                pred_bins = np.digitize(pred_viewpoint, bins)
                correct = (gt_bins == pred_bins).all(1).sum()
            else:
                pred_viewpoint = pred_viewpoint.softmax(-1).detach().cpu().numpy()
                top = pred_viewpoint.argmax(-1)
                correct = (top == gt_bins[..., 2]).sum()
            self.pred_viewpoints[prefix] += correct
            self.gt_viewpoints[prefix] += gt_bins.shape[0]
        else:
            self.pred_viewpoints[prefix] += 0
            self.gt_viewpoints[prefix] += 1

    def evaluate_bbox_by_span_opt(self, end_points, prefix):
        """
        Evaluate bounding box IoU for any span detections.

        Same protocol as RefCOCO evaluator.

        Args:
            end_points (dict): contains predictions and gt
            prefix (str): layer name
        """
        # Parse gt
        _, gt_bboxes = self._parse_gt(end_points)

        # Parse predictions
        if self.use_detected_boxes:
            pred_bbox = end_points['all_bboxes']
            mask = ~end_points['query_mask']
        else:
            pred_center = end_points[f'{prefix}center']  # B, Q, 3
            pred_size = end_points[f'{prefix}pred_size']  # (B,Q,3) (l,w,h)
            assert (pred_size < 0).sum() == 0
            pred_bbox = torch.cat([pred_center, pred_size], dim=-1)
        # nonempty_box_mask = self._detect_empty_boxes(end_points, pred_bbox)

        sem_scores = end_points[f'{prefix}sem_cls_scores'].softmax(-1)
        # sem_scores = sem_scores * nonempty_box_mask.unsqueeze(-1)

        # Highest scoring box -> iou
        for bid in range(len(gt_bboxes)):
            # Scores consider any span now
            scores = sem_scores[bid][:, :-1].sum(-1)  # (Q,)

            if self.use_detected_boxes:
                scores = scores * mask[bid]

            # 10 predictions per gt box
            top = scores.argsort(0, True)[:10]  # (10,)
            pbox = pred_bbox[bid, top.reshape(-1)]  # (10, 6)

            # IoU
            ious, _ = _iou3d_par(
                box_cxcyczwhd_to_xyzxyz(gt_bboxes[bid][:1]),  # (1, 6)
                box_cxcyczwhd_to_xyzxyz(pbox)  # (10, 6)
            )  # (1, 10)
            ious = ious.reshape(-1)

            # Measure IoU>threshold, ious are (10,)
            rel = end_points['relation'][bid]
            for t in self.thresholds:
                thresholded = ious > t
                for k in self.topks:
                    found = thresholded[:k].any().sum().item()
                    self.dets[(prefix, t, k, 'opt')] += found
                    self.gts[(prefix, t, k, 'opt')] += 1
                    if prefix == 'last_':
                        self.dets[('per_relation', t, k, 'opt', rel)] += found
                        self.gts[('per_relation', t, k, 'opt', rel)] += 1

    def evaluate_bbox_by_span(self, end_points, prefix):
        """
        Evaluate bounding box IoU for top gt span detections.

        Args:
            end_points (dict): contains predictions and gt
            prefix (str): layer name
        """
        # Parse gt
        positive_map, gt_bboxes = self._parse_gt(end_points)

        # Parse predictions
        sem_scores = end_points[f'{prefix}sem_cls_scores'].softmax(-1)

        if sem_scores.shape[-1] != positive_map.shape[-1]:
            sem_scores_ = torch.zeros(
                sem_scores.shape[0], sem_scores.shape[1],
                positive_map.shape[-1]).to(sem_scores.device)
            sem_scores_[:, :, :sem_scores.shape[-1]] = sem_scores
            sem_scores = sem_scores_

        # Parse predictions
        if self.use_detected_boxes:
            pred_bbox = end_points['all_bboxes']
            mask = ~end_points['query_mask']
        else:
            pred_center = end_points[f'{prefix}center']  # B, Q, 3
            pred_size = end_points[f'{prefix}pred_size']  # (B,Q,3) (l,w,h)
            assert (pred_size < 0).sum() == 0
            pred_bbox = torch.cat([pred_center, pred_size], dim=-1)
        # nonempty_box_mask = self._detect_empty_boxes(end_points, pred_bbox)
        # sem_scores = sem_scores * nonempty_box_mask.unsqueeze(-1)

        if prefix == 'last_':
            all_pboxes = torch.zeros((len(positive_map), 10, 6))

        # Highest scoring box -> iou
        for bid in range(len(positive_map)):
            is_correct = None
            if False:  # this works only for the target box now
                ious, _ = _iou3d_par(
                    box_cxcyczwhd_to_xyzxyz(
                        end_points['all_detected_boxes'][bid][end_points['all_detected_bbox_label_mask'][bid]]
                    ),  # (gt, 6)
                    box_cxcyczwhd_to_xyzxyz(pred_bbox[bid])  # (Q, 6)
                )  # (gt, Q)
                # matches = ious.argmax(0)  # (Q,)
                is_correct = (ious.max(0)[0] > 0.25) * 1.0
                # classes = end_points['all_detected_class_ids'][bid][matches]  # (Q,)
                # is_correct = (classes == end_points['target_cid'][bid]).float()
            # Keep scores for annotated objects only
            num_obj = int(end_points['box_label_mask'][bid].sum())
            pmap = positive_map[bid, :num_obj]
            scores = (
                sem_scores[bid].unsqueeze(0)  # (1, Q, 256)
                * pmap.unsqueeze(1)  # (obj, 1, 256)
            ).sum(-1)  # (obj, Q)
            if is_correct is not None:
                scores = scores * is_correct[None]

            if self.use_detected_boxes:
                scores = scores * mask[bid].unsqueeze(0)

            # 10 predictions per gt box
            top = scores.argsort(1, True)[:, :10]  # (obj, 10)
            pbox = pred_bbox[bid, top.reshape(-1)]
            if prefix == 'last_':
                all_pboxes[bid] = pbox[:10]

            # IoU
            ious, _ = _iou3d_par(
                box_cxcyczwhd_to_xyzxyz(gt_bboxes[bid][:num_obj]),  # (obj, 6)
                box_cxcyczwhd_to_xyzxyz(pbox)  # (obj*10, 6)
            )  # (obj, obj*10)
            ious = ious.reshape(top.size(0), top.size(0), top.size(1))
            ious = ious[torch.arange(len(ious)), torch.arange(len(ious))]

            # Measure IoU>threshold, ious are (obj, 10)
            rel = end_points['relation'][bid]
            topks = (
                self.topks if not self.run_on_target_phrases
                else [(end_points['distractor_ids'][bid] > -1).sum().item()]
            )
            for t in self.thresholds:
                thresholded = ious > t
                for k in topks:
                    found = thresholded[:, :k].any(1)
                    if self.run_on_target_phrases:
                        k = 1
                    self.dets[(prefix, t, k, 'bbs')] += found.sum().item()
                    self.gts[(prefix, t, k, 'bbs')] += len(thresholded)
                    if prefix == 'last_':
                        found = found[0].item()
                        self.dets[('per_relation', t, k, 'bbs', rel)] += found
                        self.gts[('per_relation', t, k, 'bbs', rel)] += 1
                        if k == 1 and t == self.thresholds[0]:
                            if end_points['detector_succeeded'][bid]:
                                self.gts['y_t1'] += 1
                                self.dets['yy_t1'] += found
                            else:
                                self.gts['n_t1'] += 1
                                self.dets['ny_t1'] += found

        if prefix == 'last_':
            end_points[f'{prefix}pred_boxes_bbs'] = all_pboxes

    def evaluate_bbox_by_contrast(self, end_points, prefix):
        """
        Evaluate bounding box IoU by contrasting with span features.

        Args:
            end_points (dict): contains predictions and gt
            prefix (str): layer name
        """
        # Parse gt
        positive_map, gt_bboxes = self._parse_gt(end_points)

        # Parse predictions
        if self.use_detected_boxes:
            pred_bbox = end_points['all_bboxes']
            mask = ~end_points['query_mask']
        else:
            pred_center = end_points[f'{prefix}center']  # B, Q, 3
            pred_size = end_points[f'{prefix}pred_size']  # (B,Q,3) (l,w,h)
            assert (pred_size < 0).sum() == 0
            pred_bbox = torch.cat([pred_center, pred_size], dim=-1)
        # nonempty_box_mask = self._detect_empty_boxes(end_points, pred_bbox)

        proj_tokens = end_points['proj_tokens']  # (B, tokens, 64)
        proj_queries = end_points[f'{prefix}proj_queries']  # (B, Q, 64)
        sem_scores = torch.matmul(proj_queries, proj_tokens.transpose(-1, -2))
        sem_scores_ = (sem_scores / 0.07).softmax(-1)  # (B, Q, tokens)
        sem_scores = torch.zeros(sem_scores_.size(0), sem_scores_.size(1), 256)
        sem_scores = sem_scores.to(sem_scores_.device)
        sem_scores[:, :sem_scores_.size(1), :sem_scores_.size(2)] = sem_scores_
        if prefix == 'last_':
            all_pboxes = torch.zeros((len(positive_map), 10, 6))
        # sem_scores = sem_scores * nonempty_box_mask.unsqueeze(-1)

        # Highest scoring box -> iou
        for bid in range(len(positive_map)):
            # Keep scores for annotated objects only
            num_obj = int(end_points['box_label_mask'][bid].sum())
            pmap = positive_map[bid, :num_obj]
            scores = (
                sem_scores[bid].unsqueeze(0)  # (1, Q, 256)
                * pmap.unsqueeze(1)  # (obj, 1, 256)
            ).sum(-1)  # (obj, Q)

            if self.use_detected_boxes:
                scores = scores * mask[bid].unsqueeze(0)

            # 10 predictions per gt box
            top = scores.argsort(1, True)[:, :10]  # (obj, 10)
            pbox = pred_bbox[bid, top.reshape(-1)]
            if prefix == 'last_':
                all_pboxes[bid] = pbox[:10]

            # IoU
            ious, _ = _iou3d_par(
                box_cxcyczwhd_to_xyzxyz(gt_bboxes[bid][:num_obj]),  # (obj, 6)
                box_cxcyczwhd_to_xyzxyz(pbox)  # (obj*10, 6)
            )  # (obj, obj*10)
            ious = ious.reshape(top.size(0), top.size(0), top.size(1))
            ious = ious[torch.arange(len(ious)), torch.arange(len(ious))]

            # Measure IoU>threshold, ious are (obj, 10)
            rel = end_points['relation'][bid]
            for t in self.thresholds:
                thresholded = ious > t
                for k in self.topks:
                    found = thresholded[:, :k].any(1)
                    self.dets[(prefix, t, k, 'bbf')] += found.sum().item()
                    self.gts[(prefix, t, k, 'bbf')] += len(thresholded)
                    if prefix == 'last_':
                        found = found[0].item()
                        self.dets[('per_relation', t, k, 'bbf', rel)] += found
                        self.gts[('per_relation', t, k, 'bbf', rel)] += 1

        if prefix == 'last_':
            end_points[f'{prefix}pred_boxes_bbf'] = all_pboxes

    def evaluate_span_by_bbox(self, end_points, prefix):
        """
        Evaluate span accuracy given gt box.

        Args:
            end_points (dict): contains predictions and gt
            prefix (str): layer name
        """
        # Parse gt
        positive_map, gt_bboxes = self._parse_gt(end_points)

        # Parse predictions
        sem_scores = end_points[f'{prefix}sem_cls_scores'].softmax(-1)
        if self.use_detected_boxes:
            pred_bbox = end_points['all_bboxes']
            mask = ~end_points['query_mask']
        else:
            pred_center = end_points[f'{prefix}center']  # B, Q, 3
            pred_size = end_points[f'{prefix}pred_size']  # (B,Q,3) (l,w,h)
            assert (pred_size < 0).sum() == 0
            pred_bbox = torch.cat([pred_center, pred_size], dim=-1)
        # nonempty_box_mask = self._detect_empty_boxes(end_points, pred_bbox)
        # sem_scores = sem_scores * nonempty_box_mask.unsqueeze(-1)

        # Highest scoring box -> iou
        for bid in range(len(positive_map)):
            # Keep scores for annotated objects only
            num_obj = int(end_points['box_label_mask'][bid].sum())
            pmap = positive_map[bid, :num_obj]  # (obj, 256)
            scores = sem_scores[bid]  # (Q, 256)

            if self.use_detected_boxes:
                scores = scores * mask[bid].unsqueeze(1)

            # IoU
            pbox = pred_bbox[bid]  # (Q, 6)
            ious, _ = _iou3d_par(
                box_cxcyczwhd_to_xyzxyz(gt_bboxes[bid][:num_obj]),  # (obj, 6)
                box_cxcyczwhd_to_xyzxyz(pbox)  # (Q, 6)
            )  # (obj, Q)

            # 10 predictions per gt box
            top = ious.argsort(1, True)[:, :10]  # (obj, 10)

            # Measure accuracy of span prediction
            _scores = torch.stack([
                scores[top[s]] for s in range(len(top))
            ])  # (obj, 10, 256)
            _scores = torch.stack([
                pmap[s, _scores[s].argmax(-1)] for s in range(len(top))
            ])  # (obj, 10)
            for k in self.topks:
                self.dets[(prefix, k, 'sbb')] += \
                    (_scores[:, :k] > 0).any(1).sum().item()
                self.gts[(prefix, k, 'sbb')] += len(_scores)

    def evaluate_contrast_by_bbox(self, end_points, prefix):
        """
        Evaluate feature contrast accuracy given gt box.

        Args:
            end_points (dict): contains predictions and gt
            prefix (str): layer name
        """
        # Parse gt
        positive_map, gt_bboxes = self._parse_gt(end_points)

        # Parse predictions
        if self.use_detected_boxes:
            pred_bbox = end_points['all_bboxes']
            mask = ~end_points['query_mask']
        else:
            pred_center = end_points[f'{prefix}center']  # B, Q, 3
            pred_size = end_points[f'{prefix}pred_size']  # (B,Q,3) (l,w,h)
            assert (pred_size < 0).sum() == 0
            pred_bbox = torch.cat([pred_center, pred_size], dim=-1)
        # nonempty_box_mask = self._detect_empty_boxes(end_points, pred_bbox)

        proj_tokens = end_points['proj_tokens']  # (B, tokens, 64)
        proj_queries = end_points[f'{prefix}proj_queries']  # (B, Q, 64)
        sem_scores = torch.matmul(proj_queries, proj_tokens.transpose(-1, -2))
        sem_scores_ = (sem_scores / 0.07).softmax(-1)  # (B, Q, tokens)
        sem_scores = torch.zeros(sem_scores_.size(0), sem_scores_.size(1), 256)
        sem_scores = sem_scores.to(sem_scores_.device)
        sem_scores[:, :sem_scores_.size(1), :sem_scores_.size(2)] = sem_scores_
        # sem_scores = sem_scores * nonempty_box_mask.unsqueeze(-1)

        # Highest scoring box -> iou
        for bid in range(len(positive_map)):
            # Keep scores for annotated objects only
            num_obj = int(end_points['box_label_mask'][bid].sum())
            pmap = positive_map[bid, :num_obj]  # (obj, 256)
            scores = sem_scores[bid]  # (Q, 256)

            if self.use_detected_boxes:
                scores = scores * mask[bid].unsqueeze(1)

            # IoU
            pbox = pred_bbox[bid]  # (Q, 6)
            ious, _ = _iou3d_par(
                box_cxcyczwhd_to_xyzxyz(gt_bboxes[bid][:num_obj]),  # (obj, 6)
                box_cxcyczwhd_to_xyzxyz(pbox)  # (Q, 6)
            )  # (obj, Q)

            # 10 predictions per gt box
            top = ious.argsort(1, True)[:, :10]  # (obj, 10)

            # Measure accuracy of span prediction
            _scores = torch.stack([
                scores[top[s]] for s in range(len(top))
            ])  # (obj, 10, 256)
            _scores = torch.stack([
                pmap[s, _scores[s].argmax(-1)] for s in range(len(top))
            ])
            for k in self.topks:
                self.dets[(prefix, k, 'fbb')] += \
                    (_scores[:, :k] > 0).any(1).sum().item()
                self.gts[(prefix, k, 'fbb')] += len(_scores)

    def _parse_gt(self, end_points):
        positive_map = torch.clone(end_points['positive_map'])  # (B, K, 256)
        positive_map[positive_map > 0] = 1
        gt_center = end_points['center_label'][:, :, 0:3]  # (B, K, 3)
        gt_size = end_points['size_gts']  # (B, K2,3)
        gt_bboxes = torch.cat([gt_center, gt_size], dim=-1)  # cxcyczwhd
        if self.only_root:
            positive_map = positive_map[:, :1]  # (B, 1, 256)
            gt_bboxes = gt_bboxes[:, :1]  # (B, 1, 6)
        return positive_map, gt_bboxes

    def _detect_empty_boxes(self, end_points, pred_bbox):
        pred_bbox = torch.cat((
            pred_bbox[:, :, :3] - pred_bbox[:, :, 3:] / 2,
            pred_bbox[:, :, :3] + pred_bbox[:, :, 3:] / 2
        ), -1).detach()  # (B, K, 6)
        batch_pc = end_points['point_clouds'].detach()[:, :, :3]  # (B, N, 3)
        nonempty_box_mask = (
            (batch_pc.unsqueeze(2) - pred_bbox.unsqueeze(1)[..., :3] >= 0)
            & (batch_pc.unsqueeze(2) - pred_bbox.unsqueeze(1)[..., 3:] <= 0)
        ).all(-1)  # (B, N, K)
        return (nonempty_box_mask.long().sum(1) > 5) * 1  # (B, K)


class GroundingEvaluatorGTBoxes:
    """
    Evaluate language grounding.

    Args:
        only_root (bool): detect only the root noun
        thresholds (list): IoU thresholds to check
        topks (list): k to evaluate top--k accuracy
        prefixes (list): names of layers to evaluate
    """

    def __init__(self, only_root=True,
                 topks=[1, 5, 10], prefixes=[], gt_variant=False,
                 apply_classifiers=False):
        """
        Store the configuration, load the classifier, reset counters.

        Args:
            only_root (bool): detect only the root noun
            topks (list): k to evaluate top-k accuracy
            prefixes (list): names of layers to evaluate
            gt_variant (bool): if set, evaluate() only runs
                evaluate_bbox_by_span_opt
            apply_classifiers (bool): if set, evaluate() only runs
                evaluate_bbox_by_contrast with classifier re-ranking
        """
        self.only_root = only_root
        # Copy the lists: the defaults are shared between all calls and
        # a caller-owned list must not alias the evaluator's state.
        self.topks = list(topks)
        self.prefixes = list(prefixes)
        self.gt_variant = gt_variant
        self.apply_classifiers = apply_classifiers
        self.on_classifer = RelClassifier()
        # map_location lets a checkpoint saved from a GPU load on a
        # CPU-only host; the module is moved to the boxes' device at the
        # point where it is applied.
        self.on_classifer.load_state_dict(torch.load(
            "./dataset/language_grounding/on_classifier.pt",
            map_location='cpu'
        ), strict=False)
        self.reset()

    def reset(self):
        """
        Zero every detection and ground-truth counter.

        Counters are keyed by (prefix, k, mode) for the box-given-span
        modes, (prefix, mode) for the span-given-box modes, and by
        ('per_relation' | 'per_target', k, mode, name) for breakdowns.
        """
        box_modes = ('bbs', 'bbf', 'opt')
        counters = {}
        for prefix in self.prefixes:
            for k in self.topks:
                for mode in box_modes:
                    counters[(prefix, k, mode)] = 0
        for prefix in self.prefixes:
            for mode in ('sbb', 'fbb'):
                counters[(prefix, mode)] = 0
        for k in self.topks:
            for mode in box_modes:
                for rel in set(REL_ALIASES.values()):
                    counters[('per_relation', k, mode, rel)] = 0
        for k in self.topks:
            for mode in box_modes:
                for name in SCANNET_OBJECTS:
                    counters[('per_target', k, mode, name)] = 0
        self.dets = counters
        # Same keys, independent dict, so hits and totals evolve apart
        self.gts = counters.copy()

    def print_stats(self):
        """
        Print accumulated accuracies.

        For every prefix this prints the top-k accuracy of each
        box-given-span mode and the accuracy of both span-given-box
        modes, followed by the per-relation top-k accuracies. Entries
        whose ground-truth counter is zero are left out (the
        span-given-box modes print '0' instead).
        """
        mode_str = {
            'bbs': 'Box given span (soft-token)',
            'bbf': 'Box given span (contrastive)',
            'opt': 'Box given span (optimistic)',
            'sbb': 'Span given box (soft-token)',
            'fbb': 'Span given box (contrastive)'
        }
        box_modes = ['bbs', 'bbf', 'opt']
        for prefix in self.prefixes:
            for mode in box_modes:
                print(
                    prefix, mode_str[mode], 'Acc:',
                    ', '.join([
                        'Top-%d: %.3f' % (
                            k,
                            self.dets[(prefix, k, mode)]
                            / self.gts[(prefix, k, mode)]
                        )
                        for k in self.topks
                        if self.gts[(prefix, k, mode)]
                    ])
                )
            for mode in ['sbb', 'fbb']:
                seen = self.gts[(prefix, mode)]
                if seen:
                    summary = 'Top: %.3f' % (self.dets[(prefix, mode)] / seen)
                else:
                    summary = '0'
                print(prefix, mode_str[mode], 'Acc:', summary)
        print('\nPer relation accuracy')
        for mode in box_modes:
            for rel in set(REL_ALIASES.values()):
                keys = [('per_relation', k, mode, rel) for k in self.topks]
                # Skip relations that were never seen. Checking every k
                # (instead of a hard-coded k=1) avoids a KeyError when
                # 1 is not part of topks.
                if not any(self.gts[key] for key in keys):
                    continue
                print(
                    rel,
                    mode_str[mode], 'Acc:',
                    ', '.join([
                        'Top-%d: %.3f' % (
                            k, self.dets[key] / self.gts[key]
                        )
                        for k, key in zip(self.topks, keys)
                        if self.gts[key]
                    ])
                )
        # NOTE: per-target-class accuracies ('per_target' counters) are
        # intentionally not printed.

    def evaluate(self, end_points, prefix):
        """
        Evaluate all accuracies.

        Args:
            end_points (dict): contains predictions and gt
            prefix (str): layer name
        """
        if self.gt_variant:
            self.evaluate_bbox_by_span_opt(end_points, prefix)
            return None
        if self.apply_classifiers:
            self.evaluate_bbox_by_contrast(end_points, prefix)
            return None
        self.evaluate_bbox_by_span_opt(end_points, prefix)
        self.evaluate_bbox_by_span(end_points, prefix)
        self.evaluate_bbox_by_contrast(end_points, prefix)
        self.evaluate_span_by_bbox(end_points, prefix)
        self.evaluate_contrast_by_bbox(end_points, prefix)

    def evaluate_gt_boxes(self, end_points):
        """
        Debug helper: test the "on" relation on the ground-truth boxes.

        Only a batch of one sample is supported. The result is added to
        the top-1 contrastive ('bbf') counters of the 'last_' prefix.

        Args:
            end_points (dict): contains gt
        """
        # only for debugging
        span_map, _ = self._parse_gt(end_points)
        assert not torch.sum(span_map[:, 2])

        boxes = torch.cat(
            [end_points['center_label'][:, :, 0:3], end_points['size_gts']],
            dim=-1
        )  # cxcyczwhd
        assert boxes.shape[0] == 1

        # Hack to decide which box is supposed to be "on" the other one:
        # the box with the larger z center is taken as A in "A on B".
        # Ideally this would come from the parser.
        first_is_higher = boxes[0, 0, 2] > boxes[0, 1, 2]
        ind_a, ind_b = (0, 1) if first_is_higher else (1, 0)

        # NOTE(review): `on_classifier` is not defined in the part of the
        # file visible here - confirm it exists at module level.
        found = on_classifier(
            boxes[0, ind_a].cpu().numpy(),
            boxes[0, ind_b].cpu().numpy()
        )

        key = ('last_', 1, 'bbf')
        self.dets[key] += int(found)
        self.gts[key] += 1

    def evaluate_bbox_by_span_opt(self, end_points, prefix):
        """
        Evaluate bounding box IoU for any span detections.

        Same protocol as RefCOCO evaluator.

        Args:
            end_points (dict): contains predictions and gt
            prefix (str): layer name
        """
        # Parse predictions
        if prefix == 'proposal_' and self.gt_variant:
            return None
        sem_scores = end_points[f'{prefix}sem_cls_scores'].softmax(-1)
        mask = ~end_points['query_mask']

        # Highest scoring box -> iou
        for bid in range(len(sem_scores)):
            # Scores consider any span now
            if not self.gt_variant:
                scores = sem_scores[bid][:, :-1].sum(-1)  # (Q,)
                scores = scores * mask[bid]
            else:
                scores = sem_scores[bid]  # (Q,)

            # 10 predictions per gt box
            top = scores.argsort(0, True)[:10]  # (10,)

            target = end_points['target_id'][bid]

            # Measure IoU>threshold, ious are (10,)
            rel = end_points['relation'][bid]

            for k in self.topks:
                found = 1 if target in top[:k] else 0
                self.dets[(prefix, k, 'opt')] += found
                self.gts[(prefix, k, 'opt')] += 1
                if prefix == 'last_':
                    self.dets[('per_relation', k, 'opt', rel)] += found
                    self.gts[('per_relation', k, 'opt', rel)] += 1

    def evaluate_bbox_by_span(self, end_points, prefix):
        """
        Evaluate bounding box IoU for top gt span detections.

        Args:
            end_points (dict): contains predictions and gt
            prefix (str): layer name
        """
        # Parse gt
        positive_map, _ = self._parse_gt(end_points)
        gt_center = end_points['center_label'][:, :, 0:3]  # (B, K, 3)
        gt_size = end_points['size_gts']  # (B, K2,3)
        pred_bbox = torch.cat([gt_center, gt_size], dim=-1)  # cxcyczwhd

        # Parse predictions
        sem_scores = end_points[f'{prefix}sem_cls_scores'].softmax(-1)
        mask = ~end_points['query_mask']

        if sem_scores.shape[-1] != positive_map.shape[-1]:
            sem_scores_ = torch.zeros(
                sem_scores.shape[0], sem_scores.shape[1],
                positive_map.shape[-1]).to(sem_scores.device)
            sem_scores_[:, :, :sem_scores.shape[-1]] = sem_scores
            sem_scores = sem_scores_
        if prefix == 'last_':
            all_pboxes = torch.zeros((len(positive_map), 10, 6))

        # Highest scoring box -> iou
        for bid in range(len(positive_map)):
            # Keep scores for annotated objects only
            num_obj = int(end_points['box_label_mask'][bid].sum())
            pmap = positive_map[bid, :num_obj]
            scores = (
                sem_scores[bid].unsqueeze(0)  # (1, Q, 256)
                * pmap.unsqueeze(1)  # (obj, 1, 256)
            ).sum(-1)  # (obj, Q)

            scores = scores * mask[bid].unsqueeze(0)

            # 10 predictions per gt box
            top = scores.argsort(1, True)[:, :10]  # (obj, 10)
            target = end_points['target_id'][bid]
            pbox = pred_bbox[bid, top.reshape(-1)]
            if prefix == 'last_':
                all_pboxes[bid] = pbox[:10]

            # Measure IoU>threshold, ious are (obj, 10)
            rel = end_points['relation'][bid]
            for k in self.topks:
                found = (top[0, :k] == target)
                self.dets[(prefix, k, 'bbs')] += found.sum().item()
                self.gts[(prefix, k, 'bbs')] += 1
                if prefix == 'last_':
                    found = found[0].item()
                    self.dets[('per_relation', k, 'bbs', rel)] += found
                    self.gts[('per_relation', k, 'bbs', rel)] += 1

        if prefix == 'last_':
            end_points[f'{prefix}pred_boxes_bbs'] = all_pboxes

    def evaluate_bbox_by_contrast(self, end_points, prefix):
        """
        Evaluate bounding box IoU by contrasting with span features.

        Args:
            end_points (dict): contains predictions and gt
            prefix (str): layer name
        """
        # Parse gt
        positive_map, _ = self._parse_gt(end_points)
        gt_center = end_points['center_label'][:, :, 0:3]  # (B, K, 3)
        gt_size = end_points['size_gts']  # (B, K2,3)
        pred_bbox = torch.cat([gt_center, gt_size], dim=-1)  # cxcyczwhd
        all_bboxes = end_points['all_bboxes']  # (B, 132, 6)

        proj_tokens = end_points['proj_tokens']  # (B, tokens, 64)
        proj_queries = end_points[f'{prefix}proj_queries']  # (B, Q, 64)
        sem_scores = torch.matmul(proj_queries, proj_tokens.transpose(-1, -2))
        sem_scores_ = (sem_scores / 0.07).softmax(-1)  # (B, Q, tokens)
        sem_scores = torch.zeros(sem_scores_.size(0), sem_scores_.size(1), 256)
        sem_scores = sem_scores.to(sem_scores_.device)
        sem_scores[:, :sem_scores_.size(1), :sem_scores_.size(2)] = sem_scores_
        mask = ~end_points['query_mask']

        if prefix == 'last_':
            all_pboxes = torch.zeros((len(positive_map), 10, 6))
            all_aboxes = torch.zeros((len(positive_map), 10, 6))

        # Highest scoring box -> iou
        for bid in range(len(positive_map)):
            # Keep scores for annotated objects only
            num_obj = int(end_points['box_label_mask'][bid].sum())
            pmap = positive_map[bid, :num_obj]
            scores = (
                sem_scores[bid].unsqueeze(0)  # (1, Q, 256)
                * pmap.unsqueeze(1)  # (obj, 1, 256)
            ).sum(-1)  # (obj, Q)
            scores = scores * mask[bid].unsqueeze(0)

            # 10 predictions per gt box
            top = scores.argsort(1, True)[:, :10]  # (obj, 10)
            target = end_points['target_id'][bid]
            pbox = all_bboxes[bid, top.reshape(-1)]  # (obj*10, 6)
            if prefix == 'last_':
                all_pboxes[bid] = pbox[:10]
                all_aboxes[bid] = pbox[10:20]

            if self.apply_classifiers:
                t_boxes = pbox.reshape(len(top), 10, -1)[0]
                a_boxes = pbox.reshape(len(top), 10, -1)[1]
                t_boxes = t_boxes.unsqueeze(1).repeat(1, 10, 1).reshape(-1, 6)
                a_boxes = a_boxes.unsqueeze(0).repeat(10, 1, 1).reshape(-1, 6)
                # 10, 10
                scores_on = self.on_classifer.to(t_boxes.device)(
                    box_cxcyczwhd_to_xyzxyz(t_boxes).unsqueeze(1),
                    box_cxcyczwhd_to_xyzxyz(a_boxes).unsqueeze(1)
                ).reshape(10, 10).sigmoid()
                
                top_forward_scores = scores.sort(dim=1, descending=True)[0][:, :10]
                top_forward_scores_pairs = top_forward_scores[0].unsqueeze(0) * \
                                                    top_forward_scores[1].unsqueeze(1)
                scores_final = (top_forward_scores_pairs * scores_on).max(1)[0]  # (10,)
                # combine top 10 forward model scores and on classifier score
                top_ = scores_final.argsort(descending=True)
                
                # re-rank original top 10
                top = torch.gather(top[0], 0, top_).unsqueeze(0)
                '''
                # hack for finding which of target or prediction
                # is supposed to be "on" the other one. Ideally,
                # we would get this from parser
                # A on B
                if pred_bbox[bid, 0, 2] > pred_bbox[bid, 1, 2]:
                    ind_a, ind_b = 0, 1
                else:
                    ind_a, ind_b = 1, 0

                for i in range(5):
                    found = False
                    pred_bbox[bid, 0] = all_bboxes[bid, top[0, i]]
                    for j in range(5):
                        pred_bbox[bid, 1] = all_bboxes[bid, top[1, j]]
                        found = on_classifier(
                            pred_bbox[bid, ind_a].cpu().numpy(),
                            pred_bbox[bid, ind_b].cpu().numpy()
                        )
                        if found:
                            temp = copy.deepcopy(top[0, 0])
                            top[0, 0] = top[0, i]
                            top[0, i] = temp
                            break
                    if found:
                        break
                '''

            # Measure IoU>threshold, ious are (obj, 10)
            rel = end_points['relation'][bid]
            all_found = []
            for k in self.topks:
                found = (top[0, :k] == target)
                self.dets[(prefix, k, 'bbf')] += found.sum().item()
                self.gts[(prefix, k, 'bbf')] += 1
                if prefix == 'last_':
                    found = found[0].item()
                    self.dets[('per_relation', k, 'bbf', rel)] += found
                    self.gts[('per_relation', k, 'bbf', rel)] += 1
                    if k == 1:
                        all_found.append(found)

        if prefix == 'last_':
            end_points[f'{prefix}pred_boxes_bbf'] = all_pboxes
            end_points['anchor_estimates'] = all_aboxes
            end_points['found'] = found

    def evaluate_span_by_bbox(self, end_points, prefix):
        """
        Evaluate span accuracy given gt box.

        Args:
            end_points (dict): contains predictions and gt
            prefix (str): layer name
        """
        # Parse gt
        positive_map, _ = self._parse_gt(end_points)

        # Parse predictions
        sem_scores = end_points[f'{prefix}sem_cls_scores'].softmax(-1)
        for bid in range(len(positive_map)):
            # Keep scores for annotated objects only
            num_obj = int(end_points['box_label_mask'][bid].sum())
            pmap = positive_map[bid, :num_obj]  # (obj, 256)
            scores = sem_scores[bid]  # (Q, 256)
            target = end_points['target_id'][bid]

            _score = scores[target]
            _score = pmap[:, _score.argmax(-1)]

            self.dets[(prefix, 'sbb')] += _score > 0
            self.gts[(prefix, 'sbb')] += 1

    def evaluate_contrast_by_bbox(self, end_points, prefix):
        """
        Evaluate feature contrast accuracy given gt box.

        Args:
            end_points (dict): contains predictions and gt
            prefix (str): layer name
        """
        # Parse gt
        positive_map, _ = self._parse_gt(end_points)

        proj_tokens = end_points['proj_tokens']  # (B, tokens, 64)
        proj_queries = end_points[f'{prefix}proj_queries']  # (B, Q, 64)
        sem_scores = torch.matmul(proj_queries, proj_tokens.transpose(-1, -2))
        sem_scores_ = (sem_scores / 0.07).softmax(-1)  # (B, Q, tokens)
        sem_scores = torch.zeros(sem_scores_.size(0), sem_scores_.size(1), 256)
        sem_scores = sem_scores.to(sem_scores_.device)
        sem_scores[:, :sem_scores_.size(1), :sem_scores_.size(2)] = sem_scores_

        # Highest scoring box -> iou
        for bid in range(len(positive_map)):
            # Keep scores for annotated objects only
            num_obj = int(end_points['box_label_mask'][bid].sum())
            pmap = positive_map[bid, :num_obj]  # (obj, 256)
            scores = sem_scores[bid]  # (Q, 256)
            target = end_points['target_id'][bid]

            _score = scores[target]
            _score = pmap[:, _score.argmax(-1)]

            self.dets[(prefix, 'fbb')] += _score > 0
            self.gts[(prefix, 'fbb')] += 1

    def _parse_gt(self, end_points):
        positive_map = torch.clone(end_points['positive_map'])  # (B, K, 256)
        positive_map[positive_map > 0] = 1
        gt_center = end_points['center_label'][:, :, 0:3]  # (B, K, 3)
        gt_size = end_points['size_gts']  # (B, K2,3)
        gt_bboxes = torch.cat([gt_center, gt_size], dim=-1)  # cxcyczwhd
        if self.only_root and not self.apply_classifiers:
            positive_map = positive_map[:, :1]  # (B, 1, 256)
            gt_bboxes = gt_bboxes[:, :1]  # (B, 1, 6)
        return positive_map, gt_bboxes
