from mmdet.registry import METRICS

from typing import Sequence

from mmdet.evaluation.metrics import CocoMetric
from mmengine.fileio import dump


@METRICS.register_module()
class MissingPersonMetric(CocoMetric):
    """COCO-style metric that tolerates labels with no matching category.

    Identical to :class:`CocoMetric` except for :meth:`results2json`.
    Predictions whose label cannot be mapped to a category id through
    ``self.cat_ids`` are dropped from the dumped json files instead of
    raising. This applies to both bbox and segm results. Such labels can
    occur when the model predicts classes that the evaluation annotation
    file does not define.

    All constructor arguments are inherited unchanged from
    :class:`CocoMetric`.
    """

    def _label_to_cat_id(self, label):
        """Map a predicted class index to its category id.

        Args:
            label (int-like): Predicted class index (e.g. a Python int or a
                NumPy integer scalar).

        Returns:
            int | None: ``self.cat_ids[label]``, or ``None`` when ``label``
            lies outside ``[0, len(self.cat_ids))``. Negative labels are
            rejected explicitly because plain list indexing would silently
            wrap them around to the last categories.
        """
        index = int(label)
        if 0 <= index < len(self.cat_ids):
            return self.cat_ids[index]
        return None

    def results2json(self, results: Sequence[dict],
                     outfile_prefix: str) -> dict:
        """Dump the detection results to COCO style json files.

        Bbox predictions are always dumped. Mask predictions are dumped only
        when the first result carries a ``'masks'`` entry. Predictions whose
        label has no corresponding category id are skipped in both outputs.

        Args:
            results (Sequence[dict]): Testing results of the dataset. Each
                dict provides ``'labels'``, ``'bboxes'`` and ``'scores'``. It
                may also provide ``'img_id'``, ``'masks'`` and
                ``'mask_scores'``.
            outfile_prefix (str): The filename prefix of the json files. If
                the prefix is "somepath/xxx", the files written are
                "somepath/xxx.bbox.json" and, when masks are present,
                "somepath/xxx.segm.json".

        Returns:
            dict: Maps "bbox", "proposal" and (when masks are present)
            "segm" to json filenames. "proposal" points at the bbox json
            file; no separate proposal file is written.
        """
        # Guard the results[0] peek so an empty ``results`` doesn't raise.
        has_masks = bool(results) and 'masks' in results[0]
        bbox_json_results = []
        segm_json_results = [] if has_masks else None
        for idx, result in enumerate(results):
            image_id = result.get('img_id', idx)
            labels = result['labels']
            bboxes = result['bboxes']
            scores = result['scores']
            # Resolve category ids once so the bbox and segm paths agree on
            # which predictions are skipped (None == unknown label).
            cat_ids = [self._label_to_cat_id(label) for label in labels]

            # bbox results
            for i, cat_id in enumerate(cat_ids):
                if cat_id is None:
                    continue
                bbox_json_results.append(
                    dict(
                        image_id=image_id,
                        bbox=self.xyxy2xywh(bboxes[i]),
                        score=float(scores[i]),
                        category_id=cat_id))

            if segm_json_results is None:
                continue

            # segm results
            masks = result['masks']
            mask_scores = result.get('mask_scores', scores)
            for i, cat_id in enumerate(cat_ids):
                if cat_id is None:
                    continue
                # RLE counts may arrive as bytes; decode for json dumping.
                if isinstance(masks[i]['counts'], bytes):
                    masks[i]['counts'] = masks[i]['counts'].decode()
                segm_json_results.append(
                    dict(
                        image_id=image_id,
                        bbox=self.xyxy2xywh(bboxes[i]),
                        score=float(mask_scores[i]),
                        category_id=cat_id,
                        segmentation=masks[i]))

        result_files = dict()
        result_files['bbox'] = f'{outfile_prefix}.bbox.json'
        # Proposal evaluation deliberately reuses the bbox json file.
        result_files['proposal'] = f'{outfile_prefix}.bbox.json'
        dump(bbox_json_results, result_files['bbox'])

        if segm_json_results is not None:
            result_files['segm'] = f'{outfile_prefix}.segm.json'
            dump(segm_json_results, result_files['segm'])

        return result_files
