# -*- coding: utf-8 -*-
import typing

import numpy as np
import torch


class TwentyNewsgroupDataset(torch.utils.data.Dataset):
    """Dataset of integer token sequences and integer targets read from text files.

    ``datafile`` holds one sample per line, each line being integer token ids
    separated by ``delimiter``.  ``targetfile`` holds integer targets separated
    by ``delimiter`` and/or newlines; the i-th target belongs to the i-th line
    of ``datafile``.

    Both files are read eagerly in the constructor.  All sequences are
    right-padded with ``pad_token_id`` to the length of the longest one.

    Attributes set by the constructor:
        tokens: ``(num_samples, max_length)`` int64 tensor of padded token ids.
        token_length: ``(num_samples,)`` int64 tensor of unpadded lengths.
        max_length: length of the longest sequence (0 for an empty file).
        targets: ``(num_samples,)`` int64 tensor of targets.

    Raises:
        ValueError: if a field is not an integer, or if the number of targets
            differs from the number of samples.
    """

    def __init__(self, datafile: str, targetfile: str,
                 pad_token_id: int = 0, delimiter: str = " "):
        super().__init__()
        self.datafile = datafile
        self.targetfile = targetfile
        self.pad_token_id = pad_token_id
        self.delimiter = delimiter

        self.tokens, self.token_length, self.max_length = self.read_datafile()
        self.targets = self.read_targetfile()

        # Fail fast: a mismatch would otherwise surface later as an IndexError
        # inside __getitem__, which Python's legacy iteration protocol treats
        # as end-of-sequence and therefore silently truncates the dataset.
        if len(self.targets) != len(self.tokens):
            raise ValueError(
                f"Number of targets ({len(self.targets)}) in "
                f"{self.targetfile!r} does not match number of samples "
                f"({len(self.tokens)}) in {self.datafile!r}"
            )

    def _parse_ints(self, line: str) -> typing.List[int]:
        """Split ``line`` on the delimiter and convert the fields to ints.

        Empty fields (blank line, repeated or trailing delimiters) are ignored
        instead of being passed to ``int``.
        """
        return [int(field)
                for field in line.strip().split(self.delimiter)
                if field.strip()]

    def read_datafile(self) -> typing.Tuple[torch.Tensor, torch.Tensor, int]:
        """Read ``self.datafile`` and return ``(padded, lengths, max_length)``.

        A blank line is kept as a zero-length sample so that row indices stay
        aligned with the targets.
        """
        with open(self.datafile, encoding="utf-8") as f:
            sequences = [self._parse_ints(line) for line in f]

        lengths = [len(seq) for seq in sequences]
        max_length = max(lengths, default=0)

        padded = [seq + [self.pad_token_id] * (max_length - len(seq))
                  for seq in sequences]
        # Build the arrays as int64 directly: the previous
        # torch.Tensor(...).long() round-trip went through float32 and
        # corrupted ids above 2**24.  The reshape keeps a well-defined 2-D
        # shape even when there are no samples.
        padded_array = np.array(padded, dtype=np.int64).reshape(
            len(sequences), max_length)
        lengths_array = np.array(lengths, dtype=np.int64)

        return (torch.from_numpy(padded_array),
                torch.from_numpy(lengths_array),
                max_length)

    def read_targetfile(self) -> torch.Tensor:
        """Read ``self.targetfile`` and return all targets as a 1-D int64 tensor."""
        targets: typing.List[int] = []
        with open(self.targetfile, encoding="utf-8") as f:
            for line in f:
                targets.extend(self._parse_ints(line))
        return torch.tensor(targets, dtype=torch.long)

    def __len__(self) -> int:
        """Return the number of samples."""
        return len(self.tokens)

    def __getitem__(
        self, idx: int
    ) -> typing.Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(padded_tokens, target, unpadded_length)`` for sample ``idx``."""
        return self.tokens[idx], self.targets[idx], self.token_length[idx]
