# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import List

import mmengine
from mmengine.dataset import BaseDataset

from mmpretrain.registry import DATASETS


@DATASETS.register_module()
class OCRVQA(BaseDataset):
    """OCR-VQA dataset.

    Every image entry in the annotation file carries several
    question/answer pairs; each pair becomes one sample in the data list.

    Args:
        data_root (str): The root directory for ``data_prefix`` and
            ``ann_file``.
        data_prefix (str): The directory of images.
        ann_file (str): Annotation file path. A single file holds the
            annotations of all splits; entries are selected by their numeric
            ``split`` field.
        split (str): Which split to load: 'train', 'val' or 'test'.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    # Mapping from the numeric ``split`` code stored in the annotation file
    # to the split name accepted by ``__init__``.
    _SPLIT_NAMES = {1: 'train', 2: 'val', 3: 'test'}

    def __init__(self, data_root: str, data_prefix: str, ann_file: str,
                 split: str, **kwargs):

        assert split in ['train', 'val', 'test'], \
            f'`split` must be train, val or test, but got {split!r}'
        # Must be set before ``super().__init__()``: the base class may load
        # annotations (calling ``load_data_list``) during construction.
        self.split = split
        super().__init__(
            data_root=data_root,
            data_prefix=dict(img_path=data_prefix),
            ann_file=ann_file,
            **kwargs,
        )

    def load_data_list(self) -> List[dict]:
        """Load annotations and expand them into one sample per question.

        Only entries whose ``split`` code maps to ``self.split`` and whose
        image URL ends in ``.jpg`` or ``.png`` are kept.

        Returns:
            List[dict]: One dict per question/answer pair with keys
            ``img_path``, ``question``, ``gt_answer``, ``gt_answer_weight``,
            ``imageURL``, ``title``, ``authorName`` and ``genre``.
        """
        annotations = mmengine.load(self.ann_file)

        # ann example
        # "761183272": {
        #     "imageURL": \
        #         "http://ecx.images-amazon.com/images/I/61Y5cOdHJbL.jpg",
        #     "questions": [
        #         "Who wrote this book?",
        #         "What is the title of this book?",
        #         "What is the genre of this book?",
        #         "Is this a games related book?",
        #         "What is the year printed on this calendar?"],
        #     "answers": [
        #         "Sandra Boynton",
        #         "Mom's Family Wall Calendar 2016",
        #         "Calendars",
        #         "No",
        #         "2016"],
        #     "title": "Mom's Family Wall Calendar 2016",
        #     "authorName": "Sandra Boynton",
        #     "genre": "Calendars",
        #     "split": 1
        # },

        data_list = []

        for key, ann in annotations.items():
            if self._SPLIT_NAMES[ann['split']] != self.split:
                continue

            # The local image file is named ``<key><ext>``, where ``<ext>``
            # is taken from the source URL; other formats are skipped.
            extension = osp.splitext(ann['imageURL'])[1]
            if extension not in ('.jpg', '.png'):
                continue
            img_path = mmengine.join_path(self.data_prefix['img_path'],
                                          key + extension)

            # NOTE: ``zip`` stops at the shorter list if the numbers of
            # questions and answers ever differ.
            for question, answer in zip(ann['questions'], ann['answers']):
                data_list.append({
                    'img_path': img_path,
                    'question': question,
                    'gt_answer': answer,
                    'gt_answer_weight': [1.0],
                    'imageURL': ann['imageURL'],
                    'title': ann['title'],
                    'authorName': ann['authorName'],
                    'genre': ann['genre'],
                })

        return data_list
