import torch
from torch.utils.data import Dataset
import os
import numpy as np
import librosa
import argparse
import random
from hyperparameters import args
import soundfile as sf
from audiomentations import *


class SpeechAugDataset(Dataset):
    """Real-vs-fake speech dataset with random on-the-fly waveform augmentation.

    Real clips are labelled 0 and fake clips 1. At load time each fake clip is
    kept with probability ``args.fake_ratio``; afterwards the minority class is
    oversampled (with replacement) until both labels occur equally often.

    Args:
        root_dir: Directory that every entry of ``dataset_path`` is joined onto.
        dataset_path: ``{'real'|'fake': {'speaker': [...], 'no_speaker': [...]}}``.
            ``'speaker'`` directories contain one sub-directory per speaker that
            holds audio files; ``'no_speaker'`` directories hold audio files
            directly.
        noise_dir: Directory of noise files used by the noise-based
            augmentations; entries whose name contains ``'part'`` are skipped.
        max_frames: Crop length in frames; one frame is 160 samples of the
            16 kHz output audio.
    """

    # File extensions treated as audio when scanning directories.
    AUDIO_EXTENSIONS = ('.wav', '.mp3', '.flac', '.m4a')

    def __init__(self, root_dir, dataset_path: dict, noise_dir, max_frames=args.max_frames):
        self.data = []
        self.labels = []
        self.root_dir = root_dir
        self.dataset_path = dataset_path
        self.max_frames = max_frames
        self.load_data()
        self.noise_dir = noise_dir
        self.noise_tmp = os.listdir(noise_dir)
        # Sorted so random.choice over the pool is reproducible under a fixed seed.
        self.noise = sorted(e for e in self.noise_tmp if 'part' not in e)

    def _collect_audio(self, rel_dir, with_speaker, label, keep_prob=None):
        """Append audio files under ``root_dir/rel_dir`` to the dataset with ``label``.

        With ``with_speaker`` the files are looked up inside each speaker
        sub-directory; otherwise directly inside the directory. When
        ``keep_prob`` is given, each audio file is kept with that probability.
        """
        base_dir = os.path.join(self.root_dir, rel_dir)
        if with_speaker:
            candidates = (os.path.join(base_dir, s) for s in sorted(os.listdir(base_dir)))
            search_dirs = [d for d in candidates if os.path.isdir(d)]
        else:
            search_dirs = [base_dir]

        for directory in search_dirs:
            # Sorted so that seeded subsampling selects the same files on every run.
            for name in sorted(os.listdir(directory)):
                if not name.endswith(self.AUDIO_EXTENSIONS):
                    continue
                if keep_prob is not None and random.random() >= keep_prob:
                    continue
                self.data.append(os.path.join(directory, name))
                self.labels.append(label)

    def _balance_classes(self):
        """Oversample the minority label (with replacement) so both labels are equally frequent."""
        real_data = [p for p, y in zip(self.data, self.labels) if y == 0]
        fake_data = [p for p, y in zip(self.data, self.labels) if y == 1]
        deficit = len(real_data) - len(fake_data)
        if deficit == 0:
            return

        minority, label = (fake_data, 1) if deficit > 0 else (real_data, 0)
        if not minority:
            # random.choices would raise IndexError on an empty population.
            print(f"Warning: no samples with label {label}; classes left unbalanced")
            return

        extra = random.choices(minority, k=abs(deficit))
        self.data += extra
        self.labels += [label] * len(extra)

    def load_data(self):
        """Fill ``self.data``/``self.labels`` from ``self.dataset_path``, then balance the labels."""
        # Read every config key up front so a malformed config fails before any scanning.
        real_dir_w_speaker = self.dataset_path['real']["speaker"]
        fake_dir_w_speaker = self.dataset_path['fake']["speaker"]
        real_dir_wo_speaker = self.dataset_path['real']["no_speaker"]
        fake_dir_wo_speaker = self.dataset_path['fake']["no_speaker"]

        for rel_dir in real_dir_w_speaker:
            self._collect_audio(rel_dir, with_speaker=True, label=0)
        for rel_dir in real_dir_wo_speaker:
            self._collect_audio(rel_dir, with_speaker=False, label=0)

        # Fake clips are subsampled: each is kept with probability args.fake_ratio.
        for rel_dir in fake_dir_w_speaker:
            self._collect_audio(rel_dir, with_speaker=True, label=1, keep_prob=args.fake_ratio)
        for rel_dir in fake_dir_wo_speaker:
            self._collect_audio(rel_dir, with_speaker=False, label=1, keep_prob=args.fake_ratio)

        self._balance_classes()

    def __len__(self):
        return len(self.data)

    def aug_lists(self, background_noise_path):
        """Return the pool of candidate augmentations, each with probability ``args.aug_prob``.

        ``background_noise_path`` is passed to the noise-mixing transforms.
        Not used: Mp3Compression, RepeatPart, Trim, Padding and Reverse
        (Reverse must not be used).
        """
        aug = [
            AddBackgroundNoise(sounds_path=background_noise_path, min_snr_in_db=3.0, max_snr_in_db=30.0,
                               p=args.aug_prob),
            AddColorNoise(p=args.aug_prob),
            AddGaussianSNR(p=args.aug_prob),
            AddShortNoises(sounds_path=background_noise_path, p=args.aug_prob),
            AirAbsorption(p=args.aug_prob),
            Aliasing(min_sample_rate=8000, max_sample_rate=16000, p=args.aug_prob),
            BandPassFilter(min_center_freq=100.0, max_center_freq=6000.0, p=args.aug_prob),
            BandStopFilter(p=args.aug_prob),
            BitCrush(p=args.aug_prob),
            Clip(a_min=-0.3, a_max=0.3, p=args.aug_prob),
            ClippingDistortion(min_percentile_threshold=10, max_percentile_threshold=40, p=args.aug_prob),
            Gain(min_gain_in_db=-12, max_gain_in_db=12, p=args.aug_prob),
            GainTransition(p=args.aug_prob, min_duration=2.0, max_duration=4.0, duration_unit="seconds"),
            HighPassFilter(p=args.aug_prob),
            HighShelfFilter(p=args.aug_prob),
            Limiter(p=args.aug_prob),
            LowPassFilter(p=args.aug_prob),
            LowShelfFilter(p=args.aug_prob),
            PeakingFilter(p=args.aug_prob),
            PolarityInversion(p=args.aug_prob),
            RoomSimulator(p=args.aug_prob),
            SevenBandParametricEQ(p=args.aug_prob),
            Shift(p=args.aug_prob),
            TanhDistortion(p=args.aug_prob),
            TimeMask(p=args.aug_prob),
            TimeStretch(p=args.aug_prob, min_rate=0.6, max_rate=1.5),
        ]
        return aug

    def loadWAV(self, filepath, max_frames):
        """Load one file, optionally augment it, resample to 16 kHz and randomly crop it.

        Args:
            filepath: Path of the audio file.
            max_frames: Crop length in frames of 160 samples.

        Returns:
            ``(audio, 16000)``: a 1-D crop of ``max_frames * 160`` samples and the
            sample rate of that crop (the audio has been resampled to 16 kHz).
        """
        target_sr = 16000
        max_audio = max_frames * 160

        # sr=None keeps the native rate so augmentations run on the original signal.
        audio, sample_rate = librosa.load(filepath, sr=None)
        if audio.ndim > 1:
            audio = librosa.to_mono(audio)

        if args.aug_num >= 1:
            # Only pick a noise file and build transforms when augmentation is enabled.
            background_noise_path = os.path.join(self.noise_dir, random.choice(self.noise))
            # Sampled with replacement, so a transform may be picked more than once.
            aug = random.choices(self.aug_lists(background_noise_path), k=args.aug_num)
            audio = Compose(aug)(samples=audio, sample_rate=sample_rate)

        audio = librosa.resample(y=audio, orig_sr=sample_rate, target_sr=target_sr)

        audiosize = audio.shape[0]
        if audiosize <= max_audio:
            # Tile the clip so it is strictly longer than the crop window.
            shortage = max_audio - audiosize + 1
            audio = np.pad(audio, (0, shortage), 'wrap')
            audiosize = audio.shape[0]

        startframe = np.int64(np.random.rand() * (audiosize - max_audio))
        audio = audio[startframe:startframe + max_audio]

        return audio, target_sr

    def __getitem__(self, idx):
        """Return ``(waveform, label)``: a float32 tensor crop and its 0/1 label."""
        wav_path = self.data[idx]
        try:
            waveform, _ = self.loadWAV(wav_path, self.max_frames)
        except Exception as e:
            # Report which file failed, then propagate: there is no waveform to return.
            print(e)
            print(f"Error loading {wav_path}")
            raise

        return torch.tensor(waveform, dtype=torch.float32), self.labels[idx]


if __name__ == '__main__':
    # Smoke run: build the dataset from the local data layout.
    root_dir = "/local/rcs/zz3093/data/"
    noise_dir = "/local/rcs/zz3093/data/noise/"

    real_sources = {
        'speaker': [
            "deepfake_VCTK/source/",
            "deepfake_in_the_wild/bona_fide_audio/",
        ],
        'no_speaker': [],
    }
    fake_sources = {
        'speaker': [
            "deepfake_VCTK/metavoice/",
            "deepfake_in_the_wild/bona_fide_tts/",
            "deepfake_in_the_wild/bona_fide_metavoice_tts/",
        ],
        'no_speaker': [],
    }
    dataset_path = {'real': real_sources, 'fake': fake_sources}

    dataset = SpeechAugDataset(root_dir=root_dir, dataset_path=dataset_path, noise_dir=noise_dir)
