# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

from mmengine.fileio import load

from mmpretrain.registry import DATASETS
from .base_dataset import BaseDataset


@DATASETS.register_module()
class VGVQA(BaseDataset):
    """Visual Genome VQA dataset.

    The annotation file is expected to contain a plain list of dicts, one
    per question; it has no 'metainfo' section.
    """

    def load_data_list(self) -> List[dict]:
        """Load the annotation file and build the list of data infos.

        Compared to ``BaseDataset``, the only difference is that the
        annotation file is already a list of data items. There is no
        'metainfo'.

        Every raw item is passed through ``self.parse_data_info``. Its
        'image' key is then renamed to 'img_path' and, when an 'answer' key
        is present, answer weights are attached according to the item's
        'dataset' field:

        - ``'vqa'``: 'answer' is treated as an iterable of answers. It is
          replaced by the unique answers (first-seen order),
          'answer_weight' holds the fraction of occurrences of each unique
          answer and 'answer_count' the number of unique answers.
        - ``'vg'``: the single answer is wrapped in a list stored under
          'answers', with 'answer_weight' ``[0.2]`` and 'answer_count' 1.

        Items from any other 'dataset' value are kept without answer
        weights.

        Returns:
            List[dict]: The parsed data infos.

        Raises:
            TypeError: If the loaded annotations are not a list, or if a
                parsed data element is not a dict.
        """

        raw_data_list = load(self.ann_file)
        if not isinstance(raw_data_list, list):
            raise TypeError(
                f'The VQA annotations loaded from annotation file '
                f'should be a list, but got {type(raw_data_list)}!')

        # load and parse data_infos.
        data_list = []
        for raw_data_info in raw_data_list:
            # parse raw data information to target format
            data_info = self.parse_data_info(raw_data_info)
            if not isinstance(data_info, dict):
                raise TypeError(
                    f'Each VQA data element loaded from annotation file '
                    f'should be a dict, but got {type(data_info)}!')

            # For VQA tasks, each `data_info` looks like:
            # {
            #   "question_id": 986769,
            #   "question": "How many people are there?",
            #   "answer": "two",
            #   "image": "image/1.jpg",
            #   "dataset": "vg"
            # }

            # Rename the 'image' key to 'img_path'.
            # TODO: This step will be removed once the annotation file is
            # preprocessed.
            data_info['img_path'] = data_info.pop('image')

            if 'answer' in data_info:
                # add answer_weight & answer_count, delete duplicate answer
                source = data_info['dataset']
                if source == 'vqa':
                    answers = data_info['answer']
                    # Each occurrence contributes 1 / len(answers); the dict
                    # keeps the first-seen order of the unique answers.
                    answer_weight = {}
                    for answer in answers:
                        answer_weight[answer] = (
                            answer_weight.get(answer, 0) + 1 / len(answers))

                    data_info['answer'] = list(answer_weight)
                    data_info['answer_weight'] = list(
                        answer_weight.values())
                    data_info['answer_count'] = len(answer_weight)

                elif source == 'vg':
                    # NOTE(review): this branch writes 'answers' while the
                    # 'vqa' branch rewrites 'answer' -- confirm which key
                    # the downstream pipeline reads before unifying them.
                    data_info['answers'] = [data_info['answer']]
                    data_info['answer_weight'] = [0.2]
                    data_info['answer_count'] = 1

            data_list.append(data_info)

        return data_list
