# Copyright (c) OpenMMLab. All rights reserved.
from collections import Counter
from typing import List

import mmengine
from mmengine.dataset import BaseDataset

from mmpretrain.registry import DATASETS


@DATASETS.register_module()
class TextVQA(BaseDataset):
    """TextVQA dataset.

    val image:
        https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip
    test image:
        https://dl.fbaipublicfiles.com/textvqa/images/test_images.zip
    val json:
        https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_0.5.1_val.json
    test json:
        https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_0.5.1_test.json

    folder structure:
    data/textvqa
        ├── annotations
        │   ├── TextVQA_0.5.1_test.json
        │   └── TextVQA_0.5.1_val.json
        └── images
            ├── test_images
            └── train_images

    Args:
        data_root (str): The root directory for ``data_prefix`` and
            ``ann_file``.
        data_prefix (str): The directory of images.
        ann_file (str): Annotation file path. The loaded file must provide
            the list of samples under a top-level ``'data'`` key.
            Defaults to an empty string.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    def __init__(self,
                 data_root: str,
                 data_prefix: str,
                 ann_file: str = '',
                 **kwargs):
        # ``BaseDataset`` takes ``data_prefix`` as a dict; the image directory
        # is stored under the ``img_path`` key read by ``load_data_list``.
        super().__init__(
            data_root=data_root,
            data_prefix=dict(img_path=data_prefix),
            ann_file=ann_file,
            **kwargs,
        )

    def load_data_list(self) -> List[dict]:
        """Load data list.

        Returns:
            List[dict]: One dict per annotation with the keys ``question``,
            ``question_id``, ``image_id`` and ``img_path``. When the
            annotation carries ``answers``, ``gt_answer`` (the unique
            answers, in order of first appearance) and ``gt_answer_weight``
            (the fraction of all answers equal to each unique answer) are
            added as well.
        """
        annotations = mmengine.load(self.ann_file)['data']
        img_prefix = self.data_prefix['img_path']

        data_list = []

        for ann in annotations:

            # ann example
            # {
            #     'question': 'what is the brand of...is camera?',
            #     'image_id': '003a8ae2ef43b901',
            #     'image_classes': [
            #         'Cassette deck', 'Printer', ...
            #         ],
            #     'flickr_original_url': 'https://farm2.static...04a6_o.jpg',
            #     'flickr_300k_url': 'https://farm2.static...04a6_o.jpg',
            #     'image_width': 1024,
            #     'image_height': 664,
            #     'answers': [
            #         'nous les gosses',
            #         'dakota',
            #         'clos culombu',
            #         'dakota digital' ...
            #        ],
            #     'question_tokens':
            #         ['what', 'is', 'the', 'brand', 'of', 'this', 'camera'],
            #     'question_id': 34602,
            #     'set_name': 'val'
            # }

            data_info = dict(question=ann['question'])
            data_info['question_id'] = ann['question_id']
            data_info['image_id'] = ann['image_id']

            # Image files are named after their image id.
            data_info['img_path'] = mmengine.join_path(
                img_prefix, ann['image_id'] + '.jpg')

            # Annotations without an ``answers`` key (see the ``in`` check)
            # simply get no ground-truth fields.
            if 'answers' in ann:
                # Read without popping so the loaded record is not mutated.
                answers = ann['answers']
                count = Counter(answers)
                # Counter keeps first-seen order, so keys and weights align.
                answer_weight = [i / len(answers) for i in count.values()]
                data_info['gt_answer'] = list(count.keys())
                data_info['gt_answer_weight'] = answer_weight

            data_list.append(data_info)

        return data_list
