import xlrd
import os 
import moviepy.editor as mpy

import subprocess
import librosa

import re
from decimal import Decimal

import wave
import cv2

import numpy as np

import pickle

import random
import pandas as pd
import h5py

import  argparse 

from collections import Counter

from PIL import Image

import math

from facenet_pytorch import MTCNN, InceptionResnetV1
import scipy.misc
from skimage.restoration import (denoise_tv_chambolle, denoise_bilateral,
                                 denoise_wavelet, estimate_sigma)
from skimage import data, img_as_float
from skimage.util import random_noise

import pickle


import scipy.misc

import librosa
import torchvision
from skimage import io

import cv2
import numpy as np
from PIL import Image
import imageio

# Command-line options for the dataset maker.
# NOTE(review): parse_args() runs at module import time, so importing this file
# from another program would parse that program's sys.argv as well.
parser = argparse.ArgumentParser(description='Dataset Maker')
# Dataset to preprocess; the __main__ block only handles 'RAVDESS'.
parser.add_argument('--dataset', type=str, default='RAVDESS', help='dataset type')
# Length (seconds) of each sampled audio window; also converted to a frame
# count in the __main__ block.
parser.add_argument('--win', type=float, default=0.5, help='seconds')
# Number of randomly sampled windows ("augmentations") extracted per video clip.
parser.add_argument('--num_data_augu', type=int, default=30, help='data augumentation time')
args = parser.parse_args()


def wgn(x, snr):
    """Generate additive white Gaussian noise for a signal at a target SNR.

    The noise power is chosen so that, in expectation,
    ``10 * log10(P_signal / P_noise) == snr``, where ``P_signal`` is the mean
    squared value of ``x``.

    Args:
        x: Signal samples. Any array-like of any shape is accepted; it is
            converted to float64 before the power is computed.
        snr: Desired signal-to-noise ratio in decibels.

    Returns:
        np.ndarray with the same shape as ``x``, containing zero-mean Gaussian
        noise drawn from the global NumPy RNG (``np.random``). An empty input
        yields an empty array.
    """
    # float64 avoids silent overflow when squaring integer PCM samples
    # (e.g. int16) and lets plain Python lists be passed in.
    x = np.asarray(x, dtype=np.float64)
    if x.size == 0:
        # Signal power is undefined for an empty signal; return no noise.
        return np.zeros_like(x)
    snr_linear = 10 ** (snr / 10.0)
    signal_power = np.mean(x ** 2)
    noise_power = signal_power / snr_linear
    # Use the full shape (not len(x)) so multi-dimensional signals get
    # matching noise; for 1-D input this draws exactly len(x) samples.
    return np.random.randn(*x.shape) * np.sqrt(noise_power)

    



def video_to_frames(win, fps, num_all_frames, read_path, vidcap, num_frames,
                    image_path, audio_path, mtcnn=None, max_attempts=100):
    """Extract one face crop and the matching audio window from a video clip.

    A window of ``num_frames`` frames is chosen at a random start, or the
    whole clip starting at frame 0 when it is not longer than the window.
    The window's middle frame is saved as ``middle.jpg`` and passed to MTCNN
    for face cropping. ``win`` seconds of audio starting at the window start
    are then written twice: as a log-mel spectrogram image and as raw samples
    in ``data_audio.h5``. If the middle frame is entirely black, another
    random window is drawn.

    Args:
        win: Audio window length in seconds.
        fps: Video frame rate, used to convert the start frame to seconds.
        num_all_frames: Number of frames available for sampling.
        read_path: Path of the media file whose audio librosa loads.
        vidcap: An opened ``cv2.VideoCapture`` for the same file.
        num_frames: Window length in frames.
        image_path: Output directory for ``middle.jpg`` and
            ``middle_cropped_160.jpg``.
        audio_path: Output directory for the log-mel image and
            ``data_audio.h5``.
        mtcnn: Optional ``MTCNN`` detector to reuse across calls. A new one is
            constructed when omitted.
        max_attempts: Maximum number of random windows to try while looking
            for a non-black middle frame. ``None`` retries without limit.

    Raises:
        RuntimeError: If a frame cannot be read from ``vidcap``, or if no
            non-black middle frame is found.
    """
    if mtcnn is None:
        # Callers can pass `mtcnn` to avoid constructing a detector per call.
        mtcnn = MTCNN()

    middle_offset = num_frames // 2
    attempts = 0
    while max_attempts is None or attempts < max_attempts:
        attempts += 1

        if num_all_frames > num_frames:
            start = random.randint(0, num_all_frames - num_frames)
            middle = start + middle_offset
        else:
            # Clip is not longer than the window: use it whole, from frame 0.
            start = 0
            middle = num_all_frames // 2

        vidcap.set(cv2.CAP_PROP_POS_FRAMES, middle)
        success, img = vidcap.read()
        if not success or not isinstance(img, np.ndarray):
            raise RuntimeError('failed to read frame %d from %s' % (middle, read_path))

        if (img == 0).all():
            # All-black frame. For short clips the window is fixed, so a retry
            # would read the same frame again; otherwise draw a new window.
            if num_all_frames <= num_frames:
                break
            continue

        # Convert OpenCV's BGR frame to RGB in memory and save it once; this
        # avoids a lossy JPEG write/re-read before face detection.
        rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        image = Image.fromarray(rgb)
        image.save(os.path.join(image_path, 'middle.jpg'))
        # MTCNN is asked to write the face crop to save_path; the returned
        # tensor is not used here.
        mtcnn(image, save_path=os.path.join(image_path, 'middle_cropped_160.jpg'))

        start_time = start / fps
        y_audio, sr = librosa.load(read_path, sr=None, offset=start_time, duration=win)

        target_len = int(sr * win)
        if len(y_audio) < target_len:
            # Fewer samples than the window (e.g. the window runs past the end
            # of the audio): zero-pad to a fixed length. The extra sample of an
            # odd gap goes on the left.
            gap = target_len - len(y_audio)
            right_gap = gap // 2
            left_gap = gap - right_gap
            y_audio = np.pad(y_audio, (left_gap, right_gap), 'constant',
                             constant_values=(0.0, 0.0))
        else:
            # Guard against the decoder returning a sample too many.
            y_audio = y_audio[:target_len]

        # Use the actual sample rate so the mel filterbank matches the audio.
        melspectrogram = librosa.feature.melspectrogram(
            y=y_audio, sr=sr, n_fft=1024, hop_length=256, n_mels=94)
        # NOTE(review): librosa's melspectrogram returns a power spectrogram by
        # default (power=2.0), for which power_to_db is the usual conversion.
        # amplitude_to_db is kept so existing features are unchanged.
        log_melspectrogram = librosa.amplitude_to_db(melspectrogram)
        # NOTE(review): '*' in this file name is not valid on Windows; it is
        # kept because downstream code presumably expects this exact name.
        log_audio_image_path = os.path.join(audio_path, 'audio_94*94_log.jpg')
        imageio.imwrite(log_audio_image_path, log_melspectrogram)

        with h5py.File(os.path.join(audio_path, 'data_audio.h5'), 'w') as hf:
            hf.create_dataset("data_audio", data=y_audio)
        return

    raise RuntimeError('no non-black middle frame found in %s' % read_path)

    

def RAVDESS_speech_data(ravdess_path, num_data_augu, num_frames, win, save_path):
    """Run ``video_to_frames`` over every selected RAVDESS clip.

    Each sub-directory of ``ravdess_path`` is treated as one actor. Its name is
    assumed to look like ``Actor_XX``, because characters from index 6 on are
    parsed as the subject number. Inside it, a file is processed when its name
    ends with ``mp4`` and its first two ``-``-separated fields are ``01`` and
    ``01``. Per the RAVDESS file-naming convention this presumably means
    audio-visual modality and speech vocal channel; verify against the
    dataset docs. The third field (emotion code ``01``..``08``) becomes label
    0..7, and unknown codes fall back to 0.

    For each selected clip, ``num_data_augu`` random windows are extracted
    into ``save_path/subject_<N>/sample_<k>_label_<label>/augu_<i>/{image,audio}``.
    Directory listings are sorted so that the sample numbering ``k`` is
    reproducible across filesystems.

    Args:
        ravdess_path: Root directory containing one sub-directory per actor.
        num_data_augu: Number of random windows extracted per clip.
        num_frames: Window length in frames.
        win: Audio window length in seconds.
        save_path: Root output directory.

    Raises:
        RuntimeError: If a selected video cannot be opened.
    """
    # Emotion field (third '-'-separated token) -> zero-based class label.
    emotion_to_label = {
        '01': 0, '02': 1, '03': 2, '04': 3,
        '05': 4, '06': 5, '07': 6, '08': 7,
    }
    count = 0

    for actor in sorted(os.listdir(ravdess_path)):
        actor_dir = os.path.join(ravdess_path, actor)
        if not os.path.isdir(actor_dir):
            continue
        sample_count = 0
        for mp4 in sorted(os.listdir(actor_dir)):
            split = mp4.split('-')
            # The short-circuit order guarantees split[1] exists when it is read.
            if not (mp4.endswith('mp4') and split[0] == '01' and split[1] == '01'):
                continue

            count += 1
            print(count)
            sample_count += 1
            label = emotion_to_label.get(split[2], 0)

            read_path = os.path.join(actor_dir, mp4)
            print(read_path)
            vidcap = cv2.VideoCapture(read_path)
            try:
                if not vidcap.isOpened():
                    raise RuntimeError('could not open video: ' + read_path)
                num_all_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) - 1
                fps = vidcap.get(cv2.CAP_PROP_FPS)

                sample_dir = os.path.join(
                    save_path,
                    'subject_' + str(int(actor[6:])),
                    'sample_' + str(sample_count) + '_label_' + str(label))

                for i in range(num_data_augu):
                    final_save_path = os.path.join(sample_dir, 'augu_' + str(i))
                    image_dir = os.path.join(final_save_path, 'image')
                    audio_dir = os.path.join(final_save_path, 'audio')
                    os.makedirs(audio_dir, exist_ok=True)
                    os.makedirs(image_dir, exist_ok=True)
                    video_to_frames(win, fps, num_all_frames, read_path, vidcap,
                                    num_frames, image_dir, audio_dir)
            finally:
                # Release the capture handle even if extraction fails.
                vidcap.release()


if __name__ == '__main__':

    if args.dataset=='RAVDESS':
        # Path holds the RAVDESS original dataset
        path="./RAVDESS/"
        # NOTE(review): num_subject, num_samples and sampling_rate are not used
        # in this block; they appear to be reference values for the dataset.
        num_subject=24
        num_samples=1440
        # Nominal frame rate (~30000/1001) used only to size the frame window;
        # the per-video rate is read from each file in RAVDESS_speech_data.
        fps=29.97002997002997
        sampling_rate=48000
        # Window length in frames, rounded up so it spans at least `win` seconds.
        num_frames=int(math.ceil(args.win*fps))
        win=args.win
        save_path="./processed_RAVDESS/"
        RAVDESS_speech_data(path, args.num_data_augu, num_frames, win, save_path)







        






