

"""Dataloaders."""


import random 
import torch 
import numpy as np 

from typing import Union 
from torch .utils .data import Dataset 
import torch .distributed as torch_distributed 

from megatron_lm .megatron .global_vars import get_args 
from megatron_lm .megatron .core .datasets .blended_dataset import BlendedDataset 
from megatron_lm .megatron .core .datasets .megatron_dataset import MegatronDataset 


def build_pretraining_data_loader(
    dataset: Union[BlendedDataset, MegatronDataset, None],
    consumed_samples: int,
):
    """Build a sequential pretraining DataLoader for ``dataset``.

    Batches are produced by ``MegatronPretrainingSampler``, which starts at
    index ``consumed_samples`` and hands this process its own slice of every
    global batch.

    Args:
        dataset: A single dataset object (it is passed to ``len()`` and to
            ``torch.utils.data.DataLoader``), or ``None``.
        consumed_samples: Number of samples already consumed; iteration
            starts at this index.

    Returns:
        A ``torch.utils.data.DataLoader``, or ``None`` when ``dataset`` is
        ``None``.
    """
    if dataset is None:
        return None
    args = get_args()

    # NOTE(review): the global torch.distributed rank / world size are used as
    # the data-parallel rank / size, i.e. every process is treated as its own
    # data-parallel replica. Confirm this still holds if model parallelism is
    # ever enabled.
    batch_sampler = MegatronPretrainingSampler(
        total_samples=len(dataset),
        consumed_samples=consumed_samples,
        micro_batch_size=args.micro_batch_size,
        data_parallel_rank=torch_distributed.get_rank(),
        data_parallel_size=torch_distributed.get_world_size(),
    )

    return torch.utils.data.DataLoader(
        dataset,
        batch_sampler=batch_sampler,
        num_workers=args.num_workers,
        pin_memory=True,
    )


class MegatronPretrainingSampler:
    """Sequential batch sampler that splits every global batch across ranks.

    The indices ``consumed_samples .. total_samples - 1`` are cut, in order,
    into global batches of ``micro_batch_size * data_parallel_size`` indices.
    From each global batch this rank receives the ``micro_batch_size``-long
    slice at position ``data_parallel_rank``. A trailing short global batch is
    dropped when ``drop_last`` is true. Otherwise this rank gets its slice of
    it, which may be shorter or empty.
    """

    def __init__(
        self,
        total_samples: int,
        consumed_samples: int,
        micro_batch_size: int,
        data_parallel_rank: int,
        data_parallel_size: int,
        drop_last=True
    ) -> None:
        self.total_samples = total_samples
        self.consumed_samples = consumed_samples
        self.micro_batch_size = micro_batch_size
        self.data_parallel_rank = data_parallel_rank
        # Size of one global batch (all data-parallel ranks together).
        self.micro_batch_times_data_parallel_size = micro_batch_size * data_parallel_size
        self.drop_last = drop_last

        assert total_samples > 0, f'no sample to consume: {total_samples}'
        assert consumed_samples < total_samples, (
            f'no samples left to consume: {consumed_samples}, {total_samples}'
        )
        assert micro_batch_size > 0
        assert data_parallel_size > 0
        assert data_parallel_rank < data_parallel_size, (
            'data_parallel_rank should be smaller than data size: '
            f'{data_parallel_rank} {data_parallel_size}'
        )

    def __len__(self) -> int:
        # Reports the total sample count, not the number of batches yielded.
        return self.total_samples

    def get_start_end_idx(self) -> tuple[int, int]:
        """Return this rank's half-open ``(start, end)`` slice of a global batch."""
        first = self.micro_batch_size * self.data_parallel_rank
        return first, first + self.micro_batch_size

    def __iter__(self):
        step = self.micro_batch_times_data_parallel_size
        stop = self.total_samples
        for chunk_start in range(self.consumed_samples, stop, step):
            chunk = list(range(chunk_start, min(chunk_start + step, stop)))
            if self.drop_last and len(chunk) < step:
                # Only the final chunk can be short; discard it.
                return
            lo, hi = self.get_start_end_idx()
            yield chunk[lo:hi]


class RandomSeedDataset(Dataset):
    """Dataset wrapper that reseeds the global RNGs before fetching each item.

    Before ``dataset[idx]`` is returned, ``torch``, ``random`` and
    ``numpy.random`` are all seeded from ``idx + curr_seed``. ``curr_seed`` is
    ``args.seed`` plus the epoch given to ``set_epoch``. Fetching the same
    index in the same epoch therefore always starts from the same RNG state.
    """

    def __init__(self, dataset):
        args = get_args()
        self.base_seed = args.seed  # seed from the global args
        self.curr_seed = args.seed  # base_seed offset by the current epoch
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def set_epoch(self, epoch):
        """Offset all per-item seeds by ``epoch``."""
        self.curr_seed = self.base_seed + epoch

    def __getitem__(self, idx):
        seed = idx + self.curr_seed
        torch.manual_seed(seed)
        random.seed(seed)
        # np.random.seed only accepts 0 <= seed < 2**32. Reduce the seed into
        # that range so large indices/seeds (or negative seeds) do not raise
        # ValueError. Seeds already in range are unchanged.
        np.random.seed(seed % (2 ** 32))
        return self.dataset[idx]


class MegatronPretrainingRandomSampler:
    """Shuffling batch sampler that yields this rank's micro-batches of indices.

    Each epoch contains ``total_samples - last_batch_size`` samples: the
    largest multiple of ``micro_batch_size * data_parallel_size`` that fits in
    ``total_samples``. The epoch number and the position inside the epoch are
    derived from ``consumed_samples``, so iteration resumes where it left off.
    The shuffle generator is seeded with the epoch number, so the order is
    reproducible.

    With ``data_sharding`` the index space is cut into one contiguous bucket
    per rank, and each rank walks a shuffled order of its own bucket.
    Otherwise a single global permutation is strided across the ranks.
    """

    def __init__(
        self,
        dataset,
        total_samples: int,
        consumed_samples: int,
        micro_batch_size: int,
        data_parallel_rank: int,
        data_parallel_size: int,
        data_sharding: bool,
    ) -> None:
        self.dataset = dataset
        self.total_samples = total_samples
        self.consumed_samples = consumed_samples
        self.micro_batch_size = micro_batch_size
        self.data_parallel_rank = data_parallel_rank
        self.data_parallel_size = data_parallel_size
        self.data_sharding = data_sharding
        # Size of one global batch (all data-parallel ranks together).
        self.micro_batch_times_data_parallel_size = self.micro_batch_size * data_parallel_size
        # Remainder that does not fill a whole global batch; each epoch uses
        # total_samples - last_batch_size samples.
        self.last_batch_size = self.total_samples % self.micro_batch_times_data_parallel_size

        assert self.total_samples > 0, 'no sample to consume: {}'.format(self.total_samples)
        assert self.micro_batch_size > 0
        assert data_parallel_size > 0
        assert self.data_parallel_rank < data_parallel_size, (
            'data_parallel_rank should be smaller than data size: {}, {}'.format(
                self.data_parallel_rank, data_parallel_size
            )
        )

    def __len__(self) -> int:
        # Reports the total sample count, not the number of batches yielded.
        return self.total_samples

    def __iter__(self):
        active_total_samples = self.total_samples - self.last_batch_size
        # Fail with a clear message instead of a ZeroDivisionError below.
        assert active_total_samples > 0, (
            'total_samples ({}) is smaller than micro_batch_size * '
            'data_parallel_size ({})'.format(
                self.total_samples, self.micro_batch_times_data_parallel_size
            )
        )
        self.epoch = self.consumed_samples // active_total_samples
        current_epoch_samples = self.consumed_samples % active_total_samples

        # Resuming is only supported on global-batch boundaries.
        assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0

        if isinstance(self.dataset, RandomSeedDataset):
            self.dataset.set_epoch(self.epoch)

        if self.data_sharding:
            # Each rank owns a contiguous bucket holding one micro-batch per
            # full global batch.
            bucket_size: int = (self.total_samples // self.micro_batch_times_data_parallel_size) * self.micro_batch_size
            # current_epoch_samples counts samples over all ranks; this is the
            # per-rank share already consumed.
            bucket_offset = current_epoch_samples // self.data_parallel_size
            start_idx = self.data_parallel_rank * bucket_size

            g = torch.Generator()
            g.manual_seed(self.epoch)
            random_idx = torch.randperm(bucket_size, generator=g).tolist()
            idx_range = [start_idx + x for x in random_idx[bucket_offset:]]
        else:
            full_bucket_size = (self.total_samples // self.micro_batch_size) * self.micro_batch_size
            full_bucket_offset = current_epoch_samples
            g = torch.Generator()
            g.manual_seed(self.epoch)
            # The permutation length is kept unchanged so existing shuffles
            # (and resumed runs) see the same order.
            idx_range_total = torch.randperm(full_bucket_size, generator=g).tolist()
            # Only positions [offset, active_total_samples) belong to this
            # epoch. That span is a multiple of the global batch size, so
            # every rank receives the same number of full micro-batches.
            # Positions past active_total_samples used to be dealt out
            # unevenly, giving some ranks an extra micro-batch (and an extra
            # consumed_samples increment) that the others did not get.
            idx_range_active = idx_range_total[full_bucket_offset:active_total_samples]
            idx_range = idx_range_active[self.data_parallel_rank::self.data_parallel_size]

        batch = []
        for idx in idx_range:
            batch.append(idx)
            if len(batch) == self.micro_batch_size:
                # consumed_samples is counted across all data-parallel ranks.
                self.consumed_samples += self.micro_batch_times_data_parallel_size
                yield batch
                batch = []
        # Any trailing partial micro-batch is discarded.