import pathlib
import io
from itertools import chain

import torch
from torchtext.utils import unicode_csv_reader
from torchtext.data.utils import ngrams_iterator
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

from ..utils import ROOT_DIR

# Directory holding the dataset files; the default CSV paths of
# build_text_vocab are built relative to it.
DATAPATH: pathlib.Path = ROOT_DIR / "data"


def generate_batch(batch):
    """Collate a list of ``(label, token_tensor)`` entries into one flat batch.

    The per-entry tensors are concatenated into a single 1-D tensor, and the
    start index of every entry inside that tensor is returned alongside it.

    Args:
        batch: sequence of entries where ``entry[0]`` is the label and
            ``entry[1]`` is a tensor that can be concatenated with the others.

    Returns:
        ``((text, offsets), label)`` where ``text`` is the concatenation of
        all entry tensors, ``offsets`` holds the start position of each entry
        in ``text``, and ``label`` is a tensor of the entry labels.
    """
    labels = torch.tensor([sample[0] for sample in batch])
    pieces = [sample[1] for sample in batch]

    # Start positions are the running sum of the lengths of all *preceding*
    # entries: prepend 0, discard the last length, then take the cumulative
    # sum (e.g. lengths [3, 2, 1] -> [0, 3, 2] -> offsets [0, 3, 5]).
    starts = [0]
    starts.extend(len(piece) for piece in pieces)
    starts.pop()
    offsets = torch.tensor(starts).cumsum(dim=0)

    flat_text = torch.cat(pieces)
    return (flat_text, offsets), labels


def _csv_iterator(data_path, ngrams, yield_cls=False):
    """Lazily yield n-gram iterators for every row of a CSV file.

    All columns after the first are joined with a single space and tokenized
    with torchtext's ``"basic_english"`` tokenizer before being passed to
    ``ngrams_iterator``.

    Args:
        data_path: path of the CSV file, opened as UTF-8 text.
        ngrams: n-gram setting forwarded to ``ngrams_iterator``.
        yield_cls: when true, yield ``(cls, ngram_iterator)`` pairs, where
            ``cls`` is the first column parsed as an int minus one
            (presumably mapping 1-based labels to 0-based; confirm against
            the dataset). Otherwise yield only the n-gram iterator.

    Yields:
        One item per CSV row, in file order.
    """
    tokenize = get_tokenizer("basic_english")
    with io.open(data_path, encoding="utf8") as handle:
        for record in unicode_csv_reader(handle):
            # The first column is the class; the remaining columns are text.
            words = tokenize(' '.join(record[1:]))
            if not yield_cls:
                yield ngrams_iterator(words, ngrams)
                continue
            yield int(record[0]) - 1, ngrams_iterator(words, ngrams)


def build_text_vocab(datapath_1=str(DATAPATH / "yelp_review_polarity/yelp_review_polarity_csv/train.csv"),
                     datapath_2=str(DATAPATH / "ag_news/ag_news_csv/train.csv")):
    """Build one vocabulary from the training CSVs of two datasets.

    Both files are read as unigram token streams (without class labels) and
    fed, first file then second, into ``build_vocab_from_iterator``.

    Args:
        datapath_1: path of the first training CSV.
        datapath_2: path of the second training CSV.

    Returns:
        The vocabulary object produced by ``build_vocab_from_iterator``.
        The original author noted ``len(vocab) == 515499`` for the default
        paths.
    """
    unigram_streams = [
        _csv_iterator(csv_path, ngrams=1)
        for csv_path in (datapath_1, datapath_2)
    ]
    return build_vocab_from_iterator(chain.from_iterable(unigram_streams))
