import os.path
from scipy import signal
import librosa
import soundfile as sf
import torch.utils.data
import glob
import os
import numpy as np
import random


# Emotion code -> integer class index. The codes are matched against the
# third underscore-separated field of each wav file name (presumably
# neutral, happy, sad, fear, anger, disgust -- confirm against the corpus
# documentation). The position in the tuple below IS the class index.
_EMOTION_CODES = ("NEU", "HAP", "SAD", "FEA", "ANG", "DIS")

LABELS_MAPPING = {code: index for index, code in enumerate(_EMOTION_CODES)}


class CremaDDataset(torch.utils.data.Dataset):
    """Dataset yielding normalised log-mel spectrograms and emotion labels.

    Every ``*.wav`` file found directly inside ``config['cremad']['path']``
    is one sample. Its label is looked up in ``LABELS_MAPPING`` using the
    third underscore-separated field of the file name.

    Keys read from ``config['cremad']``:
        path: directory containing the wav files.
        noise_augment / noise_augment_chance: enable flag and probability
            of adding Gaussian noise to the waveform.
        time_augment / time_augment_chance: enable flag and probability
            of circularly shifting the waveform.
        spec_augm / spec_chance: enable flag and probability of masking
            the spectrogram.

    Augmentations are only applied while ``self.training`` is true; toggle
    it with :meth:`train` and :meth:`eval`.
    """

    def __init__(self, config):
        """Index the wav files and derive their labels from the file names.

        Raises:
            KeyError: if a required config key is missing or a file name
                carries an emotion code absent from ``LABELS_MAPPING``.
            IndexError: if a file name has fewer than three
                underscore-separated fields.
        """
        super().__init__()
        self.config = config['cremad']
        # Sorted so that index -> file is reproducible; raw glob order
        # depends on the file system.
        self.data = sorted(glob.glob(os.path.join(self.config['path'], "*.wav")))
        # basename() instead of split('/') so Windows-style paths work too.
        self.labels = [
            LABELS_MAPPING[os.path.basename(path).split('_')[2]]
            for path in self.data
        ]
        self.FIXED_LENGTH = 40000
        self.training = True

    def __getitem__(self, index):
        """Load sample ``index`` and return ``(features, label)``.

        ``features`` is the mel spectrogram clipped to [-50, 0] dB
        (relative to its maximum), rescaled to [0, 1], with a leading axis
        of size 1 added. ``label`` is the integer class index.
        """
        raw_signal, fs = sf.read(self.data[index])
        raw_signal = self.get_fix_length(raw_signal)

        if self.config['noise_augment'] and self.training:
            if random.uniform(0, 1) < self.config['noise_augment_chance']:
                raw_signal = self._add_gaussian_noise(raw_signal)

        if self.config['time_augment'] and self.training:
            if random.uniform(0, 1) < self.config['time_augment_chance']:
                # Circular shift by up to a third of the fixed length.
                raw_signal = np.roll(raw_signal, random.randint(0, int(self.FIXED_LENGTH / 3)))

        # [2] selects the complex STFT matrix from (freqs, times, Zxx).
        spec = signal.stft(raw_signal, nfft=512, fs=fs, nperseg=400, noverlap=245, window='hamming',
                           return_onesided=True, padded=False, boundary=None)[2]
        if self.config['spec_augm'] and self.training:
            if random.uniform(0, 1) < self.config['spec_chance']:
                spec = self.spec_augment(spec)

        # NOTE(review): S is the magnitude (np.abs), not the squared
        # magnitude, yet power_to_db is applied below -- confirm this is
        # intended before changing it, since it alters the feature scale.
        mel = librosa.feature.melspectrogram(S=np.abs(spec), sr=fs, n_mels=128)
        mel = librosa.power_to_db(mel, ref=np.max)
        mel = (np.clip(mel, a_min=-50, a_max=0) + 50) / 50.

        return np.expand_dims(mel, 0), self.labels[index]

    def _add_gaussian_noise(self, raw_signal):
        """Return ``raw_signal`` plus zero-mean Gaussian noise.

        The signal-to-noise ratio is drawn uniformly from the integers
        10..34 dB and the noise variance is derived from the mean power of
        ``raw_signal``.
        """
        target_snr_db = np.random.randint(10, 35)
        sig_avg_watts = np.mean(raw_signal ** 2)
        # The small epsilon only guards log10(0) on an all-zero signal; it
        # must be negligible next to real signal power.
        sig_avg_db = 10 * np.log10(sig_avg_watts + 1e-8)
        noise_avg_db = sig_avg_db - target_snr_db
        noise_avg_watts = 10 ** (noise_avg_db / 10)
        noise = np.random.normal(0, np.sqrt(noise_avg_watts), raw_signal.shape)
        return raw_signal + noise

    def get_fix_length(self, raw_signal):
        """Pad or crop ``raw_signal`` towards ``self.FIXED_LENGTH`` samples.

        Shorter signals are zero-padded at the end. Longer signals are
        cropped at a random offset, but only in training mode.

        NOTE(review): outside training mode a longer signal is returned
        unchanged, so its length is not fixed -- confirm this is intended.
        """
        n_samples = len(raw_signal)
        if n_samples < self.FIXED_LENGTH:
            padding = np.zeros(self.FIXED_LENGTH - n_samples)
            raw_signal = np.concatenate((raw_signal, padding))
        elif n_samples > self.FIXED_LENGTH and self.training:
            surplus = n_samples - self.FIXED_LENGTH
            # randint's upper bound is exclusive: +1 makes every valid
            # offset 0..surplus reachable and keeps surplus == 1 legal.
            start = np.random.randint(0, surplus + 1)
            raw_signal = raw_signal[start:start + self.FIXED_LENGTH]

        return raw_signal

    @staticmethod
    def _mask_axis(spectrogram, axis):
        """Zero, in place, a random contiguous band along ``axis``.

        The band length is drawn from ``[0, size // 10)``. Nothing is
        masked when the axis has fewer than 10 bins.
        """
        size = spectrogram.shape[axis]
        max_length = size // 10
        if max_length < 1:
            return
        length = np.random.randint(0, max_length)
        start = np.random.randint(0, size - length + 1)
        region = [slice(None)] * spectrogram.ndim
        region[axis] = slice(start, start + length)
        spectrogram[tuple(region)] = 0

    def spec_augment(self, spectrogram):
        """Mask a random time band, frequency band, or both.

        ``spectrogram`` is expected in (frequency, time) layout, which is
        what ``scipy.signal.stft`` produces for a 1-D signal. The array is
        modified in place and also returned.
        """
        freq_axis, time_axis = 0, 1
        chance = random.uniform(0, 1)

        if chance < 0.33:
            # Temporal augmentation
            self._mask_axis(spectrogram, time_axis)
        elif chance > 0.66:
            # Frequency augmentation
            self._mask_axis(spectrogram, freq_axis)
        else:
            # Both augmentations
            self._mask_axis(spectrogram, time_axis)
            self._mask_axis(spectrogram, freq_axis)

        return spectrogram

    def eval(self):
        """Switch to evaluation mode: no augmentation, no random cropping."""
        self.training = False

    def train(self):
        """Switch to training mode: augmentation and random cropping on."""
        self.training = True

    def __len__(self):
        """Return the number of wav files indexed at construction time."""
        return len(self.data)
