# coding: utf-8

import os.path,glob,string
import torch
from torch.nn.utils.rnn import pad_sequence
import torchaudio
from torchaudio.datasets.librispeech import _get_librispeech_metadata,_load_waveform

class TIMIT(object):
	"""
	TIMIT phone-recognition dataset.

	Each item pairs a ``.WAV`` utterance with the phone sequence read from the
	sibling ``.PHN`` file. The original TIMIT phone inventory is folded onto a
	reduced label set (see ``_build_inventory``); the phone ``q`` is dropped
	entirely. Label index 0 is reserved for the blank symbol ``'_'``
	(presumably for CTC training -- confirm with the training code).

	Args:
		root: Directory that contains the split folders.
		split: Split name (e.g. ``'train'``); it is upper-cased to form the
			folder name that is searched.
	"""
	_ext_txt = ".PHN"
	_ext_audio = ".WAV"

	def __init__(self, root, split):
		self.root = root
		# Paths are relative to ``root`` and sit three directory levels below the
		# split folder ("<SPLIT>/*/*/*.WAV"); sorted for a stable item order.
		# NOTE: glob's ``root_dir`` argument requires Python 3.10+.
		self._walker = sorted(glob.glob(os.path.join(split.upper(), "*/*/*" + self._ext_audio),
										root_dir=self.root))
		self._build_inventory()

	def _build_inventory(self):
		"""
		Build the label vocabulary.

		Sets:
			idx2label: ``['_']`` followed by the reduced phone labels in sorted
				order (40 entries in total).
			phone2idx: every original phone except ``'q'`` -> index of its
				reduced label in ``idx2label``.
		"""
		blank = '_'
		ORIGINAL_INVENTORY = [
			'iy','ih','eh','ey','ae','aa','aw','ay','ah','ao',
			'oy','ow','uh','uw','ux','er','ax','ix','axr','ax-h',
			'jh','ch','b','d','g','p','t','k','dx','s',
			'sh','z','zh','f','th','v','dh','m','n','ng',
			'em','nx','en','eng','l','r','w','y','hh','hv',
			'el','bcl','dcl','gcl','pcl','tcl','kcl','q','pau','epi',
			'h#'
		]
		# Phones not listed here map to themselves.
		REDUCTION_MAPPING = {
			'ao':'aa',
			'ax':'ah',
			'ax-h':'ah',
			'axr':'er',
			'hv':'hh',
			'ix':'ih',
			'el':'l',
			'em':'m',
			'en':'n',
			'nx':'n',
			'eng':'ng',
			'zh':'sh',
			'ux':'uw',
			'pcl':'sil',
			'tcl':'sil',
			'kcl':'sil',
			'bcl':'sil',
			'dcl':'sil',
			'gcl':'sil',
			'h#':'sil',
			'pau':'sil',
			'epi':'sil',
		}
		kept_phones = [phone for phone in ORIGINAL_INVENTORY if phone != 'q']
		# Sort so the label <-> index mapping is identical in every process:
		# iteration order of a bare set of strings depends on hash randomization.
		reduced_labels = sorted({REDUCTION_MAPPING.get(phone, phone) for phone in kept_phones})
		self.idx2label = [blank] + reduced_labels
		label2idx = {label: i for i, label in enumerate(self.idx2label)}
		self.phone2idx = {phone: label2idx[REDUCTION_MAPPING.get(phone, phone)]
						  for phone in kept_phones}

	def get_metadata(self, idx):
		"""
		Return ``(subpath, target)`` for item ``idx`` without loading audio.

		``subpath`` is the audio path relative to ``root``. ``target`` is the
		list of label indices for the phones in the matching ``.PHN`` file,
		where the phone is the last whitespace-separated field of each line.
		``'q'`` entries and blank lines are skipped.

		Raises:
			KeyError: if the ``.PHN`` file contains a phone outside the inventory.
		"""
		subpath = self._walker[idx]
		phn_path = os.path.join(self.root, os.path.splitext(subpath)[0] + self._ext_txt)
		target = []
		with open(phn_path, 'r', encoding='utf-8') as f:
			for line in f:
				# Whitespace split tolerates trailing spaces / repeated separators.
				fields = line.split()
				if not fields:
					continue
				phone = fields[-1]
				if phone != 'q':
					target.append(self.phone2idx[phone])
		return subpath, target

	def __getitem__(self, idx):
		"""
		Return ``(waveform, fs, target)``.

		``waveform`` is the first channel of the audio, peak-normalized. ``fs``
		is the sample rate reported by torchaudio. ``target`` is a 1-D int64
		tensor of label indices.
		"""
		subpath, target = self.get_metadata(idx)
		waveform, fs = torchaudio.load(os.path.join(self.root, subpath), normalize=True)
		waveform = waveform[0]  # keep a single channel
		# TIMIT recordings sound quiet, so rescale to unit peak. Skip empty or
		# all-zero signals, which would otherwise error or become NaN (0 / 0).
		if waveform.numel() > 0:
			peak = waveform.abs().max()
			if peak > 0:
				waveform = waveform / peak
		# Explicit dtype keeps empty targets integer-typed like non-empty ones.
		target = torch.tensor(target, dtype=torch.long)
		return waveform, fs, target

	def __len__(self):
		return len(self._walker)

	@property
	def vocab_size(self):
		"""Number of output labels, including the blank."""
		return len(self.idx2label)

	def collate_fn(self, batch):
		"""
		Collate ``(waveform, fs, target)`` items into padded batch tensors.

		Assumes 1-D waveforms and targets, as returned by ``__getitem__``.

		Returns:
			waveform: (batch, max_len) waveforms, zero-padded.
			wav_lengths: unpadded length of each waveform.
			fs: sample rate of each item.
			target: (batch, max_target_len) label indices, padded with -1.
			target_lengths: unpadded length of each target.
		"""
		waveform, fs, target = zip(*batch)
		wav_lengths = torch.tensor([w.size(0) for w in waveform])
		waveform = pad_sequence(waveform, batch_first=True)
		target_lengths = torch.tensor([t.size(0) for t in target])
		# -1 is never a valid label index, so padding is distinguishable.
		target = pad_sequence(target, batch_first=True, padding_value=-1)
		fs = torch.tensor(fs)
		return waveform, wav_lengths, fs, target, target_lengths