import argparse
import os
import cv2
import json
import torch
import csv
import numpy as np
import torchaudio
import tqdm
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import time
from PIL import Image
import glob
import sys
from scipy import signal
import random
import soundfile as sf
import librosa

from torchaudio.transforms import *


def read_vggsound_a(batchsize):
    """Build the train and test DataLoaders for the audio dataset.

    One ``GetAudioVideoDataset`` is created for the ``train`` split and one
    for the ``test`` split. Both use the placeholder paths below, which must
    be edited before use.

    Args:
        batchsize: Batch size passed to both DataLoaders.

    Returns:
        A ``(train_dataloader, test_dataloader)`` tuple. The training loader
        shuffles its samples; the test loader does not.
    """
    # NOTE: the dataset concatenates these strings directly with file names
    # (``csv_path + 'stat.csv'``, ``data_path + mode + '/'``), so each one
    # has to end with a path separator.
    csv_path = r'YourPath'
    data_path = r'YourPath'

    # The dataset only reads plain attributes, so a Namespace is the right
    # container; no command-line parser is involved.
    train_args = argparse.Namespace(csv_path=csv_path, data_path=data_path, mode='train')
    test_args = argparse.Namespace(csv_path=csv_path, data_path=data_path, mode='test')

    train = GetAudioVideoDataset(args=train_args)
    test = GetAudioVideoDataset(args=test_args)

    train_dataloader = DataLoader(train, batch_size=batchsize, shuffle=True, pin_memory=False)
    test_dataloader = DataLoader(test, batch_size=batchsize, shuffle=False, pin_memory=False)
    return train_dataloader, test_dataloader


def read_vggsound_v(batchsize):
    """Build the train and test DataLoaders for the video dataset.

    One ``GetVideoAudioDataset`` is created for the ``train`` split and one
    for the ``test`` split. Both use the placeholder paths below, which must
    be edited before use.

    Args:
        batchsize: Batch size passed to both DataLoaders.

    Returns:
        A ``(train_dataloader, test_dataloader)`` tuple. The training loader
        shuffles its samples; the test loader does not.
    """
    # NOTE: the dataset concatenates these strings directly with file names
    # (``csv_path + 'stat.csv'``, ``data_path + mode + '/'``), so each one
    # has to end with a path separator.
    csv_path = r'YourPath'
    data_path = r'YourPath'

    # The dataset only reads plain attributes, so a Namespace is the right
    # container; no command-line parser is involved.
    train_args = argparse.Namespace(csv_path=csv_path, data_path=data_path, mode='train')
    test_args = argparse.Namespace(csv_path=csv_path, data_path=data_path, mode='test')

    train = GetVideoAudioDataset(args=train_args)
    test = GetVideoAudioDataset(args=test_args)

    train_dataloader = DataLoader(train, batch_size=batchsize, shuffle=True, pin_memory=False)
    test_dataloader = DataLoader(test, batch_size=batchsize, shuffle=False, pin_memory=False)
    return train_dataloader, test_dataloader


def create_new_filename(old_filename):
    """Derive the ``v<id>_<start>_<end>_out.mkv`` name for a clip file name.

    The last four characters of ``old_filename`` are dropped, the trailing
    six characters of what remains are parsed as the integer start time, and
    everything before the character preceding those six is the video id.
    The end time is the start time plus 10.

    Args:
        old_filename: Name such as ``"<id>_000030.mp4"``.

    Returns:
        The new file name string.

    Raises:
        ValueError: If the six-character start-time field is not an integer.
    """
    stem = old_filename[:-4]
    video_id, start_time = stem[:-7], int(stem[-6:])
    end_time = start_time + 10
    return f"v{video_id}_{start_time}_{end_time}_out.mkv"


class GetAudioVideoDataset(Dataset):
    """Audio dataset built from a class list and a per-split clip list.

    Two CSV files are read from ``args.csv_path``:

    * ``stat.csv`` - the first column of every row is a class name.
    * ``<args.mode>.csv`` - column 0 is a file name (its last three
      characters are replaced by ``mp4``) and column 1 is the class label.

    A clip is kept when its label is a known class and the file
    ``args.data_path + args.mode + '/' + <label with spaces as '_'> + '/' +
    <name>.mp4`` exists. Paths are built by plain string concatenation, so
    ``csv_path`` and ``data_path`` must end with a path separator.

    Each item is ``(samples, class_index)`` where ``samples`` is an array of
    shape ``(9, 16000)`` loaded from the ``.wav`` file that sits next to the
    ``.mp4`` and ``class_index`` is the label's position in the sorted class
    list.
    """

    def __init__(self, args, mode='train', transforms=None):
        """Index the clips of one split.

        Args:
            args: Object exposing ``csv_path``, ``data_path`` and ``mode``
                attributes. ``args.mode`` selects the split CSV and the data
                sub-directory.
            mode: Stored as ``self.mode``; it does not select the split
                (``args.mode`` does).
            transforms: Stored as ``self.transforms``.
        """
        classes = []
        with open(args.csv_path + 'stat.csv') as f:
            for row in csv.reader(f):
                if row:  # blank lines come back as empty rows; skip them
                    classes.append(row[0])
        # Set lookup keeps the per-row class check O(1).
        known_classes = set(classes)

        self.video_files = []
        with open(args.csv_path + args.mode + '.csv') as f:
            for item in csv.reader(f):
                if not item:
                    continue
                cur_path = args.data_path + args.mode + '/' + item[1].replace(' ', '_') + '/' + item[0][:-3] + 'mp4'
                if item[1] in known_classes and os.path.exists(cur_path):
                    self.video_files.append([cur_path, item[1]])

        self.mode = mode
        self.transforms = transforms
        self.classes = sorted(classes)

        # initialize audio transform
        self._init_atransform()

        print('# of audio files = %d ' % len(self.video_files))
        print('# of classes = %d' % len(self.classes))
        print(self.classes)

    def _init_atransform(self):
        """Create ``self.aid_transform`` (a ``ToTensor`` pipeline)."""
        self.aid_transform = transforms.Compose([transforms.ToTensor()])

    def __len__(self):
        """Return the number of indexed clips."""
        return len(self.video_files)

    @staticmethod
    def _segment_audio(samples, num_segments=9, segment_len=16000):
        """Fit a 1-D signal to ``num_segments * segment_len`` samples and fold it.

        Longer signals are truncated. Shorter signals are repeated cyclically
        until they are long enough (an empty signal yields zeros), so the
        reshape below can never fail on a short clip.

        Returns:
            Array of shape ``(num_segments, segment_len)``.
        """
        target_len = num_segments * segment_len
        if samples.shape[0] < target_len:
            # repeat in case audio is too short
            samples = np.resize(samples, target_len)
        return samples[:target_len].reshape(num_segments, segment_len)

    def __getitem__(self, idx):
        """Load the audio of clip ``idx``.

        Returns:
            ``(samples, class_index)`` with ``samples`` of shape
            ``(9, 16000)``.
        """
        video_path, label = self.video_files[idx]
        # Derive the audio path locally instead of rewriting the stored entry.
        wav_path = video_path.replace('.mp4', '.wav')
        # NOTE(review): no ``sr`` is passed, so librosa resamples to its own
        # default rate, while the 16000-sample segment length suggests 16 kHz
        # was intended - confirm before changing.
        samples, _ = librosa.load(wav_path)
        return self._segment_audio(samples), self.classes.index(label)


class GetVideoAudioDataset(Dataset):
    """Video-frame dataset built from a class list and a per-split clip list.

    Two CSV files are read from ``args.csv_path``:

    * ``stat.csv`` - the first column of every row is a class name.
    * ``<args.mode>.csv`` - column 0 is a file name (its last three
      characters are replaced by ``mp4``) and column 1 is the class label.

    A clip is kept when its label is a known class and the ``_prep.mp4``
    sibling of ``args.data_path + args.mode + '/' + <label with spaces as
    '_'> + '/' + <name>.mp4`` exists. Paths are built by plain string
    concatenation, so ``csv_path`` and ``data_path`` must end with a path
    separator.

    Each item is ``(frames, class_index)``; see ``__getitem__``.
    """

    def __init__(self, args, mode='train', transforms=None):
        """Index the clips of one split.

        Args:
            args: Object exposing ``csv_path``, ``data_path`` and ``mode``
                attributes. ``args.mode`` selects the split CSV and the data
                sub-directory.
            mode: Stored as ``self.mode``; it does not select the split
                (``args.mode`` does).
            transforms: Stored as ``self.transforms``.
        """
        classes = []
        with open(args.csv_path + 'stat.csv') as f:
            for row in csv.reader(f):
                if row:  # blank lines come back as empty rows; skip them
                    classes.append(row[0])
        # Set lookup keeps the per-row class check O(1).
        known_classes = set(classes)

        self.video_files = []
        with open(args.csv_path + args.mode + '.csv') as f:
            for item in csv.reader(f):
                if not item:
                    continue
                cur_path = args.data_path + args.mode + '/' + item[1].replace(' ', '_') + '/' + item[0][:-3] + 'mp4'
                if item[1] in known_classes and os.path.exists(cur_path.replace('.mp4', '_prep.mp4')):
                    # The un-suffixed path is stored; '_prep' is re-applied on load.
                    self.video_files.append([cur_path, item[1]])

        self.mode = mode
        self.transforms = transforms
        self.classes = sorted(classes)

        # initialize audio transform
        self._init_atransform()

        print('# of audio files = %d ' % len(self.video_files))
        print('# of classes = %d' % len(self.classes))
        print(self.classes)

    def _init_atransform(self):
        """Create ``self.aid_transform`` (a ``ToTensor`` pipeline)."""
        self.aid_transform = transforms.Compose([transforms.ToTensor()])

    def __len__(self):
        """Return the number of indexed clips."""
        return len(self.video_files)

    def __getitem__(self, idx, num_frames=45, frame_size=(128, 96)):
        """Load resized frames of clip ``idx`` from its ``_prep.mp4`` file.

        Args:
            idx: Index into the clip list.
            num_frames: One frame is read for every third value in
                ``range(num_frames)``, i.e. 15 reads for the default of 45.
            frame_size: Size passed to ``cv2.resize`` for every frame.

        Returns:
            ``(frames, class_index)`` where ``frames`` is a float tensor
            stacking the resized frames, or ``None`` when the video cannot be
            opened (the path is printed). When a read fails, the previously
            read frame is repeated; if no frame was read before the failure
            the placeholder is an empty list, which yields a degenerate
            tensor shape that callers can detect.
        """
        video_path, label = self.video_files[idx]
        prep_path = video_path.replace('.mp4', '_prep.mp4')
        cap = cv2.VideoCapture(prep_path)
        try:
            if not cap.isOpened():
                print("Error: Unable to open video file")
                print(prep_path)
                return None
            frames = []
            frame = []
            # NOTE(review): the capture is only advanced by the reads below,
            # so these are consecutive frames from the start of the file, not
            # every third frame - confirm that this is the intended sampling.
            for _ in range(0, num_frames, 3):
                ret, temp_frame = cap.read()
                if ret:
                    frame = cv2.resize(temp_frame, frame_size)
                # On a failed read this repeats the last good frame.
                frames.append(frame)
        finally:
            # Always free the decoder, including on errors and the early return.
            cap.release()
        frames = np.array(frames)
        return torch.FloatTensor(frames), self.classes.index(label)


def main():
    """Scan the training split of the video dataset and report odd samples.

    Every item of ``GetVideoAudioDataset`` is loaded once. The index of each
    item that cannot be opened is printed, and for each item whose frame
    tensor is not ``(15, 96, 128, 3)`` the shape and the index are printed.
    The placeholder paths below must be edited before use.
    """
    args = argparse.Namespace(csv_path=r'YourPath', data_path=r'YourPath', mode='train')
    dataset = GetVideoAudioDataset(args=args)
    expected_shape = torch.Size([15, 96, 128, 3])
    for i in tqdm.tqdm(range(len(dataset))):
        sample = dataset[i]
        if sample is None:
            # The dataset returns None (after printing the path) for a video
            # it cannot open; report the index and keep scanning.
            print(i)
            continue
        tensor, _ = sample
        if tensor.shape != expected_shape:
            print(tensor.shape)
            print(i)


# Run the dataset scan in main() only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()
