
import numpy as np
from torch.utils.data import Dataset


class DatasetWrapper(Dataset):
    """
    Adapt an indexable dataset to the dict layout used by the pytorch_wrapper.

    Every item of the wrapped dataset is unpacked as ``(inputs, actions)``.
    ``inputs[0]`` is exposed as ``'pov'`` and ``inputs[1]`` as ``'actions'``
    under the ``'inputs'`` key. ``actions`` is exposed as ``'actions'`` under
    the ``'targets'`` key.
    """

    def __init__(self, dataset):
        # Wrapped dataset; must support len() and integer/key indexing.
        self.dataset = dataset

    def __len__(self):
        size = len(self.dataset)
        return size

    def __getitem__(self, item):
        sample_inputs, sample_targets = self.dataset[item]
        pov = sample_inputs[0]
        input_actions = sample_inputs[1]

        model_inputs = dict(pov=pov, actions=input_actions)
        model_targets = dict(actions=sample_targets)
        return dict(inputs=model_inputs, targets=model_targets)


class DatasetWrapperOriginalSpace(Dataset):
    """
    Adapt an indexable dataset to the pytorch_wrapper, splitting actions.

    Every item of the wrapped dataset is unpacked as ``(inputs, actions)``.

    - ``inputs[0]`` is passed through as ``'pov'``.
    - ``inputs[1]`` is indexed with two indices. Its last two entries along the
      second axis become ``'camera_actions'``; the remaining entries become
      ``'binary_actions'``.
    - ``inputs[2]`` is passed through as ``'rewards'``.

    On the target side, ``actions`` is split the same way along its first axis.
    ``'values'`` is ``inputs[2]`` summed over axis 0, with a new leading axis
    of length 1 prepended.
    """

    def __init__(self, dataset):
        # Wrapped dataset; must support len() and integer/key indexing.
        self.dataset = dataset

    def __len__(self):
        size = len(self.dataset)
        return size

    def __getitem__(self, item):
        sample_inputs, sample_targets = self.dataset[item]
        pov = sample_inputs[0]
        input_actions = sample_inputs[1]
        rewards = sample_inputs[2]

        # Index separating leading "binary" entries from the trailing two
        # "camera" entries.
        cut = -2

        model_inputs = {
            'pov': pov,
            'binary_actions': input_actions[:, :cut],
            'camera_actions': input_actions[:, cut:],
            'rewards': rewards,
        }

        summed_rewards = np.sum(rewards, axis=0)
        model_targets = {
            'binary_actions': sample_targets[:cut],
            'camera_actions': sample_targets[cut:],
            'values': summed_rewards[np.newaxis, ...],
        }

        return {"inputs": model_inputs, "targets": model_targets}
