# Copyright (c) OpenMMLab. All rights reserved.
from collections import OrderedDict
from typing import List

import mmengine
from mmengine import get_file_backend

from mmpretrain.registry import DATASETS
from .base_dataset import BaseDataset


@DATASETS.register_module()
class Flickr30kRetrieval(BaseDataset):
    """Flickr30k Retrieval dataset.

    The annotation file must contain an ``images`` list. Every entry provides
    ``imgid``, ``filename``, ``split`` and a ``sentences`` list whose items
    provide ``raw`` (the caption text) and ``sentid``, for example::

        {
            "sentids": [0, 1, 2],
            "imgid": 0,
            "sentences": [
                {"raw": "Two men in green shirts standing in a yard.",
                 "imgid": 0, "sentid": 0},
                {"raw": "A man in a blue shirt standing in a garden.",
                 "imgid": 0, "sentid": 1},
                {"raw": "Two friends enjoy time spent together.",
                 "imgid": 0, "sentid": 2}
            ],
            "split": "train",
            "filename": "1000092795.jpg"
        }

    Only entries whose ``split`` equals the requested ``split`` are loaded.

    After loading, ``self.img_size`` and ``self.text_size`` hold the number
    of loaded images and captions respectively.

    Args:
        data_root (str): The root directory for ``data_prefix`` and
            ``ann_file``.
        data_prefix (str): The directory of images.
        ann_file (str): Annotation file path (shared by all splits).
        split (str): 'train', 'val' or 'test'. Selects which annotation
            entries are loaded.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    def __init__(self, data_root: str, data_prefix: str, ann_file: str,
                 split: str, **kwargs):

        assert split in ['train', 'val', 'test'], \
            f'`split` must be train, val or test, but got {split!r}'
        # Set before ``super().__init__`` because ``load_data_list`` reads
        # ``self.split`` and the base class may invoke it during construction.
        self.split = split
        super().__init__(
            data_root=data_root,
            data_prefix=dict(img_path=data_prefix),
            ann_file=ann_file,
            **kwargs,
        )

    def load_data_list(self) -> List[dict]:
        """Load the data list for ``self.split``.

        Every caption of a selected image becomes one *text item* with keys:
        ``text``, ``ori_id`` (original ``sentid``), ``text_id`` (new index,
        counted from 0 over the loaded captions), ``image_ori_id``,
        ``image_id`` (new index, counted from 0 over the loaded images),
        ``img_path`` and ``is_matched`` (always ``True``).

        Every selected image becomes one *image item* with keys ``ori_id``,
        ``image_id`` and ``img_path``. It also has the per-caption lists
        ``text``, ``gt_text_id`` and ``gt_image_id``.

        Side effect: sets ``self.img_size`` and ``self.text_size``.

        Returns:
            List[dict]: The image items when ``self.test_mode`` is true,
            otherwise the text items.
        """
        img_prefix = self.data_prefix['img_path']
        file_backend = get_file_backend(img_prefix)

        annotations = mmengine.load(self.ann_file)

        # Image items keyed by the original ``imgid``, in file order.
        img_dict = OrderedDict()
        text_list = []
        img_idx = 0
        sentence_idx = 0
        for img in annotations['images']:
            if img['split'] != self.split:
                continue

            # ``image_id`` / ``text_id`` are renumbered from 0 over the
            # selected split only; they are the ids used for evaluation.
            image_item = dict(
                ori_id=img['imgid'],
                image_id=img_idx,
                img_path=file_backend.join_path(img_prefix, img['filename']),
                text=[],
                gt_text_id=[],
                gt_image_id=[],
            )

            for sentence in img['sentences']:
                text_item = dict(
                    text=sentence['raw'],
                    ori_id=sentence['sentid'],
                    text_id=sentence_idx,
                    image_ori_id=image_item['ori_id'],
                    image_id=image_item['image_id'],
                    img_path=image_item['img_path'],
                    is_matched=True,
                )

                # 1. Text item (returned when not in test mode).
                text_list.append(text_item)
                # 2. Accumulate the caption on its image item (test mode).
                image_item['text'].append(text_item['text'])
                image_item['gt_text_id'].append(text_item['text_id'])
                image_item['gt_image_id'].append(text_item['image_id'])

                sentence_idx += 1

            img_dict[img['imgid']] = image_item
            img_idx += 1

        self.img_size = len(img_dict)
        self.text_size = len(text_list)

        if self.test_mode:
            return list(img_dict.values())
        return text_list
