from collections import deque
from typing import List, Tuple

import numpy as np
import torch
from torch.utils.data import IterableDataset


class BufferedSequencer(IterableDataset):
    """Pack token sequences from ``dataset`` into samples of a fixed size.

    Each sample is a pair ``(chunks, lengths)``: a list of token chunks and
    the length of each chunk. The size of a sample is measured as
    ``sum(len(c) - 1 for c in chunks)`` (equivalently
    ``sum(lengths) - len(lengths)``), i.e. the number of adjacent token pairs
    inside its chunks. Every sample except possibly the last one has exactly
    ``sequence_length`` of them.
    NOTE(review): presumably each chunk is shifted by one downstream
    (inputs ``c[:-1]``, targets ``c[1:]``), making this the number of
    prediction targets -- confirm with the consumer.

    A text that does not fit in the current sample is split: its head
    completes the sample and its tail starts the next one. Tokens are
    therefore emitted in dataset order, each exactly once, except for a
    remainder dropped at the end (see ``drop_last`` / ``min_sequence_length``).

    Args:
        dataset: iterable of mappings with a ``'tokens'`` entry (list or
            ``np.ndarray`` of token ids). ``iter()`` is taken once at
            construction, so this object can only be iterated once.
        sequence_length: size every full sample is packed to (see above).
        min_sequence_length: when the dataset is exhausted, the buffered
            remainder is emitted as a final short sample only if it holds at
            least this many tokens.
        drop_last: if true, the buffered remainder is always discarded.
    """

    def __init__(self, dataset: IterableDataset, sequence_length: int,
                 min_sequence_length: int = 0, drop_last: bool = False):

        self.dataset = iter(dataset)
        self.sequence_length = sequence_length
        self.min_sequence_length = min_sequence_length
        self.drop_last = drop_last

        # Texts (or tails of split texts) fetched but not yet emitted.
        self.text_buffer = deque()
        # Set after the final short sample is emitted, so that the next call
        # ends iteration.
        self.stop_iter_flag = False

    def _get_sample_from_dataset(self, as_array=True):
        """Fetch the ``'tokens'`` entry of the next dataset item.

        With ``as_array`` true, non-ndarray tokens are converted to an
        ``np.int32`` array (existing ndarrays are returned unchanged). With
        ``as_array`` false, ndarrays are converted to plain lists.

        Raises:
            StopIteration: when the dataset is exhausted.
        """
        text = next(self.dataset)['tokens']
        if as_array:
            if not isinstance(text, np.ndarray):
                text = np.array(text, dtype=np.int32)
        elif isinstance(text, np.ndarray):
            text = text.tolist()
        return text

    def get_sample(self):
        """Return the next packed sample as ``(chunks, chunk_lengths)``.

        Raises:
            StopIteration: when there is nothing (more) to emit.
        """
        if self.stop_iter_flag:
            self.stop_iter_flag = False
            raise StopIteration

        # Size of the buffer in sample units, updated incrementally instead of
        # re-summing the whole buffer on every iteration.
        buffered = sum(len(t) - 1 for t in self.text_buffer)
        while buffered < self.sequence_length:
            try:
                text = self._get_sample_from_dataset(as_array=False)
            except StopIteration:
                return self._flush()
            self.text_buffer.append(text)
            buffered += len(text) - 1

        return self._pack()

    def _flush(self):
        """Handle dataset exhaustion: emit the leftover buffer or stop."""
        if not self.text_buffer:
            raise StopIteration
        if (self.drop_last
                or sum(len(t) for t in self.text_buffer) < self.min_sequence_length):
            # The remainder is discarded.
            self.text_buffer.clear()
            raise StopIteration
        sample_text = list(self.text_buffer)
        sample_pos = [len(t) for t in sample_text]
        self.text_buffer.clear()
        # This short sample is the last one; the next call stops iteration.
        self.stop_iter_flag = True
        return sample_text, sample_pos

    def _pack(self):
        """Pop chunks off the buffer until they total ``sequence_length``."""
        sample_text = []
        sample_pos = []
        filled = 0  # sum(len(c) - 1 for c in sample_text)
        while self.text_buffer:
            text = self.text_buffer.popleft()
            remaining_space = self.sequence_length - filled

            # A whole text contributes len(text) - 1 units, not len(text).
            if len(text) - 1 <= remaining_space:
                sample_text.append(text)
                sample_pos.append(len(text))
                filled += len(text) - 1
                if filled == self.sequence_length:
                    return sample_text, sample_pos
            else:
                # remaining_space + 1 tokens give exactly remaining_space units.
                # The tail goes back to the FRONT of the buffer so it starts
                # the next sample and document order is preserved.
                head_len = remaining_space + 1
                sample_text.append(text[:head_len])
                sample_pos.append(head_len)
                self.text_buffer.appendleft(text[head_len:])
                return sample_text, sample_pos

        # Defensive: only reached if the buffer was smaller than one sample.
        raise StopIteration

    def __iter__(self):
        return self

    def __next__(self) -> Tuple[List[torch.Tensor], List[int]]:
        """Return ``(chunks, chunk_lengths)`` with chunks as ``torch.long`` tensors."""
        sequence, positions = self.get_sample()
        if isinstance(sequence, list):
            sequence = [torch.tensor(seq, dtype=torch.long) for seq in sequence]
        return sequence, positions
