from torch.utils.data import Dataset
import random
import torch
import os
from pathlib import Path
from torch.utils.data import IterableDataset
import torch.distributed as dist
import nn_ss.dataloader.processor as processor

class Processor(IterableDataset):
    """One lazy stage of a data pipeline.

    Wraps an upstream iterable ``source`` together with a callable ``f``.
    Iterating the stage calls ``f(iter(source), *args, **kw)`` and returns
    whatever ``f`` returns.
    """

    def __init__(self, source, f, *args, **kw):
        """Store the upstream source, the stage callable and its extra arguments.

        Args:
            source: upstream iterable this stage consumes.
            f: callable invoked as ``f(iterator, *args, **kw)``.
            *args: extra positional arguments forwarded to ``f``.
            **kw: extra keyword arguments forwarded to ``f``.
        """
        assert callable(f)
        self.source, self.f = source, f
        self.args, self.kw = args, kw

    def __len__(self):
        """Return the length reported by the upstream source."""
        return len(self.source)

    def __iter__(self):
        """ Return an iterator over the source dataset processed by the
            stored callable.
        """
        assert self.source is not None
        assert callable(self.f)
        upstream = iter(self.source)
        return self.f(upstream, *self.args, **self.kw)

    def set_epoch(self, epoch):
        """Forward the epoch number to the upstream source."""
        self.source.set_epoch(epoch)

    def apply(self, f):
        """Return a new stage that runs ``f`` on top of this stage.

        The new stage is created with this stage's extra positional and
        keyword arguments.
        """
        assert callable(f)
        chained = Processor(self, f, *self.args, **self.kw)
        return chained

class DistributedSampler:
    """Pick the subset of list indexes owned by the current rank and worker.

    ``update()`` refreshes ``rank``/``world_size`` (from torch.distributed)
    and ``worker_id``/``num_workers`` (from the DataLoader worker info);
    ``sample()`` uses those values to slice the index list.
    """

    def __init__(self, shuffle=True, partition=True):
        """
        Args:
            shuffle(bool): shuffle indexes (seeded by the epoch) before the
                per-rank split; only applied when ``partition`` is True.
            partition(bool): split indexes across ranks.
        """
        self.epoch = -1
        self.update()
        self.shuffle = shuffle
        self.partition = partition

    def update(self):
        """ Refresh distributed and DataLoader-worker information.

            Returns:
                dict: rank, world_size, worker_id and num_workers
        """
        # torch.distributed exposes no other API (including
        # is_initialized) on builds where the package is unavailable, so
        # check availability first instead of raising AttributeError.
        if dist.is_available() and dist.is_initialized():
            self.rank = dist.get_rank()
            self.world_size = dist.get_world_size()
        else:
            self.rank = 0
            self.world_size = 1
        # get_worker_info() returns None outside a DataLoader worker process.
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            self.worker_id = 0
            self.num_workers = 1
        else:
            self.worker_id = worker_info.id
            self.num_workers = worker_info.num_workers
        return dict(rank=self.rank,
                    world_size=self.world_size,
                    worker_id=self.worker_id,
                    num_workers=self.num_workers)

    def set_epoch(self, epoch):
        """Set the epoch used as the shuffle seed in ``sample()``."""
        self.epoch = epoch

    def sample(self, data):
        """ Sample indexes according to rank/world_size/num_workers

            Args:
                data(List): input data list; only its length is used

            Returns:
                List[int]: indexes into ``data`` selected for this
                rank and worker
        """
        indexes = list(range(len(data)))
        # TODO(Binbin Zhang): fix this
        # We can not handle uneven data for CV on DDP, so we don't
        # sample data by rank, that means every GPU gets the same
        # and all the CV data
        if self.partition:
            if self.shuffle:
                # Seeding with the epoch makes every rank compute the same
                # permutation before taking its own stride.
                random.Random(self.epoch).shuffle(indexes)
            indexes = indexes[self.rank::self.world_size]
        return indexes[self.worker_id::self.num_workers]

def read_lists(list_file):
    """Read a UTF-8 text file into a list of whitespace-stripped lines.

    Args:
        list_file: path of the text file to read.

    Returns:
        list[str]: one entry per line of the file, in file order, each with
        leading/trailing whitespace removed (blank lines become '').
    """
    with open(list_file, 'r', encoding='utf8') as fin:
        return [raw.strip() for raw in fin]

class DataList(IterableDataset):
    """Iterable dataset over the entries of ``lists`` chosen by a sampler.

    Each yielded item is a dict holding the selected entry under ``src``
    plus whatever key/value pairs the sampler's ``update()`` returned.
    """

    def __init__(self, lists, shuffle=True, partition=True):
        """
        Args:
            lists: indexable collection of entries to iterate over.
            shuffle: forwarded to DistributedSampler.
            partition: forwarded to DistributedSampler.
        """
        self.sampler = DistributedSampler(shuffle, partition)
        self.lists = lists

    def set_epoch(self, epoch):
        """Forward the epoch number to the sampler."""
        self.sampler.set_epoch(epoch)

    def __iter__(self):
        """Yield one dict per index selected by the sampler."""
        # Refresh sampler state at iteration time, then ask it which
        # indexes of self.lists to emit.
        meta = self.sampler.update()
        for idx in self.sampler.sample(self.lists):
            item = {'src': self.lists[idx]}
            item.update(meta)
            yield item

def SemanticDataset_rewrite(folder,
                 num_quant = 4,
                 stage='train',
                 max_frames_in_batch=12000,
                 partition = True):
    """Assemble the chained Processor pipeline for one dataset stage.

    Reads the list file ``<folder>/<stage>_shares_list_all``, wraps its
    entries in a DataList (always created with shuffle=True) and stacks the
    processing steps from the ``processor`` module on top of it.

    Args:
        folder: dataset directory; must exist.
        num_quant: passed positionally to ``processor.load_raw``.
        stage: prefix of the list file name.
        max_frames_in_batch: passed to ``processor.dynamic_batch``.
        partition: forwarded to DataList.

    Returns:
        Processor: the last stage of the pipeline (``processor.padding``).
    """
    root = Path(folder)
    assert root.exists(), 'folder does not exist'

    # Original note on this file: a column with duration information was added.
    list_path = os.path.join(folder, stage + '_shares_list_all')
    pipeline = DataList(read_lists(list_path), shuffle=True, partition=partition)

    # (callable, positional args, keyword args) for each stage, in order.
    steps = [
        (processor.url_opener, (), {}),
        (processor.tar_file_and_group, (), {}),
        (processor.load_raw, (num_quant,), {}),
        (processor.shuffle, (), {}),
        (processor.sort, (), {}),
        (processor.dynamic_batch, (), {'max_frames_in_batch': max_frames_in_batch}),
        (processor.padding, (), {}),
    ]
    for step, step_args, step_kw in steps:
        pipeline = Processor(pipeline, step, *step_args, **step_kw)
    return pipeline



if __name__ == "__main__":
    # Smoke run: build the default pipeline from a hard-coded local
    # directory and iterate it once, discarding every item.
    data_root = '/home/disk2/gongxuefei/DATA/soundstorm/data/aishell'
    for _item in SemanticDataset_rewrite(folder=data_root):
        pass



























