import torch
from torch.utils.data import Dataset
import os
import numpy as np
import librosa
import argparse
import random
from hyperparameters import args
import soundfile as sf


class SpeechDataset_attacked(Dataset):
    """Real/fake speech dataset that also returns a noise clip for each item.

    Audio files are collected from the directories listed in ``dataset_path``
    (relative to ``root_dir``). Real files get label ``0`` and fake files get
    label ``1``; fake files are subsampled with probability
    ``args.fake_ratio`` and the smaller class is then replicated (sampling
    with replacement) until both classes have the same number of entries.

    Each item is ``(waveform, noise, sample_rate, label)`` where ``waveform``
    and ``noise`` are float tensors of length
    ``2 * max_frames * 160 + 240``.

    Args:
        root_dir: Directory that every entry of ``dataset_path`` is joined to.
        dataset_path: Mapping of the form
            ``{'real': {'speaker': [...], 'no_speaker': [...]},
               'fake': {'speaker': [...], 'no_speaker': [...]}}``.
            ``'speaker'`` directories contain one sub-directory per speaker;
            ``'no_speaker'`` directories contain the audio files directly.
        noise_dir: Directory of noise recordings. Entries whose name contains
            ``'part'`` are ignored.
        max_frames: Controls the clip length (see ``loadWAV``).
    """

    # Extensions treated as audio when scanning directories.
    _AUDIO_EXTENSIONS = ('.wav', '.mp3', '.flac', '.m4a')

    def __init__(self, root_dir, dataset_path: dict, noise_dir, max_frames=args.max_frames):
        self.data = []
        self.labels = []
        self.root_dir = root_dir
        self.dataset_path = dataset_path
        self.max_frames = max_frames
        self.load_data()
        self.noise_dir = noise_dir
        self.noise_tmp = os.listdir(noise_dir)
        self.noise = [e for e in self.noise_tmp if 'part' not in e]
        self.save_phone = 100
        self.save_phone_w_noise = 100

    def _add_audio_files(self, directory, label, subsample):
        """Append the audio files found directly inside ``directory``.

        When ``subsample`` is true, each directory entry is kept only with
        probability ``args.fake_ratio``. The random draw is made for every
        entry *before* the extension check, so the number of draws depends
        only on the directory listing.
        """
        for file_name in os.listdir(directory):
            if subsample and not (random.random() < args.fake_ratio):
                continue
            if file_name.endswith(self._AUDIO_EXTENSIONS):
                self.data.append(os.path.join(directory, file_name))
                self.labels.append(label)

    def _balance_classes(self):
        """Replicate entries of the smaller class until both classes match.

        Raises:
            IndexError: If one class has entries and the other has none
                (``random.choices`` cannot sample from an empty list).
        """
        if not self.labels:
            return

        real_num = self.labels.count(0)
        fake_num = self.labels.count(1)
        if real_num > fake_num:
            minority_label, missing = 1, real_num - fake_num
        else:
            minority_label, missing = 0, fake_num - real_num

        minority_paths = [
            path for path, label in zip(self.data, self.labels) if label == minority_label
        ]
        self.data += random.choices(minority_paths, k=missing)
        self.labels += [minority_label] * missing

    def load_data(self):
        """Populate ``self.data`` / ``self.labels`` and balance the classes.

        Scan order is: real with speakers, real without speakers, fake with
        speakers, fake without speakers.
        """
        # Read all four lists up front so a malformed dataset_path fails
        # before any directory is touched.
        real_dir_w_speaker = self.dataset_path['real']["speaker"]
        fake_dir_w_speaker = self.dataset_path['fake']["speaker"]
        real_dir_wo_speaker = self.dataset_path['real']["no_speaker"]
        fake_dir_wo_speaker = self.dataset_path['fake']["no_speaker"]

        groups = (
            (real_dir_w_speaker, real_dir_wo_speaker, 0, False),
            # Only fake data is subsampled with args.fake_ratio.
            (fake_dir_w_speaker, fake_dir_wo_speaker, 1, True),
        )
        for dirs_w_speaker, dirs_wo_speaker, label, subsample in groups:
            for rel_dir in dirs_w_speaker:
                base_dir = os.path.join(self.root_dir, rel_dir)
                for speaker in os.listdir(base_dir):
                    speaker_path = os.path.join(base_dir, speaker)
                    if os.path.isdir(speaker_path):
                        self._add_audio_files(speaker_path, label, subsample)
            for rel_dir in dirs_wo_speaker:
                self._add_audio_files(os.path.join(self.root_dir, rel_dir), label, subsample)

        # keep fake and real data balanced, replicate
        self._balance_classes()

    def __len__(self):
        """Return the number of (possibly replicated) entries."""
        return len(self.data)

    def get_noise(self, waveform):
        """Load a random noise clip scaled relative to ``waveform``'s energy.

        A file is picked uniformly from ``self.noise``, cropped with
        ``loadWAV`` to the same length used for the speech clip, and scaled so
        that the ratio of summed squared amplitudes (waveform / noise)
        corresponds to an SNR in dB drawn uniformly from ``[0, args.snr)``.
        """
        noise_path = os.path.join(self.noise_dir, np.random.choice(self.noise))
        noise, _ = self.loadWAV(noise_path, self.max_frames * 4)

        # SNR in dB, uniform in [0, args.snr)
        noise_snr = random.random() * args.snr
        # The small epsilon avoids division by zero for silent clips.
        noise_power = np.sum(noise ** 2) + 1e-8
        waveform_power = np.sum(waveform ** 2) + 1e-8
        noise = np.sqrt(waveform_power / (10 ** (noise_snr / 10))) * noise / np.sqrt(noise_power)

        return noise

    def change_sample_rate(self, waveform, orig_sr, target_sr):
        """Resample ``waveform`` from ``orig_sr`` to ``target_sr`` with librosa.

        Returns:
            Tuple ``(resampled_waveform, target_sr)``.
        """
        y_downsampled = librosa.core.resample(y=waveform, orig_sr=orig_sr, target_sr=target_sr)
        return y_downsampled, target_sr

    def __getitem__(self, idx):
        """Return ``(waveform, noise, sample_rate, label)`` for entry ``idx``.

        ``noise`` is all zeros unless ``args.noise`` is set and the per-item
        draw falls below ``args.noise_ratio``. When ``args.downsample`` is set
        and the per-item draw falls below ``args.downsample_ratio``, both
        signals are resampled to a rate drawn uniformly from
        ``[args.downsample_min, 16000]`` and that rate (a float) is returned
        as ``sample_rate``; otherwise the rate reported by ``loadWAV`` is
        returned.
        """
        wav_path = self.data[idx]
        waveform, sample_rate = self.loadWAV(wav_path, self.max_frames * 4)

        # Both draws are made unconditionally so the number of random numbers
        # consumed per item does not depend on the augmentation flags.
        noise_random = random.random()
        downsample_random = random.random()

        noise = np.zeros_like(waveform)
        if args.noise and noise_random < args.noise_ratio:
            noise = self.get_noise(waveform)

        # downsample
        if args.downsample and downsample_random < args.downsample_ratio:
            # Keep the original rate in `sample_rate` until both signals are
            # resampled; overwriting it first would make orig_sr == target_sr
            # and turn the resampling into a no-op.
            target_sr = random.uniform(args.downsample_min, 16000)
            waveform, _ = self.change_sample_rate(waveform, sample_rate, target_sr)
            noise, _ = self.change_sample_rate(noise, sample_rate, target_sr)
            sample_rate = target_sr

        # truncate to 2 * max_frames * 160 + 240, else pad by wrapping around
        target_len = 2 * self.max_frames * 160 + 240
        if waveform.shape[0] > target_len:
            waveform = waveform[:target_len]
            noise = noise[:target_len]
        else:
            waveform = np.pad(waveform, (0, target_len - waveform.shape[0]), 'wrap')
            noise = np.pad(noise, (0, target_len - noise.shape[0]), 'wrap')

        # convert to tensor
        waveform = torch.tensor(waveform).float()
        noise = torch.tensor(noise).float()

        return waveform, noise, sample_rate, self.labels[idx]

    def loadWAV(self, filepath, max_frames):
        """Load ``filepath`` and return a random crop of fixed length.

        The file is read with ``librosa.load`` at ``args.audio_sample_rate``.
        Clips that are not longer than ``max_frames * 160 + 240`` samples are
        first extended by wrapping around; a random window of exactly that
        many samples is then cut out.

        Returns:
            Tuple ``(audio, sample_rate)`` where ``sample_rate`` is the rate
            reported by ``librosa.load``.
        """
        # Maximum audio length
        max_audio = max_frames * 160 + 240

        # Read the audio file with librosa and make sure it is mono
        audio, sample_rate = librosa.load(filepath, sr=args.audio_sample_rate)
        if audio.ndim > 1:
            audio = librosa.to_mono(audio)

        audiosize = audio.shape[0]

        if audiosize <= max_audio:
            # Pad to one sample more than needed so the random start offset
            # below is drawn from a non-empty range.
            shortage = max_audio - audiosize + 1
            audio = np.pad(audio, (0, shortage), 'wrap')
            audiosize = audio.shape[0]

        startframe = np.int64(np.random.rand() * (audiosize - max_audio))
        audio = audio[startframe:startframe + max_audio]

        return audio, sample_rate


if __name__ == '__main__':
    # Manual smoke run: build the dataset from hard-coded local paths.
    # Every directory below is resolved relative to root_dir.
    root_dir = "/local/rcs/zz3093/data/"
    noise_dir = "/local/rcs/zz3093/data/noise/"

    real_sources = dict(
        speaker=[
            "deepfake_VCTK/source/",
            "deepfake_in_the_wild/bona_fide_audio/",
        ],
        no_speaker=[],
    )
    fake_sources = dict(
        speaker=[
            "deepfake_VCTK/metavoice/",
            "deepfake_in_the_wild/bona_fide_tts/",
            "deepfake_in_the_wild/bona_fide_metavoice_tts/",
        ],
        no_speaker=[],
    )
    dataset_path = {'real': real_sources, 'fake': fake_sources}

    dataset = SpeechDataset_attacked(root_dir, dataset_path, noise_dir)
