



















import math 
import os 
import time 
import collections 
from networkx import trophic_differences 

import numpy as np 
import torch 
import torch .distributed as torch_distributed 

from megatron_lm .megatron .global_vars import (
get_args ,
)
from llama_recipes .utils .distributed import print_rank_0 
from megatron_lm .megatron .core .datasets .indexed_dataset import MMapIndexedDataset 


# The only dataset type this module can currently build.
DSET_TYPE_BERT = 'standard_bert'

# Dataset types accepted by build_dataset(); anything else is rejected.
DSET_TYPES = [DSET_TYPE_BERT]


def get_datasets_weights_and_num_samples (
data_prefix :list [str ],
train_valid_test_num_samples :list [float ]|float ,
)->tuple [list [str ],list [float ],list [int ]]:



    assert len (data_prefix )%2 ==0 
    num_datasets :int =len (data_prefix )//2 
    weights :list [float ]=[0 ]*num_datasets 
    prefixes :list [str ]=[""]*num_datasets 

    for i in range (num_datasets ):
        weights [i ]=float (data_prefix [2 *i ])
        prefixes [i ]=(data_prefix [2 *i +1 ]).strip ()

    weight_sum =0.0 
    for weight in weights :
        weight_sum +=weight 
    assert weight_sum >0.0 
    weights =[weight /weight_sum for weight in weights ]




    if isinstance (train_valid_test_num_samples ,list ):
        datasets_train_valid_test_num_samples :list [int ]=[]
        for weight in weights :
            datasets_train_valid_test_num_samples .append (
            [int (math .ceil (val *weight *1.005 ))for val in train_valid_test_num_samples ])
    else :


        datasets_train_valid_test_num_samples =[
        int (math .ceil (train_valid_test_num_samples *weight *1.005 ))
        for weight in weights ]

    return prefixes ,weights ,datasets_train_valid_test_num_samples 


def get_a_and_b_segments(sample, np_rng):
    """Split a multi-sentence sample into segments A and B.

    Args:
        sample: Sequence of sentences (each an iterable of token ids); must
            contain at least two sentences.
        np_rng: NumPy random generator providing ``randint`` and ``random``.

    Returns:
        ``(tokens_a, tokens_b, is_next_random)``. Segment A is the
        concatenation of the first ``split_at`` sentences and B the rest.
        With probability 0.5 the segments are swapped and
        ``is_next_random`` is True.
    """
    num_sentences = len(sample)
    assert num_sentences > 1, 'make sure each sample has at least two sentences.'

    # With exactly two sentences the split point is forced; otherwise draw it
    # uniformly from [1, num_sentences).
    split_at = np_rng.randint(1, num_sentences) if num_sentences >= 3 else 1

    first = [tok for pos in range(split_at) for tok in sample[pos]]
    second = [tok for pos in range(split_at, num_sentences) for tok in sample[pos]]

    # Half the time, present the segments in reversed order.
    swapped = bool(np_rng.random() < 0.5)
    if swapped:
        first, second = second, first

    return first, second, swapped


def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, np_rng):
    """Trim two token lists in place until their combined length fits.

    Each step removes one token from whichever segment is currently longer
    (segment B on ties). The token is taken from the front or the back with
    equal probability.

    Args:
        tokens_a: Mutable list for segment A; modified in place.
        tokens_b: Mutable list for segment B; modified in place.
        len_a: Current length of ``tokens_a``; must be > 0.
        len_b: Current length of ``tokens_b``.
        max_num_tokens: Upper bound on ``len_a + len_b``.
        np_rng: NumPy random generator providing ``random``.

    Returns:
        True if at least one token was removed, otherwise False.
    """
    assert len_a > 0

    truncated = False
    while len_a + len_b > max_num_tokens:
        truncated = True
        if len_a > len_b:
            target = tokens_a
            len_a -= 1
        else:
            target = tokens_b
            len_b -= 1
        # Coin flip: drop from the front or from the back.
        if np_rng.random() < 0.5:
            del target[0]
        else:
            target.pop()
    return truncated


def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id):
    """Build ``[CLS] A [SEP] B [SEP]`` together with its token-type ids.

    The leading [CLS], segment A and the [SEP] that follows it get type 0.
    Segment B and its closing [SEP] get type 1. The closing [SEP] is only
    appended when ``tokens_b`` is non-empty.

    Returns:
        ``(tokens, tokentypes)`` as two equal-length lists.
    """
    tokens = [cls_id, *tokens_a, sep_id]
    tokentypes = [0] * len(tokens)

    tokens.extend(tokens_b)
    if tokens_b:
        tokens.append(sep_id)
    # Everything appended after the first [SEP] belongs to segment B.
    tokentypes.extend([1] * (len(tokens) - len(tokentypes)))

    return tokens, tokentypes


# One masked-LM target. For per-token entries ``index`` is a position and
# ``label`` the original token id there; for span entries both are lists.
MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
                                          ["index", "label"])


def is_start_piece(piece):
    """Check if the current word piece is the starting piece (BERT).

    A piece beginning with ``##`` is treated as the continuation of the
    previous word; every other piece starts a new word.
    """
    return not piece.startswith("##")


def _sample_harmonic_ngram(np_rng, ngrams, pvals, num_options):
    """Draw an n-gram length from the first ``num_options`` (renormalized) ``pvals``."""
    probs = pvals[:num_options]
    return np_rng.choice(ngrams[:num_options],
                         p=probs / probs.sum(keepdims=True))


def _fit_ngram_to_budget(cand_index_set, n, num_used, num_to_predict):
    """Flatten the longest n-gram of length <= ``n`` that fits the remaining budget.

    Falls back to the unigram when nothing fits; callers must re-check the
    size of the returned positions against the budget.
    """
    index_set = sum(cand_index_set[n - 1], [])
    n -= 1
    while num_used + len(index_set) > num_to_predict and n != 0:
        index_set = sum(cand_index_set[n - 1], [])
        n -= 1
    return index_set


def create_masked_lm_predictions(tokens,
                                 vocab_id_list, vocab_id_to_token_dict,
                                 masked_lm_prob,
                                 cls_id, sep_id, mask_id,
                                 max_predictions_per_seq,
                                 np_rng,
                                 max_ngrams=3,
                                 do_whole_word_mask=True,
                                 favor_longer_ngram=False,
                                 do_permutation=False,
                                 geometric_dist=False,
                                 masking_style="bert"):
    """Creates the predictions for the masked LM objective.

    Note: Tokens here are vocab ids and not text tokens.

    Args:
        tokens: Sequence of token ids, including [CLS]/[SEP].
        vocab_id_list: Token ids eligible as random replacements.
        vocab_id_to_token_dict: Maps token id -> token string (used to detect
            ``##`` continuation pieces).
        masked_lm_prob: Fraction of tokens to predict; 0 disables masking.
        cls_id, sep_id, mask_id: Special token ids.
        max_predictions_per_seq: Upper bound on masked positions.
        np_rng: NumPy random generator; the only source of randomness.
        max_ngrams: Longest n-gram (in candidate words) that may be masked.
        do_whole_word_mask: Group continuation pieces with their word.
        favor_longer_ngram: Reverse the 1/n length distribution.
        do_permutation: Additionally permute up to ``num_to_predict``
            uncovered positions among themselves.
        geometric_dist: Draw n-gram lengths from Geometric(0.2) (clipped to
            ``max_ngrams``) instead of the 1/n distribution.
        masking_style: "bert" (80% [MASK] / 10% keep / 10% random) or "t5"
            (always [MASK]).

    Returns:
        Always a 5-tuple ``(output_tokens, masked_lm_positions,
        masked_lm_labels, token_boundary, masked_spans)``.
        ``token_boundary`` is 1 for [CLS]/[SEP] and word-initial pieces and
        0 otherwise. Positions/labels are sorted by position; ``masked_spans``
        is sorted by its first position.

    Raises:
        ValueError: if ``masking_style`` is unknown and a position is masked.
    """
    cand_indexes = []
    token_boundary = [0] * len(tokens)

    for (i, token) in enumerate(tokens):
        if token == cls_id or token == sep_id:
            token_boundary[i] = 1
            continue
        is_start = is_start_piece(vocab_id_to_token_dict[token])
        # With whole-word masking, continuation pieces join the previous
        # candidate so a word is always masked as a unit.
        if do_whole_word_mask and cand_indexes and not is_start:
            cand_indexes[-1].append(i)
        else:
            cand_indexes.append([i])
            if is_start:
                token_boundary[i] = 1

    output_tokens = list(tokens)
    masked_lm_positions = []
    masked_lm_labels = []

    if masked_lm_prob == 0:
        # Same 5-value shape as the normal path (empty masked_spans).
        return (output_tokens, masked_lm_positions,
                masked_lm_labels, token_boundary, [])

    num_to_predict = min(max_predictions_per_seq,
                         max(1, int(round(len(tokens) * masked_lm_prob))))

    ngrams = np.arange(1, max_ngrams + 1, dtype=np.int64)
    # Length distribution proportional to 1/n (optionally reversed). It is
    # computed unconditionally because the permutation pass always uses it,
    # even when geometric_dist is set.
    pvals = 1. / np.arange(1, max_ngrams + 1)
    pvals /= pvals.sum(keepdims=True)
    if favor_longer_ngram:
        pvals = pvals[::-1]

    # For every candidate start, the runs of 1..max_ngrams candidate words.
    ngram_indexes = []
    for idx in range(len(cand_indexes)):
        ngram_indexes.append([cand_indexes[idx:idx + n] for n in ngrams])

    np_rng.shuffle(ngram_indexes)

    (masked_lms, masked_spans) = ([], [])
    covered_indexes = set()
    for cand_index_set in ngram_indexes:
        if len(masked_lms) >= num_to_predict:
            break
        if not cand_index_set:
            continue

        if not geometric_dist:
            n = _sample_harmonic_ngram(np_rng, ngrams, pvals, len(cand_index_set))
        else:
            # Span length ~ Geometric(0.2), clipped to max_ngrams.
            n = min(np_rng.geometric(0.2), max_ngrams)

        index_set = _fit_ngram_to_budget(cand_index_set, n,
                                         len(masked_lms), num_to_predict)
        if len(masked_lms) + len(index_set) > num_to_predict:
            continue
        if any(index in covered_indexes for index in index_set):
            continue

        for index in index_set:
            covered_indexes.add(index)
            if masking_style == "bert":
                # 80% [MASK], 10% keep original, 10% random vocabulary token.
                if np_rng.random() < 0.8:
                    masked_token = mask_id
                elif np_rng.random() < 0.5:
                    masked_token = tokens[index]
                else:
                    masked_token = vocab_id_list[np_rng.randint(0, len(vocab_id_list))]
            elif masking_style == "t5":
                masked_token = mask_id
            else:
                raise ValueError("invalid value of masking style")

            output_tokens[index] = masked_token
            masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))

        masked_spans.append(MaskedLmInstance(
            index=index_set,
            label=[tokens[index] for index in index_set]))

    assert len(masked_lms) <= num_to_predict
    # Runs even when do_permutation is False, so the amount of np_rng state
    # consumed does not depend on that flag.
    np_rng.shuffle(ngram_indexes)

    select_indexes = set()
    if do_permutation:
        for cand_index_set in ngram_indexes:
            if len(select_indexes) >= num_to_predict:
                break
            if not cand_index_set:
                continue

            # Draw from the caller's generator (not the global np.random) so
            # the result is reproducible from np_rng alone.
            n = _sample_harmonic_ngram(np_rng, ngrams, pvals, len(cand_index_set))
            index_set = _fit_ngram_to_budget(cand_index_set, n,
                                             len(select_indexes), num_to_predict)
            if len(select_indexes) + len(index_set) > num_to_predict:
                continue
            if any(index in covered_indexes or index in select_indexes
                   for index in index_set):
                continue
            for index in index_set:
                select_indexes.add(index)
        assert len(select_indexes) <= num_to_predict

        select_indexes = sorted(select_indexes)
        permute_indexes = list(select_indexes)
        np_rng.shuffle(permute_indexes)
        orig_token = list(output_tokens)

        for src_i, tgt_i in zip(select_indexes, permute_indexes):
            output_tokens[src_i] = orig_token[tgt_i]
            masked_lms.append(MaskedLmInstance(index=src_i, label=orig_token[src_i]))

    masked_lms = sorted(masked_lms, key=lambda x: x.index)
    masked_spans = sorted(masked_spans, key=lambda x: x.index[0])

    for p in masked_lms:
        masked_lm_positions.append(p.index)
        masked_lm_labels.append(p.label)
    return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary, masked_spans)


def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
                             masked_labels, pad_id, max_seq_length):
    """Right-pad one example to ``max_seq_length`` and return int64 arrays.

    Returns:
        ``(tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np)``.
        Tokens and token types are padded with ``pad_id``. ``padding_mask_np``
        is 1 on real tokens and 0 on padding. ``labels_np`` holds the masked
        label at each masked position and -1 elsewhere. ``loss_mask_np`` is 1
        exactly at the masked positions.
    """
    real_len = len(tokens)
    pad_len = max_seq_length - real_len
    assert pad_len >= 0
    assert len(tokentypes) == real_len
    assert len(masked_positions) == len(masked_labels)

    pad = [pad_id] * pad_len
    tokens_np = np.array(tokens + pad, dtype=np.int64)
    tokentypes_np = np.array(tokentypes + pad, dtype=np.int64)
    padding_mask_np = np.array([1] * real_len + [0] * pad_len, dtype=np.int64)

    # Scatter the labels; unmasked positions keep label -1 and loss weight 0.
    labels = [-1] * max_seq_length
    loss_mask = [0] * max_seq_length
    for position, label in zip(masked_positions, masked_labels):
        assert position < real_len
        labels[position] = label
        loss_mask[position] = 1

    labels_np = np.array(labels, dtype=np.int64)
    loss_mask_np = np.array(loss_mask, dtype=np.int64)
    return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np


def build_train_valid_test_datasets_with_prefixes(train_valid_test_num_samples,
                                                  max_seq_length,
                                                  seed,
                                                  train_data_prefix=None,
                                                  valid_data_prefix=None,
                                                  test_data_prefix=None,
                                                  binary_head=False,
                                                  max_seq_length_dec=None,
                                                  dataset_type='standard_bert'):
    """Build train/valid/test datasets from separate, per-split data prefixes.

    Args:
        train_valid_test_num_samples: Sequence indexed by split (0=train,
            1=valid, 2=test) giving each split's sample count; an entry is
            only read when the matching prefix is provided.
        max_seq_length: Maximum sequence length per sample.
        seed: Random seed forwarded to each dataset.
        train_data_prefix, valid_data_prefix, test_data_prefix: Data path
            prefix for each split, or None to skip that split.
        binary_head: Forwarded to ``build_dataset``.
        max_seq_length_dec: Forwarded to ``build_dataset``.
        dataset_type: Forwarded to ``build_dataset``.

    Returns:
        ``(train_dataset, valid_dataset, test_dataset)``; a split whose prefix
        is None is returned as None.
    """
    print_rank_0("Separate data paths provided for train, valid & test.")

    def _build_split(name, prefix, index):
        # A split without a data prefix is simply not built.
        if prefix is None:
            return None
        # Optional parameters go by keyword so they cannot shift into the
        # wrong positional slot of build_dataset.
        return build_dataset(name, prefix,
                             train_valid_test_num_samples[index],
                             max_seq_length, seed,
                             binary_head=binary_head,
                             max_seq_length_dec=max_seq_length_dec,
                             dataset_type=dataset_type)

    train_dataset = _build_split("train", train_data_prefix, 0)
    valid_dataset = _build_split("valid", valid_data_prefix, 1)
    test_dataset = _build_split("test", test_data_prefix, 2)

    return (train_dataset, valid_dataset, test_dataset)


def build_train_valid_test_datasets(data_prefix, splits_string,
                                    train_valid_test_num_samples,
                                    max_seq_length, seed,
                                    binary_head=False,
                                    max_seq_length_dec=None,
                                    dataset_type='standard_bert'):
    """Build train/valid/test datasets by splitting a single data prefix.

    Args:
        data_prefix: List of data prefixes; exactly one entry is supported.
        splits_string: Split weights such as ``"969,30,1"``.
        train_valid_test_num_samples: Per-split sample counts.
        max_seq_length: Maximum sequence length per sample.
        seed: Random seed forwarded to the datasets.
        binary_head: Forwarded to the dataset builder.
        max_seq_length_dec: Forwarded to the dataset builder.
        dataset_type: Forwarded to the dataset builder.

    Raises:
        NotImplementedError: if ``data_prefix`` does not have exactly one
            entry (blending is unsupported here).
    """
    if len(data_prefix) != 1:
        raise NotImplementedError("Blending currently unsupported for non-GPT dataset instances")

    single_prefix = data_prefix[0]
    return _build_train_valid_test_datasets(single_prefix,
                                            splits_string,
                                            train_valid_test_num_samples,
                                            max_seq_length, seed,
                                            binary_head,
                                            max_seq_length_dec,
                                            dataset_type=dataset_type)


def _build_train_valid_test_datasets(data_prefix, splits_string,
                                     train_valid_test_num_samples,
                                     max_seq_length, seed,
                                     binary_head,
                                     max_seq_length_dec,
                                     dataset_type='standard_bert'):
    """Split one indexed dataset by documents and build train/valid/test datasets.

    Documents are divided according to ``splits_string`` (via
    ``get_train_valid_test_split_``). For each non-empty split, the shared
    indexed dataset's document index is temporarily narrowed to that split
    while its dataset is built, then restored.

    Returns:
        ``(train_dataset, valid_dataset, test_dataset)``; a split that
        receives no documents is returned as None.
    """
    indexed_dataset = get_indexed_dataset_(data_prefix, dataset_type)

    # document_indices holds one boundary more than there are documents.
    total_num_of_documents = indexed_dataset.document_indices.shape[0] - 1
    splits = get_train_valid_test_split_(splits_string, total_num_of_documents)

    print_rank_0(' > dataset split:')

    def print_split_stats(name, index):
        print_rank_0('    {}:'.format(name))
        print_rank_0('     document indices in [{}, {}) total of {} '
                     'documents'.format(splits[index], splits[index + 1],
                                        splits[index + 1] - splits[index]))
        start_index = indexed_dataset.document_indices[splits[index]]
        end_index = indexed_dataset.document_indices[splits[index + 1]]
        print_rank_0('     sentence indices in [{}, {}) total of {} '
                     'sentences'.format(start_index, end_index,
                                        end_index - start_index))
    print_split_stats('train', 0)
    print_split_stats('validation', 1)
    print_split_stats('test', 2)

    def build_split_dataset(index, name):
        if splits[index + 1] <= splits[index]:
            return None

        doc_idx_ptr = indexed_dataset.get_document_indices()
        # Narrow the shared index to this split's documents; the +1 keeps the
        # closing boundary of the split's last document.
        start_index = splits[index]
        end_index = splits[index + 1] + 1
        indexed_dataset.set_document_indices(doc_idx_ptr[start_index:end_index])
        try:
            dataset = build_dataset(
                name, data_prefix,
                train_valid_test_num_samples[index], max_seq_length,
                seed, binary_head, max_seq_length_dec,
                dataset_type, indexed_dataset)
        finally:
            # Always restore the full index so a failure while building one
            # split does not leave the shared dataset truncated.
            indexed_dataset.set_document_indices(doc_idx_ptr)

        assert indexed_dataset.document_indices[0] == 0
        assert indexed_dataset.document_indices.shape[0] == (total_num_of_documents + 1)
        return dataset

    train_dataset = build_split_dataset(0, 'train')
    valid_dataset = build_split_dataset(1, 'valid')
    test_dataset = build_split_dataset(2, 'test')

    return (train_dataset, valid_dataset, test_dataset)


def build_dataset(
    name,
    data_prefix,
    max_num_samples,
    max_seq_length,
    seed,
    binary_head,
    max_seq_length_dec,
    dataset_type='standard_bert',
    indexed_dataset=None
):
    """Build one dataset of the requested ``dataset_type``.

    Args:
        name: Split name (e.g. "train"), forwarded to the dataset.
        data_prefix: Path prefix of the indexed dataset files.
        max_num_samples: Forwarded as ``max_num_samples`` (``num_epochs`` is
            always passed as None).
        max_seq_length: Maximum sequence length per sample.
        seed: Random seed forwarded to the dataset.
        binary_head: Forwarded to ``BertDataset``.
        max_seq_length_dec: Accepted for interface compatibility; unused here.
        dataset_type: Must be a member of ``DSET_TYPES``.
        indexed_dataset: Already-opened indexed dataset to reuse; opened from
            ``data_prefix`` when None.

    Returns:
        The constructed dataset (a ``BertDataset`` for ``DSET_TYPE_BERT``).

    Raises:
        ValueError: if ``dataset_type`` is not in ``DSET_TYPES``.
        NotImplementedError: if the type is listed but has no builder.
    """
    if dataset_type not in DSET_TYPES:
        # Build the message explicitly; passing two arguments to ValueError
        # would render it as a tuple.
        raise ValueError(f"Invalid dataset_type: {dataset_type}")

    # Function-local import, presumably to avoid an import cycle - confirm.
    from megatron_lm.megatron.data.bert_dataset import BertDataset

    if indexed_dataset is None:
        indexed_dataset = get_indexed_dataset_(data_prefix, dataset_type)

    kwargs = dict(
        name=name,
        data_prefix=data_prefix,
        num_epochs=None,
        max_num_samples=max_num_samples,
        max_seq_length=max_seq_length,
        seed=seed,
    )

    if dataset_type == DSET_TYPE_BERT:
        args = get_args()
        dataset = BertDataset(
            indexed_dataset=indexed_dataset,
            masked_lm_prob=args.mask_prob,
            short_seq_prob=args.short_seq_prob,
            binary_head=binary_head,
            **kwargs
        )
    else:
        raise NotImplementedError("Dataset type not fully implemented.")

    return dataset


def get_indexed_dataset_(data_prefix, dataset_type):
    """Open the memory-mapped indexed dataset at ``data_prefix`` and log its size.

    ``dataset_type`` is accepted for interface compatibility but is not used.
    The dataset is opened with ``multimodal=False``. Its sentence count must
    equal the last entry of ``document_indices``.
    """
    print_rank_0(' > building dataset index ...')

    t_begin = time.time()
    is_multimodal = False
    dataset = MMapIndexedDataset(data_prefix, is_multimodal)
    # The final document boundary must match the total number of sentences.
    assert dataset.sequence_lengths.shape[0] == dataset.document_indices[-1]
    print_rank_0(' > finished creating indexed dataset in {:4f} seconds'.format(time.time() - t_begin))

    num_documents = dataset.document_indices.shape[0] - 1
    num_sentences = dataset.sequence_lengths.shape[0]
    print_rank_0(' > indexed dataset stats:')
    print_rank_0('    number of documents: {}'.format(num_documents))
    print_rank_0('    number of sentences: {}'.format(num_sentences))

    return dataset


def get_train_valid_test_split_(splits_string, size):
    """Turn a weight string such as ``"969,30,1"`` or ``"98/2/0"`` into split boundaries.

    Weights are separated by ``,`` (checked first) or ``/``. A string with
    neither separator is a single weight. Up to three weights are used:
    missing ones count as 0 and extras are ignored. They are normalized and
    scaled by ``size``.

    Returns:
        Four cumulative boundaries ``[0, train_end, valid_end, size]``. Any
        rounding surplus or deficit is absorbed by shifting every boundary
        after the first.

    Raises:
        AssertionError: if the weights do not sum to a positive value.
    """
    if ',' in splits_string:
        pieces = splits_string.split(',')
    elif '/' in splits_string:
        pieces = splits_string.split('/')
    else:
        pieces = [splits_string]

    weights = [float(piece) for piece in pieces]
    weights = (weights + [0.] * 3)[:3]

    total = sum(weights)
    assert total > 0.0
    fractions = [weight / total for weight in weights]

    boundaries = [0]
    for fraction in fractions:
        boundaries.append(boundaries[-1] + int(round(fraction * float(size))))

    # Shift all non-zero boundaries so the last one lands exactly on size.
    overshoot = boundaries[-1] - size
    boundaries = boundaries[:1] + [bound - overshoot for bound in boundaries[1:]]

    assert len(boundaries) == 4
    assert boundaries[-1] == size
    return boundaries


def get_samples_mapping(indexed_dataset,
                        data_prefix,
                        num_epochs,
                        max_num_samples,
                        max_seq_length,
                        short_seq_prob,
                        seed,
                        name,
                        binary_head):
    """Load the sample-index map, building and caching it on first use.

    Each row maps a sample index to a starting sentence index, an end
    sentence index, and a length. The map is cached as a ``.npy`` file whose
    name is derived from ``data_prefix`` and the parameters that shape it.
    If that file is missing, global rank 0 builds it with
    ``helpers.build_mapping`` and saves it. Every rank then joins an
    ``all_reduce`` and finally loads the file memory-mapped.

    At least one of ``num_epochs`` / ``max_num_samples`` must be truthy. A
    falsy one is replaced by a large sentinel meaning "unbounded", and
    sentinel values are left out of the cache file name.

    Raises:
        ValueError: if both ``num_epochs`` and ``max_num_samples`` are falsy.
    """
    epochs_unbounded = np.iinfo(np.int32).max - 1
    samples_unbounded = np.iinfo(np.int64).max - 1

    if not num_epochs:
        if not max_num_samples:
            raise ValueError("Need to specify either max_num_samples "
                             "or num_epochs")
        num_epochs = epochs_unbounded
    if not max_num_samples:
        max_num_samples = samples_unbounded

    # Cache file name: only non-sentinel limits are encoded.
    suffix = '_{}_indexmap'.format(name)
    if num_epochs != epochs_unbounded:
        suffix += '_{}ep'.format(num_epochs)
    if max_num_samples != samples_unbounded:
        suffix += '_{}mns'.format(max_num_samples)
    suffix += '_{}msl'.format(max_seq_length)
    suffix += '_{:0.2f}ssp'.format(short_seq_prob)
    suffix += '_{}s'.format(seed)
    indexmap_filename = data_prefix + suffix + '.npy'

    needs_build = torch.distributed.get_rank() == 0 and not os.path.isfile(indexmap_filename)
    if needs_build:
        print(' > WARNING: could not find index map file {}, building the indices on rank 0 ...'.format(indexmap_filename))

        # Dtypes presumably required by helpers.build_mapping; checked up front.
        assert indexed_dataset.document_indices.dtype == np.int64
        assert indexed_dataset.sequence_lengths.dtype == np.int32

        # Only rank 0 reaches this branch, so this evaluates to True here.
        verbose = torch.distributed.get_rank() == 0
        build_start = time.time()
        print_rank_0(' > building samples index mapping for {} ...'.format(
            name))

        from megatron_lm.megatron.core.datasets import helpers
        # Last argument: 2 when binary_head else 1 - presumably the minimum
        # number of sentences per sample; confirm against helpers.
        samples_mapping = helpers.build_mapping(
            indexed_dataset.document_indices,
            indexed_dataset.sequence_lengths,
            num_epochs,
            max_num_samples,
            max_seq_length,
            short_seq_prob,
            seed,
            verbose,
            2 if binary_head else 1)
        print_rank_0(' > done building samples index mapping')
        np.save(indexmap_filename, samples_mapping, allow_pickle=True)
        print_rank_0(' > saved the index mapping in {}'.format(
            indexmap_filename))
        print_rank_0(' > elapsed time to build and save samples mapping (seconds): {:4f}'.format(
            time.time() - build_start))

    # Every rank contributes 1. Rank 0 only reaches this collective after
    # saving the file, so it effectively serves as a barrier before loading.
    counts = torch.tensor([1], dtype=torch.long, device='cuda')
    torch_distributed.all_reduce(counts)
    assert counts[0].item() == (torch.distributed.get_world_size())

    print_rank_0(' > loading indexed mapping from {}'.format(
        indexmap_filename))
    load_start = time.time()
    samples_mapping = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r')
    print_rank_0('    loaded indexed file in {:3.3f} seconds'.format(
        time.time() - load_start))
    print_rank_0('    total number of samples: {}'.format(
        samples_mapping.shape[0]))

    return samples_mapping
