import json
import os

import librosa
import numpy as np
import torch
from imageio import imread
from torch.utils.data import Dataset
from nnAudio.Spectrogram import STFT


class AudioVideoDataset(Dataset):
    """Training dataset pairing packed video frames with pre-extracted audio.

    Expected layout under ``path``::

        path/video_jpg/<name>.jpg   # one image holding ``num_frames`` frames
        path/audio_npy/<name>.npy   # waveform array for the same clip

    Each item is a dict with:
        'video': normalized frames, shape ``(num_frames, 3, 256, 256)``.
        'audio': magnitude STFT (nnAudio, n_fft=1024, hop=512) of the
                 peak-normalized waveform.
        'label': value from the label JSON, keyed by the second-to-last
                 '_'-separated token of the jpg file name.
    """

    # Standard ImageNet RGB mean/std, shaped to broadcast over (Seq, C, H, W).
    _MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
    _STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

    def __init__(self, path, num_frames, label_path="label.json"):
        """Index every ``*jpg`` file under ``path/video_jpg`` and load the label map.

        Args:
            path: Dataset root containing ``video_jpg`` and ``audio_npy``.
            num_frames: Number of frames packed into each jpg.
            label_path: JSON file mapping label keys to labels. Defaults to
                ``"label.json"`` resolved against the current working
                directory (the previous hard-coded behavior).
        """
        # JSON text is UTF-8 by spec; don't depend on the platform locale.
        with open(label_path, encoding="utf-8") as f:
            self.label_list = json.load(f)
        self.num_frames = num_frames
        self.audio_list = []
        self.video_list = []
        self.stft = STFT(n_fft=1024, win_length=1024, hop_length=512, sr=16000, output_format='Magnitude')

        video_dir = os.path.join(path, 'video_jpg')
        audio_dir = os.path.join(path, 'audio_npy')
        for v in sorted(os.listdir(video_dir)):
            if v.endswith('jpg'):
                # Swap the 3-character 'jpg' suffix for 'npy' to find the paired audio file.
                self.audio_list.append(os.path.join(audio_dir, v[:-3] + 'npy'))
                self.video_list.append(os.path.join(video_dir, v))

    def wave_normalize(self, wav):
        """Scale ``wav`` so its peak magnitude becomes 1/1.1 (~0.91), leaving headroom.

        An empty or all-zero (silent) waveform is returned unchanged instead of
        being divided by zero, which would fill it with NaNs.
        """
        if wav.numel() == 0:
            return wav
        peak = torch.max(torch.abs(wav))
        if peak == 0:
            return wav
        return wav / (peak * 1.1)

    def video_normalize(self, video):
        """Normalize frames ``(Seq, 3, H, W)`` with pixel values in [0, 255] using ImageNet stats."""
        return (video / 255. - self._MEAN) / self._STD

    @staticmethod
    def _label_key(video_path):
        """Return the label key: the second-to-last '_'-separated token of the file name."""
        # basename is portable; splitting on "/" breaks with Windows path separators.
        return os.path.basename(video_path).split("_")[-2]

    def __len__(self):
        return len(self.video_list)

    def __getitem__(self, item):
        """Load clip ``item`` and return its video frames, audio spectrogram and label."""
        video_path = self.video_list[item]
        # All frames live in one jpg; presumably tiled vertically so a C-order
        # reshape recovers (num_frames, 256, 256, 3) — TODO confirm with the extraction script.
        video = imread(video_path).reshape(self.num_frames, 256, 256, 3)
        # Drop the leading size-1 axis of the stored array (np.squeeze(0) raises if it isn't size 1).
        wave = np.load(self.audio_list[item]).squeeze(0)

        video = self.video_normalize(torch.from_numpy(video).permute(0, 3, 1, 2))
        # nnAudio's STFT kernels are float32; cast so a float64 .npy doesn't fail in conv1d.
        wave = self.wave_normalize(torch.from_numpy(wave).float())
        stft = self.stft(wave)

        return {'video': video,  # Seq, C, H, W
                'audio': stft,
                'label': self.label_list[self._label_key(video_path)]}



class AudioVideoTestDataset(Dataset):
    """Evaluation dataset pairing packed video frames with raw ``.wav`` audio.

    Expected layout under ``path``::

        path/video_jpg/<name>.jpg   # one image holding ``num_frames`` frames
        path/audio/<name>.wav       # audio for the same clip

    Audio is decoded with librosa at ``sr`` (mono). Each item is a dict with:
        'video': normalized frames, shape ``(num_frames, 3, 256, 256)``.
        'audio': magnitude STFT (nnAudio, n_fft=1024, hop=512) of the
                 peak-normalized waveform.
        'label': value from the label JSON, keyed by the second-to-last
                 '_'-separated token of the jpg file name.
    """

    # Standard ImageNet RGB mean/std, shaped to broadcast over (Seq, C, H, W).
    _MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
    _STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

    def __init__(self, path, num_frames, sr=16000, label_path="label.json"):
        """Index every ``*jpg`` file under ``path/video_jpg`` and load the label map.

        Args:
            path: Dataset root containing ``video_jpg`` and ``audio``.
            num_frames: Number of frames packed into each jpg.
            sr: Sample rate used both to decode audio and to configure the STFT.
            label_path: JSON file mapping label keys to labels. Defaults to
                ``"label.json"`` resolved against the current working
                directory (the previous hard-coded behavior).
        """
        self.sr = sr
        # JSON text is UTF-8 by spec; don't depend on the platform locale.
        with open(label_path, encoding="utf-8") as f:
            self.label_list = json.load(f)
        self.num_frames = num_frames
        self.audio_list = []
        self.video_list = []
        # Use the same rate the audio is resampled to, rather than a fixed 16000.
        self.stft = STFT(n_fft=1024, win_length=1024, hop_length=512, sr=self.sr, output_format='Magnitude')

        video_dir = os.path.join(path, 'video_jpg')
        audio_dir = os.path.join(path, 'audio')
        for v in sorted(os.listdir(video_dir)):
            if v.endswith('jpg'):
                # Swap the 3-character 'jpg' suffix for 'wav' to find the paired audio file.
                self.audio_list.append(os.path.join(audio_dir, v[:-3] + 'wav'))
                self.video_list.append(os.path.join(video_dir, v))

    def wave_normalize(self, wav):
        """Scale ``wav`` so its peak magnitude becomes 1/1.1 (~0.91), leaving headroom.

        An empty or all-zero (silent) waveform is returned unchanged instead of
        being divided by zero, which would fill it with NaNs.
        """
        if wav.numel() == 0:
            return wav
        peak = torch.max(torch.abs(wav))
        if peak == 0:
            return wav
        return wav / (peak * 1.1)

    def video_normalize(self, video):
        """Normalize frames ``(Seq, 3, H, W)`` with pixel values in [0, 255] using ImageNet stats."""
        return (video / 255. - self._MEAN) / self._STD

    @staticmethod
    def _label_key(video_path):
        """Return the label key: the second-to-last '_'-separated token of the file name."""
        # basename is portable; splitting on "/" breaks with Windows path separators.
        return os.path.basename(video_path).split("_")[-2]

    def __len__(self):
        return len(self.video_list)

    def __getitem__(self, item):
        """Load clip ``item`` and return its video frames, audio spectrogram and label."""
        video_path = self.video_list[item]
        # All frames live in one jpg; presumably tiled vertically so a C-order
        # reshape recovers (num_frames, 256, 256, 3) — TODO confirm with the extraction script.
        video = imread(video_path).reshape(self.num_frames, 256, 256, 3)
        # Decode (and resample to self.sr) as a 1-D mono waveform.
        wave, _ = librosa.load(self.audio_list[item], sr=self.sr, mono=True)

        video = self.video_normalize(torch.from_numpy(video).permute(0, 3, 1, 2))
        # nnAudio's STFT kernels are float32; cast defensively (no-op for librosa's default float32).
        wave = self.wave_normalize(torch.from_numpy(wave).float())
        stft = self.stft(wave)

        return {'video': video,  # Seq, C, H, W
                'audio': stft,
                'label': self.label_list[self._label_key(video_path)]}
