import os
import fire
import pickle

from scipy.signal.windows import hann
from scipy.signal import ShortTimeFFT as STFFT
from scipy.interpolate import RegularGridInterpolator
import yaml
import glob
import numpy as np
from scipy.io import wavfile

# Square spectrogram shape, presumably the image size expected downstream
# (not referenced elsewhere in this file) -- TODO confirm.
X_SHAPE = (64, 64)
# Tiny constant; -1/EPSILON is get_spec's default ``fill_value``.
EPSILON = 1e-12
# Number of pixels in an X_SHAPE image (64 * 64).
SPEC_DIM = X_SHAPE[0] * X_SHAPE[1]

def get_spec(onset, offset, audio, max_dur=None, fs=32000, target_size=(128, 128),
             fill_value=-1/EPSILON, remove_dc_offset=True, spec_min_val=0.5,
             spec_max_val=2.0, fft_params=None, interp=False):
    """Compute a normalized log-magnitude spectrogram of one audio segment.

    The samples ``audio[round(onset*fs):round(offset*fs)]`` are transformed
    with a symmetric-Hann-windowed STFT, and only frames whose times fall in
    ``[onset, offset)`` are kept. If ``interp`` is true, the complex STFT is
    linearly resampled onto a ``target_size`` grid:
    - ``target_size[0]`` frequencies, evenly spaced on the mel scale between
      ``min_freq`` and ``max_freq``;
    - ``target_size[1]`` times, evenly spaced between the first and last kept
      frame.
    Otherwise the STFT is only cropped to the ``[min_freq, max_freq)`` band.

    The result is ``log(|S| + 1e-12)``, linearly rescaled so that
    ``spec_min_val`` maps to 0 and ``spec_max_val`` maps to 1, then clipped
    to ``[0, 1]``.

    Parameters
    ----------
    onset, offset : float
        Segment boundaries in seconds.
    audio : array
        The full recording, indexed by sample (presumably 1-D/mono; TODO
        confirm).
    max_dur : unused
        Accepted for backward compatibility only.
    fs : int or float
        Sampling rate in Hz.
    target_size : (int, int)
        ``(n_freqs, n_times)`` of the output when ``interp`` is true.
    fill_value : float
        Forwarded to ``RegularGridInterpolator``. It has no effect because
        ``bounds_error=True``, so out-of-range queries raise instead.
    remove_dc_offset : unused
        Accepted for backward compatibility only.
    spec_min_val, spec_max_val : float
        Log-magnitude values mapped to 0 and 1 respectively.
    fft_params : dict, optional
        Keys ``nperseg``, ``noverlap``, ``min_freq`` and ``max_freq``.
        Defaults to ``{'nperseg': 512, 'noverlap': 256, 'min_freq': 100,
        'max_freq': 15000}``.
    interp : bool
        Resample onto the mel/time grid described above.

    Returns
    -------
    Sx : ndarray
        Spectrogram of shape ``(frequency, time)`` with values in ``[0, 1]``.
    target_ts : ndarray
        Time in seconds of each column of ``Sx``.
    target_freqs : ndarray
        Frequency in Hz of each row of ``Sx``.
    flag : int
        1 if ``offset - onset <= 0.05`` s, else 0.

    Raises
    ------
    IndexError
        If no STFT frame falls inside ``[onset, offset)``.
    ValueError
        Raised by ``RegularGridInterpolator`` when ``interp`` is true and the
        requested band extends beyond the STFT frequency axis (0 to fs/2).
    """
    if fft_params is None:
        # Built per call rather than as a shared mutable default argument.
        fft_params = {'nperseg': 512, 'noverlap': 256,
                      'min_freq': 100, 'max_freq': 15000}

    # Very short segments are flagged so callers can treat them specially.
    flag = 1 if offset - onset <= 0.05 else 0

    on_ind = int(round(onset * fs))
    off_ind = int(round(offset * fs))
    segment = audio[on_ind:off_ind]
    # Use the real slice length: it is shorter than off_ind - on_ind when
    # offset runs past the end of `audio`, and the time axis must match Sx.
    n_samples = len(segment)

    nperseg = fft_params['nperseg']
    window = hann(nperseg, sym=True)  # symmetric Hann window
    transform = STFFT(window, hop=nperseg - fft_params['noverlap'], fs=fs, mfft=nperseg)
    Sx = transform.stft(segment)

    # ShortTimeFFT.t / .f give the exact slice times and bin frequencies,
    # one per column / row of Sx. Spreading np.linspace over extent() (the
    # imshow corner values) would stretch each axis by one bin.
    tAx = transform.t(n_samples) + onset
    fAx = transform.f

    # Keep only the frames whose times lie in [onset, offset).
    t0, t1 = np.searchsorted(tAx, onset), np.searchsorted(tAx, offset)
    Sx, tAx = Sx[:, t0:t1], tAx[t0:t1]
    if tAx.size == 0:
        # IndexError kept for backward compatibility with existing callers.
        raise IndexError(
            f'no STFT frames between onset={onset} and offset={offset}; '
            'segment is too short for this hop size')

    if interp:
        mel_grid = np.linspace(_mel(fft_params['min_freq']),
                               _mel(fft_params['max_freq']), target_size[0])
        target_freqs = _inv_mel(mel_grid)
        if target_freqs.size > 1:
            # The mel round trip can overshoot by an ulp; pin the endpoints
            # so e.g. max_freq == fs/2 does not trip bounds_error.
            target_freqs[0] = fft_params['min_freq']
            target_freqs[-1] = fft_params['max_freq']
        target_ts = np.linspace(tAx[0], tAx[-1], target_size[1])

        interpolator = RegularGridInterpolator((fAx, tAx), Sx, bounds_error=True,
                                               fill_value=fill_value)
        grid_f, grid_t = np.meshgrid(target_freqs, target_ts, indexing='ij', sparse=True)
        Sx = interpolator((grid_f, grid_t))
    else:
        f0 = np.searchsorted(fAx, fft_params['min_freq'])
        f1 = np.searchsorted(fAx, fft_params['max_freq'])
        target_ts, target_freqs = tAx, fAx[f0:f1]
        Sx = Sx[f0:f1, :]

    Sx = np.log(np.abs(Sx) + 1e-12)
    Sx = (Sx - spec_min_val) / (spec_max_val - spec_min_val)
    Sx = np.clip(Sx, 0.0, 1.0)

    return Sx, target_ts, target_freqs, flag

def _mel(a):

    return 1127 * np.log(1 + a/700)

def _inv_mel(a):
    return 700 * (np.exp(a/1127) - 1) 


def make_specs(path_to_data='.',path_to_save='.',windowSize=0.1,overlap=0.8,sample_len_s=0.6):
    """Build onset-aligned, windowed spectrograms for every day and pickle them.

    Inputs read under ``path_to_data``:
    - ``audio/<day>/*.wav``, where day directories match ``2[0-9]*``.
    - ``labeled_syllables/<day>/<wav stem>.txt``: whitespace-separated
      ``onset offset`` rows, in seconds.
    - ``spec_params.yml``: must define ``nperseg`` and ``noverlap``.

    For each labelled onset, the excerpt ``[onset, onset + sample_len_s]`` is
    cut into windows of ``windowSize`` s that overlap by the fraction
    ``overlap``. Each window gets a spectrogram from
    ``get_spec(..., interp=True)``. Onsets whose excerpt runs past the end of
    the file, or that yield no full window, are skipped.

    Each day is pickled to ``path_to_save`` as a list with one
    ``(n_windows, *target_size)`` array per kept onset, and its path is
    printed.

    Returns
    -------
    list
        One such per-day list for each day, in sorted day order.
    """
    audio_path = os.path.join(path_to_data, 'audio')
    seg_path = os.path.join(path_to_data, 'labeled_syllables')

    with open(os.path.join(path_to_data, 'spec_params.yml'), 'r') as file:
        # BaseLoader returns every scalar as a string, hence the int() casts.
        spec_params = yaml.load(file, Loader=yaml.BaseLoader)

    # NOTE(review): normalization bounds and band edges are hard-coded; any
    # spec_min_val / spec_max_val in spec_params.yml is ignored.
    spec_min_val = 1.
    spec_max_val = 9.5
    min_freq = 100
    max_freq = 18000
    # get_spec expects the *overlap* here (it derives hop = nperseg - noverlap
    # itself), so pass the yml value through unchanged.
    fft_params = {'nperseg': int(spec_params['nperseg']),
                  'noverlap': int(spec_params['noverlap']),
                  'min_freq': min_freq,
                  'max_freq': max_freq}

    step = windowSize * (1 - overlap)
    # round() rather than int(): e.g. 0.7 * 0.1 * 1000 == 69.999... would
    # otherwise truncate to 69.
    window_ms = int(round(windowSize * 1000))
    overlap_ms = int(round(overlap * windowSize * 1000))

    day_nums = sorted(os.path.basename(d)
                      for d in glob.glob(os.path.join(audio_path, '2[0-9]*')))

    all_specs = []
    for day in day_nums:
        day_specs = []
        # Sorted so the per-day output order is reproducible.
        audio_files = sorted(glob.glob(os.path.join(audio_path, day, '*.wav')))

        for audio_file in audio_files:
            stem = os.path.splitext(os.path.basename(audio_file))[0]
            seg_file = os.path.join(seg_path, day, stem + '.txt')

            # ndmin=2 keeps a single onset/offset row as shape (1, 2).
            onoffs = np.loadtxt(seg_file, ndmin=2)
            if onoffs.size == 0:
                continue
            sr, audio = wavfile.read(audio_file)
            duration = len(audio) / sr

            for on, off in onoffs:
                voc_offset = on + sample_len_s
                if voc_offset > duration:
                    continue

                win_ons = np.arange(on, voc_offset - windowSize, step)
                if win_ons.size == 0:
                    # sample_len_s too short for a single window; np.stack
                    # would fail on an empty list.
                    continue
                specs = [get_spec(t_on, t_on + windowSize, audio, fs=sr,
                                  fft_params=fft_params,
                                  spec_max_val=spec_max_val,
                                  spec_min_val=spec_min_val, interp=True)[0]
                         for t_on in win_ons]
                day_specs.append(np.stack(specs))

        all_specs.append(day_specs)

        # Save each day as soon as it is done so earlier days are kept if a
        # later one fails.
        save_fn = os.path.join(
            path_to_save,
            f'{day}_{window_ms}mswindow_{overlap_ms:d}msoverlap_data_A_aligned.pkl')
        print(save_fn)
        with open(save_fn, 'wb') as f:
            pickle.dump(day_specs, f)

    return all_specs


if __name__ == '__main__':
    # Expose make_specs' keyword arguments as command-line flags via python-fire.
    specs = fire.Fire(component=make_specs)
