from torch.utils.data import Dataset
import torch
import numpy as np

# Public API of this module.
# NOTE(review): 'AuGDataset' is not defined in the visible code; if it is not
# defined elsewhere in this module, `from <module> import *` raises
# AttributeError — confirm.
__all__ = ['MMDataset', 'AuGDataset']

class MMDataset(Dataset):
    """Multimodal (text / video / audio) dataset yielding one dict of tensors per sample.

    Args:
        label_ids: Per-sample labels, indexable by sample position.
        text_data: Per-sample text features; ``len(text_data)`` defines the
            dataset size.
        video_data: Mapping with ``'feats'`` and ``'lengths'`` entries, each
            indexable by sample position.
        audio_data: Mapping with ``'feats'`` and ``'lengths'`` entries, each
            indexable by sample position.
        other_data: Optional mapping of extra per-sample fields. Each key is
            exposed as an attribute of the dataset and emitted (as a tensor)
            under the same key in every sample.
        concepts_path: Optional sequence of file paths loaded with
            ``torch.load`` in the order (text, audio, video). When given, every
            sample also carries ``'text_concepts'``, ``'video_concepts'`` and
            ``'audio_concepts'``.

    Raises:
        ValueError: If a key of ``other_data`` would shadow an existing
            attribute of the dataset (e.g. ``'size'`` or ``'label_ids'``).
    """

    def __init__(self, label_ids, text_data, video_data, audio_data, other_data=None, concepts_path=None):
        self.label_ids = label_ids
        self.text_data = text_data
        self.video_data = video_data
        self.audio_data = audio_data

        if concepts_path is not None:
            # NOTE: torch.load unpickles its input (how much is allowed depends
            # on the installed torch version's `weights_only` default); only
            # point this at trusted files.
            self.text_concepts = torch.load(concepts_path[0])
            self.audio_concepts = torch.load(concepts_path[1])
            self.video_concepts = torch.load(concepts_path[2])
            self.use_concept = True
        else:
            self.use_concept = False

        self.size = len(self.text_data)

        self.other_data = other_data
        if self.other_data is not None:
            for key in other_data.keys():
                # A blind setattr would silently overwrite core state such as
                # `size` or `label_ids` and corrupt __len__/__getitem__.
                if hasattr(self, key):
                    raise ValueError(
                        f"other_data key {key!r} clashes with an existing MMDataset attribute"
                    )
                setattr(self, key, other_data[key])

    def __len__(self):
        """Return the number of samples (the length of ``text_data``)."""
        return self.size

    def __getitem__(self, index):
        """Return the tensors for sample ``index`` as a dict keyed by field name."""
        sample = {
            'label_ids': torch.tensor(self.label_ids[index]),
            'text_feats': torch.tensor(self.text_data[index]),
            # Presumably converted through np.array first to avoid torch.tensor's
            # slow path on lists of ndarrays — TODO confirm the stored types.
            'video_feats': torch.tensor(np.array(self.video_data['feats'][index])),
            'video_lengths': torch.tensor(np.array(self.video_data['lengths'][index])),
            'audio_feats': torch.tensor(np.array(self.audio_data['feats'][index])),
            'audio_lengths': torch.tensor(np.array(self.audio_data['lengths'][index])),
        }

        if self.use_concept:
            # Concept entries are returned exactly as loaded, without a tensor copy.
            sample.update({
                'text_concepts': self.text_concepts[index],
                'video_concepts': self.video_concepts[index],
                'audio_concepts': self.audio_concepts[index],
            })

        if self.other_data is not None:
            for key in self.other_data.keys():
                sample[key] = torch.tensor(getattr(self, key)[index])

        return sample
