"""Obverter dataset."""

import os

import tensorflow as tf
import tensorflow_datasets.public_api as tfds


# Obverter constants
# URL of the `dataset.tar.gz` archive that the builder downloads and extracts.
# The URL is longer than the lint line limit, hence the pylint directive.
# pylint: disable=line-too-long
_URL = 'https://raw.githubusercontent.com/benbogin/obverter/master/assets/dataset.tar.gz'

# BibTeX entry supplied as the `citation` of the dataset's `DatasetInfo`.
_CITATION = """\
@article{choi2018compositional,
  title={Compositional obverter communication learning from raw visual input},
  author={Choi, Edward and Lazaridou, Angeliki and de Freitas, Nando},
  journal={arXiv preprint arXiv:1804.02341},
  year={2018}
}
"""

class Obverter(tfds.core.GeneratorBasedBuilder):
    """TFDS builder for the Obverter dataset.

    Every example consists of one 128x128 RGB image plus two integer class
    labels: `label_color` (8 classes) and `label_shape` (5 classes). Both
    labels are derived from the image file name, whose first two
    '-'-separated fields are read as the color and the shape.
    """
    # Archive that `_split_generators` downloads and extracts.
    URL = _URL

    VERSION = tfds.core.Version('0.0.1')

    def _info(self):
        """Return the `tfds.core.DatasetInfo` describing this dataset."""
        return tfds.core.DatasetInfo(
            builder=self,
            description=('The Obverter dataset: images of colored shapes, '
                         'each labeled with its color and its shape.'),
            features=tfds.features.FeaturesDict({
                'image': tfds.features.Image(shape=(128, 128, 3)),
                'label_color': tfds.features.ClassLabel(num_classes=8),
                'label_shape': tfds.features.ClassLabel(num_classes=5),
            }),
            homepage='https://github.com/benbogin/obverter/blob/master',
            citation=_CITATION,
        )

    def _split_generators(self, dl_manager):
        """Download and extract the archive and declare the dataset splits.

        Args:
            dl_manager: Download manager used to fetch and extract `URL`.
        Returns:
            A list with a single TRAIN `tfds.core.SplitGenerator` whose
            `data_path` keyword is the extracted location.
        """
        # Use the class attribute so `URL` is the single source of truth.
        filepath = dl_manager.download_and_extract(self.URL)

        # There is no predefined train/val/test split for this dataset.
        return [
            tfds.core.SplitGenerator(
                name=tfds.Split.TRAIN,
                gen_kwargs=dict(data_path=filepath)),
        ]

    def _generate_examples(self, data_path):
        """Generate Obverter examples as dicts.

        The color and shape names are taken from the first two
        '-'-separated fields of each file name. Label ids are the positions
        of those names in the alphabetically sorted sets of distinct colors
        and shapes found under `data_path`.

        Args:
            data_path (str): Path to the data files
        Yields:
            `(index, record)` pairs, where `record` holds the decoded
            `image` array and the integer `label_color` / `label_shape`.
        Raises:
            ValueError: If a file name has fewer than two '-'-separated
                fields.
        """
        # Directory listing order is not guaranteed; sort it so that the
        # example keys and the generation order are reproducible.
        parsed = []
        for filename in sorted(tf.io.gfile.listdir(data_path)):
            fields = filename.split('-')
            if len(fields) < 2:
                raise ValueError(
                    'Cannot parse color and shape from file name %r; '
                    "expected '<color>-<shape>-...'." % filename)
            parsed.append((filename, fields[0], fields[1]))

        colors = sorted({color for _, color, _ in parsed})
        shapes = sorted({shape for _, _, shape in parsed})
        colors2labels = {color: idx for idx, color in enumerate(colors)}
        shapes2labels = {shape: idx for idx, shape in enumerate(shapes)}

        for index, (filename, color, shape) in enumerate(parsed):
            encoded = tf.io.read_file(os.path.join(data_path, filename))
            # Force three channels so the array always matches the declared
            # (128, 128, 3) image feature, whatever the PNG's channel count.
            image = tf.image.decode_png(encoded, channels=3)
            record = {
                'image': image.numpy(),
                'label_color': colors2labels[color],
                'label_shape': shapes2labels[shape],
            }
            yield index, record
