
import h5py
import torch
from torch.utils.data import Dataset,DataLoader

class VSdataset(Dataset):
    """Map-style dataset that serves one video's feature array per index.

    Args:
        data: Pair ``(features, dataset_names)``. Both sequences are indexed
            with the same integer; ``features[i]`` must be a NumPy array
            (it is passed to ``torch.from_numpy``) and ``dataset_names[i]``
            is returned unchanged alongside it.
        video_nums: Value reported by ``len()``; it is stored as given and
            not checked against ``len(features)``.
        transform: Optional callable applied independently to each of the
            two tensors built for a sample.
    """

    def __init__(self, data, video_nums, transform=None):
        self.features, self.dataset_names = data
        self.video_nums = video_nums
        self.transform = transform

    def __len__(self):
        """Return the sample count supplied at construction time."""
        return self.video_nums

    def __getitem__(self, idx):
        """Return ``(feature, tripled_feature, dataset_name)`` for ``idx``.

        For a 2-D ``(T, D)`` feature array the two tensors have shapes
        ``(1, T, D)`` and ``(1, 3, T, D)`` when no transform is set.
        """
        plain = torch.from_numpy(self.features[idx]).float()
        # Add a leading axis and broadcast it to size 3: (T, D) -> (3, T, D).
        tripled = plain[None].expand(3, -1, -1)

        transform = self.transform
        if transform is not None:
            # The transform is called once per view, plain view first.
            plain = transform(plain)
            tripled = transform(tripled)

        # Both tensors get a leading batch axis of size 1.
        batched_plain = torch.unsqueeze(plain, 0)
        batched_tripled = torch.unsqueeze(tripled, 0)
        return batched_plain, batched_tripled, self.dataset_names[idx]

def collate_fn(sample):
    """Return the first element of the batch list handed over by a DataLoader.

    With ``batch_size=1`` the list holds exactly one sample, so this simply
    unwraps it instead of stacking samples into a batch.
    """
    first_sample = sample[0]
    return first_sample

def create_ssl_loader(datasets=None, *, data_dir='./data'):
    """Build a shuffled, one-video-per-step DataLoader over the given datasets.

    For every dataset name, the file
    ``{data_dir}/eccv16_dataset_{name.lower()}_google_pool5.h5`` is opened and
    the ``features`` entry of each top-level group is loaded into memory.

    Args:
        datasets: Iterable of dataset names. ``None`` (the default) selects
            ``('SumMe', 'TVSum', 'OVP', 'Youtube')``. A single name given as
            a plain string is treated as a one-element collection.
        data_dir: Directory that contains the HDF5 files. The default
            reproduces the previously hard-coded ``./data`` location.

    Returns:
        A ``DataLoader`` wrapping a ``VSdataset`` with ``batch_size=1`` and
        ``shuffle=True``; each step yields one sample unwrapped by
        ``collate_fn``.

    Raises:
        OSError: Propagated from ``h5py.File`` when a dataset file cannot be
            opened.
    """
    # A None sentinel replaces the former mutable list default.
    if datasets is None:
        datasets = ('SumMe', 'TVSum', 'OVP', 'Youtube')
    elif isinstance(datasets, str):
        # Iterating a bare string would yield single characters as names.
        datasets = (datasets,)

    features = []
    dataset_names = []

    for dataset in datasets:
        data_path = f'{data_dir}/eccv16_dataset_{dataset.lower()}_google_pool5.h5'
        with h5py.File(data_path, 'r') as hdf:
            for video in hdf:
                # Indexing with [()] reads the dataset into an in-memory
                # array, so it remains usable after the file is closed.
                features.append(hdf[video]['features'][()])
                dataset_names.append(dataset)

    ssl_data = (features, dataset_names)
    # One feature array is appended per video, so the list length is the count.
    ssl_dataset = VSdataset(data=ssl_data, video_nums=len(features))
    # batch_size=1 because videos have different lengths and cannot be stacked.
    ssl_loader = DataLoader(ssl_dataset, batch_size=1, shuffle=True, collate_fn=collate_fn)
    print(f"SSL dataset size: {len(ssl_loader.dataset)}")
    return ssl_loader