import librosa
import numpy as np
from regression.load_meg_targets import load_meg_targets
from tqdm import tqdm

def load_spectrogram(config, feature_sample_rate=50, frame_length_ms=25, n_mels=80):
    """Compute a log-power mel spectrogram for ``config``'s aligned audio.

    The audio comes from ``config.load_aligned_audio(16000)``. The 16000 is
    presumably the requested sample rate; confirm against the config class.
    Framing works as follows:

    * Each analysis window is ``frame_length_ms`` long; this sets the FFT size.
    * Consecutive windows start ``1000 / feature_sample_rate`` ms apart; this
      sets the hop.
    * As a result, there are roughly ``feature_sample_rate`` frames per second
      of audio.

    Power is converted to dB relative to the loudest value of this recording
    (``ref=np.max``).

    The frame axis is truncated to the length of
    ``load_meg_targets([config])[0]``, presumably so that feature frames line
    up with the MEG targets. If the spectrogram has fewer frames than that,
    it is returned unpadded, so it may be shorter.

    Args:
        config: Recording config exposing ``load_aligned_audio``.
        feature_sample_rate: Desired frame rate in Hz; sets the hop length.
        frame_length_ms: Analysis window length in milliseconds (FFT size).
        n_mels: Number of mel bands.

    Returns:
        Array of shape ``(n_frames, n_mels)``: one row per frame, one column
        per mel band.
    """
    frame_period_ms = 1000 / feature_sample_rate
    audio, sr = config.load_aligned_audio(16000)

    mel_power = librosa.feature.melspectrogram(
        y=audio,
        sr=sr,
        n_fft=int(sr * frame_length_ms / 1000),
        hop_length=int(sr * frame_period_ms / 1000),
        n_mels=n_mels,
        fmin=0,
        fmax=sr // 2,  # cover the full band up to Nyquist
        power=2.0,  # power (not magnitude) spectrogram
    )
    mel_db = librosa.power_to_db(mel_power, ref=np.max)

    # Only the length of this config's MEG target array is used here.
    n_meg_samples = len(load_meg_targets([config])[0])
    return mel_db[:, :n_meg_samples].T

def delayed_spectrogram_features(configs, num_delays=40, fold_tensors=True, zero_out_delays=25):
    """Build z-scored, time-delay-embedded spectrogram features per config.

    For each config, the processing steps are:

    1. Load the spectrogram with ``load_spectrogram(config, 50, 25)``. It has
       one row per frame and one column per mel band; at 50 Hz, one lag step
       is 20 ms.
    2. Z-score each column over time.
    3. Expand each frame so it also carries the preceding ``num_delays - 1``
       frames. Frames before the start of the recording are filled with zeros.

    Delay-axis layout: index ``num_delays - 1`` holds the current frame
    (lag 0). Index ``k`` holds the frame ``num_delays - 1 - k`` steps earlier,
    so index 0 is the longest lag.

    Args:
        configs: Iterable of recording configs accepted by ``load_spectrogram``.
        num_delays: Number of lags per frame, including lag 0. Must be >= 1.
        fold_tensors: Controls the output shape.
            If true, each frame's (bands, delays) block is flattened, giving
            a 2-D array of shape ``(frames, bands * num_delays)``.
            Otherwise the array keeps the shape
            ``(frames, bands, num_delays)``.
        zero_out_delays: If not None, the first ``zero_out_delays`` entries of
            the delay axis are set to zero. These are the longest lags. With
            the defaults (40, 25), only lags 0..14 survive.
            NOTE(review): confirm that discarding the longest lags (rather
            than the most recent ones) is what is intended.

    Returns:
        A list with one float64 array per config, shaped as described above.

    Raises:
        ValueError: If ``num_delays < 1``.
    """
    if num_delays < 1:
        raise ValueError("num_delays must be >= 1")

    out = []
    for config in tqdm(configs):
        spec = load_spectrogram(config, 50, 25)
        delayed = _time_delay_embed(_zscore_columns(spec), num_delays)
        if zero_out_delays is not None:
            delayed[:, :, :zero_out_delays] = 0.0
        if fold_tensors:
            n_frames, n_bands, _ = delayed.shape
            # Explicit sizes (not -1) so a zero-frame input still reshapes.
            delayed = delayed.reshape(n_frames, n_bands * num_delays)
        out.append(delayed)
    return out


def _zscore_columns(spec):
    """Z-score each column of a 2-D array over its rows; constant columns map to 0."""
    mean = np.mean(spec, axis=0)
    std = np.std(spec, axis=0)
    # A column that never changes has std 0. For example, a band sitting at
    # the dB floor for the whole recording. Dividing by that std would give
    # NaN (0/0), so divide by 1 instead and leave the centred zeros.
    std[std == 0] = 1
    return (spec - mean[None, :]) / std[None, :]


def _time_delay_embed(features, num_delays):
    """Stack each row of ``features`` with its ``num_delays - 1`` predecessors.

    Returns a float64 array ``result`` of shape
    ``(n_frames, n_feats, num_delays)``, where
    ``result[t, f, k] == features[t - lag, f]`` with
    ``lag = num_delays - 1 - k``. Entries where ``t - lag < 0`` are 0.
    """
    n_frames, n_feats = features.shape
    delayed = np.zeros((n_frames, n_feats, num_delays))
    # Loop over lags (num_delays iterations) rather than frames; each lag
    # is a single shifted slice copy.
    for k in range(num_delays):
        lag = num_delays - 1 - k
        n_valid = max(n_frames - lag, 0)
        delayed[lag:lag + n_valid, :, k] = features[:n_valid]
    return delayed
