# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from . import BaseWrapperDataset


class ReplaceDataset(BaseWrapperDataset):
    """Replaces tokens found in the dataset by a specified replacement token.

    Items are modified in place: replacement is performed with
    ``masked_fill_`` on a slice of each object returned by the wrapped
    dataset, and those same objects are returned. The objects are therefore
    expected to support slicing, elementwise ``==`` and ``masked_fill_``
    (i.e. behave like ``torch.Tensor``).

    All entries of ``replace_map`` are applied simultaneously: every token is
    matched against the value it had *before* any replacement, so a map such
    as ``{1: 2, 2: 3}`` turns 1 into 2 (not into 3), regardless of the
    iteration order of the map.

    Args:
        dataset (~torch.utils.data.Dataset): dataset to replace tokens in
        replace_map (Dict[int, int]): map of token to replace -> replacement
            token. Must not be empty.
        offsets (List[int]): one offset per object returned by the underlying
            dataset ``__getitem__`` method. For a non-negative offset, the
            first ``offset`` tokens are left untouched; for a negative
            offset, the last ``-offset`` tokens are left untouched. Offsets
            and objects are paired positionally; if there are fewer offsets
            than objects, the extra objects are returned unmodified.
    """

    def __init__(self, dataset, replace_map, offsets):
        super().__init__(dataset)
        assert len(replace_map) > 0, "replace_map must not be empty"
        self.replace_map = replace_map
        self.offsets = offsets

    def __getitem__(self, index):
        """Return item ``index`` of the wrapped dataset with tokens replaced.

        A tuple item has each of its elements processed with the offset at
        the same position; any other item is treated as a single object and
        processed with the first offset. The item is mutated in place and
        returned as is.
        """
        item = self.dataset[index]
        srcs = item if isinstance(item, tuple) else [item]

        for offset, src in zip(self.offsets, srcs):
            # Slicing yields a view, so filling it writes through to ``src``.
            # The view does not depend on the map entry: build it once.
            window = src[offset:] if offset >= 0 else src[:offset]

            # Compute every mask before writing anything, so that a
            # replacement value which is itself a key of the map is not
            # replaced again by a later entry.
            pending = [(window == old, new) for old, new in self.replace_map.items()]
            for mask, new in pending:
                window.masked_fill_(mask, new)

        # ``srcs`` only ever wraps ``item``; the in-place edits are already
        # visible through ``item`` itself.
        return item
