

"""BERT Style dataset."""

import numpy as np 
import torch 

from megatron_lm .megatron .global_vars import (
get_tokenizer ,
)
from megatron_lm .megatron .data .dataset_utils import (
get_samples_mapping ,
get_a_and_b_segments ,
truncate_segments ,
create_tokens_and_tokentypes ,
create_masked_lm_predictions 
)


class BertDataset(torch.utils.data.Dataset):
    """Map-style dataset producing BERT pre-training samples.

    Every item is the dictionary returned by ``build_training_sample`` for
    the span of sentences that ``samples_mapping`` assigns to that index.
    """

    def __init__(self, name, indexed_dataset, data_prefix,
                 num_epochs, max_num_samples, masked_lm_prob,
                 max_seq_length, short_seq_prob, seed, binary_head):
        """Store the configuration, build the samples mapping, cache vocab ids.

        Arguments:
            name: Dataset name, forwarded to ``get_samples_mapping``.
            indexed_dataset: Indexable collection of sentences; entry ``i``
                is read as ``indexed_dataset[i]``.
            data_prefix: Forwarded to ``get_samples_mapping``.
            num_epochs: Forwarded to ``get_samples_mapping``.
            max_num_samples: Forwarded to ``get_samples_mapping``.
            masked_lm_prob: Probability to mask tokens.
            max_seq_length: Length every returned sample is padded to.
            short_seq_prob: Forwarded to ``get_samples_mapping``.
            seed: Base seed; combined with the item index in ``__getitem__``.
            binary_head: Forwarded to ``get_samples_mapping`` and to
                ``build_training_sample``.
        """
        # Underlying data and plain configuration values.
        self.indexed_dataset = indexed_dataset
        self.name = name
        self.seed = seed
        self.masked_lm_prob = masked_lm_prob
        self.max_seq_length = max_seq_length
        self.binary_head = binary_head

        # NOTE(review): the "- 3" presumably leaves room for the special
        # tokens ([CLS] and two [SEP]) added later -- confirm against
        # dataset_utils.
        mapping_seq_length = self.max_seq_length - 3
        self.samples_mapping = get_samples_mapping(
            self.indexed_dataset,
            data_prefix,
            num_epochs,
            max_num_samples,
            mapping_seq_length,
            short_seq_prob,
            self.seed,
            self.name,
            self.binary_head,
        )

        # Cache the vocabulary and the special token ids from the tokenizer.
        tokenizer = get_tokenizer()
        inv_vocab = tokenizer.inv_vocab
        self.vocab_id_to_token_dict = inv_vocab
        self.vocab_id_list = list(inv_vocab.keys())
        self.cls_id = tokenizer.cls
        self.sep_id = tokenizer.sep
        self.mask_id = tokenizer.mask
        self.pad_id = tokenizer.pad

    def __len__(self):
        """Return the number of rows in the samples mapping."""
        return self.samples_mapping.shape[0]

    def __getitem__(self, idx):
        """Build and return the training sample for mapping row ``idx``."""
        first, last, target_length = self.samples_mapping[idx]
        sentences = [self.indexed_dataset[pos] for pos in range(first, last)]

        # Per-item generator so that a given (seed, idx) pair always yields
        # the same sample. The modulo keeps the seed inside the range that
        # numpy's RandomState accepts (0 .. 2**32 - 1).
        item_seed = (self.seed + idx) % 2**32
        np_rng = np.random.RandomState(seed=item_seed)

        return build_training_sample(
            sample=sentences,
            target_seq_length=target_length,
            max_seq_length=self.max_seq_length,
            vocab_id_list=self.vocab_id_list,
            vocab_id_to_token_dict=self.vocab_id_to_token_dict,
            cls_id=self.cls_id,
            sep_id=self.sep_id,
            mask_id=self.mask_id,
            pad_id=self.pad_id,
            masked_lm_prob=self.masked_lm_prob,
            np_rng=np_rng,
            binary_head=self.binary_head,
        )


def build_training_sample(sample,
                          target_seq_length, max_seq_length,
                          vocab_id_list, vocab_id_to_token_dict,
                          cls_id, sep_id, mask_id, pad_id,
                          masked_lm_prob, np_rng, binary_head):
    """Build a training sample.

    Arguments:
        sample: A list of sentences in which each sentence is a list of
            token ids.
        target_seq_length: Desired sequence length.
        max_seq_length: Maximum length of the sequence. All values are padded
            to this length.
        vocab_id_list: List of vocabulary ids. Used to pick a random id.
        vocab_id_to_token_dict: A dictionary from vocab ids to text tokens.
        cls_id: Start of example id.
        sep_id: Separator id.
        mask_id: Mask token id.
        pad_id: Padding token id.
        masked_lm_prob: Probability to mask tokens.
        np_rng: Random number generator. Note that this rng state should be
            numpy and not python since python randint is inclusive for
            the upper bound whereas the numpy one is exclusive.
        binary_head: When truthy, the sample must contain more than one
            sentence and is split into A and B segments by
            ``get_a_and_b_segments``. Otherwise all sentences are
            concatenated into segment A and segment B is left empty.

    Returns:
        A dictionary with the keys 'text', 'types', 'labels', 'is_random',
        'loss_mask', 'padding_mask' and 'truncated'.
    """
    # Validate the inputs before any randomness is consumed.
    if binary_head:
        # Two segments need at least two sentences to draw from.
        assert len(sample) > 1
    assert target_seq_length <= max_seq_length

    # Split the sample into the A and B segments.
    if not binary_head:
        segment_a = [token for sentence in sample for token in sentence]
        segment_b = []
        is_next_random = False
    else:
        segment_a, segment_b, is_next_random = get_a_and_b_segments(
            sample, np_rng)

    # Truncate to the target length. The segments are reused below, so
    # truncate_segments presumably trims the two lists in place -- confirm
    # against dataset_utils.
    token_budget = target_seq_length
    was_truncated = truncate_segments(
        segment_a, segment_b,
        len(segment_a), len(segment_b),
        token_budget, np_rng)

    # Assemble the token sequence and its token types.
    tokens, tokentypes = create_tokens_and_tokentypes(
        segment_a, segment_b, cls_id, sep_id)

    # Apply masking; only the first three of the five results are used here.
    prediction_cap = masked_lm_prob * token_budget
    tokens, masked_positions, masked_labels, _unused_a, _unused_b = (
        create_masked_lm_predictions(
            tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob,
            cls_id, sep_id, mask_id, prediction_cap, np_rng))

    # Pad to max_seq_length and convert everything to numpy arrays.
    padded_arrays = pad_and_convert_to_numpy(
        tokens, tokentypes, masked_positions,
        masked_labels, pad_id, max_seq_length)
    (tokens_np, tokentypes_np, labels_np,
     padding_mask_np, loss_mask_np) = padded_arrays

    return {
        'text': tokens_np,
        'types': tokentypes_np,
        'labels': labels_np,
        'is_random': int(is_next_random),
        'loss_mask': loss_mask_np,
        'padding_mask': padding_mask_np,
        'truncated': int(was_truncated),
    }


def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
                             masked_labels, pad_id, max_seq_length):
    """Pad sequences and convert them to numpy.

    Arguments:
        tokens: List of token ids; at most ``max_seq_length`` long.
        tokentypes: List of token type ids, same length as ``tokens``.
        masked_positions: Indices into ``tokens`` that were masked.
        masked_labels: Label for each entry of ``masked_positions``.
        pad_id: Value used to fill both ``tokens`` and ``tokentypes`` up to
            ``max_seq_length``.
        max_seq_length: Length of every returned array.

    Returns:
        A tuple ``(tokens_np, tokentypes_np, labels_np, padding_mask_np,
        loss_mask_np)`` of int64 arrays, each of length ``max_seq_length``:
        the padded tokens, the padded token types, the labels (-1 everywhere
        except at masked positions), the padding mask (1 for real tokens,
        0 for padding) and the loss mask (1 at masked positions, 0 elsewhere).

    Raises:
        AssertionError: If ``tokens`` is longer than ``max_seq_length``, the
            input lengths are inconsistent, or a masked position is not
            smaller than the number of tokens.
    """
    # Validate sizes.
    num_tokens = len(tokens)
    padding_length = max_seq_length - num_tokens
    # Both halves of the message must be f-strings; otherwise the second
    # placeholder is emitted literally instead of being interpolated.
    assert padding_length >= 0, (
        f"num_tokens ({num_tokens}) is greater than "
        f"max_seq_length ({max_seq_length}).")
    assert len(tokentypes) == num_tokens
    assert len(masked_positions) == len(masked_labels)

    # Tokens and token types share the same filler (pad_id).
    filler = [pad_id] * padding_length
    tokens_np = np.array(tokens + filler, dtype=np.int64)
    tokentypes_np = np.array(tokentypes + filler, dtype=np.int64)

    # Padding mask: 1 for real tokens, 0 for padding.
    padding_mask_np = np.array([1] * num_tokens + [0] * padding_length,
                               dtype=np.int64)

    # Labels default to -1 and the loss mask to 0; only masked positions
    # receive a real label and contribute to the loss mask.
    labels = [-1] * max_seq_length
    loss_mask = [0] * max_seq_length
    for position, label in zip(masked_positions, masked_labels):
        assert position < num_tokens
        labels[position] = label
        loss_mask[position] = 1
    labels_np = np.array(labels, dtype=np.int64)
    loss_mask_np = np.array(loss_mask, dtype=np.int64)

    return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np
