import torch
import numpy as np
from tqdm import tqdm
from dlisa.util.utils import get_batch_aabb_pair_ious
from dlisa.evaluation.general_evaluator import GeneralEvaluator


IOU_THRESHOLD = 0.9  # referit3d uses GT boxes, a dummy iou here


class ReferIt3DEvaluator(GeneralEvaluator):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.evaluation_types = {"easy_dep": 0, "easy_indep": 1, "hard_dep": 2, "hard_indep": 3}
        self.evaluation_types_comb = {"easy": (0, 1), "hard": (2, 3), "view_dep": (0, 2), "view_indep": (1, 3)}

    def _print_results(self, results):
        print(f"{'=' * 55}")
        print("{0:<12}{1:<12}{2:<12}{3:<12}{4:<12}".format("easy", "hard", "view-dep", "view-indep", "overall"))
        print(f"{'-' * 55}")
        line_1_str = ''
        for sub_group_type, score in results.items():
            line_1_str += '{:<12.1f}'.format(score * 100)
        print(line_1_str)
        print(f"{'=' * 55}")

    def evaluate(self, predictions):
        all_gt_info_len = len(self.ground_truths)
        eval_type_mask = np.empty(all_gt_info_len, dtype=np.uint8)
        tps = np.zeros(all_gt_info_len, dtype=bool)
        iterator = enumerate(tqdm(predictions.items(), desc="Evaluating") if self.verbose else predictions.items())
        for i, (key, value) in iterator:
            eval_type_mask[i] = self.evaluation_types[self.ground_truths[key]["eval_type"]]
            tps[i] = self._evaluate_one_query(value, self.ground_truths[key])
        results = {}
        for sub_group in self.evaluation_types_comb.keys():
            selected_indices = np.isin(eval_type_mask, np.array(self.evaluation_types_comb[sub_group], dtype=np.uint8))
            if np.any(selected_indices):
                results[sub_group] = np.count_nonzero(tps[selected_indices]) / np.count_nonzero(selected_indices)
            else:
                results[sub_group] = np.nan
        results["overall"] = np.count_nonzero(tps) / tps.shape[0]

        if self.verbose:
            self._print_results(results)

        return {self.metric_name: results}

    def _evaluate_one_query(self, pred_info, gt_info):
        # initialize true positives
        tp = 0

        # TODO: convert to batch process
        iou = get_batch_aabb_pair_ious(
            torch.from_numpy(pred_info["aabb_bound"]), torch.from_numpy(gt_info["aabb_bound"])
        )[0].item()
        if iou >= IOU_THRESHOLD:
            tp += 1
        return tp
