# coding: utf-8

import os.path,glob,string
import torch
from torch.nn.utils.rnn import pad_sequence
from torchaudio.datasets.librispeech import _get_librispeech_metadata,_load_waveform

class LibriSpeech(object):
	"""
	LibriSpeech dataset for merging the 960 training data.
	"""
	_ext_txt = ".trans.txt"
	_ext_audio = ".flac"
	def __init__(self, root, split, max_length=None):
		self.root = root
		self._walker = sorted(p.split('/')[0]+'|'+os.path.splitext(os.path.basename(p))[0]
						for p in glob.glob(os.path.join(split,"*/*/*" + self._ext_audio),
											root_dir=self.root))
		self.max_length = max_length
		self.idx2char = "-|etaonihsrdlumwcfgypbvk'xjqz" # NOTE: Corresponding to torchaudio.models.decoder.ctc_decoder
		self.char2idx = {c:i for i,c in enumerate(self.idx2char)}

	def get_metadata(self, idx):
		fileid = self._walker[idx]
		split,fileid = fileid.split('|')
		return _get_librispeech_metadata(fileid, self.root, split, self._ext_audio, self._ext_txt)

	def __getitem__(self, idx):
		subpath,fs,target,*metadata = self.get_metadata(idx)
		waveform = _load_waveform(self.root, subpath, fs)[0] # 0th dim indexes recording channels.
		target = target.lower()
		if self.max_length is None:
			target = torch.tensor([self.char2idx[c] for c in target.replace(' ','|')])
		if (not self.max_length is None) and waveform.size(0)>self.max_length:
			onset = torch.randint(waveform.size(0)-self.max_length, size=(1,)).item()
			waveform = waveform[onset:onset+self.max_length]
		return waveform,fs,target,*metadata
	
	def __len__(self):
		return len(self._walker)
	
	@property
	def vocab_size(self):
		return len(self.char2idx)
	
	def collate_fn(self,batch):
		waveform,fs,target,*metadata = zip(*batch)
		wav_lengths = torch.tensor([w.size(0) for w in waveform])
		waveform = pad_sequence(waveform,batch_first=True)
		if torch.is_tensor(target[0]):
			target_lengths = torch.tensor([t.size(0) for t in target])
			target = pad_sequence(target,batch_first=True,padding_value=-1)
		else:
			target_lengths = None
		fs = torch.tensor(fs)
		return waveform,wav_lengths,fs,target,target_lengths,*metadata