import csv
import os
import random

import ffmpeg
import numpy as np
import pandas as pd
import torch
import torch.utils.data
from joblib import Parallel, delayed
from python_speech_features import logfbank
from torch.utils.data import Dataset

from .vid_utils import load_audio


# Path string per dataset key (only 'vggsound' is defined here).
ROOT_DIR = dict(
    vggsound='./data/vggsound',
)

# Per dataset key, maps a mode name ('train' / 'val') to a string; note that
# 'val' maps to 'test'. Presumably these are sub-directory names under the
# matching ROOT_DIR entry -- confirm against the code that consumes them.
MODE_DIR = dict(
    vggsound=dict(
        train='train',
        val='test',
    ),
)


#####################

def get_spec(
    wav,
    fr_sec,
    num_sec=1,
    sample_rate=48000,
    aug_audio=[],
    aud_spec_type=1,
    use_volume_jittering=False,
    use_temporal_jittering=False,
    z_normalize=False
):
    """Cut a clip out of ``wav`` and return its log filterbank features.

    Args:
        wav: Array of audio samples (presumably 1-D mono -- confirm against
            callers); indexed by sample position.
        fr_sec: Start of the clip, in seconds.
        num_sec: Length of the clip, in seconds.
        sample_rate: Samples per second of ``wav``.
        aug_audio: Accepted for call compatibility; not used by this function
            (and never mutated, so the shared default list is harmless).
        aud_spec_type: ``1`` selects 40 filters with a 10 ms step; any other
            value selects 257 filters, with a 5 ms step when ``num_sec == 1``
            and a 10 ms step otherwise.
        use_volume_jittering: Scale the clip by a random factor in (0.9, 1.1).
        use_temporal_jittering: Shift ``fr_sec`` by a random offset in
            (-0.5, 0.5) seconds before cutting.
        z_normalize: Apply ``(spec - 1.93) / 17.89`` to the result.

    Returns:
        A float32 ``torch`` tensor of shape ``(1, n_filters, n_frames)``.
    """
    # Temporal jittering - move the clip start by up to 0.5 s either way
    if use_temporal_jittering:
        fr_sec = fr_sec + np.random.uniform(-0.5, 0.5)

    # Clip length in samples; int() keeps slicing valid for fractional num_sec
    clip_len = int(sample_rate * num_sec)
    fr_aud = int(np.round(fr_sec * sample_rate))

    # Keep the window inside the waveform. Without the lower clamp a negative
    # start (jitter near 0 s, or a wav shorter than the clip) would be read as
    # a negative index and silently produce an empty or wrong slice.
    fr_aud = max(0, min(fr_aud, len(wav) - clip_len))
    to_aud = fr_aud + clip_len

    # Get subset of wav (shorter than clip_len only if wav itself is shorter)
    wav = wav[fr_aud:to_aud]

    # Volume jittering - scale volume by factor in range (0.9, 1.1)
    if use_volume_jittering:
        wav = wav * np.random.uniform(0.9, 1.1)

    # Convert to log filterbank
    if aud_spec_type == 1:
        spec = logfbank(
            wav,
            sample_rate,
            winlen=0.02,
            winstep=0.01,
            nfilt=40,
            nfft=1024
        )
    else:
        spec = logfbank(
            wav,
            sample_rate,
            winlen=0.02,
            winstep=0.005 if num_sec == 1 else 0.01,
            nfilt=257,
            nfft=1024
        )

    # (T, F) -> float32 -> (F, T) -> (1, F, T)
    spec = spec.astype('float32')
    spec = spec.T
    spec = np.expand_dims(spec, axis=0)
    spec = torch.as_tensor(spec)

    if z_normalize:
        # Same as transforms.Normalize([1.93], [17.89])(spec)
        spec = (spec - 1.93) / 17.89

    return spec


def load_audio(
    vid_path,
    fr_sec=None,
    num_sec=1,
    sample_rate=48000,
    aug_audio=[],
    aud_spec_type=1,
    use_volume_jittering=False,
    use_temporal_jittering=False,
    z_normalize=False
):
    """Decode the audio of ``vid_path`` with ffmpeg and return a spectrogram.

    The audio is decoded to mono signed 16-bit PCM at ``sample_rate``, looped
    until it covers 10 seconds, and handed to ``get_spec``.

    NOTE: this definition shadows the ``load_audio`` name imported from
    ``.vid_utils`` at the top of the module.

    Args:
        vid_path: Path of the media file passed to ``ffmpeg.input``.
        fr_sec: Clip start in seconds. When ``None`` a random whole second is
            drawn so that the clip fits inside the looped audio.
        num_sec: Clip length in seconds.
        sample_rate: Sample rate requested from ffmpeg.
        aug_audio, aud_spec_type, use_volume_jittering,
        use_temporal_jittering, z_normalize: Forwarded to ``get_spec``.

    Returns:
        Whatever ``get_spec`` returns for the selected clip.
    """
    # Load wav file @ sample_rate
    out, _ = (
        ffmpeg
        .input(vid_path)
        .output('-', format='s16le', acodec='pcm_s16le', ac=1, ar=sample_rate)
        .run(quiet=True)
    )
    wav = np.frombuffer(out, np.int16)

    # Repeat in case audio is too short. At least 10 repetitions (the original
    # behavior), but more when the audio is under 1 s so that the result still
    # reaches the full 10 s target.
    target_len = sample_rate * 10
    if len(wav) > 0:
        reps = max(10, -(-target_len // len(wav)))  # ceil division
        wav = np.tile(wav, reps)[:target_len]

    if fr_sec is None:
        length = int(len(wav) / sample_rate)
        # Clamp so randint never receives an empty range (empty audio or a
        # clip longer than the available audio); fall back to starting at 0.
        fr_sec = random.randint(0, max(0, length - num_sec))

    # Get spectrogram
    return get_spec(
        wav,
        fr_sec,
        num_sec=num_sec,
        sample_rate=sample_rate,
        aug_audio=aug_audio,
        aud_spec_type=aud_spec_type,
        use_volume_jittering=use_volume_jittering,
        use_temporal_jittering=use_temporal_jittering,
        z_normalize=z_normalize
    )




#######################


def valid_video(vid_idx, vid_path):
    """Report whether ``vid_path`` has usable video and audio streams.

    The file is probed with ``ffmpeg.probe``. It counts as valid when it has
    both a video and an audio stream and each stream's ``duration`` field is
    greater than 1.1. The outcome is printed (prefixed with ``vid_idx``) and
    returned.

    Args:
        vid_idx: Identifier used only in the printed status line.
        vid_path: Path of the file to probe.

    Returns:
        ``True`` if the file passes the checks, ``False`` otherwise, including
        when probing or reading the stream metadata fails.
    """
    try:
        streams = ffmpeg.probe(vid_path)['streams']
        video_stream = next(
            (s for s in streams if s['codec_type'] == 'video'), None
        )
        audio_stream = next(
            (s for s in streams if s['codec_type'] == 'audio'), None
        )
        is_valid = bool(
            audio_stream
            and video_stream
            and float(video_stream['duration']) > 1.1
            and float(audio_stream['duration']) > 1.1
        )
    except Exception:
        # Any probing/metadata failure marks the file invalid. Unlike a bare
        # ``except:``, this lets KeyboardInterrupt and SystemExit propagate so
        # a long filtering run can still be interrupted.
        print(f"{vid_idx}: False", flush=True)
        return False

    if is_valid:
        print(f"{vid_idx}: True", flush=True)
    else:
        print(f"{vid_idx}: False (duration short/ no audio)", flush=True)
    return is_valid


def filter_videos(vid_paths, n_jobs=30):
    """Return the positions in ``vid_paths`` whose files pass ``valid_video``.

    Args:
        vid_paths: Iterable of file paths to check.
        n_jobs: Number of parallel joblib workers (defaults to 30, the value
            that was previously hard-coded).

    Returns:
        List of integer positions, in input order, for which ``valid_video``
        returned a truthy value.
    """
    validity_flags = Parallel(n_jobs=n_jobs)(
        delayed(valid_video)(vid_idx, vid_path)
        for vid_idx, vid_path in enumerate(vid_paths)
    )
    return [vid_idx for vid_idx, is_valid in enumerate(validity_flags) if is_valid]


class GetAudioDataset(Dataset):
    """Audio classification dataset driven by CSV listings.

    ``csvpath + 'stat.csv'`` supplies the class names (first column).
    ``csvpath + 'train.csv'`` (mode ``'train'``) or ``csvpath + 'test.csv'``
    (any other mode) supplies ``(file name, class name)`` rows; rows whose
    class is not listed in ``stat.csv`` are skipped. Each item is read from
    ``datapath + <file name with its last three characters replaced by 'wav'>``.

    Items are ``(spectrogram, label, idx)`` where ``label`` is the position of
    the item's class in the sorted class list.
    """

    def __init__(self, csvpath, datapath, mode='train', transforms=None):
        """Read the class list and the split listing.

        Args:
            csvpath: Prefix that is string-concatenated with the CSV file
                names, so it must already end with a path separator.
            datapath: Prefix that is string-concatenated with wav file names.
            mode: ``'train'`` selects ``train.csv``; anything else selects
                ``test.csv``.
            transforms: Stored on ``self.transforms``; not used elsewhere in
                this class.
        """
        with open(csvpath + 'stat.csv') as f:
            classes = [row[0] for row in csv.reader(f)]
        # Set for O(1) membership tests while scanning the split listing
        known_classes = set(classes)

        self.mode = mode
        split_csv = 'train.csv' if self.mode == 'train' else 'test.csv'
        data = []
        data2class = {}
        with open(csvpath + split_csv) as f:
            for item in csv.reader(f):
                if item[1] in known_classes:
                    data.append(item[0])
                    data2class[item[0]] = item[1]

        self.audio_path = datapath

        self.transforms = transforms
        self.classes = sorted(classes)
        self.data2class = data2class

        # initialize audio transform
        self._init_atransform()
        # List of files that make up the dataset, in listing order
        self.video_files = list(data)
        print('# of audio files = %d ' % len(self.video_files))
        print('# of classes = %d' % len(self.classes))

    def _init_atransform(self):
        # NOTE(review): the module-level name `transforms` is not imported
        # anywhere in this file as shown (presumably
        # `from torchvision import transforms` is intended -- confirm).
        # The `transforms` argument of __init__ is a different, local name.
        self.aid_transform = transforms.Compose([transforms.ToTensor()])

    def __len__(self):
        return len(self.video_files)

    def __getitem__(self, idx):
        """Return ``(spectrogram, label, idx)`` for the ``idx``-th file."""
        wav_file = self.video_files[idx]
        # Audio
        # NOTE(review): `sf` is not imported anywhere in this file as shown
        # (presumably `import soundfile as sf` is intended -- confirm).
        samples, samplerate = sf.read(self.audio_path + wav_file[:-3] + 'wav')

        # Repeat in case audio is too short: at least 10 repetitions (the
        # original behavior), more when the clip is under 1 s so the looped
        # signal still reaches 10 s.
        target_len = samplerate * 10
        reps = 10
        if len(samples) > 0:
            reps = max(10, -(-target_len // len(samples)))  # ceil division
        resamples = np.tile(samples, reps)[:target_len]

        label = self.classes.index(self.data2class[wav_file])
        length = int(len(resamples) / samplerate)

        # Get spectrogram
        if self.mode == 'train':
            # Random 3 s window; clamp so very short audio cannot produce a
            # negative start time.
            fr_sec = random.random() * max(length - 3, 0)
            spectrogram = get_spec(
                resamples,
                fr_sec,
                num_sec=3,
                sample_rate=samplerate,
                aud_spec_type=2,
                use_volume_jittering=True,
                use_temporal_jittering=False,
                z_normalize=True
            )
        else:
            # Evaluation uses the full 10 s from the start, without jittering
            spectrogram = get_spec(
                resamples,
                0,
                num_sec=10,
                sample_rate=samplerate,
                aud_spec_type=2,
                use_volume_jittering=False,
                use_temporal_jittering=False,
                z_normalize=True
            )
        return spectrogram, label, idx



class ESC_DCASE(Dataset):
    """Dataset of audio clips at pre-computed start offsets.

    A metadata CSV with ``filename``, ``fold`` and ``target`` columns is split
    by fold: mode ``'train'`` keeps rows whose fold differs from ``val_fold``,
    mode ``'val'`` keeps rows whose fold equals it. For every kept file,
    ``num_samples`` start times in ``[0, last_sec]`` are generated, and each
    ``(file, start)`` pair is one dataset item.

    Items are ``(spectrogram, label, file_index)`` where ``file_index`` is the
    position of the file within the selected fold subset.
    """

    def __init__(
        self,
        root_dir='./data/ESC_50/audio',
        dataset='esc50',
        metadata_path='./data/ESC_50/esc50.csv',
        val_fold=1,
        mode='train',
        num_samples=10,
        seconds=1,
        random_starts=False,
        nfilter=80
    ):
        """Build the list of ``(filename, start_sec, label, file_index)``.

        Args:
            root_dir: Directory joined with each file name when loading.
            dataset: Accepted but not used by this class.
            metadata_path: CSV file read with ``pandas.read_csv``.
            val_fold: Fold value that separates validation from training.
            mode: ``'train'`` or ``'val'``.
            num_samples: Number of start times generated per file.
            seconds: ``1`` sets the latest start time to 4.0 s, any other
                value sets it to 3.0 s.
            random_starts: Draw start times uniformly at random instead of
                spacing them evenly with ``np.linspace``.
            nfilter: Stored on ``self.nfilter``; not used elsewhere in this
                class.

        Raises:
            ValueError: If ``mode`` is neither ``'train'`` nor ``'val'``.
        """
        # The previous `assert("...")` asserted a non-empty string, which is
        # always true, so an invalid mode only failed later with an
        # AttributeError. Fail immediately with a clear error instead.
        if mode not in ('train', 'val'):
            raise ValueError("'train' and 'val' are only modes supported")

        # Save input
        self.root_dir = root_dir
        self.nfilter = nfilter

        # Load metadata and keep the rows of the requested split
        df = pd.read_csv(metadata_path)
        if mode == 'train':
            df_fold = df[df.fold != val_fold]
        else:
            df_fold = df[df.fold == val_fold]
        self.filenames = list(df_fold['filename'])
        self.labels = list(df_fold['target'])

        self.dataset = []
        self.seconds = seconds
        self.last_sec = 4.0 if seconds == 1 else 3.0
        for i, (filename, label) in enumerate(zip(self.filenames, self.labels)):
            # Start times are generated per file, so random draws happen in
            # file order.
            if random_starts:
                starts = np.random.uniform(0, self.last_sec, num_samples)
            else:
                starts = np.linspace(0, self.last_sec, num_samples)
            self.dataset.extend((filename, fr, label, i) for fr in starts)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Load the ``idx``-th clip and return ``(spectrogram, label, file_index)``."""
        filename, fr, label, aud_idx = self.dataset[idx]
        filepath = os.path.join(self.root_dir, filename)
        # NOTE(review): the clip length is fixed at 2 s here regardless of
        # `self.seconds` -- confirm this is intentional.
        spectogram = load_audio(filepath, fr, num_sec=2, sample_rate=24000, aud_spec_type=2, z_normalize=True)
        return spectogram, label, aud_idx

