import os
import os.path as osp
from os import PathLike
import numpy as np
import torch
from PIL import Image
import cv2
import soundfile as sf
from scipy import signal
from torchvision.transforms import Compose, ToTensor
# from .base_dataset import BaseDataset
import glob

def expanduser(path):
    """Expand a leading ``~`` in *path* when it is a string or path-like object.

    Args:
        path: A ``str``, an ``os.PathLike`` object, or any other value.

    Returns:
        The result of ``os.path.expanduser`` for ``str``/``PathLike`` input;
        any other value (e.g. ``None``) is handed back untouched.
    """
    is_path_like = isinstance(path, (str, PathLike))
    return osp.expanduser(path) if is_path_like else path

def load_csv(path):
    """Load a two-column CSV file into a ``{first_column: remainder}`` dict.

    Every non-blank line is stripped of surrounding whitespace and then split
    at its FIRST comma only, so commas inside the second column are kept as
    part of the value. Blank lines are ignored. If a key occurs more than
    once, the last occurrence wins.

    Args:
        path: Path of the CSV file to read.

    Returns:
        dict mapping the text before the first comma to the text after it.

    Raises:
        ValueError: If a non-blank line contains no comma.
    """
    csv_dict = {}
    with open(path, 'r') as rf:
        for line_no, raw_line in enumerate(rf, start=1):
            # Strip BEFORE locating the comma so leading whitespace cannot
            # shift the split position.
            line = raw_line.strip()
            if not line:
                continue
            key, sep, value = line.partition(',')
            if not sep:
                raise ValueError(
                    f'Invalid terms {line} (line {line_no} of {path}).')
            csv_dict[key] = value
    return csv_dict

def audio2spectrogram(wav_path, num_samples=160000):
    """Turn an audio file into a normalised log-spectrogram image.

    The waveform is read with ``soundfile``, reduced to a single channel,
    repeated until it covers ``num_samples`` samples, truncated to exactly
    that length and clamped to ``[-1, 1]``. Its spectrogram
    (``nperseg=512``, ``noverlap=353``) is log-scaled, standardised to zero
    mean / unit variance, min-max rescaled to ``[0, 255]`` and replicated
    over three identical channels.

    Args:
        wav_path: Path of the audio file to read.
        num_samples: Number of waveform samples fed to the spectrogram.
            Defaults to 160000.

    Returns:
        PIL.Image.Image built from a ``uint8`` array of shape
        ``(frequencies, times, 3)``.

    Raises:
        ValueError: If the file contains no audio samples.
    """
    samples, samplerate = sf.read(wav_path)
    samples = np.asarray(samples)

    # soundfile yields (frames, channels) for multi-channel files; tiling such
    # an array would repeat the channel axis, so down-mix to mono first.
    if samples.ndim > 1:
        samples = samples.mean(axis=1)
    if samples.size == 0:
        raise ValueError(f'No audio samples found in {wav_path}.')

    # Repeat in case the audio is too short: ceil(num_samples / len) copies
    # always suffice, however short the clip is.
    repeats = -(-num_samples // samples.size)
    resamples = np.clip(np.tile(samples, repeats)[:num_samples], -1., 1.)

    frequencies, times, spectrogram = signal.spectrogram(
        resamples, samplerate, nperseg=512, noverlap=353)
    # The small offset avoids log(0) on silent bins.
    spectrogram = np.log(spectrogram + 1e-7)

    # Standardise; the epsilon guards against a constant spectrogram.
    mean = np.mean(spectrogram)
    std = np.std(spectrogram)
    spectrogram = np.divide(spectrogram - mean, std + 1e-9)

    # Replicate the single channel three times so the result is image-like.
    spectrogram = np.tile(spectrogram.reshape((*spectrogram.shape, 1)), (1, 1, 3))

    # Min-max rescale into the 8-bit range.
    min_val, max_val = spectrogram.min(), spectrogram.max()
    spectrogram = ((spectrogram - min_val) / ((max_val - min_val) + 1e-9)) * 255
    spectrogram = np.clip(spectrogram, 0, 255)
    return Image.fromarray(spectrogram.astype(np.uint8))

def read_save_mp4(path, save_dir, num_frames=12):
    """Decode a video and save a random subset of its frames as PNG files.

    All frames are decoded into memory, then up to ``num_frames`` distinct
    frames are drawn at random (without replacement) and written to
    ``save_dir`` as ``frame<index>.png``, where ``<index>`` is the frame's
    position in the video.

    Args:
        path: Path of the video file to decode.
        save_dir: Existing directory that receives the PNG files.
        num_frames: Maximum number of frames to save. When the video has
            fewer frames, every frame is saved.

    Returns:
        None. If no frame can be decoded, a message is printed and nothing
        is written.
    """
    print(path, save_dir)
    vidcap = cv2.VideoCapture(path)

    images = []
    try:
        success, image = vidcap.read()
        while success:
            images.append(image)
            success, image = vidcap.read()
    finally:
        # Always free the decoder handle, even if reading fails midway.
        vidcap.release()

    if not images:
        print(f'No frames could be read from {path}.')
        return

    # Sampling without replacement cannot draw more items than exist.
    num_chosen = min(num_frames, len(images))
    chosen_idxs = np.random.choice(len(images), num_chosen, replace=False)

    for idx in chosen_idxs:
        # Save the frame as a PNG file.
        cv2.imwrite(os.path.join(save_dir, f"frame{idx}.png"), images[idx])
    

def data_processing():
    """Preprocess the VGGSound videos: class index, audio tracks, statistics.

    Steps:
      1. Load the train/test CSVs (file name -> class label) with ``load_csv``.
      2. Build a ``class label -> index`` mapping from the sorted unique
         training labels, pickle it to ``cls2idx.pkl``, read it back and
         print it.
      3. For every file in the video folders of splits 0, 2 and 3, extract a
         mono 16 kHz 16-bit PCM WAV with ffmpeg into the audio directory.
      4. Group the files per class for the train and test splits and print
         the number of classes and files of each split.

    Frame extraction (``read_save_mp4``) is not performed by this function.
    All paths are hard-coded below.
    """
    import pickle
    import subprocess

    train_csv = '../external/VGGSound/data/train.csv'
    test_csv = '../external/VGGSound/data/test.csv'
    train_dict = load_csv(train_csv)
    test_dict = load_csv(test_csv)

    # Class indices are derived from the training labels only.
    CLASSES = sorted(set(train_dict.values()))
    cls2idx = {cls_tag: cls_idx for cls_idx, cls_tag in enumerate(CLASSES)}

    cls2idx_path = '../external/VGGSound/data/cls2idx.pkl'
    with open(cls2idx_path, 'wb') as wf:
        pickle.dump(cls2idx, wf)
    # Read the mapping back as a sanity check of what was written.
    with open(cls2idx_path, 'rb') as rf:
        data = pickle.load(rf)
    print(data)

    vggsound_pat = '/Users/linhy/Desktop/VGGSound/scratch {}/shared/beegfs/hchen/train_data/VGGSound_final/video'
    print('Saving Frames.')
    save_audio_dir = '/Users/linhy/Desktop/VGGSound/audios'
    # Create the directory the WAV files are actually written to.
    os.makedirs(save_audio_dir, exist_ok=True)

    cls_train_dict = {}
    cls_test_dict = {}
    for split_num in [0, 2, 3]:
        vs_dir = vggsound_pat.format(split_num)
        for file in os.listdir(vs_dir):
            save_audio_path = os.path.join(save_audio_dir, os.path.splitext(file)[0])
            # An argument list (no shell) keeps file names containing quotes
            # or other shell metacharacters from breaking the command.
            ffmpeg_cmd = [
                'ffmpeg', '-i', os.path.join(vs_dir, file),
                '-acodec', 'pcm_s16le', '-ac', '1', '-ar', '16000',
                save_audio_path + '.wav',
            ]
            try:
                # Best effort: a failed conversion must not abort the run.
                subprocess.run(ffmpeg_cmd, check=False)
            except OSError as err:
                print(f'Could not run ffmpeg for {file}: {err}')

            if file in train_dict:
                cls_name = train_dict[file]
                cls_dict = cls_train_dict
            elif file in test_dict:
                cls_name = test_dict[file]
                cls_dict = cls_test_dict
            else:
                print('Not in train/test csvs.')
                # No label is known for this file, so it must not be counted
                # under whichever class the previous file happened to have.
                continue
            cls_dict.setdefault(cls_name, []).append(file)

    whole_train_num = sum(len(files) for files in cls_train_dict.values())
    print(f'Training Stats cls num{len(cls_train_dict)}, whole num{whole_train_num}')

    whole_test_num = sum(len(files) for files in cls_test_dict.values())
    print(f'Testing Stats cls num{len(cls_test_dict)}, whole num{whole_test_num}')

# Script entry point: run the preprocessing only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    data_processing()


# class VGGSoundPair(BaseDataset):
#     """Since there is prepared dataset class in pytorch, we just wrap it here.
#     """
#     # CLASSES = [
#     #     'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog',
#     #     'horse', 'ship', 'truck'
#     # ]
#     # /sdc1/hylin/datasets/VGGSound/VGGSound
#     DEFAULT_TRNASFORMS = Compose([ToTensor()])
#     def __init__(self,
#                  root,
#                  video_transforms,
#                  audio_transforms,
#                  target_transforms=None,
#                  classes=None,
#                  ann_file=None,
#                  contrastive=False,
#                  test_mode=False):
#         super(BaseDataset, self).__init__()
#         self.root = expanduser(root)
#         if video_transforms is None:
#             video_transforms = self.DEFAULT_TRNASFORMS
#         self.video_transforms = video_transforms
#         if audio_transforms is None:
#             audio_transforms = self.DEFAULT_TRNASFORMS
#         self.audio_transforms = audio_transforms
#         if target_transforms is not None:
#             target_transforms = target_transforms
#         self.target_transforms = target_transforms
#         self.CLASSES = self.get_classes(classes)
#         self.ann_file = expanduser(ann_file)
#         self.test_mode = test_mode
#         self.num_frame = 12
#         self.data_infos = self.load_annotations()
#         self.contrastive = contrastive
#         # print(len(self.data_infos))

#     def load_annotations(self):
#         # rank, world_size = get_dist_info()
#         # self.root: /home/lhy/datasets/sketchy/rendered_256x256
#         train_csv = './external/VGGSound/data/train.csv'
#         test_csv = './external/VGGSound/data/test.csv'
#         if self.test_mode:
#             data_dict = load_csv(test_csv)
#         else:
#             data_dict = load_csv(train_csv)
#         # print(data_dict)
#         self.CLASSES = []
#         for file in data_dict:
#             cls_tag = data_dict[file]
#             if cls_tag not in self.CLASSES:
#                 self.CLASSES.append(cls_tag)

#         self.CLASSES = sorted(self.CLASSES)
#         video_root = os.path.join(self.root, 'frames')
#         audio_root = os.path.join(self.root, 'audios')
#         print(video_root, audio_root)
#         # just use the dataset to get imgs and set
#         video_sound_pairs = []
#         for video_file in data_dict:
#             video_path = os.path.join(video_root, video_file)
#             if os.path.exists(video_path):
#                 audio_path = os.path.join(audio_root, video_file[:-4] + '.wav')
#                 frames_path = list(sorted(glob.glob(f'{video_path}/*.png')))[:self.num_frame]
#                 # print(frames_path)
#                 video_sound_pairs.append({'frames':frames_path, 'audio':audio_path})
#         data_infos = video_sound_pairs
#         return data_infos

#     def __getitem__(self, idx):
#         frames_path, audio_path = self.data_infos[idx]['frames'], self.data_infos[idx]['audio']
#         frames = [Image.open(frame_path) for frame_path in frames_path]
#         # print([np.array(frame).shape for frame in frames])
#         # print(self.video_transforms)
#         if self.video_transforms is not None:
#             frames = [self.video_transforms(frame) for frame in frames]
#             if self.contrastive:
#                 frames = [torch.stack([frames[l][0] for l in range(len(frames))]), 
#                           torch.stack([frames[l][1] for l in range(len(frames))])]
#             else:
#                 frames = torch.stack(frames)
#         # print('after transfroms', frames.size())
#         audio_gram = audio2spectrogram(audio_path)
#         if self.audio_transforms is not None:
#             audio_gram = self.audio_transforms(audio_gram)
#         return frames, audio_gram



# class VGGSoundVideo(BaseDataset):
#     """Since there is prepared dataset class in pytorch, we just wrap it here.
#     """
#     # CLASSES = [
#     #     'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog',
#     #     'horse', 'ship', 'truck'
#     # ]
#     # /sdc1/hylin/datasets/VGGSound/VGGSound
#     DEFAULT_TRNASFORMS = Compose([ToTensor()])
#     def __init__(self,
#                  root,
#                  video_transforms,
#                  target_transforms=None,
#                  classes=None,
#                  ann_file=None,
#                  contrastive=False,
#                  test_mode=False):
#         super(BaseDataset, self).__init__()
#         self.root = expanduser(root)
#         if video_transforms is None:
#             video_transforms = self.DEFAULT_TRNASFORMS
#         self.video_transforms = video_transforms
#         if target_transforms is not None:
#             target_transforms = target_transforms
#         self.target_transforms = target_transforms
#         self.CLASSES = self.get_classes(classes)
#         self.ann_file = expanduser(ann_file)
#         self.test_mode = test_mode
#         self.num_frame = 12
#         self.data_infos = self.load_annotations()
#         self.contrastive = contrastive
        
#         # print(len(self.data_infos))

#     def load_annotations(self):
#         # rank, world_size = get_dist_info()
#         # self.root: /home/lhy/datasets/sketchy/rendered_256x256
#         train_csv = './external/VGGSound/data/train.csv'
#         test_csv = './external/VGGSound/data/test.csv'
#         if self.test_mode:
#             data_dict = load_csv(test_csv)
#         else:
#             data_dict = load_csv(train_csv)
        
#         self.CLASSES = []
#         for file in data_dict:
#             cls_tag = data_dict[file]
#             if cls_tag not in self.CLASSES:
#                 self.CLASSES.append(cls_tag)

#         self.CLASSES = sorted(self.CLASSES)
#         cls2idx = {cls_tag:cls_idx for cls_idx, cls_tag in enumerate(self.CLASSES)}

#         video_root = os.path.join(self.root, 'frames')
        
#         # just use the dataset to get imgs and set
#         video_label_pairs = []
#         for video_file in data_dict:
#             video_path = os.path.join(video_root, video_file)
#             if os.path.exists(video_path):
#                 frames_path = list(sorted(glob.glob(f'{video_path}/*.png')))[:self.num_frame]
#                 video_label_pairs.append({'frames':frames_path, 'label':cls2idx[data_dict[video_file]]})
#         data_infos = video_label_pairs
#         return data_infos

#     def __getitem__(self, idx):
#         frames_path, label = self.data_infos[idx]['frames'], self.data_infos[idx]['label']
#         frames = [Image.open(frame_path) for frame_path in frames_path]
#         if self.video_transforms is not None:
#             frames = [self.video_transforms(frame) for frame in frames]
#             if self.contrastive:
#                 frames = [torch.stack([frames[l][0] for l in range(len(frames))]), 
#                           torch.stack([frames[l][1] for l in range(len(frames))])]
#             else:
#                 frames = torch.stack(frames)
        
#         if self.target_transforms is not None:
#             label = self.target_transforms(label)
#         return frames, label


# class VGGSoundAudio(BaseDataset):
#     """Since there is prepared dataset class in pytorch, we just wrap it here.
#     """
#     # CLASSES = [
#     #     'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog',
#     #     'horse', 'ship', 'truck'
#     # ]
#     # /sdc1/hylin/datasets/VGGSound/VGGSound
#     DEFAULT_TRNASFORMS = Compose([ToTensor()])
#     def __init__(self,
#                  root,
#                  audio_transforms,
#                  target_transforms=None,
#                  classes=None,
#                  ann_file=None,
#                  contrastive=False,
#                  test_mode=False):
#         super(BaseDataset, self).__init__()
#         self.root = expanduser(root)
#         if audio_transforms is None:
#             audio_transforms = self.DEFAULT_TRNASFORMS
#         self.audio_transforms = audio_transforms
#         if target_transforms is not None:
#             target_transforms = target_transforms
#         self.target_transforms = target_transforms
#         self.CLASSES = self.get_classes(classes)
#         self.ann_file = expanduser(ann_file)
#         self.test_mode = test_mode
#         self.data_infos = self.load_annotations()
#         self.contrastive = contrastive
#         # print(len(self.data_infos))

#     def load_annotations(self):
#         # rank, world_size = get_dist_info()
#         # self.root: /home/lhy/datasets/sketchy/rendered_256x256
#         train_csv = './external/VGGSound/data/train.csv'
#         test_csv = './external/VGGSound/data/test.csv'
#         if self.test_mode:
#             data_dict = load_csv(test_csv)
#         else:
#             data_dict = load_csv(train_csv)
        
#         self.CLASSES = []
#         for file in data_dict:
#             cls_tag = data_dict[file]
#             if cls_tag not in self.CLASSES:
#                 self.CLASSES.append(cls_tag)

#         self.CLASSES = sorted(self.CLASSES)
#         cls2idx = {cls_tag:cls_idx for cls_idx, cls_tag in enumerate(self.CLASSES)}

#         audio_root = os.path.join(self.root, 'audios')
        
#         # just use the dataset to get imgs and set
#         audio_label_pairs = []
#         for video_file in data_dict:
#             audio_path = os.path.join(audio_root, video_file[:-4]+'.wav')
#             if os.path.exists(audio_path):
#                 audio_label_pairs.append({'audio':audio_path, 'label':cls2idx[data_dict[video_file]]})
#         data_infos = audio_label_pairs
        
#         return data_infos

#     def __getitem__(self, idx):
#         audio_path, label = self.data_infos[idx]['audio'], self.data_infos[idx]['label']
#         audio_gram = audio2spectrogram(audio_path)
#         if self.audio_transforms is not None:
#             audio_gram = self.audio_transforms(audio_gram)
#         if self.target_transforms is not None:
#             label = self.target_transforms(label)
#         return audio_gram, label


