from torch.utils import data
import torch
import numpy as np
import pickle 
import os    
       
from multiprocessing import Process, Manager   



class Utterances(data.Dataset):
    """Dataset of per-speaker embeddings and mel-spectrograms.

    ``<root_dir>/train.pkl`` holds a list with one entry per speaker. Each
    entry is a sequence whose items 0 and 1 (speaker id and embedding) are
    kept as-is. Its remaining items are paths, relative to ``root_dir``, of
    2-D arrays loaded with ``np.load``.
    """

    def __init__(self, root_dir, len_crop):
        """Load the metadata and every referenced spectrogram into memory.

        Args:
            root_dir: Directory containing ``train.pkl`` and the ``.npy``
                files it references.
            len_crop: Number of rows (axis 0) each returned spectrogram is
                cropped or zero-padded to.

        Raises:
            RuntimeError: If any loader process exits with a non-zero code.
        """
        self.root_dir = root_dir
        self.len_crop = len_crop
        # Number of metadata entries handled by each loader process.
        self.step = 10

        metaname = os.path.join(self.root_dir, "train.pkl")
        # NOTE(review): pickle can execute arbitrary code; train.pkl must be
        # a trusted file.
        with open(metaname, "rb") as f:
            meta = pickle.load(f)  # d_vector metadata

        # Load the arrays in parallel. Children write into a Manager-backed
        # list because ordinary lists are not shared between processes.
        with Manager() as manager:
            dataset = manager.list(len(meta) * [None])
            processes = []
            for i in range(0, len(meta), self.step):
                p = Process(target=self.load_data,
                            args=(meta[i:i + self.step], dataset, i))
                p.start()
                processes.append(p)
            for p in processes:
                p.join()

            # A failed child leaves None slots that would crash much later
            # in __getitem__; fail loudly here instead.
            failed = [p for p in processes if p.exitcode != 0]
            if failed:
                raise RuntimeError(
                    f"{len(failed)} dataset loader process(es) failed; "
                    "see the traceback(s) printed above")

            # Copy out before the manager (and its server process) shuts down.
            self.train_dataset = list(dataset)

        self.num_tokens = len(self.train_dataset)

        print('Finished loading the dataset...')

    def load_data(self, submeta, dataset, idx_offset):
        """Load one slice of metadata entries into ``dataset`` in place.

        Args:
            submeta: Consecutive metadata entries (see the class docstring).
            dataset: Shared list written at ``idx_offset + k`` for the k-th
                entry of ``submeta``.
            idx_offset: Position in ``dataset`` corresponding to
                ``submeta[0]``.
        """
        for k, sbmt in enumerate(submeta):
            uttrs = len(sbmt) * [None]
            for j, tmp in enumerate(sbmt):
                if j < 2:  # speaker id and embedding are stored unchanged
                    uttrs[j] = tmp
                else:  # remaining items are mel-spectrogram file paths
                    uttrs[j] = np.load(os.path.join(self.root_dir, tmp))
            dataset[idx_offset + k] = uttrs

    def __getitem__(self, index):
        """Return a random fixed-length spectrogram for speaker ``index``.

        Returns:
            Tuple ``(uttr, emb_org, label)``:
            - ``uttr`` is one randomly chosen spectrogram of the entry,
              cropped at a random offset or zero-padded at the end to
              ``len_crop`` rows.
            - ``emb_org`` is item 1 of the entry.
            - ``label`` is ``index``.
        """
        list_uttrs = self.train_dataset[index]
        emb_org = list_uttrs[1]
        # Items from position 2 onward are spectrograms; pick one at random.
        a = np.random.randint(2, len(list_uttrs))
        tmp = list_uttrs[a]
        n_frames = tmp.shape[0]
        if n_frames < self.len_crop:
            len_pad = self.len_crop - n_frames
            uttr = np.pad(tmp, ((0, len_pad), (0, 0)), 'constant')
        elif n_frames > self.len_crop:
            # randint's upper bound is exclusive; +1 lets the final window
            # (ending at the last frame) be selected too.
            left = np.random.randint(n_frames - self.len_crop + 1)
            uttr = tmp[left:left + self.len_crop, :]
        else:
            uttr = tmp
        return uttr, emb_org, index

    def __len__(self):
        """Return the number of speaker entries loaded."""
        return self.num_tokens
    
    
    

def _seed_worker(worker_id):
    """Seed NumPy's global RNG from the current DataLoader worker's torch seed.

    Defined at module level (not as a lambda) so it can be pickled when
    DataLoader workers are started with the ``spawn`` method.
    """
    np.random.seed(torch.initial_seed() % (2**32))


def get_loader(root_dir, batch_size=16, len_crop=128, num_workers=0):
    """Build and return a shuffled DataLoader over ``Utterances``.

    Args:
        root_dir: Directory passed to ``Utterances``.
        batch_size: Samples per batch; the last incomplete batch is dropped.
        len_crop: Spectrogram length passed to ``Utterances``.
        num_workers: Number of DataLoader worker processes.
    """
    dataset = Utterances(root_dir, len_crop)

    data_loader = data.DataLoader(dataset=dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=num_workers,
                                  drop_last=True,
                                  worker_init_fn=_seed_worker)
    return data_loader

# Smoke test: load one batch and print tensor shapes.
if __name__ == '__main__':
    root_dir = '../DataSets/data_aishell3/data_aishell3/train/autovc_mel_train'
    batch_size = 4
    len_crop = 128
    vcc_loader = get_loader(root_dir, batch_size, len_crop)
    data_iter = iter(vcc_loader)
    x_real, emb_org, label = next(data_iter)
    # Move to GPU only when one is present so the check also runs on CPU.
    if torch.cuda.is_available():
        x_real = x_real.cuda()
    print("test：x_real:", x_real.shape)  # e.g. (4, 128, 80); 4 is the batch size
    print("test: emb_org:", emb_org.shape)
    print("test: label", label)



