import mmengine

from mmdet.datasets.base_det_dataset import BaseDetDataset
from mmdet.registry import DATASETS

from pycocotools.coco import COCO
from mmengine.fileio import get_local_path

from typing import List, Union
import os.path as osp


@DATASETS.register_module()
class MissingPersonDataset(BaseDetDataset):
    """Single-class ('person') detection dataset read from a COCO-style file.

    Each annotation may carry an ``attributes`` dict with the keys listed in
    ``_ATTRIBUTE_KEYS``. While the data list is built, instances can be
    filtered by comparing one of these attributes against a selected value.

    Only the conditions named in ``_ACTIVE_CONDITIONS`` are applied (currently
    just ``'season'``). The other ``selected_*`` values are accepted and
    stored in ``selected_condition`` but do not affect loading.

    Args:
        *args: Positional arguments forwarded to ``BaseDetDataset``.
        selected_visible_ratio: Value to match against the
            ``visible_ratio`` attribute. Stored only. Defaults to None.
        selected_pose: Value to match against the ``pose`` attribute.
            Stored only. Defaults to None.
        selected_season: Value every kept instance's ``season`` attribute
            must equal. ``None`` disables the filter. Defaults to None.
        selected_weather: Value to match against the ``weather`` attribute.
            Stored only. Defaults to None.
        selected_place: Value to match against the ``place`` attribute.
            Stored only. Defaults to None.
        **kwargs: Keyword arguments forwarded to ``BaseDetDataset``.
    """

    METAINFO = {
        # The trailing comma is required: ``('person')`` is just the string
        # 'person', which turns class-name lookups into substring tests and
        # makes ``len(classes)`` equal 6 instead of 1.
        'classes': ('person',),
        'palette': [(220, 20, 60)]
    }

    # Attribute keys copied from a raw annotation into each parsed instance.
    _ATTRIBUTE_KEYS = ('camera', 'weather', 'season', 'place',
                       'visible_ratio', 'pose')

    # Conditions that ``load_data_list`` actually filters on, in order.
    _ACTIVE_CONDITIONS = ('season',)

    def __init__(self,
                 *args,
                 selected_visible_ratio=None,
                 selected_pose=None,
                 selected_season=None,
                 selected_weather=None,
                 selected_place=None,
                 **kwargs) -> None:
        # The selections are stored before calling the base constructor
        # because ``load_data_list`` reads them.
        # NOTE(review): mmengine's BaseDataset appears to load the data list
        # from ``__init__`` unless ``lazy_init=True`` - confirm.
        self.selected_visible_ratio = selected_visible_ratio
        self.selected_pose = selected_pose
        self.selected_season = selected_season
        self.selected_weather = selected_weather
        self.selected_place = selected_place
        self.selected_condition = {
            'visible_ratio': selected_visible_ratio,
            'pose': selected_pose,
            'season': selected_season,
            'weather': selected_weather,
            'place': selected_place,
        }
        super().__init__(*args, **kwargs)

    def filter_annotations(self, annotation_instances,
                           condition) -> List[dict]:
        """Keep the instances whose attribute equals the selected value.

        Args:
            annotation_instances (List[dict]): Parsed instances as produced
                by ``parse_data_info``.
            condition (str): Key into ``self.selected_condition``; the same
                key is looked up in each instance's ``attributes`` dict.

        Returns:
            List[dict]: New list with the matching instances, in their
            original order. Instances without an ``attributes`` entry never
            match.
        """
        wanted = self.selected_condition[condition]
        # ``parse_data_info`` only attaches 'attributes' when the raw
        # annotation had a non-empty dict, so the key may be absent here.
        return [
            instance for instance in annotation_instances
            if instance.get('attributes', {}).get(condition) == wanted
        ]

    def load_data_list(self) -> List[dict]:
        """Load the COCO annotation file and build the list of data infos.

        Sets ``self.mapping_category_to_supercategory``, ``self.cat_ids`` and
        ``self.cat2label`` as a side effect. An image is dropped when it had
        valid instances after parsing but none of them survive the active
        attribute filters. Images that had no instances to begin with are
        kept.

        Returns:
            List[dict]: One parsed data info per kept image.
        """
        # Local import keeps this method self-contained; mmengine is already
        # a dependency of this module.
        from mmengine.logging import print_log

        with get_local_path(
                self.ann_file, backend_args=self.backend_args) as local_path:
            coco_obj = COCO(local_path)

        cats = coco_obj.loadCats(coco_obj.getCatIds())
        self.mapping_category_to_supercategory = {
            cat['id']: cat['supercategory'] for cat in cats
        }
        self.cat_ids = coco_obj.getCatIds(self.metainfo['classes'])
        self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}

        img_ids = coco_obj.getImgIds()
        print_log(
            f'{len(img_ids)} images listed in {self.ann_file}',
            logger='current')

        # Resolve once which filters are switched on for this dataset.
        active_conditions = [
            condition for condition in self._ACTIVE_CONDITIONS
            if self.selected_condition[condition] is not None
        ]

        data_list = []
        total_ann_ids = []
        for img_id in img_ids:
            raw_img_info = coco_obj.loadImgs(img_id)[0]
            raw_img_info['img_id'] = img_id

            ann_ids = coco_obj.getAnnIds(imgIds=img_id)
            raw_ann_info = coco_obj.loadAnns(ann_ids)
            total_ann_ids.extend(ann_ids)

            parsed_data_info = self.parse_data_info(
                dict(raw_img_info=raw_img_info, raw_ann_info=raw_ann_info))

            had_instances = bool(parsed_data_info['instances'])
            for condition in active_conditions:
                parsed_data_info['instances'] = self.filter_annotations(
                    parsed_data_info['instances'], condition)

            # Skip images emptied by the filters, but keep images that were
            # already empty before filtering.
            if had_instances and not parsed_data_info['instances']:
                continue
            data_list.append(parsed_data_info)

        print_log(
            f'{len(data_list)} images kept after attribute filtering',
            logger='current')
        assert len(set(total_ann_ids)) == len(total_ann_ids), f"Annotation ids in '{self.ann_file}' are not unique!"
        return data_list

    def parse_data_info(self, raw_data_info: dict) -> Union[dict, List[dict]]:
        """Parse raw annotation to target format.

        Annotations are skipped when they are flagged ``ignore``, when their
        box does not overlap the image, when the area is non-positive or the
        box is narrower/shorter than one pixel, or when their category is not
        one of ``self.cat_ids``.

        Args:
            raw_data_info (dict): Raw data information load from ``ann_file``.
                Must contain ``raw_img_info`` (image record with ``img_id``,
                ``file_name``, ``height``, ``width``) and ``raw_ann_info``
                (list of annotation records for that image).

        Returns:
            Union[dict, List[dict]]: Parsed annotation. This implementation
            always returns a single dict with the image fields and an
            ``instances`` list.
        """
        img_info = raw_data_info['raw_img_info']
        ann_info = raw_data_info['raw_ann_info']

        # TODO: need to change data_prefix['img'] to data_prefix['img_path']
        img_path = osp.join(self.data_prefix['img'], img_info['file_name'])
        if self.data_prefix.get('seg', None):
            seg_map_path = osp.join(
                self.data_prefix['seg'],
                img_info['file_name'].rsplit('.', 1)[0] + self.seg_map_suffix)
        else:
            seg_map_path = None

        data_info = {
            'img_path': img_path,
            'img_id': img_info['img_id'],
            'seg_map_path': seg_map_path,
            'height': img_info['height'],
            'width': img_info['width'],
        }

        if self.return_classes:
            data_info['text'] = self.metainfo['classes']
            data_info['caption_prompt'] = self.caption_prompt
            data_info['custom_entities'] = True

        instances = []
        for ann in ann_info:
            if ann.get('ignore', False):
                continue
            x1, y1, w, h = ann['bbox']
            # Width/height of the part of the box that lies inside the image.
            inter_w = max(0, min(x1 + w, img_info['width']) - max(x1, 0))
            inter_h = max(0, min(y1 + h, img_info['height']) - max(y1, 0))
            if inter_w * inter_h == 0:
                continue
            if ann['area'] <= 0 or w < 1 or h < 1:
                continue
            if ann['category_id'] not in self.cat_ids:
                continue

            instance = {
                'ignore_flag': 1 if ann.get('iscrowd', False) else 0,
                # Convert COCO (x, y, w, h) to (x1, y1, x2, y2).
                'bbox': [x1, y1, x1 + w, y1 + h],
                'bbox_label': self.cat2label[ann['category_id']],
            }

            if ann.get('segmentation', None):
                instance['mask'] = ann['segmentation']

            attributes = ann.get('attributes', None)
            if attributes:
                # Missing keys are recorded as None so every attribute dict
                # exposes the same set of keys.
                instance['attributes'] = {
                    key: attributes.get(key, None)
                    for key in self._ATTRIBUTE_KEYS
                }
            instances.append(instance)
        data_info['instances'] = instances
        return data_info
    
    # def filter_data(self) -> List[dict]:
    #     """Filter annotations according to filter_cfg.

    #     Returns:
    #         List[dict]: Filtered results.
    #     """
    #     if self.test_mode:
    #         return self.data_list

    #     if self.filter_cfg is None:
    #         return self.data_list

    #     filter_empty_gt = self.filter_cfg.get('filter_empty_gt', False)
    #     min_size = self.filter_cfg.get('min_size', 0)

    #     # obtain images that contain annotation
    #     ids_with_ann = set(data_info['img_id'] for data_info in self.data_list)
    #     # obtain images that contain annotations of the required categories
    #     ids_in_cat = set()
    #     for i, class_id in enumerate(self.cat_ids):
    #         ids_in_cat |= set(self.cat_img_map[class_id])
    #     # merge the image id sets of the two conditions and use the merged set
    #     # to filter out images if self.filter_empty_gt=True
    #     ids_in_cat &= ids_with_ann

    #     valid_data_infos = []
    #     for i, data_info in enumerate(self.data_list):
    #         img_id = data_info['img_id']
    #         width = data_info['width']
    #         height = data_info['height']
    #         if filter_empty_gt and img_id not in ids_in_cat:
    #             continue
    #         if min(width, height) >= min_size:
    #             valid_data_infos.append(data_info)

    #     return valid_data_infos