import os
import h5py
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

class AVEDataset(Dataset):
    """Dataset yielding ``(visual_feat, audio_feat, label)`` triples.

    Samples are addressed indirectly: position ``index`` of the dataset is
    mapped through the ``order`` dataset stored in ``<split>_order.h5`` to a
    row of the feature and label stores.

    Data files are opened lazily on the first ``__getitem__`` call rather than
    in ``__init__``. The HDF5 handles are kept on the instance so they can be
    released with :meth:`close`; the next item access reopens them.

    Args:
        data_root: Directory containing the feature, label and order files.
        split: Name of the split; selects ``<split>_order.h5``.
    """

    def __init__(self, data_root, split='train'):
        super(AVEDataset, self).__init__()
        self.split = split
        # Visual features are read from a NumPy dump; the earlier
        # 'visual_feature.h5' store is no longer used.
        self.visual_feature_path = os.path.join(data_root, 'AVE_Resnet151_features.npy')
        self.audio_feature_path = os.path.join(data_root, 'audio_feature.h5')
        self.noisy_visual_feature_path = os.path.join(data_root, 'visual_feature_noisy.h5')  # only background
        self.noisy_audio_feature_path = os.path.join(data_root, 'audio_feature_noisy.h5')  # only background
        # Now for the supervised task
        self.labels_path = os.path.join(data_root, 'labels.h5')  # original labels for testing
        self.dir_labels_path = os.path.join(data_root, 'mil_labels.h5')  # video-level labels
        self.dir_labels_bg_path = os.path.join(data_root, 'labels_noisy.h5')  # only background
        self.sample_order_path = os.path.join(data_root, f'{split}_order.h5')
        self.h5_isOpen = False
        # Open h5py.File objects, retained so close() can release them.
        self._h5_files = []
        # Number of samples, cached after the first __len__ call.
        self._length = None

    def _open_h5(self, path):
        """Open *path* read-only and remember the handle for ``close()``."""
        handle = h5py.File(path, 'r')
        self._h5_files.append(handle)
        return handle

    def _open_sources(self):
        """Load the visual features and open every HDF5 store used by items."""
        self.visual_feature = np.load(self.visual_feature_path, mmap_mode=None)
        try:
            self.audio_feature = self._open_h5(self.audio_feature_path)['avadataset']
            self.labels = self._open_h5(self.labels_path)['avadataset']
            self.sample_order = self._open_h5(self.sample_order_path)['order']
        except Exception:
            # Do not keep a half-opened set of files around: a retry would
            # otherwise accumulate handles that are never closed.
            self.close()
            raise
        self.h5_isOpen = True

    def close(self):
        """Close all HDF5 files opened by this dataset.

        Safe to call when nothing is open. After closing, the next
        ``__getitem__`` call reopens the files.
        """
        open_files, self._h5_files = self._h5_files, []
        self.h5_isOpen = False
        for handle in open_files:
            handle.close()

    def __getitem__(self, index):
        """Return ``(visual_feat, audio_feat, label)`` for position *index*.

        *index* is first translated through the split's ``order`` dataset;
        the resulting value selects the row in every store.
        """
        if not self.h5_isOpen:
            self._open_sources()
        sample_index = self.sample_order[index]
        visual_feat = self.visual_feature[sample_index]
        audio_feat = self.audio_feature[sample_index]
        label = self.labels[sample_index]

        return visual_feat, audio_feat, label

    def __len__(self):
        """Return the number of entries in the split's ``order`` dataset.

        The value is read from disk once and cached, so repeated ``len()``
        calls do not reopen the file.
        """
        if self._length is None:
            # The context manager closes the file even if 'order' is missing.
            with h5py.File(self.sample_order_path, 'r') as f:
                self._length = len(f['order'])
        return self._length




