"""Wrappers for tf.data.Datasets"""

import gin
import numpy as np
import tensorflow as tf


@gin.configurable
class ReferenceGameDatasetWrapper:
    """Wrapper for a tf.data.Dataset adapting it to a reference game setting.

    The sender sees a target image and communicates it to the receiver. Then
    the receiver sees `num_distractors` + 1 candidate images and must indicate
    the target image.

    The iterator returns three values:
    1. `candidates`, a (batch_size, num_distractors + 1, ...)-dimensional
        tensor with distractors and the target image stacked along axis 1
    2. `target` image, a (batch_size, ...)-dimensional tensor
    3. `label`, a (batch_size,)-dimensional NumPy integer array of indices of
        target images along axis 1 of the `candidates` tensor

    Notes:
    * Elements of the underlying dataset must be indexable by the key
      'image'; only that entry is read, so all the original labels are
      discarded.
    * The wrapper inherits `batch_size` from the underlying dataset. Each
      returned item consumes `num_distractors` + 1 consecutive batches.
    * If the underlying dataset runs out part-way through such a group, the
      batches already drawn for it are dropped and iteration stops.
    * If the batches within one group differ in size (e.g. a smaller final
      batch), they are truncated to the smallest batch size in the group so
      that they can be stacked.
    * The wrapper is single-pass: the underlying iterator is created once in
      `__init__` and `__iter__` returns the wrapper itself.
    """

    def __init__(
            self,
            dataset: tf.data.Dataset,
            num_distractors: int
    ):
        """Create the wrapper.

        Args:
            dataset: dataset whose elements expose an 'image' entry.
            num_distractors: number of non-target candidates shown to the
                receiver; must be non-negative.

        Raises:
            ValueError: if `num_distractors` is negative.
        """
        if num_distractors < 0:
            # Fail early: with no candidates there is nothing to stack later.
            raise ValueError(
                'num_distractors must be non-negative, got {}'.format(
                    num_distractors
                )
            )
        self.iterator = iter(dataset)
        self.num_distractors = num_distractors
        self.num_candidates = num_distractors + 1

    def __iter__(self):
        """Return the wrapper itself; iteration state is not reset."""
        return self

    def __next__(self):
        """Return the next `(candidates, target, label)` triple.

        Raises:
            StopIteration: when the underlying iterator cannot supply a full
                group of `num_candidates` batches.
        """
        # StopIteration from the underlying iterator propagates out of the
        # list comprehension, which ends iteration and drops a partial group.
        images = [
            next(self.iterator)['image'] for _ in range(self.num_candidates)
        ]
        # Batches may differ in size (e.g. a smaller last batch when the
        # dataset was batched without dropping the remainder); stacking
        # requires equal shapes, so truncate to the smallest batch.
        batch_size = min(image.shape[0] for image in images)
        if any(image.shape[0] != batch_size for image in images):
            images = [image[:batch_size] for image in images]
        candidates = tf.stack(images, axis=1)
        # One uniformly sampled target position per row of the batch.
        label = np.random.randint(self.num_candidates, size=batch_size)
        # (row, candidate) index pairs selecting the target image of each row.
        indices = np.stack([np.arange(batch_size), label], axis=1)
        target = tf.gather_nd(candidates, indices)
        return candidates, target, label
