# Third Party
import scipy
import scipy.fftpack
import soundfile as sf
import torch
import librosa
import random
import torch.nn.functional as F
import numpy as np

# ===============================================
#       code from Arsha for loading data.
# This code extracts features for a given audio file.
# ===============================================
from g_lfcc import linear_fbank, extract_lfcc


def load_wav(audio_filepath, sr, min_dur_sec=4):
    """Load an audio file with librosa, resampled to `sr`, zero-padded to a minimum length.

    Args:
        audio_filepath: path of the audio file to read.
        sr: target sample rate passed to ``librosa.load``. It is also used to turn
            `min_dur_sec` into a sample count, so both agree on the same rate.
        min_dur_sec: minimum duration in seconds; shorter clips are right-padded
            with zeros up to ``int(min_dur_sec * sr)`` samples.

    Returns:
        The (possibly padded) waveform returned by ``librosa.load``.
    """
    # Previously this hard-coded sr=16000 while the padding below used `sr`,
    # so the padded length was wrong for any other rate.
    audio_data, _ = librosa.load(audio_filepath, sr=sr)
    shortfall = int(min_dur_sec * sr) - len(audio_data)
    if shortfall > 0:
        # Pad with the waveform's own dtype so padded and unpadded clips share a dtype.
        padding = np.zeros(shortfall, dtype=audio_data.dtype)
        audio_data = np.concatenate((audio_data, padding))
    return audio_data


def sf_read(audio_filepath, sr, min_dur_sec=4):
    """Read an audio file with soundfile and zero-pad it to a minimum duration.

    Args:
        audio_filepath: path of the audio file to read.
        sr: ignored. The file's native sample rate (as reported by ``sf.read``) is
            used both for the padding length and as the returned rate. It is kept
            only for call-site compatibility.
        min_dur_sec: minimum duration in seconds; shorter clips are right-padded
            with zeros along the frame axis up to ``int(min_dur_sec * file_sr)`` frames.

    Returns:
        ``(audio, file_sr)``: the (possibly padded) samples and the file's sample rate.
    """
    audio_data, file_sr = sf.read(audio_filepath)
    shortfall = int(min_dur_sec * file_sr) - len(audio_data)
    if shortfall > 0:
        # Pad only the first (frame) axis so multi-channel 2-D arrays also work;
        # concatenating a 1-D zero vector failed for them.
        pad_width = [(0, shortfall)] + [(0, 0)] * (audio_data.ndim - 1)
        audio_data = np.pad(audio_data, pad_width)
    return audio_data, file_sr


def lin_mel_from_wav(wav, sr, hop_length, win_length, n_fft=512):
    """Compute a mel spectrogram of `wav`, frames first.

    This is ``librosa.feature.melspectrogram`` (a mel spectrogram, not a linear one),
    transposed so that the first axis indexes frames.

    Note the positional order: ``hop_length`` comes before ``win_length``.
    """
    mel_spec = librosa.feature.melspectrogram(
        y=wav,
        sr=sr,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
    )
    frames_first = mel_spec.T
    return frames_first


def lin_spectogram_from_wav(wav, hop_length, win_length, n_fft=512):
    """Return the transposed complex STFT of `wav` (first axis indexes frames).

    The values are the raw ``librosa.stft`` output, so they are complex; take the
    magnitude yourself if you need a real spectrogram.
    """
    stft_matrix = librosa.stft(
        y=wav,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
    )
    return stft_matrix.T


def lin_mfcc_from_wav(wav, n_mels=40, sr=22050):
    """Compute MFCCs of `wav`, transposed so the first axis indexes frames.

    Args:
        wav: waveform samples.
        n_mels: number of MFCC coefficients to keep (passed as ``n_mfcc``).
        sr: sample rate of `wav`. Defaults to 22050, librosa's own default, so
            existing callers get identical results. Pass the real rate (e.g. 16000)
            so the mel filterbank matches the audio.

    Returns:
        The transposed ``librosa.feature.mfcc`` output.
    """
    mfcc = librosa.feature.mfcc(y=wav, sr=sr, n_mfcc=n_mels)  # MFCCs, not a linear spectrogram
    return mfcc.T


def lin(sr, n_fft, n_filter=128, fmin=0.0, fmax=None, dtype=np.float32):
    """Build a bank of triangular filters with linearly spaced centre frequencies.

    Args:
        sr: sample rate, used for the FFT bin frequencies and the default `fmax`.
        n_fft: FFT size; the bank has ``1 + n_fft // 2`` columns.
        n_filter: number of triangular filters (rows).
        fmin: lowest band edge in Hz.
        fmax: highest band edge in Hz; defaults to ``sr / 2``.
        dtype: dtype of the returned matrix.

    Returns:
        Array of shape ``(n_filter, 1 + n_fft // 2)``. Row ``i`` rises linearly from
        0 at edge ``i`` to 1 at edge ``i + 1``, then falls back to 0 at edge ``i + 2``.
    """
    top = float(sr) / 2 if fmax is None else fmax
    n_filter = int(n_filter)
    weights = np.zeros((n_filter, int(1 + n_fft // 2)), dtype=dtype)

    bin_freqs = librosa.fft_frequencies(sr=sr, n_fft=n_fft)

    # n_filter + 2 equally spaced edges define n_filter overlapping triangles.
    edges = np.linspace(fmin, top, n_filter + 2)
    widths = np.diff(edges)
    offsets = np.subtract.outer(edges, bin_freqs)  # edge minus bin frequency

    # Compute all rising/falling slopes at once by broadcasting over the rows.
    rising = -offsets[:-2] / widths[:-1, np.newaxis]
    falling = offsets[2:] / widths[1:, np.newaxis]

    # Clip the intersection of both slopes at zero; assignment casts to `dtype`.
    weights[:] = np.maximum(0, np.minimum(rising, falling))
    return weights


def linear_spec(y=None,
                sr=22050,
                n_fft=2048,
                hop_length=512,
                win_length=None,
                window='hann',
                center=True,
                pad_mode='reflect',
                power=2.0,
                **kwargs):
    """Project a power spectrogram of `y` onto a linear triangular filterbank.

    Args:
        y: waveform samples.
        sr: sample rate, forwarded to ``lin`` to place the filter edges.
        n_fft, hop_length, win_length, window, center, pad_mode: forwarded to
            ``librosa.stft``.
        power: exponent applied to the STFT magnitude (2.0 gives power).
        **kwargs: forwarded to ``lin`` (e.g. ``n_filter``, ``fmin``, ``fmax``, ``dtype``).

    Returns:
        ``lin(...) @ |STFT(y)| ** power``, one row per filter and one column per frame.
    """
    spectrum = np.abs(
        librosa.stft(y=y,
                     n_fft=n_fft,
                     hop_length=hop_length,
                     win_length=win_length,
                     window=window,
                     center=center,
                     pad_mode=pad_mode)) ** power
    # Renamed from `filter`, which shadowed the builtin.
    fbank = lin(sr=sr, n_fft=n_fft, **kwargs)
    return np.dot(fbank, spectrum)


def lfcc(y=None,
         sr=22050,
         S=None,
         n_lfcc=40,
         dct_type=2,
         norm='ortho',
         **kwargs):
    """Compute linear-frequency cepstral coefficients.

    If `S` is given it is used directly as the (log) spectrogram. Otherwise it is
    computed as ``librosa.power_to_db(linear_spec(y=y, sr=sr, **kwargs))``.
    A DCT is taken along axis 0 and the first `n_lfcc` rows are returned.
    """
    log_spec = S
    if log_spec is None:
        power = linear_spec(y=y, sr=sr, **kwargs)
        log_spec = librosa.power_to_db(power)
    cepstrum = scipy.fftpack.dct(log_spec, axis=0, type=dct_type, norm=norm)
    return cepstrum[:n_lfcc]


def feature_extraction(filepath, sr=24000, min_dur_sec=4, win_length=400, hop_length=160, spec_len=500, mode='train'):
    """Load an audio file and return normalised, truncated LFCC-based features.

    Args:
        filepath: audio file path, read via ``sf_read``.
        sr: forwarded to ``sf_read``, which uses the file's native rate instead.
        min_dur_sec: minimum clip duration; shorter clips are zero-padded by
            ``sf_read``. This was previously not forwarded and so was ignored.
        win_length, hop_length, mode: unused; kept for interface compatibility.
        spec_len: maximum number of rows (axis 0 of the transposed features) to keep.

    Returns:
        The normalised feature matrix, truncated to at most `spec_len` rows.
    """
    utt, sr = sf_read(filepath, sr, min_dur_sec=min_dur_sec)
    lfcc_fb = linear_fbank(sample_rate=sr)
    spec = extract_lfcc(utt, lfcc_fb)
    # NOTE(review): the axis layout of `spec` is defined by g_lfcc.extract_lfcc
    # (not shown); confirm which axis is time before relying on the truncation below.
    mag, _ = librosa.magphase(spec)
    mag_T = mag.T
    # Standardise each column over axis 0; the epsilon avoids division by zero.
    mu = np.mean(mag_T, 0, keepdims=True)
    std = np.std(mag_T, 0, keepdims=True)
    spec_mag = (mag_T - mu) / (std + 1e-5)
    return spec_mag[:min(spec_len, mag_T.shape[0]), :]


def load_data(filepath, sr=24000, min_dur_sec=4, win_length=400, hop_length=160, n_mels=40, spec_len=400, mode='train'):
    """Load an audio file and return normalised LFCC-based features plus their length.

    Args:
        filepath: audio file path, read via ``sf_read``.
        sr: forwarded to ``sf_read``, which uses the file's native rate instead.
        min_dur_sec: minimum clip duration; shorter clips are zero-padded by
            ``sf_read``. This was previously not forwarded and so was ignored.
        win_length, hop_length, n_mels, mode: unused; kept for interface compatibility.
        spec_len: maximum number of rows (axis 0 of the transposed features) to keep.

    Returns:
        ``(feats, feature_length)``: the truncated normalised matrix and its row count.
    """
    utt, sr = sf_read(filepath, sr, min_dur_sec=min_dur_sec)
    lfcc_fb = linear_fbank(sample_rate=sr)
    spec = extract_lfcc(utt, lfcc_fb)
    # NOTE(review): the axis layout of `spec` comes from g_lfcc.extract_lfcc
    # (not shown); confirm which axis is time.
    mag, _ = librosa.magphase(spec)
    mag_T = mag.T

    # Subtract the mean and divide by the standard deviation over axis 0
    # (epsilon guards against zero std).
    mu = np.mean(mag_T, 0, keepdims=True)
    std = np.std(mag_T, 0, keepdims=True)
    spec_mag = (mag_T - mu) / (std + 1e-5)
    feature_length = min(spec_len, mag_T.shape[0])
    feats = spec_mag[:feature_length, :]
    return feats, feature_length


def load_npy_data(filepath, spec_len=400, mode='train'):
    """Load a saved feature array and keep at most `spec_len` columns.

    Args:
        filepath: path to a ``.npy`` file loaded with ``np.load``.
        spec_len: maximum number of columns (axis 1) to keep.
        mode: unused; kept for interface compatibility.

    Returns:
        ``(features, feature_length)``: the truncated array and its column count.
    """
    stored = np.load(filepath)
    feature_length = min(spec_len, stored.shape[1])
    return stored[:, :feature_length], feature_length


def speech_collate(batch):
    """Right-pad every sample's features to the longest length in the batch.

    Each sample is a dict with ``'features'`` (a tensor padded on its last
    dimension), ``'features_length'`` and ``'labels'``.

    Returns:
        ``(specs, targets)``: a list of padded feature tensors and a list of labels,
        both in batch order.
    """
    lengths = [sample['features_length'] for sample in batch]
    longest = max(lengths)
    specs = [
        F.pad(sample['features'], pad=(0, longest - sample['features_length']), value=0.0)
        for sample in batch
    ]
    targets = [sample['labels'] for sample in batch]
    return specs, targets
