import torch
from torch.utils.data import DataLoader


class DataLoaderX(DataLoader):
    """DataLoader that prefetches the next batch to a GPU on a side CUDA stream.

    While the caller works on the current batch, the following batch is
    already fetched and its tensors are copied to ``local_rank`` on a
    dedicated stream. Batches are expected to be dicts; only top-level values
    that are tensors are moved, everything else is passed through untouched.
    Batches that are ``None`` (as produced by the underlying iterator) are
    skipped.

    Args:
        local_rank: Device handed to ``torch.cuda.Stream`` and used as the
            target of ``Tensor.to``.
        **kwargs: Forwarded unchanged to ``torch.utils.data.DataLoader``.
    """

    # Marks iterator exhaustion so it can be told apart from a ``None`` batch.
    _EXHAUSTED = object()

    def __init__(self, local_rank, **kwargs):
        super().__init__(**kwargs)
        # One dedicated copy stream per loader (i.e. per process).
        self.stream = torch.cuda.Stream(local_rank)
        self.local_rank = local_rank

    def __iter__(self):
        """Start a new pass over the data and prefetch its first batch."""
        self.iter = super().__iter__()
        self.preload()
        return self

    def preload(self):
        """Fetch the next non-``None`` batch and start copying it to the GPU.

        Stores the result in ``self.batch``; ``self.batch`` is ``None`` once
        the underlying iterator is exhausted.
        """
        # Exhaustion is detected with a sentinel rather than by inspecting
        # private iterator state, which differs between the single-process
        # and multi-worker iterators and does not say whether every batch
        # has actually been received yet.
        batch = next(self.iter, self._EXHAUSTED)
        while batch is None:
            batch = next(self.iter, self._EXHAUSTED)

        if batch is self._EXHAUSTED:
            self.batch = None
            return None

        self.batch = batch
        # Issue the host-to-device copies on the side stream so they can
        # overlap with work queued on the current stream.
        with torch.cuda.stream(self.stream):
            for key, val in batch.items():
                if isinstance(val, torch.Tensor):
                    batch[key] = val.to(
                        device=self.local_rank, non_blocking=True
                    )

    def __next__(self):
        """Return the prefetched batch and begin prefetching the next one.

        Raises:
            StopIteration: When no batch is left.
        """
        current = torch.cuda.current_stream()
        # Make the consumer stream wait until the copies have finished.
        current.wait_stream(self.stream)
        batch = self.batch
        if batch is None:
            raise StopIteration
        # The tensors were allocated on the side stream but will be used on
        # the current one; tell the caching allocator so their memory is not
        # reused while work on the current stream is still pending.
        for val in batch.values():
            if isinstance(val, torch.Tensor) and val.is_cuda:
                val.record_stream(current)
        self.preload()
        return batch

