import torch

from torchvision import transforms
from torch.utils.data import Dataset
import h5py
import numpy as np


def select_by_index(x, index):
    """Gather entries of ``x`` along the last axis covered by ``index``.

    ``index`` is padded with trailing singleton axes up to the rank of ``x``
    and then broadcast over the trailing shape of ``x``, so every integer in
    ``index`` selects a whole slice of the remaining axes.

    Args:
        x: Tensor to select from.
        index: Integer tensor of positions along axis ``index.ndim - 1``
            of ``x``.

    Returns:
        Tensor shaped ``index.shape + x.shape[index.ndim:]``.
    """
    lead = index.ndim
    tail_shape = list(x.shape[lead:])
    # Append one singleton axis per axis of x that index does not cover.
    padded = index.reshape(list(index.shape) + [1] * (x.ndim - lead))
    # -1 keeps index's own sizes; the trailing axes are stretched to match x.
    broadcast = padded.expand([-1] * lead + tail_shape)
    return x.gather(lead - 1, broadcast)


class GlobVideoDataset(Dataset):
    """Multi-view image/segmentation dataset read eagerly from an HDF5 file.

    The ``'image'`` and ``'segment'`` datasets of the ``phase`` group are
    loaded fully into memory at construction time. Their first two axes are
    treated as (video, view); a subset of views is selected once, in
    ``__init__``, and kept for the lifetime of the object.

    Args:
        root: Path of the HDF5 file.
        phase: Name of the group to read. ``'test'`` selects a fixed set of
            six views (``ep_len`` is not consulted); any other value selects
            ``ep_len`` random distinct views per video.
        img_size: Stored on the instance; not used by the code in this class.
        ep_len: Number of random views kept per video for non-test phases.
    """

    def __init__(self, root, phase, img_size, ep_len=4):
        self.root = root
        self.img_size = img_size
        self.ep_len = ep_len

        # Read everything into numpy arrays and keep no h5py object on the
        # instance: a group handle is unusable once the file is closed, and
        # h5py objects cannot be pickled, which would stop the dataset from
        # being sent to spawned DataLoader workers.
        with h5py.File(root, 'r', libver='latest', swmr=True) as f:
            group = f[phase]
            video = group['image'][()]
            seg = group['segment'][()]

        num_video, num_view = video.shape[:2]

        if phase == 'test':
            # Fix views for test. The fallback list indexes view 9, so it
            # needs at least 10 views per video.
            if num_view == 20:
                index = [0, 3, 6, 9, 12, 15]
            else:
                index = [0, 2, 4, 6, 8, 9]
            video = video[:, index]
            seg = seg[:, index]
        else:
            # Random views for train: the argsort of uniform noise is a random
            # permutation per video; its first ep_len entries are distinct.
            noise = torch.rand([num_video, num_view])
            index = torch.argsort(noise, dim=1)[:, :ep_len]
            # from_numpy()/numpy() share memory, so the only new allocation
            # is the gathered result itself.
            video = select_by_index(torch.from_numpy(video), index).numpy()
            seg = select_by_index(torch.from_numpy(seg), index).numpy()

        self.video = video
        self.seg = seg
        self.transform = transforms.ToTensor()

    def __len__(self):
        """Return the number of videos (length of the first axis)."""
        return len(self.video)

    def __getitem__(self, idx):
        """Return the selected views of video ``idx`` as stacked tensors.

        Returns:
            A ``(video, segs)`` pair: ``video`` stacks ``self.transform``
            applied to each view, ``segs`` stacks each view's segmentation
            converted to ``torch.int64``. Both are stacked along a new
            leading view axis.
        """
        frames = self.video[idx]
        masks = self.seg[idx]
        video = torch.stack([self.transform(frame) for frame in frames], dim=0)
        segs = torch.stack(
            [torch.from_numpy(mask).to(torch.int64) for mask in masks], dim=0
        )
        return video, segs
