import os
from python_speech_features import mfcc
from scipy.io import wavfile
import glob
import numpy as np
from collections import namedtuple
import warnings
import pandas as pd
import sounddevice as sd
import time


class audioLoader(object):

	def __init__(self, config):
		"""
		Prepare an empty loader. No audio is read here; call loadData() for that.

		config: configuration object. This constructor reads its `soundSource`
			and `envFolder` attributes and keeps a reference to the whole object.
		"""
		self.config = config
		self.soundSource = config.soundSource  # a dictionary containing what to load

		# MFCC framing parameters, looked up by dataset name
		SoundParam = namedtuple('sound_param', ['nFFT', 'windowLenTime', 'windowStepTime'])
		self.param_func = SoundParam
		shortWindow = dict(nFFT=512, windowLenTime=0.025, windowStepTime=0.01)
		longWindow = dict(nFFT=1024, windowLenTime=0.05, windowStepTime=0.04)
		self.param_dict = {
			'GoogleCommand': SoundParam(**shortWindow),
			'NSynth': SoundParam(**longWindow),
			'UrbanSound': SoundParam(**longWindow),
			'FSC': SoundParam(**shortWindow),
		}

		self.fs = None  # audio sampling rate, 16000 for GoogleCommand, NSynth, UrbanSound
		self.words = {}  # dict for storing raw audio signal

		# The environment type is the leading directory part of envFolder;
		# when envFolder has no directory part, envFolder itself is used.
		head, _tail = os.path.split(config.envFolder)
		self.env_type = head if head else config.envFolder

	def loadData(self): # it will be called in sheme_vec_env or dummy_vec_env
		"""
		Load the audio described by self.config.soundSource for the current env type.

		pybullet: sets numAllObject, wordSplit (the 'size' entry of the first
			dataset) and runs one loader per entry of soundSource['dataset'];
			afterwards wordsIndex[i] lists the indices of the clips in self.words[i].
		ai2thor: resets audioDataFrame / transcription and loads the FSC data.
		Any other env type raises NotImplementedError.
		"""
		envType = self.env_type
		if envType != 'pybullet' and envType != 'ai2thor':
			# for GoogleCommand, NSynth, UrbanSound
			raise NotImplementedError

		if envType == 'ai2thor':
			self.audioDataFrame = {}  # a dict containing metadata for self.words
			self.transcription = {}
			self.loadFSCData_ai2thor(loadSize=self.config.soundSource['size'])
		else:
			self.numAllObject = len(self.config.objList)
			self.wordsIndex = []
			firstDataset = self.config.soundSource['dataset'][0]
			self.wordSplit = self.config.soundSource['size'][firstDataset]

			for name in self.config.soundSource['dataset']:
				if name == 'FSC':
					self.loadFSCData_pybullet(loadSize=self.config.soundSource['size'][name])
					continue
				self.loadSoundData_pybullet(
					datasetName=name,
					soundList=self.config.soundSource['items'][name],
					loadSize=self.config.soundSource['size'][name],
				)

			# one index list per object, covering every clip loaded for it
			self.wordsIndex.extend(
				list(np.arange(len(self.words[objIdx]))) for objIdx in range(self.numAllObject)
			)

		print("Sound Loaded")

	def loadFSCData_ai2thor(self, loadSize=-1):
		"""
		This function loads fluent speech dataset (FSC) for ai2thor env

		The csv at <dataset_path>/data/<csv> is read and restricted to rows whose
		`object` is a key of soundSource['obj_act']. For every location in
		soundSource['locations'], every such object that has rows at that location,
		and every action listed for that object, three parallel entries are filled:
			self.audioDataFrame[loc][obj][act]: the matching csv rows
			self.words[loc][obj][act]: list of raw audio arrays
			self.transcription[loc][obj][act]: list of transcriptions, same order as words
		Files are visited in random order (np.random) and clips longer than
		soundSource['max_sound_dur'] seconds are skipped.

		self.audioDataFrame and self.transcription must already exist (loadData
		creates them).

		loadSize: maximum number of clips kept per (loc, obj, act). A negative
			value (the default) means no limit. Previously the default -1 stopped
			after the first clip because `len(...) >= -1` is always true.
		"""
		soundSource = self.config.soundSource
		datasetPath = soundSource['dataset_path']
		maxDuration = soundSource['max_sound_dur']
		df = pd.read_csv(os.path.join(datasetPath, 'data', soundSource['csv']))
		# filter objects
		objs = soundSource['obj_act'].keys()
		df = df[df.object.isin(objs)]

		for loc in soundSource['locations']:  # for each possible location
			# {'none': dataframe containing all location='none',
			# 'kitchen': dataframe containing location='kitchen'}
			loc_df = df[df.location.isin([loc])]  # filter location
			self.audioDataFrame[loc] = {}
			self.transcription[loc] = {}
			self.words[loc] = {}
			for obj in objs:  # for each object
				obj_df = loc_df[loc_df.object == obj]
				if obj_df.empty:
					continue
				self.audioDataFrame[loc][obj] = {}
				self.transcription[loc][obj] = {}
				self.words[loc][obj] = {}
				for act in soundSource['obj_act'][obj]:
					act_df = obj_df[obj_df.action == act]
					clips = []
					texts = []
					self.audioDataFrame[loc][obj][act] = act_df
					self.words[loc][obj][act] = clips
					self.transcription[loc][obj][act] = texts

					path_list = act_df['path'].tolist()
					trans_list = act_df['transcription'].tolist()
					idx = np.arange(len(path_list))
					np.random.shuffle(idx)
					for i in idx:
						self.fs, x = wavfile.read(os.path.join(datasetPath, path_list[i]))
						if x.size / self.fs > maxDuration:
							continue
						clips.append(x)
						texts.append(trans_list[i])

						# only a non-negative loadSize acts as a cap
						if 0 <= loadSize <= len(clips):
							break

	def loadFSCData_pybullet(self, loadSize=-1):
		"""
		This function loads fluent speech dataset (FSC) for PyBullet env

		The csv at <dataset_path>/data/<csv> is restricted to the objects in
		soundSource['objs'] and the locations in soundSource['locations']. For each
		object index i in range(self.numAllObject), clips whose object equals
		soundSource['objs'][i] and whose action equals soundSource['act'][i] are
		appended to self.words[i] in random order (np.random); clips longer than
		soundSource['max_sound_dur'] seconds are skipped.

		loadSize: cap on len(self.words[i]); loading for object i stops once the
			list reaches it. Either something indexable per object
			(loadSize[i]) or a single integer applied to every object. A negative
			cap means no limit. Previously the integer default raised TypeError
			on `loadSize[i]`.
		"""
		soundSource = self.config.soundSource
		datasetPath = soundSource['dataset_path']
		maxDuration = soundSource['max_sound_dur']
		df = pd.read_csv(os.path.join(datasetPath, 'data', soundSource['csv']))
		# filter objects
		objs = list(set(soundSource['objs']))
		df = df[df.object.isin(objs)]

		# filter locations
		df = df[df.location.isin(soundSource['locations'])]

		# a bare integer is one cap shared by all objects
		sharedLimit = isinstance(loadSize, (int, np.integer))

		for i in range(self.numAllObject):
			clips = self.words.setdefault(i, [])
			limit = loadSize if sharedLimit else loadSize[i]

			subdf = df[df.object == soundSource['objs'][i]]
			subdf = subdf[subdf.action == soundSource['act'][i]]
			path_list = subdf['path'].tolist()
			idx = np.arange(len(path_list))
			np.random.shuffle(idx)
			for item in idx:
				self.fs, x = wavfile.read(os.path.join(datasetPath, path_list[item]))
				if x.size / self.fs > maxDuration:
					continue
				clips.append(x)
				# only a non-negative limit acts as a cap
				if 0 <= limit <= len(clips):
					break

	def loadSoundData_pybullet(self, datasetName, soundList, loadSize):
		"""
		This function loads Google Command Dataset, Nsynth Dataset, and UrbanSound Dataset for PyBullet env

		Audio is read from <commonMediaPath>/<datasetName>/<train_test>/<soundList[i]>/*.wav
		and appended to self.words[i] for each object index i in
		range(self.numAllObject).

		datasetName: name of the dataset sub-folder
		soundList: per-object folder name; a None entry loads nothing for that object
		loadSize: per-object number of files to try. Files that fail to read
			still count toward this number.

		Unreadable files are reported on stdout and skipped. Only ordinary
		exceptions are caught, so KeyboardInterrupt / SystemExit still propagate.
		"""
		word_dir = os.path.join(self.config.commonMediaPath, datasetName, self.soundSource['train_test'])
		assert os.path.isdir(word_dir), word_dir

		for i in range(self.numAllObject):
			clips = self.words.setdefault(i, [])
			if soundList[i] is None:
				continue

			pattern = os.path.join(word_dir, soundList[i], '*.wav')
			for j, filePath in enumerate(glob.glob(pattern)):
				if j >= loadSize[i]:
					break
				try:
					self.fs, x = wavfile.read(filePath)
				except Exception as err:
					# best effort: one broken file must not abort the whole load
					print(f"Skipping unreadable audio file {filePath}: {err}")
				else:
					clips.append(x)

	def get_mfcc(self, audioSamples, param):
		"""
		Compute MFCC features for one clip and bring them to the fixed env length.

		audioSamples: raw audio signal, sampled at self.fs
		param: object with windowLenTime, windowStepTime and nFFT attributes
			(an entry of self.param_dict)
		Returns the output of processSoundFeat applied to the MFCC matrix.
		"""
		# calculate mfcc at run time to reduce memory usage
		mfccOptions = dict(
			winlen=param.windowLenTime,
			winstep=param.windowStepTime,
			numcep=40,
			nfilt=40,
			nfft=param.nFFT,
			winfunc=np.hamming,
		)
		rawFeat = mfcc(audioSamples, self.fs, **mfccOptions)
		return self.processSoundFeat(rawFeat)

	def genSoundFeat(self, objIndx, featType, rand_fn):
		"""
		Pick a random clip of object `objIndx` and turn it into a feature
		(pybullet env, where every object is associated with an objIndx).

		objIndx: object index; values past the end of config.objList are
			replaced by the last valid index
		featType: only 'MFCC' is implemented, anything else raises NotImplementedError
		rand_fn: called as rand_fn(0, n, size=()) to draw the clip index
		Returns (sound_feat, audioSamples).
		"""
		objIndx = min(objIndx, len(self.config.objList) - 1)

		clips = self.words[objIndx]
		soundIndx = rand_fn(0, len(clips), size=())

		if featType != 'MFCC':
			raise NotImplementedError

		# clips below the split use the first dataset's parameters (group 0),
		# the rest use the second dataset's (group 1)
		inFirstGroup = soundIndx < self.wordSplit[objIndx]
		datasetName = self.config.soundSource['dataset'][0 if inFirstGroup else 1]
		param = self.param_dict[datasetName]

		audioSamples = clips[soundIndx]
		sound_feat = self.get_mfcc(audioSamples, param)
		return sound_feat, audioSamples

	def genSoundFeatFromTask(self, task, featType, rand_fn):
		"""
		Pick a random clip matching `task` and turn it into a feature (ai2thor env).

		task: object with loc, obj and act attributes used as keys into
			self.words and self.transcription
		featType: only 'MFCC' is implemented, anything else raises NotImplementedError
		rand_fn: called as rand_fn(0, n, size=()) to draw the clip index
		Returns (sound_feat, audioSamples, audioTranscription).
		"""
		loc, obj, act = task.loc, task.obj, task.act

		soundList = self.words[loc][obj][act]
		soundIndx = rand_fn(0, len(soundList), size=())
		audioSamples = soundList[soundIndx]
		audioTranscription = self.transcription[loc][obj][act][soundIndx]

		if featType != 'MFCC':
			raise NotImplementedError

		# here soundSource['dataset'] is a single name, used directly as the key
		param = self.param_dict[self.config.soundSource['dataset']]
		sound_feat = self.get_mfcc(audioSamples, param)
		return sound_feat, audioSamples, audioTranscription

	def getAudioFromTask(self, random_func, tsk, Task):
		"""
		Map a task onto randomly chosen synonyms and fetch a matching audio feature.

		random_func: object with a randint(low=, high=, size=) method; the same
			method is handed on as the clip picker
		tsk: object with loc, obj and act attributes, each a key of config.synonym
		Task: callable building a task from (loc, obj, act)
		Returns (sound_feat, audioSamples, audioTranscription).
		"""
		synonyms = self.config.synonym

		locIdx = random_func.randint(low=0, high=len(synonyms[tsk.loc]), size=())
		loc = synonyms[tsk.loc][locIdx]

		objIdx = random_func.randint(low=0, high=len(synonyms[tsk.obj]), size=())
		obj = synonyms[tsk.obj][objIdx]

		# first element of the overlap between the actions available for the
		# chosen object and the synonyms of the requested action
		obj_act = self.config.soundSource['obj_act'][obj]
		synonym_act = synonyms[tsk.act]
		shared = set(obj_act).intersection(synonym_act)
		act = list(shared)[0]

		sound_feat, audioSamples, audioTranscription = self.genSoundFeatFromTask(
			task=Task(loc, obj, act),
			featType='MFCC',
			rand_fn=random_func.randint,
		)
		return sound_feat, audioSamples, audioTranscription

	def processSoundFeat(self, sound_feat):
		"""
		Bring a feature matrix to the fixed length the env expects.

		A leading axis of size 1 is added, then axis 1 is either cut to its
		first config.sound_dim[1] entries or zero-padded at the end up to
		config.sound_dim[1].

		sound_feat: feature array; presumably (frames, coefficients) as produced
			by get_mfcc. Its trailing size must agree with config.sound_dim for
			the padding to concatenate.
		Returns the array with the leading axis added and axis 1 of length
		config.sound_dim[1].
		"""
		sound_feat = np.expand_dims(sound_feat, axis=0)
		# process the sound
		# The observations for sound may have various length. Since the gym env accept fixed length observations,
		# we will need to set max acceptable length of sound of the env.The max acceptable length is set in the
		# self.config.sound_dim. A sound signal with length less than this max acceptable length will be padded
		# with 0. Warning: padding with zeros without telling the network with the true length
		# may decrease the performance
		maxFrames = self.config.sound_dim[1]
		nf = sound_feat.shape[1]
		if maxFrames < nf:  # drop extra if the length is too long
			# slice, not index: indexing a single frame would drop the time axis
			sound_feat = sound_feat[:, :maxFrames, :]
		else:  # pad 0 if the length is not long enough
			zeroPadShape = list(self.config.sound_dim)
			zeroPadShape[1] = maxFrames - nf
			sound_feat = np.concatenate((sound_feat, np.zeros(zeroPadShape)), axis=1)

		return sound_feat
