import os
import os.path as osp
from os import PathLike
import numpy as np
import torch
from PIL import Image
import pickle as pkl
import cv2
import soundfile as sf
from scipy import signal
from torchvision.transforms import Compose, ToTensor
from .base_dataset import BaseDataset
import glob

def expanduser(path):
    """Expand a leading ``~`` in ``path`` when it is path-like.

    Args:
        path: A ``str`` or ``os.PathLike`` object, or any other value
            (for example ``None``).

    Returns:
        The result of ``os.path.expanduser`` for ``str``/``PathLike``
        input; any other value is returned untouched.
    """
    # Guard clause: anything that is not a path is passed straight through.
    if not isinstance(path, (str, PathLike)):
        return path
    return osp.expanduser(path)

def load_csv(path):
    """Read a two-column ``key,value`` text file into a dictionary.

    Each non-blank line is stripped of surrounding whitespace and split at
    its FIRST comma only, so the value part may itself contain commas.
    Blank lines are skipped. If a key occurs more than once, the last
    occurrence wins.

    Args:
        path: Path of the file to read.

    Returns:
        dict: Maps the text before the first comma to the text after it.

    Raises:
        ValueError: If a non-blank line contains no comma.
    """
    csv_dict = {}
    with open(path, 'r') as rf:
        for raw_line in rf:
            # Strip BEFORE splitting so that leading whitespace cannot shift
            # the position of the separator.
            line = raw_line.strip()
            if not line:
                continue
            key, sep, value = line.partition(',')
            if not sep:
                raise ValueError(f'Invalid terms {line}.')
            csv_dict[key] = value
    return csv_dict

def audio2spectrogram(wav_path, num_samples=160000):
    """Load an audio file and turn it into a normalised log-spectrogram.

    The waveform is looped until it is at least ``num_samples`` long, cut to
    exactly ``num_samples``, clipped to ``[-1, 1]`` and passed to
    ``scipy.signal.spectrogram``. The log of the spectrogram is standardised
    with its own mean and standard deviation and then copied into three
    identical channels.

    Args:
        wav_path: Audio file readable by ``soundfile.read``.
        num_samples: Number of waveform samples fed to the spectrogram.
            Defaults to 160000 (presumably 10 s at 16 kHz -- confirm against
            the data preparation).

    Returns:
        numpy.ndarray: ``float32`` array of shape
        ``(freq_bins, time_frames, 3)``.

    Raises:
        ValueError: If the file contains no samples.
    """
    samples, samplerate = sf.read(wav_path)
    samples = np.asarray(samples)

    # soundfile returns (frames, channels) for multi-channel files. Tiling
    # such an array would repeat the channel axis instead of time, so mix
    # down to mono first.
    if samples.ndim > 1:
        samples = samples.mean(axis=1)
    if samples.size == 0:
        raise ValueError(f'Audio file {wav_path} contains no samples.')

    # Loop short clips as often as needed (ceil division) so that the output
    # always has the same length, however short the clip is.
    repeats = -(-num_samples // samples.size)
    resamples = np.clip(np.tile(samples, repeats)[:num_samples], -1., 1.)

    _, _, spectrogram = signal.spectrogram(
        resamples, samplerate, nperseg=512, noverlap=353)
    # Small offset keeps the log finite for silent bins.
    spectrogram = np.log(spectrogram + 1e-7)

    # Per-clip standardisation; the epsilon guards against zero variance.
    mean = np.mean(spectrogram)
    std = np.std(spectrogram)
    spectrogram = np.divide(spectrogram - mean, std + 1e-9)

    # Replicate the single channel three times (H, W) -> (H, W, 3).
    spectrogram = np.tile(spectrogram[:, :, np.newaxis], (1, 1, 3))
    return spectrogram.astype(np.float32)

def read_save_mp4(path, save_dir, num_frames=12):
    """Decode a video and save a random subset of its frames as PNG files.

    All frames are decoded into memory, then up to ``num_frames`` distinct
    frame indices are drawn with ``np.random.choice`` and each chosen frame
    is written to ``<save_dir>/frame<idx>.png``. ``save_dir`` is created if
    it does not exist. If the video has fewer than ``num_frames`` frames,
    every frame is saved.

    Args:
        path: Path of the video file.
        save_dir: Directory that receives the PNG files.
        num_frames: Maximum number of frames to save.

    Raises:
        ValueError: If no frame could be decoded from ``path``.
    """
    print(path, save_dir)
    vidcap = cv2.VideoCapture(path)

    images = []
    try:
        success, image = vidcap.read()
        while success:
            images.append(image)
            success, image = vidcap.read()
    finally:
        # Always free the decoder, even if reading fails part-way.
        vidcap.release()

    if not images:
        raise ValueError(f'No frames could be read from {path}.')

    # Sampling without replacement cannot draw more items than exist.
    num_chosen = min(num_frames, len(images))
    chosen_idxs = np.random.choice(
        np.arange(len(images)), num_chosen, replace=False)

    # cv2.imwrite does not create missing directories.
    os.makedirs(save_dir, exist_ok=True)
    for idx in chosen_idxs:
        # The ".png" extension selects the PNG encoder.
        cv2.imwrite(os.path.join(save_dir, f"frame{idx}.png"), images[idx])
    

class VGGSoundPair(BaseDataset):
    """VGGSound dataset yielding ``(frames, audio_spectrogram)`` pairs.

    Frames are read from ``<root>/frames/<video_file>/*.png`` and the audio
    from ``<root>/audios/<video_file[:-4]>.wav``. The list of videos comes
    from ``./external/VGGSound/data/{train,test}.csv`` and the class mapping
    from ``./external/VGGSound/data/cls2idx.pkl``; these paths are relative
    to the current working directory.

    Args:
        root: Dataset root containing the ``frames`` and ``audios`` folders.
        video_transforms: Callable applied to every PIL frame. ``None``
            selects ``DEFAULT_TRNASFORMS``.
        audio_transforms: Callable applied to the spectrogram array.
            ``None`` selects ``DEFAULT_TRNASFORMS``.
        target_transforms: Stored on the instance; not used by this class.
        classes: Forwarded to ``get_classes`` (inherited, not shown here).
            Note that ``CLASSES`` is overwritten with the keys of
            ``cls2idx.pkl`` when the annotations are loaded.
        ann_file: Stored (user-expanded) on the instance; the annotation
            files actually read are the hard-coded CSV paths above.
        contrastive: If True, ``__getitem__`` returns two stacked views of
            the frames instead of one.
        test_mode: If True read ``test.csv``, otherwise ``train.csv``.
    """

    # The misspelled name is kept as-is because it is a public attribute.
    DEFAULT_TRNASFORMS = Compose([ToTensor()])

    def __init__(self,
                 root,
                 video_transforms,
                 audio_transforms,
                 target_transforms=None,
                 classes=None,
                 ann_file=None,
                 contrastive=False,
                 test_mode=False):
        # NOTE(review): super(BaseDataset, self) deliberately starts the MRO
        # lookup AFTER BaseDataset, so BaseDataset.__init__ is not run --
        # presumably because its signature differs; confirm.
        super(BaseDataset, self).__init__()
        self.root = expanduser(root)
        if video_transforms is None:
            video_transforms = self.DEFAULT_TRNASFORMS
        self.video_transforms = video_transforms
        if audio_transforms is None:
            audio_transforms = self.DEFAULT_TRNASFORMS
        self.audio_transforms = audio_transforms
        self.target_transforms = target_transforms
        self.CLASSES = self.get_classes(classes)
        self.ann_file = expanduser(ann_file)
        self.test_mode = test_mode
        # Maximum number of frames kept per video.
        self.num_frame = 12
        self.data_infos = self.load_annotations()
        self.contrastive = contrastive

    def load_annotations(self):
        """Build the list of ``{'frames': [...], 'audio': path}`` records.

        Also sets ``self.cls2idx`` and ``self.CLASSES`` from the pickled
        class mapping. Videos whose frame directory is missing or contains
        no ``*.png`` file are skipped, because they could not be stacked
        into a tensor later on.

        Returns:
            list[dict]: One record per usable video.
        """
        train_csv = './external/VGGSound/data/train.csv'
        test_csv = './external/VGGSound/data/test.csv'
        data_dict = load_csv(test_csv if self.test_mode else train_csv)

        # NOTE: pickle must only be used on trusted files. The context
        # manager makes sure the file handle is closed again.
        with open('./external/VGGSound/data/cls2idx.pkl', 'rb') as pkl_file:
            self.cls2idx = pkl.load(pkl_file)
        self.CLASSES = list(self.cls2idx.keys())

        video_root = os.path.join(self.root, 'frames')
        audio_root = os.path.join(self.root, 'audios')
        print(video_root, audio_root)

        video_sound_pairs = []
        for video_file in data_dict:
            video_path = os.path.join(video_root, video_file)
            if not os.path.exists(video_path):
                continue
            # Sorted for a deterministic choice of the first `num_frame`.
            frames_path = sorted(
                glob.glob(f'{video_path}/*.png'))[:self.num_frame]
            if not frames_path:
                # torch.stack() on an empty list raises at access time.
                continue
            # Dropping the last 4 characters assumes a 3-letter extension
            # such as ".mp4" on the CSV key -- TODO confirm.
            audio_path = os.path.join(audio_root, video_file[:-4] + '.wav')
            video_sound_pairs.append(
                {'frames': frames_path, 'audio': audio_path})
        return video_sound_pairs

    def __getitem__(self, idx):
        """Return ``(frames, audio_gram)`` for record ``idx``.

        ``frames`` is a stacked tensor of the transformed frames, or, in
        contrastive mode, a list of two such tensors built from element 0
        and element 1 of every transformed frame (presumably the transform
        returns two augmented views -- confirm). ``audio_gram`` is the
        transformed output of ``audio2spectrogram``.
        """
        info = self.data_infos[idx]
        frames = [Image.open(frame_path) for frame_path in info['frames']]
        if self.video_transforms is not None:
            frames = [self.video_transforms(frame) for frame in frames]
            if self.contrastive:
                frames = [torch.stack([views[0] for views in frames]),
                          torch.stack([views[1] for views in frames])]
            else:
                frames = torch.stack(frames)

        audio_gram = audio2spectrogram(info['audio'])
        if self.audio_transforms is not None:
            audio_gram = self.audio_transforms(audio_gram)
        return frames, audio_gram



class VGGSoundVideo(BaseDataset):
    """VGGSound dataset yielding ``(frames, label)`` pairs.

    Frames are read from ``<root>/frames/<video_file>/*.png``. The list of
    videos and their class names comes from
    ``./external/VGGSound/data/{train,test}.csv`` and the class-name to
    index mapping from ``./external/VGGSound/data/cls2idx.pkl``; these paths
    are relative to the current working directory.

    Args:
        root: Dataset root containing the ``frames`` folder.
        video_transforms: Callable applied to every PIL frame. ``None``
            selects ``DEFAULT_TRNASFORMS``.
        target_transforms: Optional callable applied to the label.
        classes: Forwarded to ``get_classes`` (inherited, not shown here).
            Note that ``CLASSES`` is overwritten with the keys of
            ``cls2idx.pkl`` when the annotations are loaded.
        ann_file: Stored (user-expanded) on the instance; the annotation
            files actually read are the hard-coded CSV paths above.
        contrastive: If True, ``__getitem__`` returns two stacked views of
            the frames instead of one.
        test_mode: If True read ``test.csv``, otherwise ``train.csv``.
    """

    # The misspelled name is kept as-is because it is a public attribute.
    DEFAULT_TRNASFORMS = Compose([ToTensor()])

    def __init__(self,
                 root,
                 video_transforms,
                 target_transforms=None,
                 classes=None,
                 ann_file=None,
                 contrastive=False,
                 test_mode=False):
        # NOTE(review): super(BaseDataset, self) deliberately starts the MRO
        # lookup AFTER BaseDataset, so BaseDataset.__init__ is not run --
        # presumably because its signature differs; confirm.
        super(BaseDataset, self).__init__()
        self.root = expanduser(root)
        if video_transforms is None:
            video_transforms = self.DEFAULT_TRNASFORMS
        self.video_transforms = video_transforms
        self.target_transforms = target_transforms
        self.CLASSES = self.get_classes(classes)
        self.ann_file = expanduser(ann_file)
        self.test_mode = test_mode
        # Maximum number of frames kept per video.
        self.num_frame = 12
        self.data_infos = self.load_annotations()
        # __getitem__ reads `self.contrastive`; the attribute used to be
        # stored only under the misspelled name below.
        self.contrastive = contrastive
        # Misspelled alias kept for backward compatibility.
        self.contrative = contrastive

    def load_annotations(self):
        """Build the list of ``{'frames': [...], 'label': int}`` records.

        Also sets ``self.cls2idx`` and ``self.CLASSES`` from the pickled
        class mapping. Videos whose frame directory is missing or contains
        no ``*.png`` file are skipped, because they could not be stacked
        into a tensor later on.

        Returns:
            list[dict]: One record per usable video.

        Raises:
            KeyError: If a class name from the CSV is not in ``cls2idx``.
        """
        train_csv = './external/VGGSound/data/train.csv'
        test_csv = './external/VGGSound/data/test.csv'
        data_dict = load_csv(test_csv if self.test_mode else train_csv)

        # NOTE: pickle must only be used on trusted files. The context
        # manager makes sure the file handle is closed again.
        with open('./external/VGGSound/data/cls2idx.pkl', 'rb') as pkl_file:
            self.cls2idx = pkl.load(pkl_file)
        self.CLASSES = list(self.cls2idx.keys())

        video_root = os.path.join(self.root, 'frames')

        video_label_pairs = []
        for video_file, class_name in data_dict.items():
            video_path = os.path.join(video_root, video_file)
            if not os.path.exists(video_path):
                continue
            # Sorted for a deterministic choice of the first `num_frame`.
            frames_path = sorted(
                glob.glob(f'{video_path}/*.png'))[:self.num_frame]
            if not frames_path:
                # torch.stack() on an empty list raises at access time.
                continue
            video_label_pairs.append(
                {'frames': frames_path, 'label': self.cls2idx[class_name]})
        return video_label_pairs

    def __getitem__(self, idx):
        """Return ``(frames, label)`` for record ``idx``.

        ``frames`` is a stacked tensor of the transformed frames, or, in
        contrastive mode, a list of two such tensors built from element 0
        and element 1 of every transformed frame (presumably the transform
        returns two augmented views -- confirm). ``label`` is the class
        index, passed through ``target_transforms`` when one is set.
        """
        info = self.data_infos[idx]
        label = info['label']
        frames = [Image.open(frame_path) for frame_path in info['frames']]
        if self.video_transforms is not None:
            frames = [self.video_transforms(frame) for frame in frames]
            if self.contrastive:
                frames = [torch.stack([views[0] for views in frames]),
                          torch.stack([views[1] for views in frames])]
            else:
                frames = torch.stack(frames)

        if self.target_transforms is not None:
            label = self.target_transforms(label)
        return frames, label


class VGGSoundAudio(BaseDataset):
    """VGGSound dataset yielding ``(audio_spectrogram, label)`` pairs.

    Audio is read from ``<root>/audios/<video_file[:-4]>.wav``. The list of
    clips and their class names comes from
    ``./external/VGGSound/data/{train,test}.csv`` and the class-name to
    index mapping from ``./external/VGGSound/data/cls2idx.pkl``; these paths
    are relative to the current working directory.

    Args:
        root: Dataset root containing the ``audios`` folder.
        audio_transforms: Callable applied to the spectrogram array.
            ``None`` selects ``DEFAULT_TRNASFORMS``.
        target_transforms: Optional callable applied to the label.
        classes: Forwarded to ``get_classes`` (inherited, not shown here).
            Note that ``CLASSES`` is overwritten with the keys of
            ``cls2idx.pkl`` when the annotations are loaded.
        ann_file: Stored (user-expanded) on the instance; the annotation
            files actually read are the hard-coded CSV paths above.
        contrastive: Stored on the instance; not used by this class.
        test_mode: If True read ``test.csv``, otherwise ``train.csv``.
    """

    # The misspelled name is kept as-is because it is a public attribute.
    DEFAULT_TRNASFORMS = Compose([ToTensor()])

    def __init__(self,
                 root,
                 audio_transforms,
                 target_transforms=None,
                 classes=None,
                 ann_file=None,
                 contrastive=False,
                 test_mode=False):
        # NOTE(review): super(BaseDataset, self) deliberately starts the MRO
        # lookup AFTER BaseDataset, so BaseDataset.__init__ is not run --
        # presumably because its signature differs; confirm.
        super(BaseDataset, self).__init__()
        self.root = expanduser(root)
        if audio_transforms is None:
            audio_transforms = self.DEFAULT_TRNASFORMS
        self.audio_transforms = audio_transforms
        self.target_transforms = target_transforms
        self.CLASSES = self.get_classes(classes)
        self.ann_file = expanduser(ann_file)
        self.test_mode = test_mode
        self.data_infos = self.load_annotations()
        self.contrastive = contrastive

    def load_annotations(self):
        """Build the list of ``{'audio': path, 'label': int}`` records.

        Also sets ``self.cls2idx`` and ``self.CLASSES`` from the pickled
        class mapping. Clips whose ``.wav`` file does not exist are skipped.

        Returns:
            list[dict]: One record per clip found on disk.

        Raises:
            KeyError: If a class name from the CSV is not in ``cls2idx``.
        """
        train_csv = './external/VGGSound/data/train.csv'
        test_csv = './external/VGGSound/data/test.csv'
        data_dict = load_csv(test_csv if self.test_mode else train_csv)

        # NOTE: pickle must only be used on trusted files. The context
        # manager makes sure the file handle is closed again.
        with open('./external/VGGSound/data/cls2idx.pkl', 'rb') as pkl_file:
            self.cls2idx = pkl.load(pkl_file)
        self.CLASSES = list(self.cls2idx.keys())

        audio_root = os.path.join(self.root, 'audios')

        audio_label_pairs = []
        for video_file, class_name in data_dict.items():
            # Dropping the last 4 characters assumes a 3-letter extension
            # such as ".mp4" on the CSV key -- TODO confirm.
            audio_path = os.path.join(audio_root, video_file[:-4] + '.wav')
            if os.path.exists(audio_path):
                audio_label_pairs.append(
                    {'audio': audio_path, 'label': self.cls2idx[class_name]})
        return audio_label_pairs

    def __getitem__(self, idx):
        """Return ``(audio_gram, label)`` for record ``idx``.

        ``audio_gram`` is the output of ``audio2spectrogram`` passed through
        ``audio_transforms``; ``label`` is the class index passed through
        ``target_transforms``. Either transform is skipped when it is None.
        """
        info = self.data_infos[idx]
        label = info['label']
        audio_gram = audio2spectrogram(info['audio'])
        if self.audio_transforms is not None:
            audio_gram = self.audio_transforms(audio_gram)
        if self.target_transforms is not None:
            label = self.target_transforms(label)
        return audio_gram, label


