import os
import csv
import random
from torch.utils.data import Dataset
import torch
import numpy as np
from PIL import Image

from torchvision import transforms
from .helper import KINETICS_SOUND_CLASSES, AVE_CLASSES, CREMA_D_CLASSES, OPEN_CLIP_MEAN, OPEN_CLIP_STD
from imagebind import data

import warnings

class AudioDataset(Dataset):
    """Audio-only classification dataset (Kinetics-Sounds, AVE or CREMA-D).

    Samples are listed in a CSV file (``data/train_*.csv`` or the matching
    ``test_*.csv``). Column 0 of each row is an audio path that is appended
    verbatim to ``self.root_path``; its second ``/``-separated component is
    used as the class name, so paths are presumably of the form
    ``/<class>/<file>`` (confirm against the CSV files).
    """

    # Accepted values for the ``dataset`` constructor argument.
    _DATASETS = ('ks', 'AVE', 'CREMA-D')

    def __init__(self, train=True, dataset='ks'):
        """Read the sample list for the requested split.

        Args:
            train: Read the training list if True, otherwise the test list.
            dataset: One of ``'ks'``, ``'AVE'`` or ``'CREMA-D'``.

        Raises:
            ValueError: If ``dataset`` is not one of the supported names.
        """
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, which would silently fall through to CREMA-D.
        if dataset not in self._DATASETS:
            raise ValueError(
                f"dataset must be one of {self._DATASETS}, got {dataset!r}")
        list_path, self.root_path, label_names = {
            'ks': ('data/train_kinetics.csv',
                   '/localdata_ssd/av_dataset/Kinetics-sound/audios',
                   KINETICS_SOUND_CLASSES),
            'AVE': ('data/train_AVE.csv',
                    '/localdata_ssd/av_dataset/dataset_ave/audios',
                    AVE_CLASSES),
            'CREMA-D': ('data/train_CREMA-D.csv',
                        '/localdata_ssd/av_dataset/CREMA-D/audios',
                        CREMA_D_CLASSES),
        }[dataset]
        if not train:
            # e.g. data/train_AVE.csv -> data/test_AVE.csv
            list_path = list_path.replace('train', 'test')

        # The csv module expects the underlying file to be opened with newline=''.
        with open(list_path, newline='') as csvfile:
            self.data = list(csv.reader(csvfile))
        print("total data size:", len(self.data))

        # Class name -> integer id, in the order given by the class list.
        self.label_name_2_id = {name: i for i, name in enumerate(label_names)}

    def _get_label(self, path):
        """Return the class id of ``path`` as a 0-d ``torch.long`` tensor.

        The class name is the second ``/``-separated component of ``path``
        (the first component is empty when ``path`` starts with ``/``).

        Raises:
            KeyError: If that component is not a known class name.
        """
        label_name = path.split('/')[1]
        return torch.tensor(self.label_name_2_id[label_name], dtype=torch.long)

    def __getitem__(self, index):
        """Return ``(audio_input, label)`` for the ``index``-th CSV row.

        ``audio_input`` is element 0 of what imagebind's
        ``data.load_and_transform_audio_data`` returns for this single file
        (presumably the preprocessed clip tensor; shape not verified here).
        """
        audio_path = self.data[index][0]
        # Silence any warnings emitted by the imagebind audio loader.
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            audio_input = data.load_and_transform_audio_data([self.root_path + audio_path], "cpu")
        audio_input = audio_input[0]

        label = self._get_label(audio_path)

        return audio_input, label

    def __len__(self):
        """Number of rows read from the CSV list."""
        return len(self.data)


class VideoDataset(Dataset):
    """Frame-based visual classification dataset (Kinetics-Sounds, AVE or CREMA-D).

    Each CSV row supplies two columns. Column 0 is used only to derive the
    label: its second ``/``-separated component is the class name. Column 1
    is a directory path, appended verbatim to ``self.root_path``, whose files
    are the extracted frames of one clip.
    """

    # Accepted values for the ``dataset`` constructor argument.
    _DATASETS = ('ks', 'AVE', 'CREMA-D')

    def __init__(self, train=True, dataset='ks', use_num_frames=3):
        """Read the sample list and build the image transform.

        Args:
            train: If True, use the training list with random crop/flip
                augmentation and random frame sampling. Otherwise, use the test
                list with a plain resize and deterministic frame selection.
            dataset: One of ``'ks'``, ``'AVE'`` or ``'CREMA-D'``.
            use_num_frames: Number of frames returned per sample.

        Raises:
            ValueError: If ``dataset`` is not one of the supported names.
        """
        self.use_num_frames = use_num_frames
        self.train = train
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, which would silently fall through to CREMA-D.
        if dataset not in self._DATASETS:
            raise ValueError(
                f"dataset must be one of {self._DATASETS}, got {dataset!r}")
        list_path, self.root_path, label_names = {
            'ks': ('data/train_kinetics.csv',
                   '/localdata_ssd/av_dataset/Kinetics-sound/frames',
                   KINETICS_SOUND_CLASSES),
            'AVE': ('data/train_AVE.csv',
                    '/localdata_ssd/av_dataset/dataset_ave/frames',
                    AVE_CLASSES),
            'CREMA-D': ('data/train_CREMA-D.csv',
                        '/localdata_ssd/av_dataset/CREMA-D/frames',
                        CREMA_D_CLASSES),
        }[dataset]
        if not train:
            # e.g. data/train_AVE.csv -> data/test_AVE.csv
            list_path = list_path.replace('train', 'test')

        # Class name -> integer id, in the order given by the class list.
        self.label_name_2_id = {name: i for i, name in enumerate(label_names)}

        # The csv module expects the underlying file to be opened with newline=''.
        with open(list_path, newline='') as csvfile:
            self.data = list(csv.reader(csvfile))

        # Only the spatial part differs between train and eval; both end in
        # 224x224 tensors normalized with the OpenCLIP statistics.
        if self.train:
            spatial = [transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip()]
        else:
            spatial = [transforms.Resize(size=(224, 224))]
        self.transform = transforms.Compose(spatial + [
            transforms.ToTensor(),
            transforms.Normalize(OPEN_CLIP_MEAN, OPEN_CLIP_STD),
        ])

        print("total data size:", len(self.data))

    def __len__(self):
        """Number of rows read from the CSV list."""
        return len(self.data)

    def _get_label(self, path):
        """Return the class id of ``path`` as a 0-d ``torch.long`` tensor.

        The class name is the second ``/``-separated component of ``path``.

        Raises:
            KeyError: If that component is not a known class name.
        """
        label_name = path.split('/')[1]
        return torch.tensor(self.label_name_2_id[label_name], dtype=torch.long)

    def _select_frames(self, all_images):
        """Pick ``self.use_num_frames`` file names from the non-empty ``all_images``.

        Names are sorted first because ``os.listdir`` returns entries in
        arbitrary order. The result is always in file-name order, which is
        chronological only if frame names sort that way (e.g. zero-padded
        indices; assumption).

        In training mode, a random subset is drawn. It is sampled with
        replacement when the clip has fewer frames than requested.

        In evaluation mode, the selection is deterministic and evenly spaced,
        including the first and last frame. For 3 frames this is
        first/middle/last. Frames repeat when the clip is shorter than
        requested.
        """
        frames = sorted(all_images)
        num_available = len(frames)
        num_wanted = self.use_num_frames
        if self.train:
            if num_available < num_wanted:
                chosen = random.choices(frames, k=num_wanted)
            else:
                chosen = random.sample(frames, num_wanted)
            return sorted(chosen)
        if num_wanted == 1:
            return [frames[num_available // 2]]
        # Index i * n // (k - 1), clamped to the last frame.
        return [
            frames[min(i * num_available // (num_wanted - 1), num_available - 1)]
            for i in range(num_wanted)
        ]

    def __getitem__(self, index):
        """Return ``(frames, label)`` for the ``index``-th CSV row.

        ``frames`` has shape ``(use_num_frames, 3, 224, 224)``.

        Raises:
            ValueError: If the clip's frame directory is empty.
        """
        label = self._get_label(self.data[index][0])

        video_path = self.root_path + self.data[index][1]
        all_images = os.listdir(video_path)
        if not all_images:
            raise ValueError(f"no frames found in {video_path!r}")

        imgs = []
        for name in self._select_frames(all_images):
            # Context manager closes the file once the RGB copy exists.
            with Image.open(os.path.join(video_path, name)) as img:
                imgs.append(self.transform(img.convert('RGB')))
        return torch.stack(imgs), label

class MultimodalDataset(Dataset):
    """Paired audio + frames classification dataset (Kinetics-Sounds, AVE or CREMA-D).

    Each CSV row supplies two columns:

    - Column 0 is an audio path, appended verbatim to
      ``self.audio_root_path``. Its second ``/``-separated component is the
      class name.
    - Column 1 is a frame-directory path, appended verbatim to
      ``self.video_root_path``.
    """

    # Accepted values for the ``dataset`` constructor argument.
    _DATASETS = ('ks', 'AVE', 'CREMA-D')

    def __init__(self, train=True, dataset='ks', use_num_frames=3) -> None:
        """Read the sample list and build the image transform.

        Args:
            train: If True, use the training list with random crop/flip
                augmentation and random frame sampling. Otherwise, use the test
                list with a plain resize and deterministic frame selection.
            dataset: One of ``'ks'``, ``'AVE'`` or ``'CREMA-D'``.
            use_num_frames: Number of frames returned per sample.

        Raises:
            ValueError: If ``dataset`` is not one of the supported names.
        """
        self.use_num_frames = use_num_frames
        self.train = train
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, which would silently fall through to CREMA-D.
        if dataset not in self._DATASETS:
            raise ValueError(
                f"dataset must be one of {self._DATASETS}, got {dataset!r}")
        list_path, dataset_root, label_names = {
            'ks': ('data/train_kinetics.csv',
                   '/localdata_ssd/av_dataset/Kinetics-sound',
                   KINETICS_SOUND_CLASSES),
            'AVE': ('data/train_AVE.csv',
                    '/localdata_ssd/av_dataset/dataset_ave',
                    AVE_CLASSES),
            'CREMA-D': ('data/train_CREMA-D.csv',
                        '/localdata_ssd/av_dataset/CREMA-D',
                        CREMA_D_CLASSES),
        }[dataset]
        if not train:
            # e.g. data/train_AVE.csv -> data/test_AVE.csv
            list_path = list_path.replace('train', 'test')

        # Frames and audio live in sibling directories of the dataset root.
        self.video_root_path = dataset_root + '/frames'
        self.audio_root_path = dataset_root + '/audios'

        # Class name -> integer id, in the order given by the class list.
        self.label_name_2_id = {name: i for i, name in enumerate(label_names)}

        # The csv module expects the underlying file to be opened with newline=''.
        with open(list_path, newline='') as csvfile:
            self.data = list(csv.reader(csvfile))

        # Only the spatial part differs between train and eval; both end in
        # 224x224 tensors normalized with the OpenCLIP statistics.
        if self.train:
            spatial = [transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip()]
        else:
            spatial = [transforms.Resize(size=(224, 224))]
        self.transform = transforms.Compose(spatial + [
            transforms.ToTensor(),
            transforms.Normalize(OPEN_CLIP_MEAN, OPEN_CLIP_STD),
        ])

        print("total data size:", len(self.data))

    def _get_label(self, path):
        """Return the class id of ``path`` as a 0-d ``torch.long`` tensor.

        The class name is the second ``/``-separated component of ``path``.

        Raises:
            KeyError: If that component is not a known class name.
        """
        label_name = path.split('/')[1]
        return torch.tensor(self.label_name_2_id[label_name], dtype=torch.long)

    def __len__(self):
        """Number of rows read from the CSV list."""
        return len(self.data)

    def _get_audio(self, index):
        """Load and preprocess the audio of the ``index``-th CSV row.

        Returns element 0 of imagebind's ``data.load_and_transform_audio_data``
        output for this file. The original author noted a shape of
        (3, 1, 128, 204); not verified here.
        """
        audio_path = self.data[index][0]
        # Silence any warnings emitted by the imagebind audio loader.
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            audio_input = data.load_and_transform_audio_data([self.audio_root_path + audio_path], "cpu")
        return audio_input[0]

    def _select_frames(self, all_images):
        """Pick ``self.use_num_frames`` file names from the non-empty ``all_images``.

        Names are sorted first because ``os.listdir`` returns entries in
        arbitrary order. The result is always in file-name order, which is
        chronological only if frame names sort that way (e.g. zero-padded
        indices; assumption).

        In training mode, a random subset is drawn. It is sampled with
        replacement when the clip has fewer frames than requested.

        In evaluation mode, the selection is deterministic and evenly spaced,
        including the first and last frame. For 3 frames this is
        first/middle/last. Frames repeat when the clip is shorter than
        requested.
        """
        frames = sorted(all_images)
        num_available = len(frames)
        num_wanted = self.use_num_frames
        if self.train:
            if num_available < num_wanted:
                chosen = random.choices(frames, k=num_wanted)
            else:
                chosen = random.sample(frames, num_wanted)
            return sorted(chosen)
        if num_wanted == 1:
            return [frames[num_available // 2]]
        # Index i * n // (k - 1), clamped to the last frame.
        return [
            frames[min(i * num_available // (num_wanted - 1), num_available - 1)]
            for i in range(num_wanted)
        ]

    def _get_frame(self, index):
        """Load the selected frames of the ``index``-th CSV row.

        Returns a tensor of shape ``(use_num_frames, 3, 224, 224)``.

        Raises:
            ValueError: If the clip's frame directory is empty.
        """
        video_path = self.video_root_path + self.data[index][1]
        all_images = os.listdir(video_path)
        if not all_images:
            raise ValueError(f"no frames found in {video_path!r}")

        imgs = []
        for name in self._select_frames(all_images):
            # Context manager closes the file once the RGB copy exists.
            with Image.open(os.path.join(video_path, name)) as img:
                imgs.append(self.transform(img.convert('RGB')))
        return torch.stack(imgs)

    def __getitem__(self, index):
        """Return ``(audios, frames, labels)`` for the ``index``-th CSV row."""
        audios = self._get_audio(index)
        frames = self._get_frame(index)
        labels = self._get_label(self.data[index][0])
        return audios, frames, labels


if __name__ == "__main__":
    # Smoke test: build the default multimodal dataset (train split, 'ks')
    # and print the shapes of its first sample together with the label.
    smoke_dataset = MultimodalDataset()
    first_audio, first_frames, first_label = smoke_dataset[0]
    print(first_audio.shape, first_frames.shape, first_label)