# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch

from . import BaseWrapperDataset


class PrependDataset(BaseWrapperDataset):
    """Wrap a dataset and overwrite the first token of every item.

    Despite the name, the sequence does not grow. Position 0 of each item's
    source sequence is *replaced* with an index computed per item by
    ``prepend_getter``. This is typically a placeholder slot (e.g. BOS)
    reserved by an upstream dataset.

    Args:
        dataset: the wrapped dataset. Items are either a sequence, or a tuple
            whose first element is the sequence to modify.
        prepend_getter: callable ``(dataset, idx) -> int`` returning the
            token index to write into position 0 of item ``idx``.
        ensure_first_token_is: if not ``None``, each item's original first
            token is asserted to equal this value before it is overwritten.
    """

    def __init__(self, dataset, prepend_getter, ensure_first_token_is=None):
        super().__init__(dataset)
        self.prepend_getter = prepend_getter
        self.ensure_first_token = ensure_first_token_is

    def __getitem__(self, idx):
        """Return item ``idx`` with its first source token replaced.

        The wrapped dataset's item is copied before modification, so the
        underlying storage is never mutated.
        """
        item = self.dataset[idx]
        is_tuple = isinstance(item, tuple)
        src = item[0] if is_tuple else item

        assert (
            self.ensure_first_token is None or src[0] == self.ensure_first_token
        ), "unexpected first token; expected {}".format(self.ensure_first_token)

        prepend_idx = self.prepend_getter(self.dataset, idx)
        # Accept NumPy integer scalars as well as Python ints, then normalize.
        assert isinstance(prepend_idx, (int, np.integer)), (
            "prepend_getter must return an integer, got {}".format(
                type(prepend_idx)
            )
        )
        prepend_idx = int(prepend_idx)

        # Copy before writing. The wrapped dataset may return a cached or
        # stored object, and an in-place write would corrupt it for every
        # later access (including the ensure_first_token check above).
        if torch.is_tensor(src):
            src = src.clone()
        elif isinstance(src, (np.ndarray, list)):
            src = src.copy()
        src[0] = prepend_idx

        return (src,) + item[1:] if is_tuple else src
