# Copyright (c) OpenMMLab. All rights reserved.
import re
from itertools import chain
from typing import List

import mmengine
from mmengine.dataset import BaseDataset

from mmpretrain.registry import DATASETS


@DATASETS.register_module()
class VisualGenomeQA(BaseDataset):
    """Visual Genome Question Answering dataset.

    dataset structure: ::

        data_root
        ├── image
        │   ├── 1.jpg
        │   ├── 2.jpg
        │   └── ...
        └── question_answers.json

    Every question-answer pair in the annotation file becomes one sample: a
    dict with the keys ``img_path``, ``question``, ``question_id``,
    ``image_id`` and ``gt_answer``. ``gt_answer`` is a single-element list
    holding the annotated answer.

    Args:
        data_root (str): The root directory for ``data_prefix`` and
            ``ann_file``.
        data_prefix (str): The directory of images. Defaults to ``"image"``.
        ann_file (str, optional): Annotation file path for training and
            validation. Defaults to ``"question_answers.json"``.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    def __init__(self,
                 data_root: str,
                 data_prefix: str = 'image',
                 ann_file: str = 'question_answers.json',
                 **kwargs):
        super().__init__(
            data_root=data_root,
            data_prefix=dict(img_path=data_prefix),
            ann_file=ann_file,
            **kwargs,
        )

    def _create_image_index(self) -> dict:
        """Build a lookup from integer image id to image file path.

        Lists the files (not sub-directories) in ``data_prefix['img_path']``.
        Each file's id is taken from the last run of digits in its name, e.g.
        ``1.jpg`` -> ``1``. Files whose names contain no digits are skipped.

        Returns:
            dict: Mapping of ``image_id (int)`` to the joined image path.
        """
        img_prefix = self.data_prefix['img_path']

        files = mmengine.list_dir_or_file(img_prefix, list_dir=False)
        image_index = {}
        for file in files:
            digits = re.findall(r'\d+', file)
            if digits:
                # Use the last number so prefixes such as "VG_100K/" or
                # "img2_" in a name do not shadow the actual image id.
                image_index[int(digits[-1])] = mmengine.join_path(
                    img_prefix, file)

        return image_index

    def load_data_list(self) -> List[dict]:
        """Load data list.

        Flattens the per-image ``qas`` lists of the annotation file into one
        sample per question-answer pair.

        Returns:
            List[dict]: One dict per QA pair, with keys ``img_path``,
            ``question``, ``question_id``, ``image_id`` and ``gt_answer``.

        Raises:
            KeyError: If a QA pair references an image id that has no
                matching file in the image directory.
        """
        annotations = mmengine.load(self.ann_file)

        # The original Visual Genome annotation file includes only image ids
        # but no image file paths, so build an id -> path lookup from the
        # files in the image directory.
        self.image_index = self._create_image_index()

        # ann example
        # {
        #     'id': 1,
        #     'qas': [
        #         {
        #             'a_objects': [],
        #             'question': 'What color is the clock?',
        #             'image_id': 1,
        #             'qa_id': 986768,
        #             'answer': 'Two.',
        #             'q_objects': [],
        #         }
        #         ...
        #     ]
        # }
        data_list = []
        for qa in chain.from_iterable(ann['qas'] for ann in annotations):
            data_list.append({
                'img_path': self.image_index[qa['image_id']],
                'question': qa['question'],
                # Visual Genome stores the question-answer id as ``qa_id``.
                'question_id': qa['qa_id'],
                'image_id': qa['image_id'],
                'gt_answer': [qa['answer']],
            })

        return data_list
