import torch

from torch.utils.data import Dataset


class TokenDataset(Dataset):
    """Sliding-window dataset over a sequence of token frames.

    Sample ``i`` pairs the flattened frames ``i .. i + history_length - 1``
    (the input) with frame ``i + history_length`` (the label). All samples are
    built eagerly and stored as ``float64`` tensors on ``device``.

    Attributes set by the constructor: ``data``, ``labels``, ``raw_data`` (the
    frame-skipped ``float64`` frames), ``length``, ``history_length``,
    ``frame_skip`` and ``device``.
    """

    def __init__(self, raw_data, history_length: int, frame_skip: int, device) -> None:
        """Build every (history window, next frame) pair up front.

        Args:
            raw_data: NumPy array indexed as ``(frame, token)``.
            history_length: Number of consecutive frames in each input window.
            frame_skip: Stride used to subsample ``raw_data`` along the frame
                axis before windowing.
            device: Device on which ``data`` and ``labels`` are stored.

        Raises:
            ValueError: If ``history_length`` is negative or exceeds the number
                of frames left after frame skipping.
        """
        super().__init__()
        frames = torch.from_numpy(raw_data[::frame_skip]).type(torch.float64)
        data, labels = self._build_windows(frames, history_length)

        self.frame_skip = frame_skip
        self.device = device
        self.history_length = history_length
        self.length = data.shape[0]
        self.data = data.to(device)
        self.labels = labels.to(device)
        self.raw_data = frames

    @staticmethod
    def _build_windows(frames, history_length):
        """Return ``(data, labels)`` sliding windows over ``frames``.

        Row ``i`` of ``data`` is ``frames[i:i + history_length]`` flattened and
        row ``i`` of ``labels`` is ``frames[i + history_length]``.
        """
        num_frames = frames.shape[0]
        if not 0 <= history_length <= num_frames:
            raise ValueError(
                f"history_length must be between 0 and the number of frames "
                f"({num_frames}), got {history_length}"
            )
        length = num_frames - history_length
        tokens_per_frame = frames.shape[1]

        # index[i, k] == i + k, so one gather yields every window at once.
        starts = torch.arange(length).unsqueeze(1)
        offsets = torch.arange(history_length).unsqueeze(0)
        windows = frames[starts + offsets]
        data = windows.reshape(length, tokens_per_frame * history_length)

        # Clone so the labels never alias the stored frames.
        labels = frames[history_length:].clone()
        return data, labels

    def set_resolution_labels(self, u) -> None:
        """Replace the labels with the entries of ``u`` at each window's middle frame.

        ``u`` is subsampled with the same ``frame_skip`` stride as the raw
        data; label ``i`` then becomes entry ``i + history_length // 2`` of the
        subsampled sequence. ``data`` is left untouched.

        Args:
            u: Sliceable sequence (NumPy array, tensor or nested list) with one
                entry per un-skipped frame.

        Raises:
            ValueError: If ``history_length`` is even, or if ``u`` is too short
                to provide one label per sample.
        """
        # Input is not symmetrical about the middle frame if history is even
        if self.history_length % 2 != 1:
            raise ValueError(
                f"history_length must be odd to have a middle frame, "
                f"got {self.history_length}"
            )

        middle = self.history_length // 2
        window = u[::self.frame_skip][middle:self.length + middle]
        if isinstance(window, torch.Tensor):
            # torch.tensor(tensor) warns; detach().clone() is the same copy.
            labels = window.detach().clone().to(torch.float64)
        else:
            labels = torch.tensor(window, dtype=torch.float64)

        if labels.shape[0] != self.length:
            raise ValueError(
                f"expected {self.length} labels after frame skipping, "
                f"got {labels.shape[0]}"
            )
        self.labels = labels.to(self.device)

    def reset_token_labels(self, history_length: int) -> None:
        """Rebuild ``data`` and ``labels`` from ``raw_data`` for a new window size.

        The tensors are built before any attribute is reassigned, so a failed
        call leaves the dataset unchanged.

        Raises:
            ValueError: If ``history_length`` is negative or exceeds the number
                of stored frames.
        """
        data, labels = self._build_windows(self.raw_data, history_length)
        self.history_length = history_length
        self.length = data.shape[0]
        self.data = data.to(self.device)
        self.labels = labels.to(self.device)

    def __len__(self) -> int:
        """Return the number of samples."""
        return self.length

    def __getitem__(self, i):
        """Return ``(data[i], labels[i])``; ``i`` wraps around modulo ``len(self)``.

        Because of the wrap-around no index is out of range; an empty dataset
        raises ``ZeroDivisionError`` instead.
        """
        i = i % self.length
        return self.data[i, :], self.labels[i]


class TokenConcatDataset(torch.utils.data.ConcatDataset):
    """Concatenation of several ``TokenDataset`` objects into one flat dataset.

    The members' ``data`` and ``labels`` tensors are concatenated along the
    sample axis into new tensors on ``device``. Indexing reads those combined
    tensors directly rather than dispatching to the member datasets, so later
    changes to a member are not reflected here.
    """

    def __init__(self, datasets, device):
        """Validate the members and build the combined tensors.

        Args:
            datasets: Iterable of ``TokenDataset`` instances (subclasses are
                accepted).
            device: Device on which the combined tensors are stored.

        Raises:
            TypeError: If any member is not a ``TokenDataset``.
        """
        # Materialise once: a one-shot iterator would otherwise be exhausted
        # by the parent constructor before validation and concatenation.
        datasets = list(datasets)
        for dataset in datasets:
            if not isinstance(dataset, TokenDataset):
                raise TypeError(
                    "TokenConcatDataset expects TokenDataset instances, "
                    f"got {type(dataset).__name__}"
                )
        super().__init__(datasets)

        # Move each piece first so members on different devices can be joined.
        self.data = torch.cat([dataset.data.to(device) for dataset in datasets], dim=0)
        self.labels = torch.cat([dataset.labels.to(device) for dataset in datasets], dim=0)
        self.length = self.data.shape[0]
        self.device = device

    def __len__(self):
        """Return the total number of samples across all members."""
        return self.length

    def __getitem__(self, i):
        """Return ``(data[i], labels[i])``; ``i`` wraps around modulo ``len(self)``."""
        i = i % self.length
        return self.data[i, :], self.labels[i]