from torch.utils.data.dataset import Dataset
from torch.utils.data.dataset import ConcatDataset
from torchvision import transforms
import pickle
import glob
import torch
import os
import numpy as np
import matplotlib.pyplot as plt
import cv2
import sounddevice as sd


class RSI3Dataset(Dataset):
    """
    Map-style dataset of ([image], [sound], label) samples loaded from a single pickle file.

    Each pickled record is a dict with at least an 'image' array and a 'ground_truth' label.
    The sound paired with an image is produced by the audio loader on every access (the audio
    helpers are handed ``torch`` / ``torch.randint``, so the pairing is presumably random).
    """

    def __init__(self, picklePath, config, **kwargs):
        """
        :param picklePath: path to a pickle file holding a list of sample dicts
        :param config: experiment config; reads name, allTasks, taskNum and sound_dim
        :param kwargs: must contain 'audio', the audio loader used to produce sound features
        """
        self.filePath = picklePath
        self.config = config
        # NOTE: pickle.load can execute arbitrary code; only load trusted dataset files
        with open(self.filePath, 'rb') as pickleFile:
            self.ground_truth_pair = pickle.load(pickleFile)
        self.audio = kwargs['audio']

        if config.name == 'AI2ThorConfig':
            from Envs.ai2thor.RSI2.RL_env_RSI2 import Task
            self.Task = Task
            # flatten allTasks[loc][obj] -> act into an ordered task list;
            # a ground-truth label is an index into this list
            taskTable = self.config.allTasks
            self.tl = [Task(loc=loc, obj=obj, act=act)
                       for loc in taskTable
                       for obj in taskTable[loc]
                       for act in taskTable[loc][obj]]

        self.img_transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])

    def getImgSoundPair(self, gt):
        """
        Return the sound feature paired with ground-truth label ``gt``.

        The label ``config.taskNum`` yields an all-zero feature of shape ``config.sound_dim``.
        Any other label is turned into a feature by the audio loader: via the task list for
        'AI2ThorConfig', otherwise via an MFCC feature for object index ``gt``.
        """
        if gt == self.config.taskNum:
            return np.zeros(shape=self.config.sound_dim)

        if self.config.name == 'AI2ThorConfig':
            feature, _, _ = self.audio.getAudioFromTask(torch, self.tl[gt], self.Task)
            return feature

        feature, _ = self.audio.genSoundFeat(objIndx=gt, featType='MFCC',
                                             rand_fn=torch.randint)
        return feature

    def __getitem__(self, index):
        record = self.ground_truth_pair[index]
        # image assumed channel-first with values in [0, 255]; scale to float in [0, 1]
        pixels = torch.from_numpy(record['image'])
        pixels = (pixels / 255.).float()

        label = int(record['ground_truth'])
        sound = self.getImgSoundPair(label)

        return [pixels], [sound], label

    def __len__(self):
        return len(self.ground_truth_pair)


class RSI3FineTuneDataset(RSI3Dataset):
    """
    Fine-tuning variant of RSI3Dataset with a fixed image/sound pairing.

    During fine tuning there are no labels, so an image cannot be re-paired with a random sound
    on every access. Instead, each record is paired with a sound exactly once in __init__
    (stored under the 'sound_positive' key) and that pairing is reused for the whole training run.
    """
    def __init__(self, picklePath, config, **kwargs):
        super().__init__(picklePath, config, **kwargs)

        # freeze the pairing: compute each record's sound once and cache it on the record
        for record in self.ground_truth_pair:
            record['sound_positive'] = self.getImgSoundPair(int(record['ground_truth']))

    def __getitem__(self, index):
        record = self.ground_truth_pair[index]
        # image assumed channel-first with values in [0, 255]; scale to float in [0, 1]
        pixels = (torch.from_numpy(record['image']) / 255.).float()

        label = int(record['ground_truth'])
        cached_sound = record['sound_positive']

        return [pixels], [cached_sound], label


class RSI2Dataset(Dataset):
    """
    Map-style dataset of (image, positive sound, negative sound, label) samples from one pickle.

    Each record is a dict with an 'image' array and a 'ground_truth' label. It also holds either
    precomputed 'sound_positive' / 'sound_negative' features or a 'sound_negative_id'. In the
    latter case both sounds are generated by the audio loader on every access. On-the-fly
    generation is only implemented for the 'AI2ThorConfig' environment.
    """

    def __init__(self, picklePath, config, **kwargs):
        """
        :param picklePath: path to a pickle file holding a list of sample dicts
        :param config: experiment config; reads name, allTasks, taskNum and sound_dim
        :param kwargs: must contain 'audio', the audio loader used to produce sound features
        """
        self.filePath = picklePath
        self.config = config
        # NOTE: pickle.load can execute arbitrary code; only load trusted dataset files
        with open(self.filePath, 'rb') as f:
            self.ground_truth_pair = pickle.load(f)
        self.audio = kwargs['audio']

        if config.name == 'AI2ThorConfig':
            from Envs.ai2thor.RSI2.RL_env_RSI2 import Task
            self.Task = Task

            # task list: flattened allTasks[loc][obj] -> act; a task id indexes into this list
            self.tl = []
            for loc in self.config.allTasks:
                for obj in self.config.allTasks[loc]:
                    for act in self.config.allTasks[loc][obj]:
                        self.tl.append(Task(loc=loc, obj=obj, act=act))

    def _soundFromId(self, taskId):
        """
        Return the sound feature for a single task id.

        ``config.taskNum`` (the label beyond the task list, presumably "no task") maps to an
        all-zero feature of shape ``config.sound_dim``. Any other id is looked up in the task
        list and rendered by the audio loader.
        """
        if taskId == self.config.taskNum:
            return np.zeros(shape=self.config.sound_dim)
        feature, _, _ = self.audio.getAudioFromTask(torch, self.tl[taskId], self.Task)
        return feature

    def getImgSoundPair(self, gt, sn_id):
        """
        Generate the (positive, negative) sound features for one sample.

        :param gt: task id of the positive sound (the image's ground truth)
        :param sn_id: task id of the negative sound
        :return: (sound_positive, sound_negative); a sound is all zeros when its id equals
                 config.taskNum
        :raises NotImplementedError: if config.name is not 'AI2ThorConfig'
        """
        # self.tl / self.Task only exist for AI2Thor, so reject other configs up front instead
        # of failing with an AttributeError on the task-list lookup.
        if self.config.name != 'AI2ThorConfig':
            raise NotImplementedError
        # Both ids share one helper so config.taskNum is handled identically for the positive
        # and the negative sound (it must never be used as an index into self.tl).
        sound_positive = self._soundFromId(gt)
        sound_negative = self._soundFromId(sn_id)
        return sound_positive, sound_negative

    def __getitem__(self, index):
        # image assumed channel-first with values in [0, 255]; scale to float in [0, 1]
        # the sound is assumed to have shape (1, frame, features) — TODO confirm
        record = self.ground_truth_pair[index]
        image = torch.from_numpy(record['image'])
        image = (image / 255.).float()

        gt = int(record['ground_truth'])
        if 'sound_negative' not in record:
            # new-format records only store the negative id; generate both sounds now
            sn_id = int(record['sound_negative_id'])
            sound_positive, sound_negative = self.getImgSoundPair(gt, sn_id)
        else:
            sound_positive = record['sound_positive']
            sound_negative = record['sound_negative']

        return image, sound_positive, sound_negative, gt

    def __len__(self):
        return len(self.ground_truth_pair)


class RSI2FineTuneDataset(RSI2Dataset):
    """
    Fine-tuning variant of RSI2Dataset with a fixed image/sound pairing.

    During fine tuning there are no labels, so an image cannot be re-paired with a random sound
    on every access. Instead, each record gets its positive/negative sounds exactly once in
    __init__, and those sounds are reused for the whole training run.
    """
    def __init__(self, picklePath, config, **kwargs):
        super().__init__(picklePath, config, **kwargs)

        for record in self.ground_truth_pair:
            if 'sound_negative' in record:
                # sound_positive and sound_negative were stored in the pickle; keep them
                continue
            # otherwise a 'sound_negative_id' is assumed to be present
            positive, negative = self.getImgSoundPair(int(record['ground_truth']),
                                                      int(record['sound_negative_id']))
            record['sound_positive'] = positive
            record['sound_negative'] = negative

    def __getitem__(self, index):
        record = self.ground_truth_pair[index]
        # image assumed channel-first with values in [0, 255]; scale to float in [0, 1]
        pixels = (torch.from_numpy(record['image']) / 255.).float()

        label = int(record['ground_truth'])
        positive = record['sound_positive']
        negative = record['sound_negative']

        return pixels, positive, negative, label


class RSI1Dataset(Dataset):
    """
    Map-style dataset of (image, goal sound, goal_sound_label, ground_truth, inSight, exi)
    samples loaded from a single pickle file.

    The goal sound is chosen from the argmax of each record's 'goal_sound_label' (not from
    'ground_truth') and is produced by the audio loader on every access.
    """

    def __init__(self, picklePath, config, **kwargs):
        """
        :param picklePath: path to a pickle file holding a list of sample dicts
        :param config: experiment config; reads name, allTasks, taskNum and sound_dim
        :param kwargs: must contain 'audio', the audio loader used to produce sound features
        """
        self.filePath = picklePath
        self.config = config
        # NOTE: pickle.load can execute arbitrary code; only load trusted dataset files
        with open(self.filePath, 'rb') as pickleFile:
            self.ground_truth_pair = pickle.load(pickleFile)
        self.audio = kwargs['audio']

        if config.name == 'AI2ThorConfig':
            from Envs.ai2thor.RSI2.RL_env_RSI2 import Task
            self.Task = Task
            # flatten allTasks[loc][obj] -> act into an ordered task list indexed by label
            taskTable = self.config.allTasks
            self.tl = [Task(loc=loc, obj=obj, act=act)
                       for loc in taskTable
                       for obj in taskTable[loc]
                       for act in taskTable[loc][obj]]

    def getImgSoundPair(self, gt):
        """
        Return the goal-sound feature for label ``gt``.

        The label ``config.taskNum`` yields an all-zero feature of shape ``config.sound_dim``.
        Any other label is rendered by the audio loader (task list for 'AI2ThorConfig',
        otherwise an MFCC feature for object index ``gt``).
        """
        if gt == self.config.taskNum:
            return np.zeros(shape=self.config.sound_dim)

        if self.config.name == 'AI2ThorConfig':
            feature, _, _ = self.audio.getAudioFromTask(torch, self.tl[gt], self.Task)
            return feature

        feature, _ = self.audio.genSoundFeat(objIndx=gt, featType='MFCC',
                                             rand_fn=torch.randint)
        return feature

    def __getitem__(self, index):
        record = self.ground_truth_pair[index]
        # image assumed channel-first with values in [0, 255]; scale to float in [0, 1]
        # the sound is assumed to have shape (1, frame, features) — TODO confirm
        pixels = (torch.from_numpy(record['image']) / 255.).float()

        # the sound is paired through goal_sound_label (presumably one-hot), via its argmax
        goal_sound = self.getImgSoundPair(np.argmax(record['goal_sound_label']))

        return (pixels,
                goal_sound,
                record['goal_sound_label'],
                record['ground_truth'],
                record['inSight'],
                record['exi'])

    def __len__(self):
        return len(self.ground_truth_pair)


class featureDataset(Dataset):

    def __init__(self, data, config):
        """
        Dataset of pre-computed VAR feature vectors, used to train the linear model for VAR
        evaluation.

        :param data: sequence indexed by task id; ``data[i]`` is a dict whose 'img' and 'sound'
                     entries are 2-D arrays of feature vectors (one row per sample) for task i
        :param config: experiment config (stored but not otherwise used here)
        """
        self.config = config

        def labelled(features, taskId):
            # append the task id as a trailing float column so each row carries its label
            labelColumn = np.full((len(features), 1), float(taskId))
            return np.concatenate([features, labelColumn], axis=1)

        taskIds = range(len(data))
        self.img_data = np.concatenate(
            [labelled(data[t]['img'], t) for t in taskIds], axis=0)
        self.sound_data = np.concatenate(
            [labelled(data[t]['sound'], t) for t in taskIds], axis=0)

    def __getitem__(self, index):
        # the last column holds the task label; everything before it is the feature vector
        imgRow = self.img_data[index]
        soundRow = self.sound_data[index]
        return imgRow[:-1], soundRow[:-1], imgRow[-1], soundRow[-1]

    def __len__(self):
        # NOTE(review): length follows the image features only; image and sound rows are
        # paired by position, so both are presumably expected to have the same count
        return len(self.img_data)


def loadEnvData(data_dir, config, batch_size, shuffle, num_workers, drop_last, loadNum=None,
                dtype=RSI2Dataset, train_test='train'):
    """
    Build a DataLoader over the pickle files found in ``<dir>/<train_test>`` for each data dir.

    :param data_dir: iterable of dataset root directories
    :param config: experiment config; passed to the audio loader and datasets, and
                   ``config.taskNum`` sizes the per-label histogram printed at the end
    :param batch_size: forwarded to torch DataLoader
    :param shuffle: forwarded to torch DataLoader
    :param num_workers: forwarded to torch DataLoader
    :param drop_last: forwarded to torch DataLoader
    :param loadNum: optional per-directory cap on the number of pickle files to load. If None,
                    or if the entry for a directory is 'all', every file is loaded; otherwise a
                    random subset of that many distinct files is used.
    :param dtype: dataset class instantiated once per pickle file
    :param train_test: sub-directory to read from (e.g. 'train')
    :return: (DataLoader, ConcatDataset)
    """
    # load audio dataset
    from Envs.audioLoader import audioLoader
    audio = audioLoader(config=config)
    audio.loadData()

    all_datasets = []
    for i, dirs in enumerate(data_dir):
        assert os.path.exists(dirs), 'data directory does not exist: {}'.format(dirs)
        path = os.path.join(dirs, train_test)
        # glob order is filesystem-dependent; sort so the file list (and any seeded
        # subsample of it) is reproducible
        fileList = sorted(glob.glob(os.path.join(path, '*.pickle')))
        if loadNum is not None and loadNum[i] != 'all':
            limit = int(loadNum[i])
            if len(fileList) > limit:
                # sample without replacement so no file is loaded twice (which would duplicate
                # its samples and yield fewer distinct files than requested)
                fileList = np.random.choice(fileList, size=limit, replace=False)
        for filePath in fileList:
            # np.random.choice returns numpy string scalars; convert back to plain str
            all_datasets.append(dtype(picklePath=str(filePath), config=config, audio=audio))

    final_dataset = ConcatDataset(all_datasets)
    generator = torch.utils.data.DataLoader(final_dataset,
                                            batch_size=batch_size,
                                            shuffle=shuffle,
                                            num_workers=num_workers,
                                            pin_memory=True,
                                            drop_last=drop_last)

    # histogram of ground-truth labels 0..taskNum (inclusive)
    num = [0] * (config.taskNum + 1)
    for dataset in final_dataset.datasets:
        for pairs in dataset.ground_truth_pair:
            num[int(pairs['ground_truth'])] += 1
    print("The number of pairs for each object in the dataset is:", num)
    return generator, final_dataset


# custom transform
class AddGaussianNoise(object):
    """
    Transform that adds i.i.d. Gaussian noise with the given mean and std to a tensor.

    The noise is sampled on the input tensor's device, so the transform works for CPU and
    CUDA tensors alike.
    """

    def __init__(self, mean=0., std=1.0):
        """
        :param mean: mean of the added noise
        :param std: standard deviation of the added noise
        """
        self.std = std
        self.mean = mean

    def __call__(self, tensor):
        # Sample on the input's device to avoid a CPU/CUDA device mismatch in the addition.
        # dtype is left at torch's default float so type promotion matches the CPU behaviour
        # (e.g. integer inputs are still promoted to float).
        noise = torch.randn(tensor.size(), device=tensor.device)
        return tensor + noise * self.std + self.mean

    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
