from mmdet.registry import DATASETS
from mmdet.datasets import DIORDataset
import os.path as osp
from typing import List
from mmdet.registry import DATASETS
from .xml_style import XMLDataset

@DATASETS.register_module()
class DiorIncDataset(DIORDataset):
    """DIOR dataset restricted to the classes of one incremental step.

    Class indices are split into ``[0, start)`` (previously learned classes,
    stored as ``metainfo['ori_classes']``) and ``[start, end)`` (classes new in
    this step, stored as ``metainfo['new_classes']``). :meth:`load_data_list`
    keeps only instances whose label lies in ``[0, end)`` and drops images
    left without any instance.

    Args:
        *args: Positional arguments forwarded to ``DIORDataset``.
        start (int): Index of the first new class. Defaults to 0.
        end (int): One past the index of the last class used in this step.
            Defaults to 20.
        **kwargs: Keyword arguments forwarded to ``DIORDataset``.

    Raises:
        ValueError: If ``start`` is negative or greater than ``end``.
    """

    METAINFO = {
        'classes': ('airplane', 'airport', 'baseballfield', 'basketballcourt',
        'bridge', 'chimney', 'dam', 'Expressway-Service-area', 'Expressway-toll-station',
        'golffield', 'groundtrackfield', 'harbor', 'overpass', 'ship', 'stadium', 'storagetank',
        'tenniscourt', 'trainstation', 'vehicle', 'windmill')
    }

    def __init__(self,
                 *args,
                 start: int = 0,
                 end: int = 20,
                 **kwargs) -> None:
        if not 0 <= start <= end:
            raise ValueError(
                f'Expected 0 <= start <= end, got start={start}, end={end}')
        # Set before super().__init__(): load_data_list() reads these, and
        # the base initializer presumably loads annotations through it.
        self.start = start
        self.end = end
        super().__init__(*args, **kwargs)

        # Record the split between previously learned and new classes.
        classes = self.metainfo['classes']
        self._metainfo['start'] = start
        self._metainfo['ori_classes'] = classes[:start]
        self._metainfo['new_classes'] = classes[start:end]

    def load_data_list(self) -> List[dict]:
        """Load annotations, keeping only classes with index in ``[0, end)``.

        Instances whose ``bbox_label`` falls outside ``[0, self.end)`` are
        removed, and images with no remaining instances are dropped. When
        ``self.return_classes`` is set, each kept image additionally gets
        ``custom_entities=True``, ``text`` (class names ``[0, end)``) and
        ``ori_text`` (class names ``[0, start)``) for GroundingDINO-style
        text prompts.

        Returns:
            List[dict]: The filtered data list.
        """
        data_list = super().load_data_list()

        # NOTE: a "current classes only" variant ([start, end)) was once
        # prototyped here; this implementation keeps every class seen so far.
        start, end = self.start, self.end
        # Read metainfo once instead of once per image (the property may
        # return a copy on each access).
        classes = self.metainfo['classes']

        filtered_list = []
        for data_info in data_list:
            valid_instances = [
                ins for ins in data_info['instances']
                if 0 <= ins['bbox_label'] < end
            ]
            if not valid_instances:
                continue
            data_info['instances'] = valid_instances

            # Text / prompt information for GroundingDINO.
            if self.return_classes:
                data_info['custom_entities'] = True
                # Fresh slices per image so no mutable object is shared.
                data_info['text'] = classes[:end]
                data_info['ori_text'] = classes[:start]

            filtered_list.append(data_info)

        return filtered_list
