# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import re
from collections import Counter
from typing import List

import mmengine
from mmengine.dataset import BaseDataset

from mmpretrain.registry import DATASETS


@DATASETS.register_module()
class COCOVQA(BaseDataset):
    """VQAv2 dataset.

    Args:
        data_root (str): The root directory for ``data_prefix``, ``ann_file``
            and ``question_file``.
        data_prefix (str): The directory of images.
        question_file (str): Question file path.
        ann_file (str, optional): Annotation file path for training and
            validation. Defaults to an empty string.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    def __init__(self,
                 data_root: str,
                 data_prefix: str,
                 question_file: str,
                 ann_file: str = '',
                 **kwarg):
        self.question_file = question_file
        super().__init__(
            data_root=data_root,
            data_prefix=dict(img_path=data_prefix),
            ann_file=ann_file,
            **kwarg,
        )

    def _join_prefix(self):
        if not mmengine.is_abs(self.question_file) and self.question_file:
            self.question_file = osp.join(self.data_root, self.question_file)

        return super()._join_prefix()

    def _create_image_index(self):
        img_prefix = self.data_prefix['img_path']

        files = mmengine.list_dir_or_file(img_prefix, list_dir=False)
        image_index = {}
        for file in files:
            image_id = re.findall(r'\d{12}', file)
            if len(image_id) > 0:
                image_id = int(image_id[-1])
                image_index[image_id] = mmengine.join_path(img_prefix, file)

        return image_index

    def load_data_list(self) -> List[dict]:
        """Load data list."""
        questions = mmengine.load(self.question_file)['questions']
        if self.ann_file:
            annotations = mmengine.load(self.ann_file)['annotations']
            assert len(questions) == len(annotations)
        else:
            annotations = [None] * len(questions)

        # The original VQAv2 annotation file and question file includes
        # only image id but no image file paths.
        self.image_index = self._create_image_index()

        data_list = []
        for question, ann in zip(questions, annotations):
            # question example
            # {
            #     'image_id': 262144,
            #     'question': "Is the ball flying towards the batter?",
            #     'question_id': 262144000
            # }
            #
            # ann example
            # {
            #     'question_type': "what are the",
            #     'answer_type': "other",
            #     'answers': [
            #         {'answer': 'watching',
            #          'answer_id': 1,
            #          'answer_confidence': 'yes'},
            #         ...
            #     ],
            #     'image_id': 262148,
            #     'question_id': 262148000,
            #     'multiple_choice_answer': 'watching',
            #     'answer_type': 'other',
            # }

            data_info = question
            data_info['img_path'] = self.image_index[question['image_id']]

            if ann is not None:
                assert ann['question_id'] == question['question_id']

                # add answer_weight & answer_count, delete duplicate answer
                answers = [item['answer'] for item in ann.pop('answers')]
                count = Counter(answers)
                answer_weight = [i / len(answers) for i in count.values()]
                data_info['gt_answer'] = list(count.keys())
                data_info['gt_answer_weight'] = answer_weight
                data_info.update(ann)

            data_list.append(data_info)

        return data_list
