# -*- coding: utf-8 -*-
import typing

import numpy as np
import torch


class PTBDataset(torch.utils.data.Dataset):
    """Dataset of integer token-id sequences loaded from a text file.

    Each non-blank line of ``datafile`` holds one sequence of integer ids
    separated by ``delimiter``. All sequences are right-padded with
    ``pad_token_id`` to the length of the longest sequence in the file.

    Attributes set by ``__init__``:
        tokens: ``torch.long`` tensor of shape ``(num_sequences, max_length)``
            holding the padded ids.
        token_length: ``torch.long`` tensor of shape ``(num_sequences,)``
            holding the unpadded length of each sequence.
        max_length: length of the longest sequence (``0`` for an empty file).
    """

    def __init__(self, datafile: str, pad_token_id: int = 0, delimiter: str = " "):
        """Read ``datafile`` eagerly and store the padded sequences.

        Args:
            datafile: path of the text file to read.
            pad_token_id: id appended to short sequences as padding.
            delimiter: separator between ids on a line.
        """
        super().__init__()
        self.datafile = datafile
        self.pad_token_id = pad_token_id
        self.delimiter = delimiter

        self.tokens, self.token_length, self.max_length = self.read_datafile()

    def read_datafile(self) -> typing.Tuple[torch.Tensor, torch.Tensor, int]:
        """Parse ``self.datafile`` into padded ids, lengths and max length.

        Blank lines and empty fields (e.g. from repeated delimiters) are
        skipped rather than being passed to ``int``.

        Returns:
            A tuple ``(padded, lengths, max_length)`` where ``padded`` is a
            ``torch.long`` tensor of shape ``(num_sequences, max_length)``,
            ``lengths`` is a ``torch.long`` tensor of shape
            ``(num_sequences,)`` and ``max_length`` is an ``int``.

        Raises:
            ValueError: if a non-empty field is not a valid integer.
        """
        sequences = []
        with open(self.datafile, encoding="utf-8") as f:
            for line in f:
                fields = line.strip().split(self.delimiter)
                # Whitespace-only fields come from blank lines or doubled
                # delimiters; int() would raise on them.
                ids = [int(field) for field in fields if field.strip()]
                if ids:
                    sequences.append(ids)

        lengths = [len(ids) for ids in sequences]
        max_length = max(lengths, default=0)
        padded = [
            ids + [self.pad_token_id] * (max_length - len(ids))
            for ids in sequences
        ]

        # Build integer tensors directly: going through torch.Tensor (float32)
        # would round ids larger than 2**24.
        # The reshape is a no-op except for an empty file, where it turns the
        # shape (0,) result into a consistent (0, 0) matrix.
        padded_tensor = torch.tensor(padded, dtype=torch.long).reshape(
            len(sequences), max_length
        )
        lengths_tensor = torch.tensor(lengths, dtype=torch.long)

        return padded_tensor, lengths_tensor, max_length

    def __len__(self) -> int:
        """Return the number of sequences in the dataset."""
        return len(self.tokens)

    def __getitem__(self, idx: int) -> typing.Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(padded_ids, length)`` for the sequence at ``idx``.

        ``padded_ids`` is a 1-D ``torch.long`` tensor of ``max_length``
        elements; ``length`` is a 0-D ``torch.long`` tensor with the
        sequence's unpadded length.
        """
        return self.tokens[idx], self.token_length[idx]

