from typing import Union

import numpy as np
import torch
from torch.utils.data import Dataset

import utils


class MultimodalDataset(Dataset):
    """Dataset pairing two input modalities with per-sample labels.

    Indexing returns ``(x, s)`` where ``x`` is the list ``[x1[idx], x2[idx]]``
    and ``s`` is a dict of supplementary per-sample information, currently
    ``{'y': y[idx]}``.
    """

    def __init__(self,
                 x1: torch.Tensor,
                 x2: torch.Tensor,
                 y: Union[np.ndarray, torch.Tensor]):
        """Store both modalities and the labels, checking their lengths agree.

        Args:
            x1: First modality; must satisfy ``len(x1) == len(y)``.
            x2: Second modality; must satisfy ``len(x2) == len(y)``.
            y: Labels, converted with ``utils.to_torch(y, dtype=torch.int32)``.

        Raises:
            ValueError: If ``x1`` or ``x2`` differs in length from ``y``.
        """
        self.x = [x1, x2]
        # supplementary information, keyed by name and indexed per sample
        # NOTE(review): many torch losses expect int64 class indices; confirm
        # that int32 labels are intended for downstream consumers.
        self.s = {'y': utils.to_torch(y, dtype=torch.int32)}
        self.len = len(self.s['y'])
        # Explicit raise instead of ``assert``: assertions are stripped under
        # ``python -O``, which would silently accept misaligned modalities.
        lengths = [len(v) for v in self.x]
        if any(n != self.len for n in lengths):
            raise ValueError(
                f"every modality must have the same length as y "
                f"({self.len}); got lengths {lengths}")

    def __len__(self) -> int:
        """Return the number of samples (the length of ``y``)."""
        return self.len

    def __getitem__(self, idx):
        """Return ``([x1[idx], x2[idx]], {name: value[idx], ...})``."""
        x = [v[idx] for v in self.x]
        s = {k: v[idx] for k, v in self.s.items()}
        return x, s
