from typing import Dict, List, Tuple, Union

import numpy as np
import torch

from offline_rl.data.index_utils import cumulative_index_to_multi_index

# Type of the label attached to each merged dataset. Only int is valid for now;
# a single-member Union collapses to plain `int` at runtime. Presumably it is
# spelled as a Union so that further label types can be added later.
LabelType = Union[int]


class LabeledMergeDataset(torch.utils.data.Dataset):
    """A dataset that merges other torch Datasets and adds a discrete label to the output.

    The merged dataset exposes a single flat index range whose size is the sum of the
    wrapped datasets' lengths. A flat index is translated into a (dataset, sub-index) pair
    by `cumulative_index_to_multi_index` using the running totals of those lengths.

    Currently only uniform sampling from the datasets is implemented.

    Args:
        datasets: A list of tuples of keys and torch datasets. Whatever the key is for a given dataset,
            that is what the label will be when sampling from that dataset. The key must be of
            a type specified in the `LabelType`. This is a list of tuples instead of a mapping so that
            the order is deterministic.
        label_key: The key used in the return sample to indicate the label.

    Raises:
        AssertionError: If any of the given datasets is empty.
    """
    def __init__(self, datasets: List[Tuple[LabelType, torch.utils.data.Dataset]], label_key: str = "label"):
        super().__init__()
        # Parallel lists: labels[i] is the label attached to samples drawn from datasets[i].
        self.labels = [label_dataset_pair[0] for label_dataset_pair in datasets]
        self.datasets = [label_dataset_pair[1] for label_dataset_pair in datasets]
        lengths = [len(dataset) for dataset in self.datasets]
        assert all(length > 0 for length in lengths), "Empty datasets are not allowed"
        # Total number of samples across all wrapped datasets.
        self.length = sum(lengths)
        # Running totals of the dataset lengths, used to map a flat index to a dataset.
        self.cumulative_lengths = np.cumsum(lengths)
        self.label_key = label_key

    def __len__(self) -> int:
        """Return the total number of samples over all wrapped datasets."""
        return self.length

    def __getitem__(self, index: int) -> Dict:
        """Return the sample at the flat `index` with the owning dataset's label added.

        The returned dictionary is a shallow copy of the wrapped dataset's sample, so the
        wrapped dataset's own dictionary is never modified. The label is stored under
        `label_key` as a float32 tensor of shape ``(1,)``.

        Raises:
            AssertionError: If the wrapped dataset does not return a dictionary, or if the
                sample already contains `label_key`.
        """
        dataset_index, subindex = cumulative_index_to_multi_index(index, self.cumulative_lengths)
        sample = self.datasets[dataset_index][subindex]
        assert isinstance(sample, dict), "Only dictionary-based torch datasets are implemented"
        assert self.label_key not in sample, f"Sample already contains the label key {self.label_key!r}"
        # Copy before writing: a wrapped dataset may hand out the very dict it stores, and
        # mutating it would pollute that dataset and break the next fetch of this index.
        labeled_sample = sample.copy()
        labeled_sample[self.label_key] = torch.FloatTensor([self.labels[dataset_index]])
        return labeled_sample
