import pickle

import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.data import random_split


class MTDataset(Dataset):
    """Torch ``Dataset`` over a pickled sequence of records.

    The pickle file must contain a sized iterable of dict-like records whose
    values are NumPy arrays. Every array is converted to a ``float32`` tensor.
    Each record must provide at least the ``'observations'`` and ``'actions'``
    keys, which ``__getitem__`` reads.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, path):
        """Load the records stored at ``path`` and precompute ``self.rtg``.

        Args:
            path: Location of the pickle file to load.
        """
        # SECURITY: unpickling can execute arbitrary code. Only load files
        # that come from a trusted source.
        with open(path, "rb") as file:
            data = pickle.load(file)
        # Number of records; this is what ``__len__`` reports.
        self.data_size = len(data)
        # List of dicts mapping each record key to a float32 tensor.
        self.dataset = self.to_torch_dict(data)
        # NOTE(review): ``discount_cumsum`` is not defined or imported in the
        # visible part of this file, and it receives the whole list of record
        # dicts. Presumably it yields one return-to-go entry per record
        # (``__getitem__`` indexes the result by record index) -- confirm
        # against its definition.
        self.rtg = discount_cumsum(self.dataset, gamma=0.9)

    def __len__(self):
        """Return the number of records loaded from the pickle file."""
        return self.data_size

    def __getitem__(self, idx):
        """Return ``[observations, actions, rtg]`` for the record at ``idx``."""
        record = self.dataset[idx]
        return [record['observations'], record['actions'], self.rtg[idx]]

    def to_torch_dict(self, data):
        """Convert an iterable of ``{key: ndarray}`` records to float32 tensors.

        New dicts are built, so the mappings inside ``data`` are left
        untouched. Arrays that are already ``float32`` are wrapped without a
        dtype conversion; anything else is cast with ``Tensor.float()``.

        Args:
            data: Iterable of mappings from key to NumPy array.

        Returns:
            A list with one dict per input record, holding float32 tensors
            under the same keys.
        """
        converted = []
        for element in data:
            record = {}
            for key, value in element.items():
                tensor = torch.from_numpy(value)
                if tensor.dtype != torch.float32:
                    tensor = tensor.float()
                record[key] = tensor
            converted.append(record)
        return converted

    def split_data(self, n_test=0.33):
        """Randomly split this dataset into train and test subsets.

        Args:
            n_test: Fraction of the records assigned to the test subset;
                must lie in ``[0, 1]``.

        Returns:
            The ``[train, test]`` subsets produced by ``random_split``.

        Raises:
            ValueError: If ``n_test`` is outside ``[0, 1]``.
        """
        if not 0 <= n_test <= 1:
            raise ValueError(f"n_test must be in [0, 1], got {n_test!r}")
        # Size the split from the dataset itself; the ``obs`` attribute that
        # used to be read here is never assigned.
        total = len(self)
        test_size = round(n_test * total)
        train_size = total - test_size
        return random_split(self, [train_size, test_size])


# dataset = MTDataset("halfcheetah-expert-v2.pkl")
