import argparse
import os
import random
import warnings

import librosa
import numpy as np
import soundfile as sf
import torch
from torch.utils.data import Dataset

from hyperparameters import args


class SpeechDataset(Dataset):
    """Real/fake speech dataset with on-the-fly waveform augmentation.

    Files found under ``dataset_path['real']`` get label 0 and files found
    under ``dataset_path['fake']`` get label 1.  After scanning, the smaller
    class is oversampled (with replacement) so both classes have the same
    number of entries.

    ``dataset_path`` has the form::

        {'real': {'speaker': [dir, ...], 'no_speaker': [dir, ...]},
         'fake': {'speaker': [dir, ...], 'no_speaker': [dir, ...]}}

    where every ``dir`` is relative to ``root_dir``.  ``'speaker'``
    directories hold one sub-directory per speaker with the audio files
    inside; ``'no_speaker'`` directories hold the audio files directly.

    Args:
        root_dir: Base directory that the ``dataset_path`` entries are joined to.
        dataset_path: Directory layout described above.
        noise_dir: Directory with noise recordings used by :meth:`add_noise`.
            Entries whose name contains ``'part'`` are ignored.
        max_frames: Controls the clip length; every item returned by
            ``__getitem__`` has ``max_frames * 160 + 240`` samples.
    """

    # File extensions treated as audio when scanning the dataset directories.
    AUDIO_EXTENSIONS = ('.wav', '.mp3', '.flac', '.m4a')

    def __init__(self, root_dir, dataset_path: dict, noise_dir, max_frames=args.max_frames):
        super().__init__()
        self.data = []    # audio file paths
        self.labels = []  # parallel to self.data: 0 = real, 1 = fake
        self.root_dir = root_dir
        self.dataset_path = dataset_path
        self.max_frames = max_frames
        self.load_data()
        self.noise_dir = noise_dir
        self.noise_tmp = os.listdir(noise_dir)
        self.noise = [e for e in self.noise_tmp if 'part' not in e]
        # Budgets for the (currently disabled) debug dump at the end of __getitem__.
        self.save_phone = 100
        self.save_phone_w_noise = 100

    def _list_audio_files(self, directory, keep_prob=None):
        """Return the paths of the audio files located directly in ``directory``.

        When ``keep_prob`` is given, each directory entry is kept with that
        probability.  The random draw is made for every entry, before the
        extension check, so the number of draws equals the number of entries.
        """
        selected = []
        for name in os.listdir(directory):
            if keep_prob is not None and random.random() >= keep_prob:
                continue
            if name.endswith(self.AUDIO_EXTENSIONS):
                selected.append(os.path.join(directory, name))
        return selected

    def _collect(self, directories, label, with_speaker, keep_prob=None):
        """Append every audio file under ``directories`` to data/labels with ``label``."""
        for rel_dir in directories:
            base = os.path.join(self.root_dir, rel_dir)
            if with_speaker:
                # One level of per-speaker sub-directories; plain files are skipped.
                candidates = (os.path.join(base, entry) for entry in os.listdir(base))
                folders = [path for path in candidates if os.path.isdir(path)]
            else:
                folders = [base]
            for folder in folders:
                files = self._list_audio_files(folder, keep_prob)
                self.data.extend(files)
                self.labels.extend([label] * len(files))

    def _balance_classes(self):
        """Oversample the smaller class (with replacement) until both are equal."""
        real_paths = [path for path, label in zip(self.data, self.labels) if label == 0]
        fake_paths = [path for path, label in zip(self.data, self.labels) if label == 1]
        deficit = len(real_paths) - len(fake_paths)
        if deficit == 0:
            return

        minority, minority_label = (fake_paths, 1) if deficit > 0 else (real_paths, 0)
        if not minority:
            # Nothing to replicate from; random.choices would raise IndexError.
            warnings.warn(
                f"SpeechDataset: no samples with label {minority_label}; "
                "skipping class balancing"
            )
            return

        missing = abs(deficit)
        self.data += random.choices(minority, k=missing)
        self.labels += [minority_label] * missing

    def load_data(self):
        """Scan the configured directories and fill ``self.data`` / ``self.labels``.

        Real files are always kept; each fake directory entry is kept with
        probability ``args.fake_ratio``.  The classes are balanced afterwards.
        """
        real_cfg = self.dataset_path['real']
        fake_cfg = self.dataset_path['fake']

        self._collect(real_cfg["speaker"], label=0, with_speaker=True)
        self._collect(real_cfg["no_speaker"], label=0, with_speaker=False)
        self._collect(fake_cfg["speaker"], label=1, with_speaker=True,
                      keep_prob=args.fake_ratio)
        self._collect(fake_cfg["no_speaker"], label=1, with_speaker=False,
                      keep_prob=args.fake_ratio)

        # keep fake and real data balanced, replicate
        self._balance_classes()

    def __len__(self):
        """Return the number of entries, including oversampled duplicates."""
        return len(self.data)

    def add_noise(self, waveform):
        """Mix a randomly chosen noise clip into ``waveform``.

        The SNR in dB is drawn uniformly from ``[0, args.snr)``.  The noise is
        first cut or zero-padded to the waveform length so the power ratio is
        measured over the span that is actually mixed.

        Returns:
            A new array; the input ``waveform`` is not modified.
        """
        noise_path = os.path.join(self.noise_dir, np.random.choice(self.noise))
        noise, _ = self.loadWAV(noise_path, self.max_frames * 4)

        # Match the waveform length before measuring power.
        if len(noise) > len(waveform):
            noise = noise[:len(waveform)]
        else:
            noise = np.pad(noise, (0, len(waveform) - len(noise)))

        noise_snr = random.random() * args.snr
        # The epsilon guards against division by zero for silent clips.
        noise_power = np.sum(noise ** 2) + 1e-8
        waveform_power = np.sum(waveform ** 2) + 1e-8
        scale = np.sqrt(waveform_power / (10 ** (noise_snr / 10))) / np.sqrt(noise_power)

        return waveform + scale * noise

    def change_sample_rate(self, waveform, orig_sr, target_sr):
        """Resample ``waveform`` from ``orig_sr`` to ``target_sr``.

        Returns:
            Tuple ``(resampled_waveform, target_sr)``.
        """
        resampled = librosa.core.resample(y=waveform, orig_sr=orig_sr, target_sr=target_sr)
        return resampled, target_sr

    def time_stertch(self, waveform, factor):
        """Time-stretch ``waveform`` by ``factor`` (passed to librosa as ``rate``).

        The misspelled name is kept for backward compatibility; see the
        ``time_stretch`` alias below.
        """
        return librosa.effects.time_stretch(y=waveform, rate=factor)

    # Correctly spelled alias for time_stertch.
    time_stretch = time_stertch

    def pitch_shift(self, waveform, sample_rate, n_steps):
        """Shift the pitch of ``waveform`` by ``n_steps`` (librosa ``n_steps``)."""
        return librosa.effects.pitch_shift(y=waveform, sr=sample_rate, n_steps=n_steps)

    def add_reverb(self, waveform, sample_rate):
        """Placeholder: not implemented, returns ``None``."""
        pass

    def enhance_speech(self, waveform, sample_rate):
        """Placeholder: not implemented, returns ``None``."""
        pass

    def add_phone_call(self, waveform, sample_rate):
        """Band-limit ``waveform`` by zeroing STFT bins outside a pass band.

        The signal is resampled to 16 kHz, bins below a random lower cut-off
        (``args.phone_low_pass_min`` .. ``args.phone_low_pass_max`` Hz) and
        from 7000 Hz upwards are zeroed, and the result is transformed back.

        Returns:
            Tuple ``(filtered_waveform, 16000)``.  The output has the same
            number of samples as the input (truncated or zero-padded).
        """
        origin_length = waveform.shape[0]
        target_sr = 16000  # rate of the signal the STFT is computed on
        upper_hz = 7000
        n_fft = 1024  # FFT size, adjust as needed

        # double n_fft
        if random.random() < 0.5:
            n_fft = n_fft * 2

        low_hz = random.randint(args.phone_low_pass_min, args.phone_low_pass_max)

        # STFT bin k is centred at k * sr / n_fft, where sr is the rate of the
        # transformed signal.  That is target_sr, not the input sample_rate.
        low_bin = int(low_hz * n_fft / target_sr)
        upper_bin = int(upper_hz * n_fft / target_sr)

        resampled = librosa.core.resample(y=waveform, orig_sr=sample_rate, target_sr=target_sr)
        spectrum = librosa.core.stft(y=resampled, n_fft=n_fft)

        # Clear out speech bands
        # TODO(jonluca) is there a better way of doing this? Feels hacky
        spectrum[:low_bin] = 0
        spectrum[upper_bin:] = 0

        reconstructed = librosa.core.istft(spectrum)

        # truncate or pad to keep the same number of samples
        if len(reconstructed) > origin_length:
            reconstructed = reconstructed[:origin_length]
        else:
            reconstructed = np.pad(reconstructed, (0, origin_length - len(reconstructed)))

        return reconstructed, target_sr

    def __getitem__(self, idx):
        """Load, augment and return item ``idx``.

        Returns:
            Tuple ``(waveform, label)`` where ``waveform`` is a float32 tensor
            of exactly ``max_frames * 160 + 240`` samples and ``label`` is 0
            (real) or 1 (fake).
        """
        wav_path = self.data[idx]
        label = self.labels[idx]
        # Load 4x the final length so that augmentations which shorten the
        # signal (time stretch, downsampling) still leave enough samples.
        waveform, sample_rate = self.loadWAV(wav_path, self.max_frames * 4)

        # All gates are drawn up front so the number of draws per item does
        # not depend on which augmentations are enabled.
        noise_random = random.random()
        phone_random = random.random()
        time_stretch_random = random.random()
        pitch_shift_random = random.random()
        downsample_random = random.random()

        # Time stretch / pitch shift are applied to fake data only.
        if label == 1:
            time_stretched = False
            if args.time_stretch and time_stretch_random < args.time_stretch_ratio:
                factor = random.uniform(args.time_stretch_min, args.time_stretch_max)
                waveform = self.time_stertch(waveform, factor)
                time_stretched = True

            # Pitch shift is skipped only when time stretch was actually
            # applied (previously it was also skipped when time stretch was
            # disabled but its random gate happened to fire).
            if (args.pitch_shift and not time_stretched
                    and pitch_shift_random < args.pitch_shift_ratio):
                n_steps = random.uniform(args.pitch_shift_min, args.pitch_shift_max)
                waveform = self.pitch_shift(waveform, sample_rate, n_steps)

        if args.noise and noise_random < args.noise_ratio:
            waveform = self.add_noise(waveform)

        if args.phone and phone_random < args.phone_ratio:
            waveform, sample_rate = self.add_phone_call(waveform, sample_rate)

        # downsample
        if args.downsample and downsample_random < args.downsample_ratio:
            # Keep the current rate as orig_sr; overwriting it with the target
            # first would turn the resample into a no-op.
            target_sr = random.uniform(args.downsample_min, 16000)
            waveform, sample_rate = self.change_sample_rate(waveform, sample_rate, target_sr)

        # Force the fixed output length: truncate, or wrap-pad if the
        # augmentations shortened the signal below it.
        target_len = self.max_frames * 160 + 240
        if waveform.shape[0] > target_len:
            waveform = waveform[:target_len]
        elif waveform.shape[0] < target_len:
            waveform = np.pad(waveform, (0, target_len - waveform.shape[0]), 'wrap')

        # Debug aid (disabled): dump a few augmented clips to disk.
        # os.makedirs("saved_output", exist_ok=True)

        # if (args.phone and noise_random > args.noise_ratio and phone_random < args.phone_ratio
        #         and self.save_phone > 0 and random.random() > 0.95):
        #     sf.write(f"saved_output/phone_call_{idx}.wav", waveform, sample_rate)
        #     self.save_phone -= 1
        #
        # # save noisy phone call
        # if (args.phone and noise_random < args.noise_ratio and phone_random < args.phone_ratio
        #         and self.save_phone_w_noise > 0 and random.random() > 0.95):
        #     sf.write(f"saved_output/phone_call_noise_{idx}.wav", waveform, sample_rate)
        #     self.save_phone_w_noise -= 1

        return torch.tensor(waveform, dtype=torch.float32), label

    def loadWAV(self, filepath, max_frames):
        """Load a random fixed-length excerpt of an audio file.

        The file is loaded at ``args.audio_sample_rate``.  Files that are not
        longer than the requested excerpt are extended by wrapping around.

        Args:
            filepath: Path of the audio file.
            max_frames: Excerpt length is ``max_frames * 160 + 240`` samples
                (presumably a 160-sample hop plus 240 samples of frame
                context; confirm against the feature extractor).

        Returns:
            Tuple ``(audio, sample_rate)``.

        Raises:
            ValueError: If the file contains no samples.
        """
        # Maximum audio length
        max_audio = max_frames * 160 + 240

        audio, sample_rate = librosa.load(filepath, sr=args.audio_sample_rate)
        if audio.ndim > 1:
            audio = librosa.to_mono(audio)

        audiosize = audio.shape[0]
        if audiosize == 0:
            # np.pad(..., 'wrap') cannot extend an empty array; fail with the path.
            raise ValueError(f"Audio file contains no samples: {filepath}")

        if audiosize <= max_audio:
            shortage = max_audio - audiosize + 1
            audio = np.pad(audio, (0, shortage), 'wrap')
            audiosize = audio.shape[0]

        # Random start offset of the excerpt.
        startframe = np.int64(np.random.rand() * (audiosize - max_audio))
        audio = audio[startframe:startframe + max_audio]

        return audio, sample_rate


if __name__ == '__main__':
    # Smoke test: build the dataset from the hard-coded local data layout.
    root_dir = "/local/rcs/zz3093/data/"
    noise_dir = "/local/rcs/zz3093/data/noise/"

    # Directories (relative to root_dir) that hold one sub-directory per speaker.
    real_speaker_dirs = [
        "deepfake_VCTK/source/",
        "deepfake_in_the_wild/bona_fide_audio/",
    ]
    fake_speaker_dirs = [
        "deepfake_VCTK/metavoice/",
        "deepfake_in_the_wild/bona_fide_tts/",
        "deepfake_in_the_wild/bona_fide_metavoice_tts/",
    ]

    dataset_path = {
        'real': {'speaker': real_speaker_dirs, 'no_speaker': []},
        'fake': {'speaker': fake_speaker_dirs, 'no_speaker': []},
    }

    dataset = SpeechDataset(root_dir, dataset_path, noise_dir)
