import torch
from torch.utils.data import Dataset


class DiscriminativeDataset(Dataset):
    """Map-style dataset that chains a real dataset followed by a synthetic one.

    Indices ``0 .. len(data_real) - 1`` address samples of ``data_real``; the
    indices after that address samples of ``data_synthetic``. Each sample is
    returned exactly as the underlying dataset produces it. No real/synthetic
    label is attached here.

    NOTE(review): the ``torch.Tensor`` return annotation on ``__getitem__``
    assumes both source datasets yield tensors. Confirm this against callers.
    """

    def __init__(self, data_real: Dataset, data_synthetic: Dataset) -> None:
        super().__init__()
        # Order matters: real samples occupy the low indices, synthetic the high.
        sources = [data_real, data_synthetic]
        self.data = torch.utils.data.ConcatDataset(sources)

    def __getitem__(self, item: int) -> torch.Tensor:
        """Return sample ``item`` of the combined dataset (lookup delegated to ConcatDataset)."""
        combined = self.data
        return combined[item]

    def __len__(self) -> int:
        """Return the total number of samples across both source datasets."""
        combined = self.data
        return len(combined)
