# Copyright (c) OpenMMLab. All rights reserved.
from collections import Counter
from typing import List, Tuple

import mmengine
from mmengine.dataset import BaseDataset

from mmpretrain.registry import DATASETS


@DATASETS.register_module()
class VizWiz(BaseDataset):
    """VizWiz dataset.

    Each entry of the annotation file looks like::

        {
            "image": "VizWiz_val_00000001.jpg",
            "question": "Can you tell me what this medicine is please?",
            "answers": [
                {"answer": "unanswerable", "answer_confidence": "yes"},
                {"answer": "night time", "answer_confidence": "maybe"},
                {"answer": "night time medicine", "answer_confidence": "yes"},
                ...
            ],
            "answer_type": "other",
            "answerable": 1
        }

    Which annotations become samples:

    - Annotations without an ``answerable`` key are kept with only
      ``question`` and ``img_path``.
    - Annotations with ``answerable == 1`` additionally get ``gt_answer``
      and ``gt_answer_weight`` (see :meth:`_parse_answers`).
    - Any other annotation (``answerable`` present but not 1) is skipped.

    Args:
        data_root (str): The root directory for ``data_prefix`` and
            ``ann_file``.
        data_prefix (str): The directory of images.
        ann_file (str, optional): Annotation file path for training and
            validation. Defaults to an empty string.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    def __init__(self,
                 data_root: str,
                 data_prefix: str,
                 ann_file: str = '',
                 **kwargs):
        super().__init__(
            data_root=data_root,
            # BaseDataset expects a dict of prefixes; images live under
            # the ``img_path`` key, which ``load_data_list`` reads back.
            data_prefix=dict(img_path=data_prefix),
            ann_file=ann_file,
            **kwargs,
        )

    @staticmethod
    def _parse_answers(answers: List[dict]) -> Tuple[List[str], List[float]]:
        """Deduplicate confident answers and weight them by frequency.

        Only entries whose ``answer_confidence`` is ``'yes'`` and whose
        ``answer`` is not the literal ``'unanswerable'`` are counted.

        Args:
            answers (List[dict]): Raw answer entries, each with ``answer``
                and ``answer_confidence`` keys.

        Returns:
            Tuple[List[str], List[float]]: The unique kept answers in
            first-seen order, and for each one its count divided by the
            total number of kept answers. Both lists are empty when no
            answer passes the filter.
        """
        kept = [
            item['answer'] for item in answers
            if item['answer_confidence'] == 'yes'
            and item['answer'] != 'unanswerable'
        ]
        count = Counter(kept)
        total = len(kept)
        # When ``kept`` is empty, ``count`` is empty too, so the division
        # below never runs with ``total == 0``.
        return list(count.keys()), [n / total for n in count.values()]

    def load_data_list(self) -> List[dict]:
        """Load data list.

        Returns:
            List[dict]: One dict per kept annotation. Each dict contains
            ``question`` and ``img_path``. Answerable samples also contain
            ``gt_answer`` and ``gt_answer_weight``.
        """
        annotations = mmengine.load(self.ann_file)
        img_prefix = self.data_prefix['img_path']

        data_list = []
        for ann in annotations:
            data_info = dict()
            data_info['question'] = ann['question']
            data_info['img_path'] = mmengine.join_path(
                img_prefix, ann['image'])

            if 'answerable' not in ann:
                # No answerability info: keep the sample without GT answers.
                data_list.append(data_info)
            elif ann['answerable'] == 1:
                # Read (rather than pop) so the loaded record is left intact.
                gt_answer, gt_answer_weight = self._parse_answers(
                    ann['answers'])
                data_info['gt_answer'] = gt_answer
                data_info['gt_answer_weight'] = gt_answer_weight
                data_list.append(data_info)
            # Otherwise ``answerable`` is present but not 1: skip the sample.

        return data_list
