"""ImageNet-X.

https://facebookresearch.github.io/imagenetx/site/dataset
https://arxiv.org/pdf/2211.01866v1.pdf
https://github.com/facebookresearch/imagenetx/
"""
import dataclasses

import imagenet_x
import numpy as np

from em import datasets as em_datasets
from . import imagenet


###############################################################################

# ImageNet split whose image file names are matched against the ImageNet-X
# annotation table when building the index maps below.
_SPLIT = "validation"

###############################################################################

# Names of the ImageNet-X annotation factors (16 in total). The tuple keeps
# exactly this order; the names are split on whitespace.
ANNOTATION_TYPES = tuple("""
    multiple_objects background color brighter darker style
    larger smaller object_blocking person_blocking partial_view
    pattern pose shape subcategory texture
""".split())


###############################################################################

@dataclasses.dataclass
class ImageNetXAnnotations:
    """ImageNet-X annotations aligned with this project's ImageNet split.

    On construction, loads the ImageNet-X annotation table via
    ``imagenet_x.load_annotations()``. It then builds two index maps that
    translate between ImageNet-X row positions and this project's example
    positions, as ordered by ``imagenet.get_image_filenames(_SPLIT)``.
    Examples are matched by file name.

    Attributes:
        x_to_my_index_map: int32 array. Entry ``k`` is the position of
            ImageNet-X row ``k``'s file in this project's file list, or -1
            if that file is absent from the project's split.
        my_to_x_index_map: int32 array. Entry ``k`` is the ImageNet-X row
            holding this project's example ``k``, or -1 if the example has
            no ImageNet-X annotation.

    Note: -1 is a "not found" sentinel, not a real index. NumPy would
    interpret it as the last element, so mask it out (``m >= 0``) before
    using a map to index arrays.
    """

    def __post_init__(self):
        # Presumably a pandas DataFrame: it is indexed by column name and
        # ``.tolist()`` is called on the 'file_name' column below.
        self._annotations = imagenet_x.load_annotations()
        self._make_example_index_maps()

    @staticmethod
    def _positions_in(queries, reference):
        """Return, for each item of *queries*, its position in *reference*.

        Args:
            queries: Iterable of hashable keys (file names) to look up.
            reference: Iterable of keys that defines the positions.

        Returns:
            An int32 NumPy array with one entry per query. Queries missing
            from *reference* map to -1. If *reference* contains duplicates,
            the position of the last occurrence wins.
        """
        position = {name: i for i, name in enumerate(reference)}
        return np.array(
            [position.get(name, -1) for name in queries], dtype=np.int32)

    def _make_example_index_maps(self):
        """Build ``x_to_my_index_map`` and ``my_to_x_index_map`` by file name."""
        # Materialize once: both lists are traversed twice (as queries and as
        # a reference), which would silently fail for a one-shot iterator.
        my_filenames = list(imagenet.get_image_filenames(_SPLIT))
        x_filenames = self._annotations['file_name'].tolist()

        # The two file lists need not have the same length or order, so
        # unmatched entries on either side are marked with -1.
        self.x_to_my_index_map = self._positions_in(x_filenames, my_filenames)
        self.my_to_x_index_map = self._positions_in(my_filenames, x_filenames)


###############################################################################
