import subprocess
import argparse
import h5py
import numpy as np
import os
import shutil
from pathlib import Path
from joblib import Parallel, delayed

def audio_process(video_file_path, dst_root_path, ext, sample_rate=32000, high_f=50, low_f=14000):
    """Extract the audio of one media file into a filtered ``.wav`` file.

    The output is written to ``dst_root_path / f'{video_file_path.stem}.wav'``.
    Nothing is converted when that file already exists or when the suffix of
    ``video_file_path`` differs from ``ext``.

    Args:
        video_file_path (Path): Source media file.
        dst_root_path (Path): Directory receiving the ``.wav`` file. It is
            created (including missing parents) when absent.
        ext (str): Expected source suffix including the dot, e.g. ``'.mp4'``.
        sample_rate (int): Value handed to ffmpeg's ``asetrate`` filter. A
            non-positive value (e.g. ``-1``) omits the filter so the original
            sample rate is kept.
        high_f (int): Cut-off frequency passed to ffmpeg's ``highpass`` filter.
        low_f (int): Cut-off frequency passed to ffmpeg's ``lowpass`` filter.

    Returns:
        None.
    """
    dst_root_path.mkdir(parents=True, exist_ok=True)
    wav_path = dst_root_path / f'{video_file_path.stem}.wav'

    # Already converted on a previous run: nothing to do.
    if wav_path.is_file():
        return

    if ext != video_file_path.suffix:
        print(f'Skipping {video_file_path}: expected suffix {ext!r}, '
              f'got {video_file_path.suffix!r}')
        return

    filters = []
    if sample_rate > 0:
        # NOTE(review): `asetrate` relabels the sample rate without resampling
        # the PCM data, which alters speed/pitch when the source rate differs;
        # `aresample` is the resampling filter. Kept unchanged so the output
        # stays consistent with previously generated data -- confirm intent.
        filters.append(f'asetrate={sample_rate}')
    filters.append(f'highpass=f={high_f}')
    filters.append(f'lowpass=f={low_f}')

    ffmpeg_cmd = ['ffmpeg', '-i', str(video_file_path),
                  '-af', ','.join(filters), str(wav_path)]
    result = subprocess.run(ffmpeg_cmd, check=False)

    if result.returncode != 0:
        print(f'ffmpeg failed (exit code {result.returncode}) for {video_file_path}')
        # Drop any partial output; otherwise the existence check above would
        # skip this file forever on subsequent runs.
        if wav_path.is_file():
            wav_path.unlink()


def class_process(class_dir_path, dst_root_path, ext, sample_rate, high_f, low_f):
    """Convert every entry of one class directory via ``audio_process``.

    Outputs go to ``dst_root_path / class_dir_path.name``, which is created
    (including missing parents) when absent. Entries are handled in sorted
    order and each path is printed before it is processed.

    Args:
        class_dir_path (Path): Directory holding the media files of one class.
            Anything that is not a directory is ignored.
        dst_root_path (Path): Root directory under which the per-class output
            directory is created.
        ext (str): Expected source suffix, forwarded to ``audio_process``.
        sample_rate (int): Forwarded to ``audio_process``.
        high_f (int): Forwarded to ``audio_process``.
        low_f (int): Forwarded to ``audio_process``.

    Returns:
        None.
    """
    if not class_dir_path.is_dir():
        return

    dst_class_path = dst_root_path / class_dir_path.name
    # parents=True so a not-yet-existing destination root does not abort the job.
    dst_class_path.mkdir(parents=True, exist_ok=True)

    for audio_file_path in sorted(class_dir_path.iterdir()):
        print(audio_file_path)
        audio_process(audio_file_path, dst_class_path, ext, sample_rate, high_f, low_f)

def _select_split(items, split, total_split):
    """Return the portion of ``items`` assigned to chunk ``split`` of ``total_split``."""
    interval = len(items) // total_split
    start = split * interval
    if split + 1 == total_split:
        # The last chunk also takes the remainder of the integer division.
        return items[start:]
    return items[start:start + interval]


def main(argv=None):
    """Parse the command line and convert the selected class directories.

    Args:
        argv (list[str] | None): Argument list to parse; ``None`` makes
            argparse read ``sys.argv[1:]``.

    Exits with status 2 (via ``parser.error``) on missing or inconsistent
    arguments.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dir_path',
                        default=None,
                        type=Path,
                        help='Directory path of audios')
    parser.add_argument('--dst_path',
                        default=None,
                        type=Path,
                        help='Directory path of wav audios')
    parser.add_argument('--dataset',
                        default='',
                        type=str,
                        help='Dataset name (kinetics | mit | ucf101 | hmdb51 | activitynet | vggsound)')
    parser.add_argument('--n_jobs',
                        default=1,
                        type=int,
                        help='Number of parallel jobs')
    parser.add_argument('--split',
                        default=0,
                        type=int,
                        help='Index of the split to process (0 <= split < total_split)')
    parser.add_argument('--total_split',
                        default=1,
                        type=int,
                        help='Total number of splits')
    parser.add_argument('--sample_rate',
                        default=32000,
                        type=int,
                        help=('Sample rate of audio file. '
                              '-1 means original sample rates.'))
    parser.add_argument('--low_pass_f',
                        default=14000,
                        type=int,
                        help='Low pass frequency threshold of audio file')
    parser.add_argument('--high_pass_f',
                        default=50,
                        type=int,
                        help='High pass frequency threshold of audio file')
    args = parser.parse_args(argv)

    # Fail with a usage message instead of an AttributeError/ZeroDivisionError
    # further down.
    if args.dir_path is None or args.dst_path is None:
        parser.error('--dir_path and --dst_path are required')
    if args.total_split < 1:
        parser.error('--total_split must be at least 1')
    if not 0 <= args.split < args.total_split:
        parser.error('--split must satisfy 0 <= split < total_split')
    if not args.dir_path.is_dir():
        parser.error(f'--dir_path is not a directory: {args.dir_path}')

    # Datasets distributed as .mp4; everything else is expected to be .avi.
    if args.dataset in ('kinetics', 'mit', 'activitynet', 'vggsound'):
        ext = '.mp4'
    else:
        ext = '.avi'

    args.dst_path.mkdir(parents=True, exist_ok=True)

    class_dir_paths = _select_split(sorted(args.dir_path.iterdir()),
                                    args.split, args.total_split)

    Parallel(n_jobs=args.n_jobs, backend='threading')(
        delayed(class_process)(class_dir_path, args.dst_path, ext,
                               args.sample_rate, args.high_pass_f, args.low_pass_f)
        for class_dir_path in class_dir_paths)


if __name__ == '__main__':
    main()