# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.

import numpy as np
import torch

from megatron.training import get_args, get_tokenizer


class BertEmbeddingDataset(torch.utils.data.Dataset):
    '''Dataset to convert a text dataset to Bert tokens.

    Each item ``text_dataset[idx]`` must provide a ``"text"`` string. The text
    is tokenized with the tokenizer returned by ``get_tokenizer()``. The tokens
    are truncated so that ``[CLS] + tokens + [SEP]`` fits in ``max_seq_length``
    (when ``max_seq_length >= 3``). The result is packaged by ``build_sample``.
    '''

    def __init__(self, text_dataset, max_seq_length):
        '''
        Args:
            text_dataset: sized, indexable dataset whose items are dicts with
                a "text" string entry.
            max_seq_length: maximum sample length, counting the [CLS] and
                [SEP] tokens. Values below 3 cannot be honored: every sample
                holds at least one content token (see ``__getitem__``), so
                samples are then 3 tokens long.
        '''
        super().__init__()

        # Dataset, tokenizer.
        self.text_dataset = text_dataset
        self.max_seq_length = max_seq_length
        self.bert_tokenizer = get_tokenizer()

    def __len__(self):
        return len(self.text_dataset)

    @classmethod
    def build_sample(cls, tokenizer, token_ids):
        '''Wrap token ids in [CLS] ... [SEP] and build a Bert sample dict.

        Args:
            tokenizer: object exposing ``cls`` and ``sep`` token ids.
            token_ids: sequence of content token ids (no special tokens).

        Returns:
            dict with int64 arrays of length ``len(token_ids) + 2``:
            "text" (the token ids), "types" (all 0), "labels" (all -1),
            "loss_mask" (all 0) and "padding_mask" (all 1). It also holds the
            scalar flags "is_random" and "truncated", both always 0 here
            (even when ``__getitem__`` truncated the text).
        '''
        seq_len = len(token_ids) + 2  # +2 for [CLS] and [SEP].

        def constant(value):
            return np.full((seq_len,), value, dtype="int64")

        return {
            "text" : np.array([ tokenizer.cls, *token_ids, tokenizer.sep ], dtype="int64"),
            "types" : constant(0),
            "labels" : constant(-1),
            "is_random" : 0,
            "loss_mask" : constant(0),
            "padding_mask" : constant(1),
            "truncated" : 0,
        }

    def __getitem__(self, idx):
        '''Return the Bert sample dict built from ``text_dataset[idx]``.'''

        # Text, with any "<|endoftext|>" markers removed.
        text = self.text_dataset[idx]["text"].replace("<|endoftext|>", "")

        # Bert/Wordpiece tokens, truncated to leave room for [CLS] and [SEP].
        # Clamp at 0: a negative slice bound would drop tokens from the END
        # of the sequence instead of truncating it, producing samples longer
        # than max_seq_length whenever max_seq_length < 2.
        max_content_len = max(self.max_seq_length - 2, 0)
        bert_token_ids = self.bert_tokenizer.tokenize(text)[:max_content_len]
        if not bert_token_ids:
            # Hack: never emit an empty content sequence; use one pad token.
            bert_token_ids = [ self.bert_tokenizer.pad_id ]

        # Bert sample.
        return self.build_sample(self.bert_tokenizer, bert_token_ids)
