import xlrd
import os 
import moviepy.editor as mpy

import subprocess
import librosa

import re
from decimal import Decimal

import wave
import cv2

import numpy as np

import pickle

import random
import pandas as pd
import h5py

import  argparse 

from collections import Counter

from PIL import Image

import math

from facenet_pytorch import MTCNN, InceptionResnetV1
import scipy.misc
from skimage.restoration import (denoise_tv_chambolle, denoise_bilateral,
                                 denoise_wavelet, estimate_sigma)
from skimage import data, img_as_float
from skimage.util import random_noise

import pickle


import scipy.misc

import librosa
import torchvision
from skimage import io

import cv2
import numpy as np
from PIL import Image
import imageio

# Command-line interface of the dataset maker. The arguments are parsed at
# module level, so `args` is available to the script entry point below.
parser = argparse.ArgumentParser(description='Dataset Maker')

# Which dataset to preprocess.
parser.add_argument(
    '--dataset',
    default='enterface',
    type=str,
    help='dataset type',
)

# Length, in seconds, of the window cut out of each video per sample.
parser.add_argument(
    '--win',
    default=0.5,
    type=float,
    help='seconds',
)

# How many samples are drawn from every video.
parser.add_argument(
    '--num_data_augu',
    default=30,
    type=int,
    help='data augumentation time',
)

args = parser.parse_args()


def wgn(x, snr):
    """Draw white Gaussian noise scaled to a target signal-to-noise ratio.

    Args:
        x: Signal the noise level is matched to. It must support ``x ** 2``
            and ``len`` (for example a 1-D NumPy array).
        snr: Desired signal-to-noise ratio in decibels.

    Returns:
        An array of ``len(x)`` samples from ``np.random.randn`` scaled by the
        square root of the noise power, where the noise power is the mean
        power of ``x`` divided by the linear SNR.
    """
    # Decibels -> linear power ratio.
    linear_snr = 10 ** (snr / 10.0)

    sample_count = len(x)
    signal_power = np.sum(x ** 2) / sample_count
    noise_std = np.sqrt(signal_power / linear_snr)

    return noise_std * np.random.randn(sample_count)

    



# Face detector shared by every call to video_to_frames. It is created lazily
# so the MTCNN model is built once instead of once per extracted sample.
_FACE_DETECTOR = None


def _get_face_detector():
    """Return the shared MTCNN instance, building it on first use."""
    global _FACE_DETECTOR
    if _FACE_DETECTOR is None:
        _FACE_DETECTOR = MTCNN()
    return _FACE_DETECTOR


def video_to_frames(win, fps, num_all_frames, read_path, vidcap,  num_frames, image_path, audio_path):
    """Extract one face image and the matching audio window from a video.

    A window of ``num_frames`` frames is placed in the video (at a random
    start when the video has more than ``num_frames`` frames, otherwise at
    frame 0) and its middle frame is read. Frames whose pixels are all zero
    are rejected and another window is drawn. For the accepted window the
    function writes:

    * ``image_path/middle.jpg`` - the middle frame,
    * ``image_path/middle_cropped_160.jpg`` - whatever MTCNN saves for that
      frame through its ``save_path`` argument,
    * ``audio_path/audio_94*94_log.jpg`` - the dB-scaled mel spectrogram of
      the audio window,
    * ``audio_path/data_audio.h5`` - the raw audio samples in the dataset
      ``"data_audio"``.

    Args:
        win: Length of the audio window in seconds.
        fps: Frame rate used to convert the start frame into a time offset.
        num_all_frames: Number of frames that may be addressed in the video.
        read_path: Path of the video file; the audio is loaded from it.
        vidcap: Opened ``cv2.VideoCapture`` for ``read_path``.
        num_frames: Number of video frames spanned by one window.
        image_path: Existing directory receiving the image files.
        audio_path: Existing directory receiving the audio files.

    Raises:
        RuntimeError: If the selected frame cannot be read, or if the only
            frame that can be selected (video not longer than the window) is
            entirely black, which previously made the loop spin forever.
        AssertionError: If the audio window does not have ``int(sr * win)``
            samples after padding.
    """
    face_detector = _get_face_detector()

    half_window = int(num_frames / 2)
    # A different window can only be drawn when the video is longer than it.
    can_resample = num_all_frames > num_frames

    while True:

        if can_resample:
            start = random.randint(0, num_all_frames - num_frames)
            middle = start + half_window
        else:
            start = 0
            middle = int(num_all_frames / 2)

        vidcap.set(cv2.CAP_PROP_POS_FRAMES, middle)
        _, img = vidcap.read()
        if not isinstance(img, np.ndarray):
            raise RuntimeError(
                'could not read frame %d of %s' % (middle, read_path))

        if (img == 0).all():
            if not can_resample:
                # Retrying would select exactly the same frame again.
                raise RuntimeError(
                    'frame %d of %s is entirely black and the video is too '
                    'short to pick another one' % (middle, read_path))
            continue

        middle_jpg = os.path.join(image_path, 'middle.jpg')

        # The frame is written as JPEG by OpenCV, read back, converted from
        # BGR to RGB and saved again with PIL. The round trip is kept on
        # purpose so the stored pixels stay identical to earlier runs.
        cv2.imwrite(middle_jpg, img)
        img_new = cv2.imread(middle_jpg)
        img_new = cv2.cvtColor(img_new, cv2.COLOR_BGR2RGB)
        image = Image.fromarray(img_new)
        image.save(middle_jpg)

        # Only the file written through save_path is used; the returned
        # tensor is not needed here.
        face_detector(image, save_path=os.path.join(image_path, 'middle_cropped_160.jpg'))

        # sr=None keeps the file's own sampling rate.
        start_time = start / fps
        y_audio, sr = librosa.load(read_path, sr=None, offset=start_time, duration=win)

        target_len = int(sr * win)
        if len(y_audio) < target_len:
            # Zero-pad on both sides; the left side gets the extra sample
            # when the gap is odd.
            gap = target_len - len(y_audio)
            right_gap = int(gap / 2)
            left_gap = gap - right_gap

            # np.pad is the public name of the function formerly reached
            # through np.lib.pad.
            y_audio = np.pad(y_audio, (left_gap, right_gap), 'constant', constant_values=(0.0, 0.0))

        assert len(y_audio) == target_len

        # Use the sampling rate of the loaded audio instead of a hard-coded
        # 48000 so the mel filter bank matches the signal.
        melspectrogram = librosa.feature.melspectrogram(y=y_audio, sr=sr, n_fft=1024, hop_length=256, n_mels=94)
        # NOTE(review): melspectrogram returns a power spectrogram by default,
        # for which power_to_db would be the matching conversion. Left as
        # amplitude_to_db to keep the generated data unchanged - confirm.
        log_melspectrogram = librosa.amplitude_to_db(melspectrogram)
        log_audio_image_path = os.path.join(audio_path, 'audio_94*94_log.jpg')
        imageio.imwrite(log_audio_image_path, log_melspectrogram)

        with h5py.File(os.path.join(audio_path, 'data_audio.h5'), 'w') as hf:
            hf.create_dataset("data_audio", data=y_audio)

        return

    

# Emotion code found in the clip file name -> class label.
# Presumably anger, disgust, fear, happiness, sadness, surprise - TODO confirm.
_ENTERFACE_LABELS = {'an': 0, 'di': 1, 'fe': 2, 'ha': 3, 'sa': 4, 'su': 5}

# Subject whose emotion folders contain the .avi files directly instead of
# one sub-folder per sentence.
_FLAT_LAYOUT_SUBJECT = 'subject 6'


def _label_from_filename(avi):
    """Return the class label encoded in a clip file name.

    The emotion code is the second '_'-separated field, or the third one when
    the second field is '3'. Unknown codes map to 0, as before.
    """
    fields = avi.split('_')
    code = fields[1]
    if code == '3':
        code = fields[2]
    return _ENTERFACE_LABELS.get(code, 0)


def _iter_subject_clips(subject_dir, flat_layout):
    """Yield (file name, full path) for every entry ending in 'avi'."""
    for emotion in os.listdir(subject_dir):
        emotion_dir = os.path.join(subject_dir, emotion)
        if not os.path.isdir(emotion_dir):
            continue

        if flat_layout:
            clip_dirs = [emotion_dir]
        else:
            clip_dirs = (os.path.join(emotion_dir, sentence)
                         for sentence in os.listdir(emotion_dir))

        for clip_dir in clip_dirs:
            if not os.path.isdir(clip_dir):
                continue
            for avi in os.listdir(clip_dir):
                if avi.endswith('avi'):
                    yield avi, os.path.join(clip_dir, avi)


def _augment_clip(read_path, sample_dir, num_data_augu, num_frames, win):
    """Draw num_data_augu samples from one clip into sample_dir/augu_<i>."""
    vidcap = cv2.VideoCapture(read_path)
    try:
        # One frame less than reported is used, presumably to stay clear of
        # an unreadable last frame - TODO confirm.
        num_all_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) - 1
        fps = vidcap.get(cv2.CAP_PROP_FPS)

        for i in range(num_data_augu):
            final_save_path = os.path.join(sample_dir, 'augu_' + str(i))
            audio_dir = os.path.join(final_save_path, 'audio')
            image_dir = os.path.join(final_save_path, 'image')
            os.makedirs(audio_dir, exist_ok=True)
            os.makedirs(image_dir, exist_ok=True)

            video_to_frames(win, fps, num_all_frames, read_path, vidcap, num_frames, image_dir, audio_dir)
    finally:
        # Release the capture even when a sample fails.
        vidcap.release()


def enterface_data(enterface_path,num_data_augu,num_frames,win,save_path):
    """Preprocess every clip found under ``enterface_path``.

    The directory tree is walked as subject / emotion / sentence / clip; for
    the subject named 'subject 6' the clips are expected directly inside the
    emotion folders. Every entry whose name ends in 'avi' is treated as a
    clip. For each clip ``num_data_augu`` samples are produced by
    ``video_to_frames`` and stored under::

        save_path/subject_<id>/sample_<k>_label_<label>/augu_<i>/{image,audio}

    where ``<id>`` is the subject folder name without its first 8 characters,
    ``<k>`` is a running clip counter that restarts at 1 for every subject and
    ``<label>`` comes from the emotion code in the file name. A global running
    clip count is printed for every clip.

    Args:
        enterface_path: Root directory of the original dataset.
        num_data_augu: Number of samples drawn from each clip.
        num_frames: Number of video frames spanned by one sample window.
        win: Length of the audio window in seconds.
        save_path: Root directory for the generated samples.

    Raises:
        IndexError: If a clip file name has too few '_'-separated fields to
            contain an emotion code.
    """
    count = 0

    for subject in os.listdir(enterface_path):
        subject_dir = os.path.join(enterface_path, subject)
        if not os.path.isdir(subject_dir):
            continue

        sample_count = 0
        flat_layout = subject == _FLAT_LAYOUT_SUBJECT

        for avi, read_path in _iter_subject_clips(subject_dir, flat_layout):
            count += 1
            print(count)
            sample_count += 1
            label = _label_from_filename(avi)

            sample_dir = os.path.join(
                save_path,
                'subject_' + subject[8:],
                'sample_' + str(sample_count) + '_label_' + str(label))
            _augment_clip(read_path, sample_dir, num_data_augu, num_frames, win)



if __name__ == '__main__':

    if args.dataset == 'enterface':
        # Reference figures for the dataset. Only `fps` is used below; the
        # other three are not read anywhere in this block.
        num_subject = 44
        num_samples = 1293
        sampling_rate = 48000
        fps = 25

        # One sample covers `win` seconds, i.e. this many frames (rounded up).
        win = args.win
        num_frames = int(math.ceil(win * fps))

        # Directory holding the original enterface dataset, and the directory
        # that receives the processed samples.
        path = "./enterface/"
        save_path = "./processed_enterface/"

        enterface_data(path, args.num_data_augu, num_frames, win, save_path)







        






