import os
import pickle
import numpy as np
import soundfile as sf
from scipy import signal
from scipy.signal import get_window
from librosa.filters import mel
from numpy.random import RandomState

def butter_highpass(cutoff, fs, order=5):
    """Return the coefficients of a digital Butterworth high-pass filter.

    Args:
        cutoff: Cutoff frequency, in the same unit as ``fs``.
        fs: Sampling rate of the signal that will be filtered.
        order: Order of the filter.

    Returns:
        Tuple ``(b, a)`` holding the numerator and denominator polynomial
        coefficients produced by ``scipy.signal.butter``.
    """
    # scipy takes the critical frequency as a fraction of the Nyquist rate.
    nyquist = fs * 0.5
    numerator, denominator = signal.butter(
        order, cutoff / nyquist, btype='high', analog=False)
    return numerator, denominator
    
    
def pySTFT(x, fft_length=1024, hop_length=256):
    """Compute the magnitude short-time Fourier transform of ``x``.

    The signal is reflect-padded by ``fft_length // 2`` samples, cut into
    overlapping frames of ``fft_length`` samples whose starts are
    ``hop_length`` samples apart, multiplied by a periodic Hann window and
    transformed with a real FFT.

    Args:
        x: Input signal. Presumably 1-D -- ``np.pad`` is applied to every
            axis, so confirm before passing multi-dimensional input.
        fft_length: Frame length, also used as the FFT size.
        hop_length: Number of samples between the starts of two frames.

    Returns:
        Array of magnitudes; for 1-D input its shape is
        ``(fft_length // 2 + 1, n_frames)``.
    """
    padded = np.pad(x, int(fft_length // 2), mode='reflect')

    # Number of samples that two consecutive frames have in common.
    overlap = fft_length - hop_length
    n_frames = (padded.shape[-1] - overlap) // hop_length
    sample_stride = padded.strides[-1]

    # Build the frame matrix as a strided view of the padded signal: moving
    # one row advances hop_length samples, moving one column advances one.
    frames = np.lib.stride_tricks.as_strided(
        padded,
        shape=padded.shape[:-1] + (n_frames, fft_length),
        strides=padded.strides[:-1] + (hop_length * sample_stride,
                                       sample_stride),
    )

    # Window every frame, then take the real FFT along the frame axis.
    window = get_window('hann', fft_length, fftbins=True)
    spectrum = np.fft.rfft(window * frames, n=fft_length)

    return np.abs(spectrum.T)
    
    
# Mel filterbank for 16 kHz audio and a 1024-point FFT, transposed so that
# spectrogram frames (one frame per row) can be right-multiplied by it.
# sr and n_fft are passed by keyword: they are keyword-only in
# librosa >= 0.10, and the keyword form also works on older releases.
mel_basis = mel(sr=16000, n_fft=1024, fmin=90, fmax=7600, n_mels=80).T
# Amplitude floor equal to -100 dB, i.e. 10 ** (-100 / 20); used to avoid
# taking the logarithm of zero.
min_level = np.exp(-100 / 20 * np.log(10))
# 5th-order high-pass filter coefficients: 30 Hz cutoff at 16 kHz sampling.
b, a = butter_highpass(30, 16000, order=5)


# audio file directory
root_dir = '/home/ttsdev/nastts/resample_wav'
# spectrogram directory
targetDir = '/home/ttsdev/nastts/AVCT/vctk_mel'


# Only the first level below root_dir is listed; every immediate
# sub-directory is processed as one group of audio files.
dirName, subdirList, _ = next(os.walk(root_dir))
print('Found directory: %s' % dirName)

for subdir in sorted(subdirList):
    print(subdir)
    # exist_ok=True replaces the racy "check, then create" sequence.
    os.makedirs(os.path.join(targetDir, subdir), exist_ok=True)
    _, _, fileList = next(os.walk(os.path.join(dirName, subdir)))
    # Seeded generator: the same seed always yields the same noise sequence,
    # so the output is reproducible from run to run.
    # NOTE(review): the seed is the *length* of the directory name minus its
    # first character, so equally long names share one seed -- confirm that
    # this is intended.
    random_seed = RandomState(int(len(subdir[1:])))
    for fileName in sorted(fileList):
        # Read audio file
        x, fs = sf.read(os.path.join(dirName, subdir, fileName))
        if fs != 16000:
            # The high-pass filter and the mel basis are built for 16 kHz;
            # warn instead of failing so existing runs are not interrupted.
            print('Warning: %s has sample rate %d, expected 16000'
                  % (fileName, fs))
        # Remove drifting noise
        y = signal.filtfilt(b, a, x)
        # Add a little random noise for model robustness
        wav = y * 0.96 + (random_seed.rand(y.shape[0])-0.5)*1e-06
        # Compute spect (transposed so that each row is one frame)
        D = pySTFT(wav).T
        print("测试D的形状：",D.shape)
        print("测试mel_basis:",mel_basis.shape)
        # Convert to mel and normalize: project the magnitudes onto the mel
        # basis, convert to dB with a floor of min_level and a -16 dB
        # offset, then map [-100, 0] dB linearly onto [0, 1].
        D_mel = np.dot(D, mel_basis)
        D_db = 20 * np.log10(np.maximum(min_level, D_mel)) - 16
        S = np.clip((D_db + 100) / 100, 0, 1)
        # save spect; splitext strips the extension whatever its length
        # (np.save appends ".npy" to the resulting name).
        np.save(os.path.join(targetDir, subdir, os.path.splitext(fileName)[0]),
                S.astype(np.float32), allow_pickle=False)

